Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type config struct {
Enabled bool `envconfig:"ENABLED"`
AuthDisabled bool `envconfig:"AUTH_DISABLED"`
Path string `envconfig:"PATH" default:"/tmp"`
Expiration time.Duration `envconfig:"EXPIRATION" default:"10s"`
Expiration time.Duration `envconfig:"EXPIRATION" default:"1m"`
} `envconfig:"CACHE_"`
Github github.Config `envconfig:"GITHUB_"`
Modules modules.Config `envconfig:"MODULES_"`
Expand All @@ -37,10 +37,21 @@ func main() {

log := slog.New(slog.NewTextHandler(os.Stdout, nil))

var cache modules.KeyValueStore
if cfg.Cache.Enabled {
c := mcache.New[string, []string](cfg.Cache.Expiration)
if cfg.Cache.Expiration > 0 {
// Start a cleanup loop to remove expired items from the cache periodically.
stop := mcache.StartCleanupLoop(c, cfg.Cache.Expiration)
defer stop()
}
cache = c
}

var repo modules.Repository
repo = github.New(cfg.Github, &http.Client{
Timeout: 5 * time.Second,
})
}, cache)

if cfg.Cache.Enabled {
log.Info("enabling cache", "path", cfg.Cache.Path, "expiration", cfg.Cache.Expiration, "authDisabled", cfg.Cache.AuthDisabled)
Expand All @@ -49,7 +60,7 @@ func main() {
}
repo = modules.NewCache(
repo,
mcache.New[string, []string](cfg.Cache.Expiration),
cache,
modules.StoreInPath(cfg.Cache.Path),
log,
cfg.Cache.AuthDisabled,
Expand Down
72 changes: 60 additions & 12 deletions pkg/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"regexp"
"slices"
"strings"
"time"

"github.com/reMarkable/orbit/pkg/auth"
)
Expand All @@ -36,16 +37,39 @@ type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

func New(cfg Config, c HTTPClient) *Service {
// TagCache caches the aggregated tag names of a repository, keyed by
// "owner/repo", to avoid re-paginating all tags on every request.
type TagCache interface {
Get(key string) ([]string, bool)
Set(key string, value []string, d ...time.Duration)
}

// noopTagCache is a cache that never stores anything. It satisfies the same Get/Set
// interface as Cache and can be used to disable caching
type noopTagCache struct{}

// Get always reports a miss.
func (noopTagCache) Get(string) ([]string, bool) { return nil, false }

// Set discards the value.
func (noopTagCache) Set(string, []string, ...time.Duration) {}

func New(cfg Config, c HTTPClient, cache TagCache) *Service {
if cache == nil {
cache = noopTagCache{}
}

return &Service{
cfg: cfg,
client: c,
cache: cache,
}
}

type Service struct {
cfg Config
client HTTPClient
cache TagCache
}

// https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-repository-tags
Expand All @@ -55,10 +79,34 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string)
return nil, err
}

tags, err := s.listTags(ctx, owner, repo)
if err != nil {
return nil, err
}

prefix := module + "/"
versions := []string{}
for _, name := range tags {
if strings.HasPrefix(name, prefix) {
versions = append(versions, strings.TrimPrefix(name, prefix))
}
}
return versions, nil
}

// listTags returns the names of all tags in the repository, fetching every page
// from the GitHub API. When a cache is configured, results are cached per
// repository to avoid re-paginating all tags on subsequent requests within the
// cache expiration window.
func (s *Service) listTags(ctx context.Context, owner, repo string) ([]string, error) {
key := owner + "/" + repo
if tags, ok := s.cache.Get(key); ok {
return tags, nil
}

var (
page = 1
prefix = module + "/"
versions = []string{}
page = 1
tags = []string{}
)
for {
uri := fmt.Sprintf("repos/%s/%s/tags?per_page=%d&page=%d", owner, repo, tagsPerPage, page)
Expand All @@ -67,10 +115,10 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string)
return nil, err
}

var tags []struct {
var batch []struct {
Name string `json:"name"`
}
err = json.NewDecoder(res).Decode(&tags)
err = json.NewDecoder(res).Decode(&batch)
cerr := res.Close()
if cerr != nil {
return nil, fmt.Errorf("closing response: %w", cerr)
Expand All @@ -80,18 +128,18 @@ func (s *Service) ListVersions(ctx context.Context, system, repo, module string)
return nil, fmt.Errorf("decoding response: %w", err)
}

for _, tag := range tags {
if strings.HasPrefix(tag.Name, prefix) {
versions = append(versions, strings.TrimPrefix(tag.Name, prefix))
}
for _, tag := range batch {
tags = append(tags, tag.Name)
}

if len(tags) < tagsPerPage {
if len(batch) < tagsPerPage {
break
}
page++
}
return versions, nil

s.cache.Set(key, tags)
return tags, nil
}

func (s *Service) ProxyDownload(ctx context.Context, system, repo, module, version string, w io.Writer) error {
Expand Down
102 changes: 100 additions & 2 deletions pkg/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import (
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net/http"
"testing"

"github.com/reMarkable/orbit/pkg/mcache"
)

type mockHTTPClient struct {
Expand Down Expand Up @@ -40,7 +43,7 @@ func TestService_ListVersions(t *testing.T) {
cfg := Config{
OrgMappings: map[string]string{"test-system": "test-org"},
}
service := New(cfg, mockClient)
service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration))

versions, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module")
if err != nil {
Expand Down Expand Up @@ -97,7 +100,7 @@ func TestService_ProxyDownload(t *testing.T) {
cfg := Config{
OrgMappings: map[string]string{"test-system": "test-org"},
}
service := New(cfg, mockClient)
service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration))

var buf bytes.Buffer
err := service.ProxyDownload(context.Background(), "test-system", "test-repo", "module", "v1.0.0", &buf)
Expand Down Expand Up @@ -126,3 +129,98 @@ func TestService_ProxyDownload(t *testing.T) {
t.Errorf("expected tarball to contain %q, but it was %q", expectedContent, tarBuf.Bytes())
}
}

func TestService_ListVersions_CachesTagsPerRepo(t *testing.T) {
var calls int
mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
if req.URL.Path == "/repos/test-org/test-repo/tags" {
calls++
body := `[
{"name": "module/v1.0.0"},
{"name": "other/v2.0.0"}
]`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(body))),
}, nil
}
return nil, errors.New("unexpected request")
},
}

cfg := Config{
OrgMappings: map[string]string{"test-system": "test-org"},
}
service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration))

// First call for "module" should hit the API.
if _, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module"); err != nil {
t.Fatalf("unexpected error: %v", err)
}

// Second call for a different module in the same repo should be served from
// the cache without hitting the API again.
other, err := service.ListVersions(context.Background(), "test-system", "test-repo", "other")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if calls != 1 {
t.Fatalf("expected tags endpoint to be called once, got %d", calls)
}

if len(other) != 1 || other[0] != "v2.0.0" {
t.Errorf("expected cached tags to yield [v2.0.0], got %v", other)
}
}

func TestService_ListVersions_Pagination(t *testing.T) {
// Build a first page with exactly tagsPerPage entries to force a second request.
var firstPage bytes.Buffer
firstPage.WriteString("[")
for i := 0; i < tagsPerPage; i++ {
if i > 0 {
firstPage.WriteString(",")
}
fmt.Fprintf(&firstPage, `{"name": "filler/v0.0.%d"}`, i)
}
firstPage.WriteString("]")

mockClient := &mockHTTPClient{
doFunc: func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/repos/test-org/test-repo/tags" {
return nil, errors.New("unexpected request")
}
switch req.URL.Query().Get("page") {
case "1":
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(firstPage.Bytes())),
}, nil
case "2":
body := `[{"name": "module/v1.0.0"}]`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader([]byte(body))),
}, nil
default:
return nil, errors.New("unexpected page")
}
},
}

cfg := Config{
OrgMappings: map[string]string{"test-system": "test-org"},
}
service := New(cfg, mockClient, mcache.New[string, []string](mcache.NoExpiration))

versions, err := service.ListVersions(context.Background(), "test-system", "test-repo", "module")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(versions) != 1 || versions[0] != "v1.0.0" {
t.Errorf("expected [v1.0.0] across pages, got %v", versions)
}
}
Loading