diff --git a/docs/modules/ROOT/pages/ec_oci_image_manifests.adoc b/docs/modules/ROOT/pages/ec_oci_image_manifests.adoc
new file mode 100644
index 000000000..aaee30eea
--- /dev/null
+++ b/docs/modules/ROOT/pages/ec_oci_image_manifests.adoc
@@ -0,0 +1,20 @@
+= ec.oci.image_manifests
+
+Fetch Image Manifests from an OCI registry in parallel.
+
+== Usage
+
+ manifests = ec.oci.image_manifests(refs: set[string])
+
+== Parameters
+
+* `refs` (`set[string]`): set of OCI image references
+
+== Return
+
+`manifests` (`object`): object mapping refs to their Image Manifest objects
+
+The object contains dynamic attributes.
+The attributes are of `string` type and represent the OCI image reference.
+The values are of `object<annotations: object[string: string], config: object<annotations: object[string: string], artifactType: string, data: string, digest: string, mediaType: string, platform: object<architecture: string, features: array<string>, os: string, os.features: array<string>, os.version: string, variant: string>, size: number, urls: array<string>>, layers: array<object<annotations: object[string: string], artifactType: string, data: string, digest: string, mediaType: string, platform: object<architecture: string, features: array<string>, os: string, os.features: array<string>, os.version: string, variant: string>, size: number, urls: array<string>>>, mediaType: string, schemaVersion: number, subject: object<annotations: object[string: string], artifactType: string, data: string, digest: string, mediaType: string, platform: object<architecture: string, features: array<string>, os: string, os.features: array<string>, os.version: string, variant: string>, size: number, urls: array<string>>>` type and hold the Image Manifest object.
+
diff --git a/docs/modules/ROOT/pages/rego_builtins.adoc b/docs/modules/ROOT/pages/rego_builtins.adoc
index 4bb827626..cbbfaacb4 100644
--- a/docs/modules/ROOT/pages/rego_builtins.adoc
+++ b/docs/modules/ROOT/pages/rego_builtins.adoc
@@ -18,6 +18,8 @@ information.
 |Fetch an Image Index from an OCI registry.
 |xref:ec_oci_image_manifest.adoc[ec.oci.image_manifest]
 |Fetch an Image Manifest from an OCI registry.
+|xref:ec_oci_image_manifests.adoc[ec.oci.image_manifests]
+|Fetch Image Manifests from an OCI registry in parallel.
 |xref:ec_purl_is_valid.adoc[ec.purl.is_valid]
 |Determine whether or not a given PURL is valid.
 |xref:ec_purl_parse.adoc[ec.purl.parse]
diff --git a/docs/modules/ROOT/partials/rego_nav.adoc b/docs/modules/ROOT/partials/rego_nav.adoc
index 6abfed01e..1335fbeab 100644
--- a/docs/modules/ROOT/partials/rego_nav.adoc
+++ b/docs/modules/ROOT/partials/rego_nav.adoc
@@ -4,6 +4,7 @@
 ** xref:ec_oci_image_files.adoc[ec.oci.image_files]
 ** xref:ec_oci_image_index.adoc[ec.oci.image_index]
 ** xref:ec_oci_image_manifest.adoc[ec.oci.image_manifest]
+** xref:ec_oci_image_manifests.adoc[ec.oci.image_manifests]
 ** xref:ec_purl_is_valid.adoc[ec.purl.is_valid]
 ** xref:ec_purl_parse.adoc[ec.purl.parse]
 ** xref:ec_sigstore_verify_attestation.adoc[ec.sigstore.verify_attestation]
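The documentation entries above describe the new builtin only by its signature. As a rough end-to-end sketch, the builtin can be exercised through OPA's Go API once this package has been imported for its side-effect registration; the reference, query, and result handling below are illustrative placeholders rather than part of the generated docs:

[source,go]
----
package main

import (
	"context"
	"fmt"

	"github.com/open-policy-agent/opa/v1/rego"

	// Importing the package for side effects registers the ec.oci.* builtins.
	// Illustrative: this path is internal to the conforma/cli module.
	_ "github.com/conforma/cli/internal/rego/oci"
)

func main() {
	// Placeholder digest-pinned reference; any set of refs works.
	query := `ec.oci.image_manifests({"registry.local/app@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"})`

	rs, err := rego.New(rego.Query(query)).Eval(context.Background())
	if err != nil {
		panic(err)
	}
	if len(rs) == 0 {
		// The expression was undefined, e.g. the operand was not a set of strings.
		return
	}

	// The result is an object keyed by ref; each value is an Image Manifest object.
	manifests := rs[0].Expressions[0].Value.(map[string]interface{})
	for ref, m := range manifests {
		fmt.Println(ref, m.(map[string]interface{})["mediaType"])
	}
}
----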
diff --git a/internal/rego/oci/oci.go b/internal/rego/oci/oci.go
index eced53a59..13c22ecc5 100644
--- a/internal/rego/oci/oci.go
+++ b/internal/rego/oci/oci.go
@@ -27,6 +27,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"runtime"
+	"sync"
+	"sync/atomic"
 
 	"github.com/google/go-containerregistry/pkg/name"
 	v1 "github.com/google/go-containerregistry/pkg/v1"
@@ -35,6 +38,8 @@ import (
 	"github.com/open-policy-agent/opa/v1/topdown/builtins"
 	"github.com/open-policy-agent/opa/v1/types"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/sync/errgroup"
+	"golang.org/x/sync/singleflight"
 	"k8s.io/client-go/util/retry"
 
 	"github.com/conforma/cli/internal/fetchers/oci/files"
@@ -42,11 +47,12 @@ import (
 )
 
 const (
-	ociBlobName          = "ec.oci.blob"
-	ociDescriptorName    = "ec.oci.descriptor"
-	ociImageManifestName = "ec.oci.image_manifest"
-	ociImageFilesName    = "ec.oci.image_files"
-	ociImageIndexName    = "ec.oci.image_index"
+	ociBlobName                = "ec.oci.blob"
+	ociDescriptorName          = "ec.oci.descriptor"
+	ociImageManifestName       = "ec.oci.image_manifest"
+	ociImageManifestsBatchName = "ec.oci.image_manifests"
+	ociImageFilesName          = "ec.oci.image_files"
+	ociImageIndexName          = "ec.oci.image_index"
 )
 
 func registerOCIBlob() {
@@ -212,6 +218,78 @@ func registerOCIImageManifest() {
 	})
 }
 
+func registerOCIImageManifestsBatch() {
+	platform := types.NewObject(
+		[]*types.StaticProperty{
+			{Key: "architecture", Value: types.S},
+			{Key: "os", Value: types.S},
+			{Key: "os.version", Value: types.S},
+			{Key: "os.features", Value: types.NewArray([]types.Type{types.S}, nil)},
+			{Key: "variant", Value: types.S},
+			{Key: "features", Value: types.NewArray([]types.Type{types.S}, nil)},
+		},
+		nil,
+	)
+
+	annotations := types.NewObject(nil, types.NewDynamicProperty(types.S, types.S))
+
+	descriptor := types.NewObject(
+		[]*types.StaticProperty{
+			{Key: "mediaType", Value: types.S},
+			{Key: "size", Value: types.N},
+			{Key: "digest", Value: types.S},
+			{Key: "data", Value: types.S},
+			{Key: "urls", Value: types.NewArray(
+				[]types.Type{types.S}, nil,
+			)},
+			{Key: "annotations", Value: annotations},
+			{Key: "platform", Value: platform},
+			{Key: "artifactType", Value: types.S},
+		},
+		nil,
+	)
+
+	manifest := types.NewObject(
+		[]*types.StaticProperty{
+			{Key: "schemaVersion", Value: types.N},
+			{Key: "mediaType", Value: types.S},
+			{Key: "config", Value: descriptor},
+			{Key: "layers", Value: types.NewArray(
+				[]types.Type{descriptor}, nil,
+			)},
+			{Key: "annotations", Value: annotations},
+			{Key: "subject", Value: descriptor},
+		},
+		nil,
+	)
+
+	// Return type is an object mapping ref strings to manifests
+	resultType := types.NewObject(nil, types.NewDynamicProperty(
+		types.Named("ref", types.S).Description("the OCI image reference"),
+		types.Named("manifest", manifest).Description("the Image Manifest object"),
+	))
+
+	decl := rego.Function{
+		Name: ociImageManifestsBatchName,
+		Decl: types.NewFunction(
+			types.Args(
+				types.Named("refs", types.NewSet(types.S)).Description("set of OCI image references"),
+			),
+			types.Named("manifests", resultType).Description("object mapping refs to their Image Manifest objects"),
+		),
+		Memoize:          true,
+		Nondeterministic: true,
+	}
+
+	rego.RegisterBuiltin1(&decl, ociImageManifestsBatch)
+	ast.RegisterBuiltin(&ast.Builtin{
+		Name:             decl.Name,
+		Description:      "Fetch Image Manifests from an OCI registry in parallel.",
+		Decl:             decl.Decl,
+		Nondeterministic: decl.Nondeterministic,
+	})
+}
+
 func registerOCIImageFiles() {
 	filesObject := types.NewObject(
 		nil,
@@ -309,69 +387,93 @@ func ociBlob(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
 		logger.Error("input is not a string")
 		return nil, nil
 	}
-	logger = logger.WithField("ref", string(uri))
-	logger.Debug("Starting blob retrieval")
+	refStr := string(uri)
+	logger = logger.WithField("ref", refStr)
 
-	ref, err := name.NewDigest(string(uri))
-	if err != nil {
-		logger.WithFields(log.Fields{
-			"action": "new digest",
-			"error":  err,
-		}).Error("failed to create new digest")
-		return nil, nil
+	// Check cache first (fast path)
+	if cached, found := blobCache.Load(refStr); found {
+		logger.Debug("Blob served from cache")
+		return cached.(*ast.Term), nil
 	}
 
-	rawLayer, err := oci.NewClient(bctx.Context).Layer(ref)
-	if err != nil {
-		logger.WithFields(log.Fields{
-			"action": "fetch layer",
-			"error":  err,
-		}).Error("failed to fetch OCI layer")
-		return nil, nil
-	}
+	// Use singleflight to prevent thundering herd - only one goroutine fetches per key
+	result, err, _ := blobFlight.Do(refStr, func() (any, error) {
+		// Double-check cache inside singleflight (another goroutine may have populated it)
+		if cached, found := blobCache.Load(refStr); found {
+			logger.Debug("Blob served 
from cache (after singleflight)") + return cached, nil + } + logger.Debug("Starting blob retrieval") + + ref, err := name.NewDigest(refStr) + if err != nil { + logger.WithFields(log.Fields{ + "action": "new digest", + "error": err, + }).Error("failed to create new digest") + return nil, nil //nolint:nilerr // intentional: return nil to signal failure without OPA error + } - layer, err := rawLayer.Uncompressed() - if err != nil { - logger.WithFields(log.Fields{ - "action": "uncompress layer", - "error": err, - }).Error("failed to uncompress OCI layer") - return nil, nil - } - defer layer.Close() + rawLayer, err := oci.NewClient(bctx.Context).Layer(ref) + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch layer", + "error": err, + }).Error("failed to fetch OCI layer") + return nil, nil //nolint:nilerr + } + + layer, err := rawLayer.Uncompressed() + if err != nil { + logger.WithFields(log.Fields{ + "action": "uncompress layer", + "error": err, + }).Error("failed to uncompress OCI layer") + return nil, nil //nolint:nilerr + } + defer layer.Close() + + // TODO: Other algorithms are technically supported, e.g. sha512. However, support for those is + // not complete in the go-containerregistry library, e.g. name.NewDigest throws an error if + // sha256 is not used. This is good for now, but may need revisiting later. + hasher := sha256.New() + reader := io.TeeReader(layer, hasher) + + var blob bytes.Buffer + if _, err := io.Copy(&blob, reader); err != nil { + logger.WithFields(log.Fields{ + "action": "copy buffer", + "error": err, + }).Error("failed to copy data into buffer") + return nil, nil //nolint:nilerr + } - // TODO: Other algorithms are technically supported, e.g. sha512. However, support for those is - // not complete in the go-containerregistry library, e.g. name.NewDigest throws an error if - // sha256 is not used. This is good for now, but may need revisiting later. - hasher := sha256.New() - reader := io.TeeReader(layer, hasher) + sum := fmt.Sprintf("sha256:%x", hasher.Sum(nil)) + // io.LimitReader truncates the layer if it exceeds its limit. The condition below catches this + // scenario in order to avoid unexpected behavior caused by partial data being returned. + if sum != ref.DigestStr() { + logger.WithFields(log.Fields{ + "action": "verify digest", + "computed_digest": sum, + "expected_digest": ref.DigestStr(), + }).Error("computed digest does not match expected digest") + return nil, nil + } - var blob bytes.Buffer - if _, err := io.Copy(&blob, reader); err != nil { logger.WithFields(log.Fields{ - "action": "copy buffer", - "error": err, - }).Error("failed to copy data into buffer") - return nil, nil - } + "action": "complete", + "digest": sum, + }).Debug("Successfully retrieved blob") - sum := fmt.Sprintf("sha256:%x", hasher.Sum(nil)) - // io.LimitReader truncates the layer if it exceeds its limit. The condition below catches this - // scenario in order to avoid unexpected behavior caused by partial data being returned. 
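Stripped of the logging and caching around it, the digest check described in the comments above is just: hash the stream with a TeeReader while copying it, then compare the hex sum against the expected digest and reject the data on mismatch. A self-contained sketch of that pattern, with illustrative names and inputs:

[source,go]
----
package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io"
	"strings"
)

// verifiedCopy copies r into a buffer while hashing it, then compares the
// sha256 sum against want (e.g. "sha256:abc..."). It returns the data only
// if the digests match, so truncated or tampered streams are rejected.
func verifiedCopy(r io.Reader, want string) ([]byte, error) {
	hasher := sha256.New()
	tee := io.TeeReader(r, hasher)

	var buf bytes.Buffer
	if _, err := io.Copy(&buf, tee); err != nil {
		return nil, err
	}

	got := fmt.Sprintf("sha256:%x", hasher.Sum(nil))
	if got != want {
		return nil, fmt.Errorf("digest mismatch: got %s, want %s", got, want)
	}
	return buf.Bytes(), nil
}

func main() {
	data := "hello"
	want := fmt.Sprintf("sha256:%x", sha256.Sum256([]byte(data)))
	out, err := verifiedCopy(strings.NewReader(data), want)
	fmt.Println(string(out), err)
}
----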
- if sum != ref.DigestStr() { - logger.WithFields(log.Fields{ - "action": "verify digest", - "computed_digest": sum, - "expected_digest": ref.DigestStr(), - }).Error("computed digest does not match expected digest") + term := ast.StringTerm(blob.String()) + blobCache.Store(refStr, term) + return term, nil + }) + + if err != nil || result == nil { return nil, nil } - - logger.WithFields(log.Fields{ - "action": "complete", - "digest": sum, - }).Debug("Successfully retrieved blob") - return ast.StringTerm(blob.String()), nil + return result.(*ast.Term), nil } func ociDescriptor(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { @@ -382,29 +484,52 @@ func ociDescriptor(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { logger.Error("input is not a string") return nil, nil } - logger = logger.WithField("input_ref", string(uriValue)) - logger.Debug("Starting descriptor retrieval") + refStr := string(uriValue) + logger = logger.WithField("input_ref", refStr) - client := oci.NewClient(bctx.Context) - - uri, ref, err := resolveIfNeeded(client, string(uriValue)) - if err != nil { - logger.WithField("action", "resolveIfNeeded").Error(err) - return nil, nil + // Check cache first (fast path) + if cached, found := descriptorCache.Load(refStr); found { + logger.Debug("Descriptor served from cache") + return cached.(*ast.Term), nil } - logger = logger.WithField("ref", uri) - descriptor, err := client.Head(ref) - if err != nil { - logger.WithFields(log.Fields{ - "action": "fetch head", - "error": err, - }).Error("failed to fetch image descriptor") + // Use singleflight to prevent thundering herd + result, err, _ := descriptorFlight.Do(refStr, func() (any, error) { + // Double-check cache inside singleflight + if cached, found := descriptorCache.Load(refStr); found { + logger.Debug("Descriptor served from cache (after singleflight)") + return cached, nil + } + logger.Debug("Starting descriptor retrieval") + + client := oci.NewClient(bctx.Context) + + uri, ref, err := resolveIfNeeded(client, refStr) + if err != nil { + logger.WithField("action", "resolveIfNeeded").Error(err) + return nil, nil //nolint:nilerr + } + logger.WithField("ref", uri).Debug("Resolved reference") + + descriptor, err := client.Head(ref) + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch head", + "error": err, + }).Error("failed to fetch image descriptor") + return nil, nil //nolint:nilerr + } + + logger.Debug("Successfully retrieved descriptor") + term := newDescriptorTerm(*descriptor) + descriptorCache.Store(refStr, term) + return term, nil + }) + + if err != nil || result == nil { return nil, nil } - - logger.Debug("Successfully retrieved descriptor") - return newDescriptorTerm(*descriptor), nil + return result.(*ast.Term), nil } func ociImageManifest(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { @@ -415,64 +540,267 @@ func ociImageManifest(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) logger.Error("input is not a string") return nil, nil } - logger = logger.WithField("input_ref", string(uriValue)) - logger.Debug("Starting image manifest retrieval") + refStr := string(uriValue) + logger = logger.WithField("input_ref", refStr) - client := oci.NewClient(bctx.Context) + // Check cache first (fast path) + if cached, found := manifestCache.Load(refStr); found { + logger.Debug("Image manifest served from cache") + return cached.(*ast.Term), nil + } - uri, ref, err := resolveIfNeeded(client, string(uriValue)) - if err != nil { - logger.WithField("action", 
"resolveIfNeeded").Error(err) + // Use singleflight to prevent thundering herd + result, err, _ := manifestFlight.Do(refStr, func() (any, error) { + // Double-check cache inside singleflight + if cached, found := manifestCache.Load(refStr); found { + logger.Debug("Image manifest served from cache (after singleflight)") + return cached, nil + } + logger.Debug("Starting image manifest retrieval") + + client := oci.NewClient(bctx.Context) + + uri, ref, err := resolveIfNeeded(client, refStr) + if err != nil { + logger.WithField("action", "resolveIfNeeded").Error(err) + return nil, nil //nolint:nilerr + } + logger.WithField("ref", uri).Debug("Resolved reference") + + var image v1.Image + err = retry.OnError(retry.DefaultRetry, func(_ error) bool { return true }, func() error { + image, err = client.Image(ref) + return err + }) + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch image", + "error": err, + }).Error("failed to fetch image") + return nil, nil //nolint:nilerr + } + + manifest, err := image.Manifest() + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch manifest", + "error": err, + }).Error("failed to fetch manifest") + return nil, nil //nolint:nilerr + } + + if manifest == nil { + logger.Error("manifest is nil") + return nil, nil + } + + layers := []*ast.Term{} + for _, layer := range manifest.Layers { + layers = append(layers, newDescriptorTerm(layer)) + } + + manifestTerms := [][2]*ast.Term{ + ast.Item(ast.StringTerm("schemaVersion"), ast.NumberTerm(json.Number(fmt.Sprintf("%d", manifest.SchemaVersion)))), + ast.Item(ast.StringTerm("mediaType"), ast.StringTerm(string(manifest.MediaType))), + ast.Item(ast.StringTerm("config"), newDescriptorTerm(manifest.Config)), + ast.Item(ast.StringTerm("layers"), ast.ArrayTerm(layers...)), + ast.Item(ast.StringTerm("annotations"), newAnnotationsTerm(manifest.Annotations)), + } + + if s := manifest.Subject; s != nil { + manifestTerms = append(manifestTerms, ast.Item(ast.StringTerm("subject"), newDescriptorTerm(*s))) + } + + logger.Debug("Successfully retrieved image manifest") + term := ast.ObjectTerm(manifestTerms...) + manifestCache.Store(refStr, term) + return term, nil + }) + + if err != nil || result == nil { return nil, nil } - logger = logger.WithField("ref", uri) + return result.(*ast.Term), nil +} - var image v1.Image - err = retry.OnError(retry.DefaultRetry, func(_ error) bool { return true }, func() error { - image, err = client.Image(ref) - return err - }) +// manifestResult holds the result of fetching a single manifest +type manifestResult struct { + ref string + manifest *ast.Term +} + +// maxParallelManifestFetches limits concurrent manifest fetches to avoid overwhelming registries. +// Defaults to GOMAXPROCS * 4, which provides good parallelism while being respectful of resources. +var maxParallelManifestFetches = runtime.GOMAXPROCS(0) * 4 + +// Package-level caches for OCI operations. +// OPA's Memoize only works within a single Eval() call, but we validate multiple +// images in separate Eval() calls. These caches persist for the lifetime of the process. +// All caches are keyed by the ref string (or ref+paths for image files). 
+// +// We use singleflight.Group alongside sync.Map to prevent thundering herd: +// - sync.Map stores the cached results +// - singleflight.Group ensures only one goroutine fetches a given key at a time +var ( + blobCache sync.Map // map[string]*ast.Term - for ociBlob + blobFlight singleflight.Group // deduplicates concurrent blob fetches + descriptorCache sync.Map // map[string]*ast.Term - for ociDescriptor + descriptorFlight singleflight.Group + manifestCache sync.Map // map[string]*ast.Term - for ociImageManifest + manifestFlight singleflight.Group + imageFilesCache sync.Map // map[string]*ast.Term - for ociImageFiles (key: ref+pathsHash) + imageFilesFlight singleflight.Group + imageIndexCache sync.Map // map[string]*ast.Term - for ociImageIndex + imageIndexFlight singleflight.Group +) + +// batchCallCounter tracks how many times ociImageManifestsBatch is called (for debugging) +var batchCallCounter uint64 + +// ClearCaches clears all package-level caches. This is primarily used for testing +// to ensure tests don't interfere with each other via cached values. +func ClearCaches() { + blobCache = sync.Map{} + descriptorCache = sync.Map{} + manifestCache = sync.Map{} + imageFilesCache = sync.Map{} + imageIndexCache = sync.Map{} +} + +func ociImageManifestsBatch(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { + callNum := atomic.AddUint64(&batchCallCounter, 1) + logger := log.WithField("function", ociImageManifestsBatchName) + + refsSet, err := builtins.SetOperand(a.Value, 1) if err != nil { logger.WithFields(log.Fields{ - "action": "fetch image", + "action": "convert refs", "error": err, - }).Error("failed to fetch image") + }).Error("failed to convert refs to set operand") return nil, nil } - manifest, err := image.Manifest() + // Collect all ref terms and check cache + var uncachedTerms []*ast.Term + cachedResults := make(map[string]*ast.Term) + + err = refsSet.Iter(func(refTerm *ast.Term) error { + refStr, ok := refTerm.Value.(ast.String) + if !ok { + return fmt.Errorf("ref is not a string: %#v", refTerm) + } + ref := string(refStr) + + // Check cache first + if cached, found := manifestCache.Load(ref); found { + cachedResults[ref] = cached.(*ast.Term) + } else { + uncachedTerms = append(uncachedTerms, refTerm) + } + return nil + }) if err != nil { logger.WithFields(log.Fields{ - "action": "fetch manifest", + "action": "iterate refs", "error": err, - }).Error("failed to fetch manifest") + }).Error("failed iterating refs") return nil, nil } - if manifest == nil { - logger.Error("manifest is nil") - return nil, nil + totalRefs := len(cachedResults) + len(uncachedTerms) + logger.WithFields(log.Fields{ + "call_number": callNum, + "total_refs": totalRefs, + "cached_refs": len(cachedResults), + "uncached_refs": len(uncachedTerms), + "concurrency": maxParallelManifestFetches, + }).Debug("Starting parallel image manifest retrieval with caching") + + if totalRefs == 0 { + return ast.ObjectTerm(), nil + } + + // Build result from cached entries + resultTerms := make([][2]*ast.Term, 0, totalRefs) + for ref, manifest := range cachedResults { + resultTerms = append(resultTerms, ast.Item(ast.StringTerm(ref), manifest)) + } + + // If everything was cached, return early + if len(uncachedTerms) == 0 { + logger.WithField("success_count", len(resultTerms)).Debug("All manifests served from cache") + return ast.ObjectTerm(resultTerms...), nil } - layers := []*ast.Term{} - for _, layer := range manifest.Layers { - layers = append(layers, newDescriptorTerm(layer)) + // Fetch uncached refs in 
parallel + g, ctx := errgroup.WithContext(bctx.Context) + g.SetLimit(maxParallelManifestFetches) + + results := make(chan manifestResult, len(uncachedTerms)) + + bctxWithCancel := rego.BuiltinContext{ + Context: ctx, + Cancel: bctx.Cancel, + Runtime: bctx.Runtime, + Time: bctx.Time, + Seed: bctx.Seed, + Metrics: bctx.Metrics, + Location: bctx.Location, + Tracers: bctx.Tracers, } - manifestTerms := [][2]*ast.Term{ - ast.Item(ast.StringTerm("schemaVersion"), ast.NumberTerm(json.Number(fmt.Sprintf("%d", manifest.SchemaVersion)))), - ast.Item(ast.StringTerm("mediaType"), ast.StringTerm(string(manifest.MediaType))), - ast.Item(ast.StringTerm("config"), newDescriptorTerm(manifest.Config)), - ast.Item(ast.StringTerm("layers"), ast.ArrayTerm(layers...)), - ast.Item(ast.StringTerm("annotations"), newAnnotationsTerm(manifest.Annotations)), + for _, refTerm := range uncachedTerms { + term := refTerm + g.Go(func() error { + select { + case <-ctx.Done(): + return nil + default: + } + + ref := string(term.Value.(ast.String)) + manifest, err := ociImageManifest(bctxWithCancel, term) + if err != nil { + logger.WithFields(log.Fields{ + "ref": ref, + "error": err, + }).Error("failed to fetch manifest in batch") + results <- manifestResult{ref: ref, manifest: nil} + return nil + } + + // Store in cache (even nil results to avoid re-fetching failures) + if manifest != nil { + manifestCache.Store(ref, manifest) + } + + results <- manifestResult{ref: ref, manifest: manifest} + return nil + }) } - if s := manifest.Subject; s != nil { - manifestTerms = append(manifestTerms, ast.Item(ast.StringTerm("subject"), newDescriptorTerm(*s))) + go func() { + _ = g.Wait() + close(results) + }() + + // Collect newly fetched results + var mu sync.Mutex + for result := range results { + if result.manifest != nil { + mu.Lock() + resultTerms = append(resultTerms, ast.Item(ast.StringTerm(result.ref), result.manifest)) + mu.Unlock() + } } - logger.Debug("Successfully retrieved image manifest") - return ast.ObjectTerm(manifestTerms...), nil + logger.WithFields(log.Fields{ + "success_count": len(resultTerms), + "from_cache": len(cachedResults), + "newly_fetched": len(resultTerms) - len(cachedResults), + }).Debug("Completed parallel image manifest retrieval") + + return ast.ObjectTerm(resultTerms...), nil } func ociImageFiles(bctx rego.BuiltinContext, refTerm *ast.Term, pathsTerm *ast.Term) (*ast.Term, error) { @@ -483,64 +811,96 @@ func ociImageFiles(bctx rego.BuiltinContext, refTerm *ast.Term, pathsTerm *ast.T logger.Error("input ref is not a string") return nil, nil } - logger = logger.WithField("ref", string(uri)) - logger.Debug("Starting image files extraction") + refStr := string(uri) + logger = logger.WithField("ref", refStr) - ref, err := name.NewDigest(string(uri)) - if err != nil { - logger.WithFields(log.Fields{ - "action": "new digest", - "error": err, - }).Error("failed to create new digest") + if pathsTerm == nil { + logger.Error("paths term is nil") return nil, nil } - pathsArray, err := builtins.ArrayOperand(pathsTerm.Value, 1) - if err != nil { - logger.WithFields(log.Fields{ - "action": "convert paths", - "error": err, - }).Error("failed to convert paths to array operand") - return nil, nil + // Build cache key from ref + paths (hash the paths for a stable key) + pathsHash := fmt.Sprintf("%x", sha256.Sum256([]byte(pathsTerm.String())))[:12] + cacheKey := refStr + ":" + pathsHash + + // Check cache first (fast path) + if cached, found := imageFilesCache.Load(cacheKey); found { + logger.Debug("Image files served from 
cache") + return cached.(*ast.Term), nil } - var extractors []files.Extractor - err = pathsArray.Iter(func(pathTerm *ast.Term) error { - pathString, ok := pathTerm.Value.(ast.String) - if !ok { - return fmt.Errorf("path is not a string: %#v", pathTerm) + // Use singleflight to prevent thundering herd + result, err, _ := imageFilesFlight.Do(cacheKey, func() (any, error) { + // Double-check cache inside singleflight + if cached, found := imageFilesCache.Load(cacheKey); found { + logger.Debug("Image files served from cache (after singleflight)") + return cached, nil + } + logger.Debug("Starting image files extraction") + + ref, err := name.NewDigest(refStr) + if err != nil { + logger.WithFields(log.Fields{ + "action": "new digest", + "error": err, + }).Error("failed to create new digest") + return nil, nil //nolint:nilerr } - extractors = append(extractors, files.PathExtractor{Path: string(pathString)}) - return nil - }) - if err != nil { - logger.WithFields(log.Fields{ - "action": "iterate paths", - "error": err, - }).Error("failed iterating paths") - return nil, nil - } - files, err := files.ImageFiles(bctx.Context, ref, extractors) - if err != nil { - logger.WithFields(log.Fields{ - "action": "extract files", - "error": err, - }).Error("failed to extract image files") - return nil, nil - } + pathsArray, err := builtins.ArrayOperand(pathsTerm.Value, 1) + if err != nil { + logger.WithFields(log.Fields{ + "action": "convert paths", + "error": err, + }).Error("failed to convert paths to array operand") + return nil, nil //nolint:nilerr + } - filesValue, err := ast.InterfaceToValue(files) - if err != nil { - logger.WithFields(log.Fields{ - "action": "convert files", - "error": err, - }).Error("failed to convert files object to value") + var extractors []files.Extractor + err = pathsArray.Iter(func(pathTerm *ast.Term) error { + pathString, ok := pathTerm.Value.(ast.String) + if !ok { + return fmt.Errorf("path is not a string: %#v", pathTerm) + } + extractors = append(extractors, files.PathExtractor{Path: string(pathString)}) + return nil + }) + if err != nil { + logger.WithFields(log.Fields{ + "action": "iterate paths", + "error": err, + }).Error("failed iterating paths") + return nil, nil //nolint:nilerr + } + + filesResult, err := files.ImageFiles(bctx.Context, ref, extractors) + if err != nil { + logger.WithFields(log.Fields{ + "action": "extract files", + "error": err, + }).Error("failed to extract image files") + return nil, nil //nolint:nilerr + } + + filesValue, err := ast.InterfaceToValue(filesResult) + if err != nil { + logger.WithFields(log.Fields{ + "action": "convert files", + "error": err, + }).Error("failed to convert files object to value") + return nil, nil //nolint:nilerr + } + + logger.Debug("Successfully extracted image files") + term := ast.NewTerm(filesValue) + imageFilesCache.Store(cacheKey, term) + return term, nil + }) + + if err != nil || result == nil { return nil, nil } - - logger.Debug("Successfully extracted image files") - return ast.NewTerm(filesValue), nil + return result.(*ast.Term), nil } func ociImageIndex(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { @@ -551,59 +911,82 @@ func ociImageIndex(bctx rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { logger.Error("input is not a string") return nil, nil } - logger = logger.WithField("input_ref", string(uriValue)) - logger.Debug("Starting image index retrieval") - - client := oci.NewClient(bctx.Context) + refStr := string(uriValue) + logger = logger.WithField("input_ref", refStr) - uri, ref, err := 
resolveIfNeeded(client, string(uriValue)) - if err != nil { - logger.WithField("action", "resolveIfNeeded").Error(err) - return nil, nil + // Check cache first (fast path) + if cached, found := imageIndexCache.Load(refStr); found { + logger.Debug("Image index served from cache") + return cached.(*ast.Term), nil } - logger = logger.WithField("ref", uri) - imageIndex, err := client.Index(ref) - if err != nil { - logger.WithFields(log.Fields{ - "action": "fetch image index", - "error": err, - }).Error("failed to fetch image index") - return nil, nil - } + // Use singleflight to prevent thundering herd + result, err, _ := imageIndexFlight.Do(refStr, func() (any, error) { + // Double-check cache inside singleflight + if cached, found := imageIndexCache.Load(refStr); found { + logger.Debug("Image index served from cache (after singleflight)") + return cached, nil + } + logger.Debug("Starting image index retrieval") - indexManifest, err := imageIndex.IndexManifest() - if err != nil { - logger.WithFields(log.Fields{ - "action": "fetch index manifest", - "error": err, - }).Error("failed to fetch index manifest") - return nil, nil - } + client := oci.NewClient(bctx.Context) - if indexManifest == nil { - logger.Error("index manifest is nil") - return nil, nil - } + uri, ref, err := resolveIfNeeded(client, refStr) + if err != nil { + logger.WithField("action", "resolveIfNeeded").Error(err) + return nil, nil //nolint:nilerr + } + logger.WithField("ref", uri).Debug("Resolved reference") + + imageIndex, err := client.Index(ref) + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch image index", + "error": err, + }).Error("failed to fetch image index") + return nil, nil //nolint:nilerr + } - manifestTerms := []*ast.Term{} - for _, manifest := range indexManifest.Manifests { - manifestTerms = append(manifestTerms, newDescriptorTerm(manifest)) - } + indexManifest, err := imageIndex.IndexManifest() + if err != nil { + logger.WithFields(log.Fields{ + "action": "fetch index manifest", + "error": err, + }).Error("failed to fetch index manifest") + return nil, nil //nolint:nilerr + } - imageIndexTerms := [][2]*ast.Term{ - ast.Item(ast.StringTerm("schemaVersion"), ast.NumberTerm(json.Number(fmt.Sprintf("%d", indexManifest.SchemaVersion)))), - ast.Item(ast.StringTerm("mediaType"), ast.StringTerm(string(indexManifest.MediaType))), - ast.Item(ast.StringTerm("manifests"), ast.ArrayTerm(manifestTerms...)), - ast.Item(ast.StringTerm("annotations"), newAnnotationsTerm(indexManifest.Annotations)), - } + if indexManifest == nil { + logger.Error("index manifest is nil") + return nil, nil + } - if s := indexManifest.Subject; s != nil { - imageIndexTerms = append(imageIndexTerms, ast.Item(ast.StringTerm("subject"), newDescriptorTerm(*s))) - } + manifestTerms := []*ast.Term{} + for _, manifest := range indexManifest.Manifests { + manifestTerms = append(manifestTerms, newDescriptorTerm(manifest)) + } + + imageIndexTerms := [][2]*ast.Term{ + ast.Item(ast.StringTerm("schemaVersion"), ast.NumberTerm(json.Number(fmt.Sprintf("%d", indexManifest.SchemaVersion)))), + ast.Item(ast.StringTerm("mediaType"), ast.StringTerm(string(indexManifest.MediaType))), + ast.Item(ast.StringTerm("manifests"), ast.ArrayTerm(manifestTerms...)), + ast.Item(ast.StringTerm("annotations"), newAnnotationsTerm(indexManifest.Annotations)), + } + + if s := indexManifest.Subject; s != nil { + imageIndexTerms = append(imageIndexTerms, ast.Item(ast.StringTerm("subject"), newDescriptorTerm(*s))) + } - logger.Debug("Successfully retrieved image 
index") - return ast.ObjectTerm(imageIndexTerms...), nil + logger.Debug("Successfully retrieved image index") + term := ast.ObjectTerm(imageIndexTerms...) + imageIndexCache.Store(refStr, term) + return term, nil + }) + + if err != nil || result == nil { + return nil, nil + } + return result.(*ast.Term), nil } func newPlatformTerm(p v1.Platform) *ast.Term { @@ -702,5 +1085,6 @@ func init() { registerOCIDescriptor() registerOCIImageFiles() registerOCIImageManifest() + registerOCIImageManifestsBatch() registerOCIImageIndex() } diff --git a/internal/rego/oci/oci_test.go b/internal/rego/oci/oci_test.go index aa220e108..2cd93ccd3 100644 --- a/internal/rego/oci/oci_test.go +++ b/internal/rego/oci/oci_test.go @@ -21,6 +21,7 @@ package oci import ( "context" "errors" + "fmt" "testing" "github.com/gkampitakis/go-snaps/snaps" @@ -39,6 +40,9 @@ import ( ) func TestOCIBlob(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() // Clear before test to avoid interference from previous tests + cases := []struct { name string data string @@ -86,6 +90,8 @@ func TestOCIBlob(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} if c.remoteErr != nil { client.On("Layer", mock.Anything, mock.Anything).Return(nil, c.remoteErr) @@ -111,6 +117,9 @@ func TestOCIBlob(t *testing.T) { } func TestOCIDescriptorManifest(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + cases := []struct { name string ref *ast.Term @@ -224,6 +233,8 @@ func TestOCIDescriptorManifest(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} if c.headErr != nil { client.On("Head", mock.Anything).Return(nil, c.headErr) @@ -251,6 +262,9 @@ func TestOCIDescriptorManifest(t *testing.T) { } func TestOCIDescriptorErrors(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + cases := []struct { name string ref *ast.Term @@ -271,6 +285,8 @@ func TestOCIDescriptorErrors(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} client.On("Head", mock.Anything, mock.Anything).Return(nil, errors.New("expected")) ctx := oci.WithClient(context.Background(), &client) @@ -284,6 +300,9 @@ func TestOCIDescriptorErrors(t *testing.T) { } func TestOCIImageManifest(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + cases := []struct { name string ref *ast.Term @@ -467,6 +486,8 @@ func TestOCIImageManifest(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} if c.imageErr != nil { client.On("Image", mock.Anything, mock.Anything).Return(nil, c.imageErr) @@ -495,7 +516,180 @@ func TestOCIImageManifest(t *testing.T) { } } +func TestOCIImageManifestsBatch(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + + minimalManifest := &v1.Manifest{ + SchemaVersion: 2, + MediaType: types.OCIManifestSchema1, + Config: v1.Descriptor{ + MediaType: types.OCIConfigJSON, + Size: 123, + Digest: v1.Hash{ + Algorithm: "sha256", + Hex: "4e388ab32b10dc8dbc7e28144f552830adc74787c1e2c0824032078a79f227fb", + }, + }, + Layers: []v1.Descriptor{ + { + MediaType: types.OCILayer, + Size: 9999, + Digest: v1.Hash{ + Algorithm: "sha256", + Hex: "325392e8dd2826a53a9a35b7a7f8d71683cd27ebc2c73fee85dab673bc909b67", + }, + }, + }, + } + + cases := []struct { + name string 
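+		// refs is the set-of-image-references operand passed to the builtin under test.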
+ refs *ast.Term + manifest *v1.Manifest + manifestErr error + wantErr bool + wantCount int + wantKeys []string + }{ + { + name: "single ref success", + refs: ast.NewTerm(ast.NewSet( + ast.StringTerm("registry.local/spam:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"), + )), + manifest: minimalManifest, + wantCount: 1, + wantKeys: []string{"registry.local/spam:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"}, + }, + { + name: "multiple refs success", + refs: ast.NewTerm(ast.NewSet( + ast.StringTerm("registry.local/img1:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"), + ast.StringTerm("registry.local/img2:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"), + ast.StringTerm("registry.local/img3:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"), + )), + manifest: minimalManifest, + wantCount: 3, + }, + { + name: "empty set", + refs: ast.NewTerm(ast.NewSet()), + manifest: minimalManifest, + wantCount: 0, + }, + { + name: "invalid input type", + refs: ast.StringTerm("not-a-set"), + wantErr: true, + }, + { + name: "non-string ref in set", + refs: ast.NewTerm(ast.NewSet( + ast.IntNumberTerm(42), + )), + wantErr: true, + }, + { + name: "manifest fetch error excludes ref from result", + refs: ast.NewTerm(ast.NewSet( + ast.StringTerm("registry.local/spam:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b"), + )), + manifestErr: errors.New("fetch error"), + wantCount: 0, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + + client := fake.FakeClient{} + if c.manifestErr != nil { + imageManifest := v1fake.FakeImage{} + imageManifest.ManifestReturns(nil, c.manifestErr) + client.On("Image", mock.Anything, mock.Anything).Return(&imageManifest, nil) + } else { + imageManifest := v1fake.FakeImage{} + imageManifest.ManifestReturns(c.manifest, nil) + client.On("Image", mock.Anything, mock.Anything).Return(&imageManifest, nil) + } + + ctx := oci.WithClient(context.Background(), &client) + bctx := rego.BuiltinContext{Context: ctx} + + got, err := ociImageManifestsBatch(bctx, c.refs) + require.NoError(t, err) + + if c.wantErr { + require.Nil(t, got) + } else { + require.NotNil(t, got) + obj, ok := got.Value.(ast.Object) + require.True(t, ok, "result should be an object") + require.Equal(t, c.wantCount, obj.Len(), "unexpected number of results") + + if len(c.wantKeys) > 0 { + for _, key := range c.wantKeys { + val := obj.Get(ast.StringTerm(key)) + require.NotNil(t, val, "expected key %s not found", key) + } + } + } + }) + } +} + +func TestOCIImageManifestsBatchConcurrency(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + + // Save and restore the original value + original := maxParallelManifestFetches + defer func() { maxParallelManifestFetches = original }() + + // Set a low concurrency limit for testing + maxParallelManifestFetches = 2 + + minimalManifest := &v1.Manifest{ + SchemaVersion: 2, + MediaType: types.OCIManifestSchema1, + Config: v1.Descriptor{ + MediaType: types.OCIConfigJSON, + Size: 123, + Digest: v1.Hash{ + Algorithm: "sha256", + Hex: "4e388ab32b10dc8dbc7e28144f552830adc74787c1e2c0824032078a79f227fb", + }, + }, + Layers: []v1.Descriptor{}, + } + + // Create more refs than the concurrency limit to test bounded concurrency + refsSet := ast.NewSet() + for i := 0; i < 10; i++ { + 
refsSet.Add(ast.StringTerm(fmt.Sprintf("registry.local/img%d:latest@sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b", i))) + } + + client := fake.FakeClient{} + imageManifest := v1fake.FakeImage{} + imageManifest.ManifestReturns(minimalManifest, nil) + client.On("Image", mock.Anything, mock.Anything).Return(&imageManifest, nil) + + ctx := oci.WithClient(context.Background(), &client) + bctx := rego.BuiltinContext{Context: ctx} + + got, err := ociImageManifestsBatch(bctx, ast.NewTerm(refsSet)) + require.NoError(t, err) + require.NotNil(t, got) + + obj, ok := got.Value.(ast.Object) + require.True(t, ok) + require.Equal(t, 10, obj.Len(), "all refs should be processed") +} + func TestOCIImageFiles(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() image, err := crane.Image(map[string][]byte{ "autoexec.bat": []byte(`@ECHO OFF`), @@ -549,6 +743,8 @@ func TestOCIImageFiles(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} if c.remoteErr != nil { client.On("Image", mock.Anything).Return(nil, c.remoteErr) @@ -572,6 +768,9 @@ func TestOCIImageFiles(t *testing.T) { } func TestOCIImageIndex(t *testing.T) { + t.Cleanup(ClearCaches) + ClearCaches() + cases := []struct { name string ref *ast.Term @@ -728,6 +927,8 @@ func TestOCIImageIndex(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { + ClearCaches() // Clear cache before each subtest + client := fake.FakeClient{} if c.indexErr != nil { @@ -765,6 +966,7 @@ func TestFunctionsRegistered(t *testing.T) { ociDescriptorName, ociImageFilesName, ociImageManifestName, + ociImageManifestsBatchName, ociImageIndexName, } for _, name := range names {