diff --git a/cmd/server/main.go b/cmd/server/main.go index 7739b80..497738a 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -24,7 +24,7 @@ type config struct { Enabled bool `envconfig:"ENABLED"` AuthDisabled bool `envconfig:"AUTH_DISABLED"` Path string `envconfig:"PATH" default:"/tmp"` - Expiration time.Duration `envconfig:"EXPIRATION" default:"10s"` + Expiration time.Duration `envconfig:"EXPIRATION" default:"1m"` } `envconfig:"CACHE_"` Github github.Config `envconfig:"GITHUB_"` Modules modules.Config `envconfig:"MODULES_"` @@ -37,10 +37,21 @@ func main() { log := slog.New(slog.NewTextHandler(os.Stdout, nil)) + var cache modules.KeyValueStore + if cfg.Cache.Enabled { + c := mcache.New[string, []string](cfg.Cache.Expiration) + if cfg.Cache.Expiration > 0 { + // Start a cleanup loop to remove expired items from the cache periodically. + stop := mcache.StartCleanupLoop(c, cfg.Cache.Expiration) + defer stop() + } + cache = c + } + var repo modules.Repository repo = github.New(cfg.Github, &http.Client{ Timeout: 5 * time.Second, - }) + }, cache) if cfg.Cache.Enabled { log.Info("enabling cache", "path", cfg.Cache.Path, "expiration", cfg.Cache.Expiration, "authDisabled", cfg.Cache.AuthDisabled) @@ -49,7 +60,7 @@ func main() { } repo = modules.NewCache( repo, - mcache.New[string, []string](cfg.Cache.Expiration), + cache, modules.StoreInPath(cfg.Cache.Path), log, cfg.Cache.AuthDisabled, diff --git a/pkg/github/github.go b/pkg/github/github.go index 07794b6..81fdd32 100644 --- a/pkg/github/github.go +++ b/pkg/github/github.go @@ -16,6 +16,7 @@ import ( "regexp" "slices" "strings" + "time" "github.com/reMarkable/orbit/pkg/auth" ) @@ -36,16 +37,39 @@ type HTTPClient interface { Do(req *http.Request) (*http.Response, error) } -func New(cfg Config, c HTTPClient) *Service { +// TagCache caches the aggregated tag names of a repository, keyed by +// "owner/repo", to avoid re-paginating all tags on every request. +type TagCache interface { + Get(key string) ([]string, bool) + Set(key string, value []string, d ...time.Duration) +} + +// noopTagCache is a cache that never stores anything. It satisfies the same Get/Set +// interface as Cache and can be used to disable caching +type noopTagCache struct{} + +// Get always reports a miss. +func (noopTagCache) Get(string) ([]string, bool) { return nil, false } + +// Set discards the value. +func (noopTagCache) Set(string, []string, ...time.Duration) {} + +func New(cfg Config, c HTTPClient, cache TagCache) *Service { + if cache == nil { + cache = noopTagCache{} + } + return &Service{ cfg: cfg, client: c, + cache: cache, } } type Service struct { cfg Config client HTTPClient + cache TagCache } // https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-repository-tags @@ -55,10 +79,34 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string) return nil, err } + tags, err := s.listTags(ctx, owner, repo) + if err != nil { + return nil, err + } + + prefix := module + "/" + versions := []string{} + for _, name := range tags { + if strings.HasPrefix(name, prefix) { + versions = append(versions, strings.TrimPrefix(name, prefix)) + } + } + return versions, nil +} + +// listTags returns the names of all tags in the repository, fetching every page +// from the GitHub API. When a cache is configured, results are cached per +// repository to avoid re-paginating all tags on subsequent requests within the +// cache expiration window. +func (s *Service) listTags(ctx context.Context, owner, repo string) ([]string, error) { + key := owner + "/" + repo + if tags, ok := s.cache.Get(key); ok { + return tags, nil + } + var ( - page = 1 - prefix = module + "/" - versions = []string{} + page = 1 + tags = []string{} ) for { uri := fmt.Sprintf("repos/%s/%s/tags?per_page=%d&page=%d", owner, repo, tagsPerPage, page) @@ -67,10 +115,10 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string) return nil, err } - var tags []struct { + var batch []struct { Name string `json:"name"` } - err = json.NewDecoder(res).Decode(&tags) + err = json.NewDecoder(res).Decode(&batch) cerr := res.Close() if cerr != nil { return nil, fmt.Errorf("closing response: %w", cerr) @@ -80,18 +128,18 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string) return nil, fmt.Errorf("decoding response: %w", err) } - for _, tag := range tags { - if strings.HasPrefix(tag.Name, prefix) { - versions = append(versions, strings.TrimPrefix(tag.Name, prefix)) - } + for _, tag := range batch { + tags = append(tags, tag.Name) } - if len(tags) < tagsPerPage { + if len(batch) < tagsPerPage { break } page++ } - return versions, nil + + s.cache.Set(key, tags) + return tags, nil } func (s *Service) ProxyDownload(ctx context.Context, system, repo, module, version string, w io.Writer) error { diff --git a/pkg/github/github_test.go b/pkg/github/github_test.go index 5b87016..d93e0b0 100644 --- a/pkg/github/github_test.go +++ b/pkg/github/github_test.go @@ -6,9 +6,12 @@ import ( "compress/gzip" "context" "errors" + "fmt" "io" "net/http" "testing" + + "github.com/reMarkable/orbit/pkg/mcache" ) type mockHTTPClient struct { @@ -40,7 +43,7 @@ func TestService_ListVersions(t *testing.T) { cfg := Config{ OrgMappings: map[string]string{"test-system": "test-org"}, } - service := New(cfg, mockClient) + service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration)) versions, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module") if err != nil { @@ -97,7 +100,7 @@ func TestService_ProxyDownload(t *testing.T) { cfg := Config{ OrgMappings: map[string]string{"test-system": "test-org"}, } - service := New(cfg, mockClient) + service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration)) var buf bytes.Buffer err := service.ProxyDownload(context.Background(), "test-system", "test-repo", "module", "v1.0.0", &buf) @@ -126,3 +129,98 @@ func TestService_ProxyDownload(t *testing.T) { t.Errorf("expected tarball to contain %q, but it was %q", expectedContent, tarBuf.Bytes()) } } + +func TestService_ListVersions_CachesTagsPerRepo(t *testing.T) { + var calls int + mockClient := &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/repos/test-org/test-repo/tags" { + calls++ + body := `[ + {"name": "module/v1.0.0"}, + {"name": "other/v2.0.0"} + ]` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(body))), + }, nil + } + return nil, errors.New("unexpected request") + }, + } + + cfg := Config{ + OrgMappings: map[string]string{"test-system": "test-org"}, + } + service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration)) + + // First call for "module" should hit the API. + if _, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Second call for a different module in the same repo should be served from + // the cache without hitting the API again. + other, err := service.ListVersions(context.Background(), "test-system", "test-repo", "other") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if calls != 1 { + t.Fatalf("expected tags endpoint to be called once, got %d", calls) + } + + if len(other) != 1 || other[0] != "v2.0.0" { + t.Errorf("expected cached tags to yield [v2.0.0], got %v", other) + } +} + +func TestService_ListVersions_Pagination(t *testing.T) { + // Build a first page with exactly tagsPerPage entries to force a second request. + var firstPage bytes.Buffer + firstPage.WriteString("[") + for i := 0; i < tagsPerPage; i++ { + if i > 0 { + firstPage.WriteString(",") + } + fmt.Fprintf(&firstPage, `{"name": "filler/v0.0.%d"}`, i) + } + firstPage.WriteString("]") + + mockClient := &mockHTTPClient{ + doFunc: func(req *http.Request) (*http.Response, error) { + if req.URL.Path != "/repos/test-org/test-repo/tags" { + return nil, errors.New("unexpected request") + } + switch req.URL.Query().Get("page") { + case "1": + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(firstPage.Bytes())), + }, nil + case "2": + body := `[{"name": "module/v1.0.0"}]` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(body))), + }, nil + default: + return nil, errors.New("unexpected page") + } + }, + } + + cfg := Config{ + OrgMappings: map[string]string{"test-system": "test-org"}, + } + service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration)) + + versions, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(versions) != 1 || versions[0] != "v1.0.0" { + t.Errorf("expected [v1.0.0] across pages, got %v", versions) + } +}