Skip to content
Open
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
10f9113
Tracing putTag
hweawer Jan 18, 2026
ba94f3c
Nginx and logs in functions
hweawer Jan 19, 2026
bba1aae
Tracing executor tasks
hweawer Jan 19, 2026
32bb846
Revert mocks
hweawer Jan 20, 2026
c56651d
Fix storage mock
hweawer Jan 20, 2026
e480321
Partially revert
hweawer Jan 20, 2026
fceac88
Revert mock
hweawer Jan 20, 2026
299db87
Update
hweawer Jan 20, 2026
c267118
Remove changes to executor
hweawer Jan 26, 2026
ce87a86
Separate DB operations
hweawer Jan 26, 2026
ed5ea03
DecodeString
hweawer Jan 26, 2026
ff8f723
Fix typo
hweawer Jan 26, 2026
0441e8a
Database persistence for writeback trace context
hweawer Jan 26, 2026
98587d6
feat(origin): add tracing to blob upload handlers
hweawer Jan 26, 2026
25aac66
Remove context import
hweawer Jan 26, 2026
d432533
otel middleware
hweawer Jan 26, 2026
f2c8b22
feat(tracing): propagate trace context to writeback tasks in build-in…
hweawer Jan 26, 2026
e0847ed
Update tagstore tests
hweawer Jan 26, 2026
8f71806
feat(tracing): add trace-aware logging to writeback executor
hweawer Jan 26, 2026
9ecc727
feat: add SendTracingContext option to httputil for OpenTelemetry tra…
hweawer Jan 26, 2026
16ad6d1
feat: client side tracing upload path
hweawer Jan 26, 2026
f1c7884
fix: gomock
hweawer Jan 26, 2026
2f4f623
feat(origin): pass trace headers in the nginx
hweawer Jan 26, 2026
507411f
feat: Add tracing to origin DownloadBlob operation
hweawer Jan 26, 2026
8ae75b4
fix: Fix tests
hweawer Jan 26, 2026
b27beb7
feat(proxy): Tracing prefetch api
hweawer Jan 28, 2026
4aa8e3d
Merge branch 'master' into feat/tracing-proxy-prefetch
hweawer Mar 9, 2026
767ae7c
fix: fix
hweawer Mar 9, 2026
b4d9847
fix: fix
hweawer Mar 9, 2026
8bedd68
fix: fix err
hweawer Mar 9, 2026
5b535ed
Update proxy/proxyserver/prefetch.go
hweawer Mar 9, 2026
ce80978
fix: fix comments
hweawer Mar 9, 2026
45b9950
fix context
hweawer Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 113 additions & 54 deletions proxy/proxyserver/prefetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ import (
"github.com/uber/kraken/origin/blobclient"
"github.com/uber/kraken/utils/httputil"
"github.com/uber/kraken/utils/log"
"go.uber.org/zap"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)

// Constants for prefetch status.
Expand All @@ -43,6 +47,7 @@ type PrefetchHandler struct {
metrics tally.Scope
getManifestLatency tally.Histogram
getTagLatency tally.Histogram
tracer trace.Tracer
}

// blobInfo holds digest and size information for a blob.
Expand All @@ -53,24 +58,21 @@ type blobInfo struct {

// Request and response payloads.
type prefetchBody struct {
Tag string `json:"tag"`
TraceId string `json:"trace_id"`
Tag string `json:"tag"`
}
Comment thread
hweawer marked this conversation as resolved.

type prefetchResponse struct {
Tag string `json:"tag"`
Prefetched bool `json:"prefetched"`
Status string `json:"status"`
Message string `json:"message"`
TraceId string `json:"trace_id"`
}

type prefetchError struct {
Error string `json:"error"`
Prefetched bool `json:"prefetched"`
Status string `json:"status"`
Message string `json:"message"`
TraceId string `json:"trace_id,omitempty"`
}

type TagParser interface {
Expand Down Expand Up @@ -121,28 +123,27 @@ func NewPrefetchHandler(
metrics: m,
getManifestLatency: m.Histogram("download_manifest_latency", tally.MustMakeExponentialDurationBuckets(1*time.Second, 2, 12)),
getTagLatency: m.Histogram("get_tag_latency", tally.MustMakeExponentialDurationBuckets(100*time.Millisecond, 2, 10)),
tracer: otel.Tracer("kraken-proxy-prefetch"),
}
}

// newPrefetchSuccessResponse constructs a successful response.
func newPrefetchSuccessResponse(tag, msg, traceId string) *prefetchResponse {
func newPrefetchSuccessResponse(tag, msg string) *prefetchResponse {
return &prefetchResponse{
Tag: tag,
Prefetched: true,
Status: StatusSuccess,
Message: msg,
TraceId: traceId,
}
}

// newPrefetchError constructs an error response.
func newPrefetchError(status int, msg, traceId string) *prefetchError {
func newPrefetchError(status int, msg string) *prefetchError {
return &prefetchError{
Error: http.StatusText(status),
Prefetched: false,
Status: StatusFailure,
Message: msg,
TraceId: traceId,
}
}

Expand All @@ -164,16 +165,16 @@ func writeJSON(w http.ResponseWriter, status int, payload interface{}) {
}
}

func writeBadRequestError(w http.ResponseWriter, msg, traceId string) {
writeJSON(w, http.StatusBadRequest, newPrefetchError(http.StatusBadRequest, msg, traceId))
func writeBadRequestError(w http.ResponseWriter, msg string) {
writeJSON(w, http.StatusBadRequest, newPrefetchError(http.StatusBadRequest, msg))
}

func writeInternalError(w http.ResponseWriter, msg, traceId string) {
writeJSON(w, http.StatusInternalServerError, newPrefetchError(http.StatusInternalServerError, msg, traceId))
func writeInternalError(w http.ResponseWriter, msg string) {
writeJSON(w, http.StatusInternalServerError, newPrefetchError(http.StatusInternalServerError, msg))
}

func writePrefetchResponse(w http.ResponseWriter, tag, msg, traceId string) {
writeJSON(w, http.StatusOK, newPrefetchSuccessResponse(tag, msg, traceId))
func writePrefetchResponse(w http.ResponseWriter, tag, msg string) {
writeJSON(w, http.StatusOK, newPrefetchSuccessResponse(tag, msg))
}

// HandleV1 processes the prefetch request.
Expand All @@ -184,7 +185,7 @@ func (ph *PrefetchHandler) HandleV1(w http.ResponseWriter, r *http.Request) {
}

ph.metrics.Counter("initiated").Inc(1)
writePrefetchResponse(w, input.tag, "prefetching initiated successfully", input.traceID)
writePrefetchResponse(w, input.tag, "prefetching initiated successfully")

if ph.v1Synchronous {
ph.downloadBlobs(input)
Expand All @@ -195,88 +196,124 @@ func (ph *PrefetchHandler) HandleV1(w http.ResponseWriter, r *http.Request) {
}

type prefetchInput struct {
ctx context.Context
blobs []blobInfo
namespace string
logger *zap.SugaredLogger
tag string
traceID string
}

// preparePrefetch parses the request, calls build-index to get the image manifest SHA,
// downloads the manifest(s) from the origin cluster, parses them, and returns the blob layers to prefetch.
// If an error occurs, preparePrefetch returns the appropriate HTTP response.
func (ph *PrefetchHandler) preparePrefetch(w http.ResponseWriter, r *http.Request) (res *prefetchInput, errOccurred bool) {
ctx, span := ph.tracer.Start(r.Context(), "prefetch.prepare",
trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(
attribute.String("component", "proxy-prefetch"),
attribute.String("operation", "prepare_prefetch"),
),
Comment thread
hweawer marked this conversation as resolved.
)
defer span.End()
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preparePrefetch returns ctx that was created from ph.tracer.Start(...), but the span is ended via defer span.End() before the caller starts prefetch.download_blobs / prefetch.trigger_prefetch. This makes those later spans children of an already-ended parent span, producing confusing timelines in traces. Consider returning a context whose active span is still appropriate for downstream work (e.g., return r.Context() / the otelhttp span context, or create a higher-level span in HandleV1/HandleV2 that stays open for the synchronous portion and use child spans for prepare/trigger/download).

Suggested change
defer span.End()
defer span.End()
// Use the original request context for downstream work to avoid returning
// a context tied to a span that will be ended before it is used.
ctx = r.Context()

Copilot uses AI. Check for mistakes.

ph.metrics.Counter("requests").Inc(1)
var reqBody prefetchBody
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
writeBadRequestError(w, fmt.Sprintf("failed to decode request body: %s", err), "")
log.With("error", err).Error("Failed to decode request body")
span.RecordError(err)
span.SetStatus(codes.Error, "failed to decode request body")
writeBadRequestError(w, fmt.Sprintf("failed to decode request body: %s", err))
log.WithTraceContext(ctx).With("error", err).Error("Failed to decode request body")
return nil, true
}
logger := log.
With("trace_id", reqBody.TraceId).
With("image_tag", reqBody.Tag)

span.SetAttributes(attribute.String("image.tag", reqBody.Tag))

namespace, tag, err := ph.tagParser.ParseTag(reqBody.Tag)
if err != nil {
writeBadRequestError(w, fmt.Sprintf("tag: %s, invalid tag format: %s", reqBody.Tag, err), reqBody.TraceId)
span.RecordError(err)
span.SetStatus(codes.Error, "invalid tag format")
writeBadRequestError(w, fmt.Sprintf("tag: %s, invalid tag format: %s", reqBody.Tag, err))
return nil, true
}

span.SetAttributes(
attribute.String("image.namespace", namespace),
attribute.String("image.name", tag),
)

tagRequest := url.QueryEscape(fmt.Sprintf("%s/%s", namespace, tag))
startTime := time.Now()
digest, err := ph.tagClient.Get(tagRequest)
if err != nil {
ph.metrics.Counter("get_tag_error").Inc(1)
logger.With("error", err).Error("Failed to get manifest tag")
writeInternalError(w, fmt.Sprintf("tag request: %s, failed to get tag: %s", tagRequest, err), reqBody.TraceId)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to get manifest tag")
log.WithTraceContext(ctx).With("error", err).Error("Failed to get manifest tag")
writeInternalError(w, fmt.Sprintf("tag request: %s, failed to get tag: %s", tagRequest, err))
return nil, true
}
ph.getTagLatency.RecordDuration(time.Since(startTime))
logger.Infof("Namespace: %s, Tag: %s", namespace, tag)
span.SetAttributes(attribute.String("manifest.digest", digest.Hex()))
log.WithTraceContext(ctx).Infof("Namespace: %s, Tag: %s", namespace, tag)

buf := &bytes.Buffer{}
startTime = time.Now()
if err := ph.clusterClient.DownloadBlob(context.Background(), namespace, digest, buf); err != nil {
if err := ph.clusterClient.DownloadBlob(ctx, namespace, digest, buf); err != nil {
ph.metrics.Counter("download_manifest_error").Inc(1)
logger.With("error", err).Error("Failed to download manifest blob")
writeInternalError(w, fmt.Sprintf("error downloading manifest blob: %s", err), reqBody.TraceId)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to download manifest")
log.WithTraceContext(ctx).With("error", err).Error("Failed to download manifest blob")
writeInternalError(w, fmt.Sprintf("error downloading manifest blob: %s", err))
return nil, true
}
ph.getManifestLatency.RecordDuration(time.Since(startTime))

// Process manifest (ManifestList or single Manifest)
blobs, err := ph.processManifest(logger, namespace, buf.Bytes())
blobs, err := ph.processManifest(ctx, namespace, buf.Bytes())
if err != nil {
writeInternalError(w, fmt.Sprintf("failed to process manifest: %s", err), reqBody.TraceId)
span.RecordError(err)
span.SetStatus(codes.Error, "failed to process manifest")
writeInternalError(w, fmt.Sprintf("failed to process manifest: %s", err))
return nil, true
}

span.SetAttributes(attribute.Int("blobs.count", len(blobs)))
span.SetStatus(codes.Ok, "prepare completed")

return &prefetchInput{
ctx: ctx,
blobs: blobs,
namespace: namespace,
logger: logger,
tag: tag,
traceID: reqBody.TraceId,
}, false
}

// downloadBlobs downloads blobs in parallel.
func (ph *PrefetchHandler) downloadBlobs(input *prefetchInput) {
ctx, span := ph.tracer.Start(input.ctx, "prefetch.download_blobs",
trace.WithAttributes(
attribute.String("component", "proxy-prefetch"),
attribute.String("operation", "download_blobs"),
attribute.String("image.namespace", input.namespace),
attribute.String("image.tag", input.tag),
attribute.Int("blobs.total", len(input.blobs)),
),
)
defer span.End()

var wg sync.WaitGroup
var mu sync.Mutex
var errList []error

for _, b := range input.blobs {
if ph.shouldSkipPrefetch(b, input.logger) {
if ph.shouldSkipPrefetch(input.ctx, b) {
continue
}
Comment thread
hweawer marked this conversation as resolved.

wg.Add(1)
go func(blob blobInfo) {
defer wg.Done()
blobStart := time.Now()
err := ph.clusterClient.DownloadBlob(context.Background(), input.namespace, blob.digest, io.Discard)
err := ph.clusterClient.DownloadBlob(ctx, input.namespace, blob.digest, io.Discard)
Comment thread
hweawer marked this conversation as resolved.
Outdated
blobDuration := time.Since(blobStart)
ph.metrics.Timer("blob_download_time").Record(blobDuration)
ph.metrics.Counter("bytes_downloaded").Inc(blob.size)
Expand All @@ -297,16 +334,20 @@ func (ph *PrefetchHandler) downloadBlobs(input *prefetchInput) {

if len(errList) > 0 {
ph.metrics.Counter("failed").Inc(1)
span.RecordError(errors.Join(errList...))
span.SetStatus(codes.Error, fmt.Sprintf("%d blob downloads failed", len(errList)))
for _, err := range errList {
input.logger.With("error", err).Error("Error downloading blob")
log.WithTraceContext(input.ctx).With("error", err).Error("Error downloading blob")
Comment thread
hweawer marked this conversation as resolved.
Outdated
Comment thread
hweawer marked this conversation as resolved.
Outdated
}
} else {
span.SetStatus(codes.Ok, "all blobs downloaded")
}
}

// shouldSkipPrefetch reports whether a blob should be skipped because its size is outside the range [minBlobSizeBytes, maxBlobSizeBytes].
func (ph *PrefetchHandler) shouldSkipPrefetch(b blobInfo, logger *zap.SugaredLogger) bool {
func (ph *PrefetchHandler) shouldSkipPrefetch(ctx context.Context, b blobInfo) bool {
if b.size < ph.minBlobSizeBytes {
logger.With(
log.WithTraceContext(ctx).With(
"digest", b.digest,
"size", b.size,
"min_threshold", ph.minBlobSizeBytes,
Expand All @@ -315,7 +356,7 @@ func (ph *PrefetchHandler) shouldSkipPrefetch(b blobInfo, logger *zap.SugaredLog
return true
}
if b.size > ph.maxBlobSizeBytes {
logger.With(
log.WithTraceContext(ctx).With(
"digest", b.digest,
"size", b.size,
"max_threshold", ph.maxBlobSizeBytes,
Expand All @@ -327,34 +368,34 @@ func (ph *PrefetchHandler) shouldSkipPrefetch(b blobInfo, logger *zap.SugaredLog
}

// processManifest handles both ManifestLists and single Manifests.
func (ph *PrefetchHandler) processManifest(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) ([]blobInfo, error) {
func (ph *PrefetchHandler) processManifest(ctx context.Context, namespace string, manifestBytes []byte) ([]blobInfo, error) {
// Attempt to process as a manifest list.
blobs, err := ph.tryProcessManifestList(logger, namespace, manifestBytes)
blobs, err := ph.tryProcessManifestList(ctx, namespace, manifestBytes)
if err == nil && len(blobs) > 0 {
return blobs, nil
}

// Fallback to single manifest.
var manifest schema2.Manifest
if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifest); err != nil {
logger.With("namespace", namespace).Errorf("Failed to parse single manifest: %v", err)
log.WithTraceContext(ctx).With("namespace", namespace).Errorf("Failed to parse single manifest: %v", err)
return nil, fmt.Errorf("invalid single manifest: %w", err)
}
return ph.processLayers(manifest.Layers)
}

// tryProcessManifestList attempts to decode a manifest list.
func (ph *PrefetchHandler) tryProcessManifestList(logger *zap.SugaredLogger, namespace string, manifestBytes []byte) ([]blobInfo, error) {
func (ph *PrefetchHandler) tryProcessManifestList(ctx context.Context, namespace string, manifestBytes []byte) ([]blobInfo, error) {
var manifestList manifestlist.ManifestList
if err := json.NewDecoder(bytes.NewReader(manifestBytes)).Decode(&manifestList); err != nil || len(manifestList.Manifests) == 0 {
return nil, fmt.Errorf("not a valid manifest list")
}
logger.With("namespace", namespace).Info("Processing manifest list")
return ph.processManifestList(logger, namespace, manifestList)
log.WithTraceContext(ctx).With("namespace", namespace).Info("Processing manifest list")
return ph.processManifestList(ctx, namespace, manifestList)
}

// processManifestList processes a manifest list.
func (ph *PrefetchHandler) processManifestList(logger *zap.SugaredLogger, namespace string, manifestList manifestlist.ManifestList) ([]blobInfo, error) {
func (ph *PrefetchHandler) processManifestList(ctx context.Context, namespace string, manifestList manifestlist.ManifestList) ([]blobInfo, error) {
var allBlobs []blobInfo
for _, descriptor := range manifestList.Manifests {
manifestDigestHex := descriptor.Digest.Hex()
Expand All @@ -364,9 +405,9 @@ func (ph *PrefetchHandler) processManifestList(logger *zap.SugaredLogger, namesp
}
buf := &bytes.Buffer{}
startTime := time.Now()
if err := ph.clusterClient.DownloadBlob(context.Background(), namespace, digest, buf); err != nil {
if err := ph.clusterClient.DownloadBlob(ctx, namespace, digest, buf); err != nil {
ph.metrics.Counter("download_manifest_error").Inc(1)
logger.With("error", err).Error("Failed to download manifest blob")
log.WithTraceContext(ctx).With("error", err).Error("Failed to download manifest blob")
continue
}
ph.getManifestLatency.RecordDuration(time.Since(startTime))
Expand Down Expand Up @@ -413,23 +454,36 @@ func (ph *PrefetchHandler) HandleV2(w http.ResponseWriter, r *http.Request) {

err := ph.triggerPrefetchBlobs(input)
if err != nil {
writeInternalError(w, fmt.Sprintf("failed to trigger image prefetch: %s", err), input.traceID)
input.logger.Errorf("Failed to trigger image prefetch")
writeInternalError(w, fmt.Sprintf("failed to trigger image prefetch: %s", err))
log.WithTraceContext(input.ctx).Errorf("Failed to trigger image prefetch")
return
}

ph.metrics.Counter("initiated").Inc(1)
writePrefetchResponse(w, input.tag, "prefetching initiated successfully", input.traceID)
writePrefetchResponse(w, input.tag, "prefetching initiated successfully")
}

// triggerPrefetchBlobs triggers a blob prefetch for all blobs in parallel.
func (ph *PrefetchHandler) triggerPrefetchBlobs(input *prefetchInput) error {
ctx, span := ph.tracer.Start(input.ctx, "prefetch.trigger_prefetch",
trace.WithAttributes(
attribute.String("component", "proxy-prefetch"),
attribute.String("operation", "trigger_prefetch"),
attribute.String("image.namespace", input.namespace),
attribute.String("image.tag", input.tag),
attribute.Int("blobs.total", len(input.blobs)),
),
)
defer span.End()

_ = ctx // PrefetchBlob doesn't accept context yet
Comment thread
hweawer marked this conversation as resolved.
Outdated

var wg sync.WaitGroup
var mu sync.Mutex
var errList []error

for _, b := range input.blobs {
if ph.shouldSkipPrefetch(b, input.logger) {
if ph.shouldSkipPrefetch(input.ctx, b) {
Comment thread
hweawer marked this conversation as resolved.
Outdated
continue
Comment thread
hweawer marked this conversation as resolved.
Outdated
}

Expand All @@ -447,7 +501,12 @@ func (ph *PrefetchHandler) triggerPrefetchBlobs(input *prefetchInput) error {
wg.Wait()

if len(errList) != 0 {
return fmt.Errorf("at least one layer could not be prefetched: %w", errors.Join(errList...))
err := fmt.Errorf("at least one layer could not be prefetched: %w", errors.Join(errList...))
span.RecordError(err)
span.SetStatus(codes.Error, fmt.Sprintf("%d prefetch requests failed", len(errList)))
return err
}

span.SetStatus(codes.Ok, "all prefetch requests triggered")
return nil
}
Loading
Loading