From dff286a570a675dbf74ce56e2954a2d6c8ad17bf Mon Sep 17 00:00:00 2001 From: Wael Nasreddine Date: Thu, 15 Jan 2026 01:40:29 -0800 Subject: [PATCH] feat: Implement Nix Binary Cache Index Protocol This implements the Nix Binary Cache Index Protocol as outlined in RFC 0195. The index provides a way for clients to efficiently determine if a NAR info hash exists in the cache without querying the server for every hash. Key changes: - Added FileStore interface and implementations for local and S3 storage. - Added GetAllNarInfos database query and implementation. - Implemented nixcacheindex package for manifest, shard, and delta management. - Wired up the index feature into pkg/cache and pkg/server. - Added --experimental-cache-index flag to enable the feature. - Implemented background index generation using cron. --- db/query.mysql.sql | 4 + db/query.postgres.sql | 4 + db/query.sqlite.sql | 4 + pkg/cache/cache.go | 67 ++++++++ pkg/cache/cache_distributed_test.go | 3 + pkg/cache/cache_internal_test.go | 2 +- pkg/cache/cache_test.go | 11 +- pkg/cache/export_test.go | 7 + pkg/cache/healthcheck/healthcheck_test.go | 2 +- pkg/cache/index.go | 148 ++++++++++++++++ pkg/cache/index_test.go | 104 ++++++++++++ pkg/database/mysqldb/querier.go | 7 +- pkg/database/mysqldb/query.mysql.sql.go | 36 +++- pkg/database/postgresdb/querier.go | 7 +- pkg/database/postgresdb/query.postgres.sql.go | 36 +++- pkg/database/querier.go | 1 + pkg/database/sqlitedb/querier.go | 7 +- pkg/database/sqlitedb/query.sqlite.sql.go | 36 +++- pkg/database/wrapper_mysql.go | 10 ++ pkg/database/wrapper_postgres.go | 10 ++ pkg/database/wrapper_sqlite.go | 10 ++ pkg/ncps/serve.go | 30 ++-- pkg/server/server.go | 50 ++++++ pkg/server/server_internal_test.go | 2 +- pkg/server/server_test.go | 11 +- pkg/storage/local/local.go | 159 +++++++++++++++++- pkg/storage/local/local_test.go | 42 +++++ pkg/storage/s3/s3.go | 122 ++++++++++++++ pkg/storage/store.go | 16 ++ 29 files changed, 912 insertions(+), 36 deletions(-) create mode 100644 pkg/cache/export_test.go create mode 100644 pkg/cache/index.go create mode 100644 pkg/cache/index_test.go diff --git a/db/query.mysql.sql b/db/query.mysql.sql index 7e95790e..f21ac88d 100644 --- a/db/query.mysql.sql +++ b/db/query.mysql.sql @@ -164,3 +164,7 @@ SELECT nf.* FROM nar_files nf LEFT JOIN narinfo_nar_files ninf ON nf.id = ninf.nar_file_id WHERE ninf.narinfo_id IS NULL; + +-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos; diff --git a/db/query.postgres.sql b/db/query.postgres.sql index eb0f814c..20281173 100644 --- a/db/query.postgres.sql +++ b/db/query.postgres.sql @@ -168,3 +168,7 @@ SELECT nf.* FROM nar_files nf LEFT JOIN narinfo_nar_files ninf ON nf.id = ninf.nar_file_id WHERE ninf.narinfo_id IS NULL; + +-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos; diff --git a/db/query.sqlite.sql b/db/query.sqlite.sql index eb8a5f4b..d4bc8c9f 100644 --- a/db/query.sqlite.sql +++ b/db/query.sqlite.sql @@ -168,3 +168,7 @@ SELECT nf.* FROM nar_files nf LEFT JOIN narinfo_nar_files ninf ON nf.id = ninf.nar_file_id WHERE ninf.narinfo_id IS NULL; + +-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos; diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index f359c8b5..cd01c0f1 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -30,6 +30,7 @@ import ( "github.com/kalbasit/ncps/pkg/database" "github.com/kalbasit/ncps/pkg/lock" "github.com/kalbasit/ncps/pkg/nar" + "github.com/kalbasit/ncps/pkg/nixcacheindex" "github.com/kalbasit/ncps/pkg/storage" ) @@ -137,10 +138,17 @@ type Cache 
struct {
 	configStore  storage.ConfigStore
 	narInfoStore storage.NarInfoStore
 	narStore     storage.NarStore
+	fileStore    storage.FileStore
 
 	// Should the cache sign the narinfos?
 	shouldSignNarinfo bool
 
+	// Should the cache generate the experimental cache index?
+	experimentalCacheIndex      bool
+	experimentalCacheIndexHTTPS bool
+	indexClient                 *nixcacheindex.Client
+	indexGenerationJobID        cron.EntryID
+
 	// recordAgeIgnoreTouch represents the duration at which a record is
 	// considered up to date and a touch is not invoked. This helps avoid
 	// repetitive touching of records in the database which are causing `database
@@ -216,6 +224,13 @@ func (ds *downloadState) getError() error {
 	return ds.downloadError
 }
 
+// Fetch implements the nixcacheindex.Fetcher interface.
+func (c *Cache) Fetch(ctx context.Context, path string) (io.ReadCloser, error) {
+	_, rc, err := c.fileStore.GetFile(ctx, path)
+
+	return rc, err
+}
+
 // New returns a new Cache.
 func New(
 	ctx context.Context,
@@ -225,6 +240,7 @@ func New(
 	configStore storage.ConfigStore,
 	narInfoStore storage.NarInfoStore,
 	narStore storage.NarStore,
+	fileStore storage.FileStore,
 	secretKeyPath string,
 	downloadLocker lock.Locker,
 	cacheLocker lock.RWLocker,
@@ -238,6 +254,7 @@ func New(
 		configStore:       configStore,
 		narInfoStore:      narInfoStore,
 		narStore:          narStore,
+		fileStore:         fileStore,
 		shouldSignNarinfo: true,
 		downloadLocker:    downloadLocker,
 		cacheLocker:       cacheLocker,
@@ -276,6 +293,9 @@ func New(
 		c.processHealthChanges(ctx, healthChangeCh)
 	})
 
+	c.cron = cron.New()
+	c.cron.Start()
+
 	return c, nil
 }
 
@@ -376,6 +396,40 @@ func (c *Cache) GetHealthChecker() *healthcheck.HealthChecker { return c.healthC
 
 // SetCacheSignNarinfo configure ncps to sign or not sign narinfos.
 func (c *Cache) SetCacheSignNarinfo(shouldSignNarinfo bool) { c.shouldSignNarinfo = shouldSignNarinfo }
 
+// SetExperimentalCacheIndex configures ncps to generate the cache index or not.
+func (c *Cache) SetExperimentalCacheIndex(experimentalCacheIndex bool) {
+	c.experimentalCacheIndex = experimentalCacheIndex
+
+	if !c.experimentalCacheIndex {
+		c.indexClient = nil
+
+		if c.indexGenerationJobID != 0 {
+			c.cron.Remove(c.indexGenerationJobID)
+			c.indexGenerationJobID = 0
+		}
+
+		return
+	}
+
+	c.indexClient = nixcacheindex.NewClient(c.baseContext, c)
+
+	if c.indexGenerationJobID != 0 {
+		return
+	}
+
+	jobID, err := c.cron.AddFunc("@hourly", c.generateIndex)
+	if err != nil {
+		zerolog.Ctx(c.baseContext).Error().Err(err).Msg("failed to schedule cache index generation")
+
+		return
+	}
+
+	c.indexGenerationJobID = jobID
+}
+
+// SetExperimentalCacheIndexHTTPS configures ncps to use HTTPS for the cache index.
+func (c *Cache) SetExperimentalCacheIndexHTTPS(https bool) { c.experimentalCacheIndexHTTPS = https }
+
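For illustration (not part of the diff): the two setters above are written to be safe to call repeatedly and at runtime. A minimal lifecycle sketch, assuming a constructed *Cache value c; only the method calls themselves come from this change.

	// Enabling wires up the nixcacheindex client and schedules the hourly
	// regeneration job exactly once; a second call is a no-op while the
	// job is still scheduled.
	c.SetExperimentalCacheIndex(true)
	c.SetExperimentalCacheIndexHTTPS(true) // manifest URLs use the https scheme

	// Disabling drops the client and removes the cron entry, so no
	// further index generations run.
	c.SetExperimentalCacheIndex(false)
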
 // SetMaxSize sets the maxsize of the cache. This will be used by the LRU
 // cronjob to automatically clean-up the store.
 func (c *Cache) SetMaxSize(maxSize uint64) { c.maxSize = maxSize }
@@ -1087,6 +1141,19 @@ func (c *Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo,
 		}
 	}
 
+	if c.indexClient != nil {
+		status, err := c.indexClient.Query(ctx, hash)
+		if err != nil {
+			// Don't fail the request if the index lookup fails; just log it and proceed.
+			zerolog.Ctx(ctx).Warn().Err(err).Str("hash", hash).Msg("cache index query failed")
+		} else if status == nixcacheindex.DefiniteMiss {
+			// Avoid checking upstream if we know it's a definite miss.
+			zerolog.Ctx(ctx).Debug().Str("hash", hash).Msg("cache index says definite miss")
+
+			return nil, storage.ErrNotFound
+		}
+	}
+
 	metricAttrs = append(metricAttrs,
 		attribute.String("result", "miss"),
 		attribute.String("status", "success"),
diff --git a/pkg/cache/cache_distributed_test.go b/pkg/cache/cache_distributed_test.go
index 640c0c25..560e262e 100644
--- a/pkg/cache/cache_distributed_test.go
+++ b/pkg/cache/cache_distributed_test.go
@@ -130,6 +130,7 @@ func TestDistributedDownloadDeduplication(t *testing.T) {
 		sharedStore,
 		sharedStore,
 		sharedStore,
+		sharedStore,
 		"",
 		downloadLocker,
 		cacheLocker,
@@ -247,6 +248,7 @@ func TestDistributedConcurrentReads(t *testing.T) {
 		sharedStore,
 		sharedStore,
 		sharedStore,
+		sharedStore,
 		"",
 		downloadLocker,
 		cacheLocker,
@@ -288,6 +290,7 @@ func TestDistributedConcurrentReads(t *testing.T) {
 		sharedStore,
 		sharedStore,
 		sharedStore,
+		sharedStore,
 		"",
 		downloadLocker,
 		cacheLocker,
diff --git a/pkg/cache/cache_internal_test.go b/pkg/cache/cache_internal_test.go
index 8fafc13d..a914a902 100644
--- a/pkg/cache/cache_internal_test.go
+++ b/pkg/cache/cache_internal_test.go
@@ -69,7 +69,7 @@ func setupTestCache(t *testing.T) (*Cache, func()) {
 	downloadLocker := locklocal.NewLocker()
 	cacheLocker := locklocal.NewRWLocker()
 
-	c, err := New(newContext(), cacheName, db, localStore, localStore, localStore, "",
+	c, err := New(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "",
 		downloadLocker, cacheLocker, downloadLockTTL, cacheLockTTL)
 	if err != nil {
 		cleanup()
diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go
index abe7fbb0..cbb3dde2 100644
--- a/pkg/cache/cache_test.go
+++ b/pkg/cache/cache_test.go
@@ -47,12 +47,13 @@ func newTestCache(
 	configStore storage.ConfigStore,
 	narInfoStore storage.NarInfoStore,
 	narStore storage.NarStore,
+	fileStore storage.FileStore,
 	secretKeyPath string,
 ) (*cache.Cache, error) {
 	downloadLocker := locklocal.NewLocker()
 	cacheLocker := locklocal.NewRWLocker()
 
-	return cache.New(ctx, hostName, db, configStore, narInfoStore, narStore, secretKeyPath,
+	return cache.New(ctx, hostName, db, configStore, narInfoStore, narStore, fileStore, secretKeyPath,
 		downloadLocker, cacheLocker, 5*time.Minute, 30*time.Minute)
 }
 
@@ -83,7 +84,7 @@ func setupTestCache(t *testing.T) (*cache.Cache, database.Querier, *local.Store,
 
 	db, localStore, dir, cleanup := setupTestComponents(t)
 
-	c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, "")
+	c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "")
 	require.NoError(t, err)
 
 	return c, db, localStore, dir, cleanup
@@ -129,7 +130,7 @@ func TestNew(t *testing.T) {
 			db, localStore, _, cleanup := setupTestComponents(t)
 			defer cleanup()
 
-			_, err := newTestCache(newContext(), tt.hostname, db, localStore, localStore, localStore, "")
+			_, err := newTestCache(newContext(), tt.hostname, db, localStore, localStore, localStore, localStore, "")
 			if tt.wantErr != nil {
 				assert.ErrorIs(t,
err, tt.wantErr) } else { @@ -180,7 +181,7 @@ func TestNew(t *testing.T) { require.NoError(t, skFile.Close()) - c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, skFile.Name()) + c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, localStore, skFile.Name()) require.NoError(t, err) // Verify key is NOT in local store @@ -207,7 +208,7 @@ func TestNew(t *testing.T) { err = localStore.PutSecretKey(newContext(), sk) require.NoError(t, err) - c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, "") + c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "") require.NoError(t, err) // Verify key is NOT in local store anymore diff --git a/pkg/cache/export_test.go b/pkg/cache/export_test.go new file mode 100644 index 00000000..4ad5ad81 --- /dev/null +++ b/pkg/cache/export_test.go @@ -0,0 +1,7 @@ +package cache + +import "context" + +func (c *Cache) GenerateIndexForTest(ctx context.Context) error { + return c.doGenerateIndex(ctx) +} diff --git a/pkg/cache/healthcheck/healthcheck_test.go b/pkg/cache/healthcheck/healthcheck_test.go index dce47cc4..c894956a 100644 --- a/pkg/cache/healthcheck/healthcheck_test.go +++ b/pkg/cache/healthcheck/healthcheck_test.go @@ -53,7 +53,7 @@ func TestHealthCheck(t *testing.T) { downloadLocker := locklocal.NewLocker() cacheLocker := locklocal.NewRWLocker() - c, err := cache.New(newContext(), cacheName, db, localStore, localStore, localStore, "", + c, err := cache.New(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "", downloadLocker, cacheLocker, 5*time.Minute, 30*time.Minute) require.NoError(t, err) diff --git a/pkg/cache/index.go b/pkg/cache/index.go new file mode 100644 index 00000000..b287c232 --- /dev/null +++ b/pkg/cache/index.go @@ -0,0 +1,148 @@ +package cache + +import ( + "bytes" + "context" + "fmt" + "math/big" + "sort" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/rs/zerolog" + + "github.com/kalbasit/ncps/pkg/nixcacheindex" +) + +// generateIndex rebuilds the entire cache index (manifest and shards) from the database. +func (c *Cache) generateIndex() { + ctx := c.baseContext + logger := zerolog.Ctx(ctx).With().Str("component", "index_generator").Logger() + ctx = logger.WithContext(ctx) + + logger.Info().Msg("starting cache index generation") + + start := time.Now() + + if err := c.doGenerateIndex(ctx); err != nil { + logger.Error().Err(err).Msg("failed to generate cache index") + } else { + logger.Info().Dur("duration", time.Since(start)).Msg("cache index generation completed successfully") + } +} + +func (c *Cache) doGenerateIndex(ctx context.Context) error { + // 1. Fetch all NarInfo hashes from DB + rows, err := c.db.GetAllNarInfos(ctx) + if err != nil { + return fmt.Errorf("failed to fetch all narinfos: %w", err) + } + + // 2. Prepare Manifest + // For now, we recreate a fresh manifest. + // TODO: Load existing manifest to preserve journal/version? + // But we are doing a full rebuild, so a fresh manifest might be appropriate. + manifest := nixcacheindex.NewManifest() + + scheme := "http" + if c.experimentalCacheIndexHTTPS { + scheme = "https" + } + + baseURL := fmt.Sprintf("%s://%s/nix-cache-index/", scheme, c.hostName) + manifest.Urls.JournalBase = baseURL + "journal/" + manifest.Urls.ShardsBase = baseURL + "shards/" + manifest.Urls.DeltasBase = baseURL + "deltas/" + + // 3. 
Process Hashes + type Item struct { + Big *big.Int + Str string + } + + items := make([]Item, 0, len(rows)) + for _, hash := range rows { + h, err := nixcacheindex.ParseHash(hash) + if err != nil { + zerolog.Ctx(ctx).Warn().Str("hash", hash).Err(err).Msg("skipping invalid hash") + + continue + } + + items = append(items, Item{Big: h, Str: hash}) + } + + // 4. Sort Hashes + sort.Slice(items, func(i, j int) bool { + return items[i].Big.Cmp(items[j].Big) < 0 + }) + + manifest.ItemCount = int64(len(items)) + + // 5. Build Shards + shards := make(map[string][]*big.Int) + depth := manifest.Sharding.Depth + + if depth == 0 { + list := make([]*big.Int, len(items)) + for i, item := range items { + list[i] = item.Big + } + + shards["root"] = list + } else { + for _, item := range items { + if len(item.Str) < depth { + continue + } + + prefix := item.Str[:depth] + shards[prefix] = append(shards[prefix], item.Big) + } + } + + // 6. Write Shards to Store + epoch := manifest.Epoch.Current + + for name, list := range shards { + var buf bytes.Buffer + + // Encode + params := manifest.Encoding + // IMPORTANT: nixcacheindex.WriteShard expects hashes to be sorted. They are. + + // Compress with zstd and write shard directly to the compressor. + enc, err := zstd.NewWriter(&buf) + if err != nil { + return fmt.Errorf("failed to create zstd writer: %w", err) + } + + writeErr := nixcacheindex.WriteShard(enc, list, params) + closeErr := enc.Close() + + if writeErr != nil { + return fmt.Errorf("failed to write shard %s: %w", name, writeErr) + } + + if closeErr != nil { + return fmt.Errorf("failed to close zstd writer for shard %s: %w", name, closeErr) + } + + path := fmt.Sprintf("/nix-cache-index/shards/%d/%s.idx.zst", epoch, name) + if _, err := c.fileStore.PutFile(ctx, path, &buf); err != nil { + return fmt.Errorf("failed to store shard file %s: %w", path, err) + } + } + + // 7. Write Manifest + var buf bytes.Buffer + if err := manifest.Write(&buf); err != nil { + return fmt.Errorf("failed to encode manifest: %w", err) + } + + if _, err := c.fileStore.PutFile(ctx, nixcacheindex.ManifestPath, &buf); err != nil { + return fmt.Errorf("failed to store manifest: %w", err) + } + + return nil +} diff --git a/pkg/cache/index_test.go b/pkg/cache/index_test.go new file mode 100644 index 00000000..d21b66c5 --- /dev/null +++ b/pkg/cache/index_test.go @@ -0,0 +1,104 @@ +package cache_test + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/database" + "github.com/kalbasit/ncps/pkg/nixcacheindex" + "github.com/kalbasit/ncps/pkg/storage/local" + "github.com/kalbasit/ncps/testdata" + "github.com/kalbasit/ncps/testhelper" +) + +func TestGenerateIndex(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "cache-index-test-") + require.NoError(t, err) + + defer os.RemoveAll(dir) + + dbFile := filepath.Join(dir, "db.sqlite") + testhelper.CreateMigrateDatabase(t, dbFile) + + db, err := database.Open("sqlite:"+dbFile, nil) + require.NoError(t, err) + + localStore, err := local.New(newContext(), dir) + require.NoError(t, err) + + c, err := newTestCache(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "") + require.NoError(t, err) + + // 1. 
Insert some NarInfos + ctx := newContext() + + // Nar1 + err = c.PutNarInfo(ctx, testdata.Nar1.NarInfoHash, io.NopCloser(strings.NewReader(testdata.Nar1.NarInfoText))) + require.NoError(t, err) + + // Nar2 + err = c.PutNarInfo(ctx, testdata.Nar2.NarInfoHash, io.NopCloser(strings.NewReader(testdata.Nar2.NarInfoText))) + require.NoError(t, err) + + // 2. Trigger Generation + err = c.GenerateIndexForTest(ctx) + require.NoError(t, err) + + // 3. Verify Manifest exists + manifestPath := filepath.Join(dir, "store", "nix-cache-index", "manifest.json") + assert.FileExists(t, manifestPath) + + f, err := os.Open(manifestPath) + require.NoError(t, err) + + defer f.Close() + + var m nixcacheindex.Manifest + + err = json.NewDecoder(f).Decode(&m) + require.NoError(t, err) + + assert.Equal(t, int64(2), m.ItemCount) + assert.Equal(t, 1, m.Version) + + // Verify URLs + assert.Equal(t, "http://cache.example.com/nix-cache-index/journal/", m.Urls.JournalBase) + assert.Equal(t, "http://cache.example.com/nix-cache-index/shards/", m.Urls.ShardsBase) + assert.Equal(t, "http://cache.example.com/nix-cache-index/deltas/", m.Urls.DeltasBase) + + // 4. Verify Shards exist and are compressed + shardsDir := filepath.Join(dir, "store", "nix-cache-index", "shards", "1") + entries, err := os.ReadDir(shardsDir) + require.NoError(t, err) + assert.NotEmpty(t, entries) + + // Check for .zst extension + found := false + + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".idx.zst") { + found = true + + break + } + } + + assert.True(t, found, "Expected to find .idx.zst shard file") + + // 5. Test Query + c.SetExperimentalCacheIndex(true) + + // Check that we can fetch the manifest via the cache interface + rc, err := c.Fetch(ctx, nixcacheindex.ManifestPath) + require.NoError(t, err) + rc.Close() +} diff --git a/pkg/database/mysqldb/querier.go b/pkg/database/mysqldb/querier.go index 8f5cf4ed..4de0d2b7 100644 --- a/pkg/database/mysqldb/querier.go +++ b/pkg/database/mysqldb/querier.go @@ -70,6 +70,11 @@ type Querier interface { // FROM narinfo_nar_files // ) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) + //GetAllNarInfos + // + // SELECT hash + // FROM narinfos + GetAllNarInfos(ctx context.Context) ([]string, error) //GetConfigByID // // SELECT id, `key`, value, created_at, updated_at @@ -106,7 +111,7 @@ type Querier interface { // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( - // SELECT nnf.nar_file_id + // SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/mysqldb/query.mysql.sql.go b/pkg/database/mysqldb/query.mysql.sql.go index f5eabadc..04f12814 100644 --- a/pkg/database/mysqldb/query.mysql.sql.go +++ b/pkg/database/mysqldb/query.mysql.sql.go @@ -198,6 +198,38 @@ func (q *Queries) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) { return result.RowsAffected() } +const getAllNarInfos = `-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos +` + +// GetAllNarInfos +// +// SELECT hash +// FROM narinfos +func (q *Queries) GetAllNarInfos(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getAllNarInfos) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var hash string + if err := rows.Scan(&hash); err != nil { + return nil, err + } + items = append(items, hash) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := 
rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getConfigByID = `-- name: GetConfigByID :one SELECT id, ` + "`" + `key` + "`" + `, value, created_at, updated_at FROM config @@ -308,7 +340,7 @@ WHERE ( SELECT COALESCE(SUM(nf.file_size), 0) FROM nar_files nf WHERE nf.id IN ( - SELECT nnf.nar_file_id + SELECT DISTINCT nnf.nar_file_id FROM narinfo_nar_files nnf INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id WHERE ni2.last_accessed_at < ni1.last_accessed_at @@ -328,7 +360,7 @@ WHERE ( // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( -// SELECT nnf.nar_file_id +// SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/postgresdb/querier.go b/pkg/database/postgresdb/querier.go index ae0762fc..a181ea8b 100644 --- a/pkg/database/postgresdb/querier.go +++ b/pkg/database/postgresdb/querier.go @@ -72,6 +72,11 @@ type Querier interface { // FROM narinfo_nar_files // ) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) + //GetAllNarInfos + // + // SELECT hash + // FROM narinfos + GetAllNarInfos(ctx context.Context) ([]string, error) //GetConfigByID // // SELECT id, key, value, created_at, updated_at @@ -108,7 +113,7 @@ type Querier interface { // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( - // SELECT nnf.nar_file_id + // SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/postgresdb/query.postgres.sql.go b/pkg/database/postgresdb/query.postgres.sql.go index b2a6ee17..8af69a9e 100644 --- a/pkg/database/postgresdb/query.postgres.sql.go +++ b/pkg/database/postgresdb/query.postgres.sql.go @@ -233,6 +233,38 @@ func (q *Queries) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) { return result.RowsAffected() } +const getAllNarInfos = `-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos +` + +// GetAllNarInfos +// +// SELECT hash +// FROM narinfos +func (q *Queries) GetAllNarInfos(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getAllNarInfos) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var hash string + if err := rows.Scan(&hash); err != nil { + return nil, err + } + items = append(items, hash) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getConfigByID = `-- name: GetConfigByID :one SELECT id, key, value, created_at, updated_at FROM config @@ -343,7 +375,7 @@ WHERE ( SELECT COALESCE(SUM(nf.file_size), 0) FROM nar_files nf WHERE nf.id IN ( - SELECT nnf.nar_file_id + SELECT DISTINCT nnf.nar_file_id FROM narinfo_nar_files nnf INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id WHERE ni2.last_accessed_at < ni1.last_accessed_at @@ -363,7 +395,7 @@ WHERE ( // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( -// SELECT nnf.nar_file_id +// SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/querier.go b/pkg/database/querier.go index 2d691407..bf1e3d0f 100644 --- a/pkg/database/querier.go +++ b/pkg/database/querier.go @@ -20,6 +20,7 @@ type Querier 
interface { DeleteOrphanedNarInfos(ctx context.Context) (int64, error) GetConfigByID(ctx context.Context, id int64) (Config, error) GetConfigByKey(ctx context.Context, key string) (Config, error) + GetAllNarInfos(ctx context.Context) ([]string, error) GetLeastUsedNarFiles(ctx context.Context, fileSize uint64) ([]NarFile, error) GetLeastUsedNarInfos(ctx context.Context, fileSize uint64) ([]NarInfo, error) GetNarFileByHash(ctx context.Context, hash string) (NarFile, error) diff --git a/pkg/database/sqlitedb/querier.go b/pkg/database/sqlitedb/querier.go index 6f23bcd5..207f2d1f 100644 --- a/pkg/database/sqlitedb/querier.go +++ b/pkg/database/sqlitedb/querier.go @@ -72,6 +72,11 @@ type Querier interface { // FROM narinfo_nar_files // ) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) + //GetAllNarInfos + // + // SELECT hash + // FROM narinfos + GetAllNarInfos(ctx context.Context) ([]string, error) //GetConfigByID // // SELECT id, "key", value, created_at, updated_at @@ -108,7 +113,7 @@ type Querier interface { // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( - // SELECT nnf.nar_file_id + // SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/sqlitedb/query.sqlite.sql.go b/pkg/database/sqlitedb/query.sqlite.sql.go index 2fa550eb..aed17832 100644 --- a/pkg/database/sqlitedb/query.sqlite.sql.go +++ b/pkg/database/sqlitedb/query.sqlite.sql.go @@ -233,6 +233,38 @@ func (q *Queries) DeleteOrphanedNarInfos(ctx context.Context) (int64, error) { return result.RowsAffected() } +const getAllNarInfos = `-- name: GetAllNarInfos :many +SELECT hash +FROM narinfos +` + +// GetAllNarInfos +// +// SELECT hash +// FROM narinfos +func (q *Queries) GetAllNarInfos(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getAllNarInfos) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var hash string + if err := rows.Scan(&hash); err != nil { + return nil, err + } + items = append(items, hash) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getConfigByID = `-- name: GetConfigByID :one SELECT id, "key", value, created_at, updated_at FROM config @@ -343,7 +375,7 @@ WHERE ( SELECT COALESCE(SUM(nf.file_size), 0) FROM nar_files nf WHERE nf.id IN ( - SELECT nnf.nar_file_id + SELECT DISTINCT nnf.nar_file_id FROM narinfo_nar_files nnf INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id WHERE ni2.last_accessed_at < ni1.last_accessed_at @@ -363,7 +395,7 @@ WHERE ( // SELECT COALESCE(SUM(nf.file_size), 0) // FROM nar_files nf // WHERE nf.id IN ( -// SELECT nnf.nar_file_id +// SELECT DISTINCT nnf.nar_file_id // FROM narinfo_nar_files nnf // INNER JOIN narinfos ni2 ON nnf.narinfo_id = ni2.id // WHERE ni2.last_accessed_at < ni1.last_accessed_at diff --git a/pkg/database/wrapper_mysql.go b/pkg/database/wrapper_mysql.go index 695f3cec..6ed1c1bf 100644 --- a/pkg/database/wrapper_mysql.go +++ b/pkg/database/wrapper_mysql.go @@ -155,6 +155,16 @@ func (w *mysqlWrapper) GetConfigByKey(ctx context.Context, key string) (Config, return Config(res), nil } +func (w *mysqlWrapper) GetAllNarInfos(ctx context.Context) ([]string, error) { + res, err := w.adapter.GetAllNarInfos(ctx) + if err != nil { + return nil, err + } + + // Return Slice of Primitives (direct match) + return res, 
nil +} + func (w *mysqlWrapper) GetLeastUsedNarFiles(ctx context.Context, fileSize uint64) ([]NarFile, error) { res, err := w.adapter.GetLeastUsedNarFiles(ctx, fileSize) if err != nil { diff --git a/pkg/database/wrapper_postgres.go b/pkg/database/wrapper_postgres.go index 0f40ffd7..b1d69589 100644 --- a/pkg/database/wrapper_postgres.go +++ b/pkg/database/wrapper_postgres.go @@ -140,6 +140,16 @@ func (w *postgresWrapper) GetConfigByKey(ctx context.Context, key string) (Confi return Config(res), nil } +func (w *postgresWrapper) GetAllNarInfos(ctx context.Context) ([]string, error) { + res, err := w.adapter.GetAllNarInfos(ctx) + if err != nil { + return nil, err + } + + // Return Slice of Primitives (direct match) + return res, nil +} + func (w *postgresWrapper) GetLeastUsedNarFiles(ctx context.Context, fileSize uint64) ([]NarFile, error) { res, err := w.adapter.GetLeastUsedNarFiles(ctx, fileSize) if err != nil { diff --git a/pkg/database/wrapper_sqlite.go b/pkg/database/wrapper_sqlite.go index 9f4bb185..9653ecf5 100644 --- a/pkg/database/wrapper_sqlite.go +++ b/pkg/database/wrapper_sqlite.go @@ -140,6 +140,16 @@ func (w *sqliteWrapper) GetConfigByKey(ctx context.Context, key string) (Config, return Config(res), nil } +func (w *sqliteWrapper) GetAllNarInfos(ctx context.Context) ([]string, error) { + res, err := w.adapter.GetAllNarInfos(ctx) + if err != nil { + return nil, err + } + + // Return Slice of Primitives (direct match) + return res, nil +} + func (w *sqliteWrapper) GetLeastUsedNarFiles(ctx context.Context, fileSize uint64) ([]NarFile, error) { res, err := w.adapter.GetLeastUsedNarFiles(ctx, fileSize) if err != nil { diff --git a/pkg/ncps/serve.go b/pkg/ncps/serve.go index 65485e3c..8c56b371 100644 --- a/pkg/ncps/serve.go +++ b/pkg/ncps/serve.go @@ -255,6 +255,11 @@ func serveCommand( Usage: "Enable the use of the experimental binary cache index", Sources: flagSources("experimental.cache-index", "EXPERIMENTAL_CACHE_INDEX"), }, + &cli.BoolFlag{ + Name: "experimental-cache-index-https", + Usage: "Use HTTPS for the experimental binary cache index URLs", + Sources: flagSources("experimental.cache-index-https", "EXPERIMENTAL_CACHE_INDEX_HTTPS"), + }, // Redis Configuration (optional - for distributed locking in HA deployments) &cli.StringSliceFlag{ @@ -723,7 +728,7 @@ func getUpstreamCaches(ctx context.Context, cmd *cli.Command, netrcData *netrc.N func getStorageBackend( ctx context.Context, cmd *cli.Command, -) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, error) { +) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, storage.FileStore, error) { // Handle backward compatibility for cache-data-path (deprecated) deprecatedDataPath := cmd.String("cache-data-path") localDataPath := cmd.String("cache-storage-local") @@ -745,7 +750,7 @@ func getStorageBackend( switch { case localDataPath != "" && s3Bucket != "": - return nil, nil, nil, ErrStorageConflict + return nil, nil, nil, nil, ErrStorageConflict case localDataPath != "": return createLocalStorage(ctx, localDataPath) @@ -754,7 +759,7 @@ func getStorageBackend( return createS3Storage(ctx, cmd) default: - return nil, nil, nil, ErrStorageConfigRequired + return nil, nil, nil, nil, ErrStorageConfigRequired } } @@ -762,22 +767,22 @@ func getStorageBackend( func createLocalStorage( ctx context.Context, dataPath string, -) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, error) { +) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, storage.FileStore, error) { localStore, err := 
localstorage.New(ctx, dataPath) if err != nil { - return nil, nil, nil, fmt.Errorf("error creating a new local store at %q: %w", dataPath, err) + return nil, nil, nil, nil, fmt.Errorf("error creating a new local store at %q: %w", dataPath, err) } zerolog.Ctx(ctx).Info().Str("path", dataPath).Msg("using local storage") - return localStore, localStore, localStore, nil + return localStore, localStore, localStore, localStore, nil } //nolint:staticcheck // deprecated: migration support func createS3Storage( ctx context.Context, cmd *cli.Command, -) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, error) { +) (storage.ConfigStore, storage.NarInfoStore, storage.NarStore, storage.FileStore, error) { s3Bucket := cmd.String("cache-storage-s3-bucket") s3Endpoint := cmd.String("cache-storage-s3-endpoint") s3AccessKeyID := cmd.String("cache-storage-s3-access-key-id") @@ -785,7 +790,7 @@ func createS3Storage( s3ForcePathStyle := cmd.Bool("cache-storage-s3-force-path-style") if s3Endpoint == "" || s3AccessKeyID == "" || s3SecretAccessKey == "" { - return nil, nil, nil, ErrS3ConfigIncomplete + return nil, nil, nil, nil, ErrS3ConfigIncomplete } ctx = zerolog.Ctx(ctx). @@ -809,12 +814,12 @@ func createS3Storage( s3Store, err := s3.New(ctx, s3Cfg) if err != nil { - return nil, nil, nil, fmt.Errorf("error creating a new S3 store: %w", err) + return nil, nil, nil, nil, fmt.Errorf("error creating a new S3 store: %w", err) } zerolog.Ctx(ctx).Info().Msg("using S3 storage") - return s3Store, s3Store, s3Store, nil + return s3Store, s3Store, s3Store, s3Store, nil } func createDatabaseQuerier(cmd *cli.Command) (database.Querier, error) { @@ -849,7 +854,7 @@ func createCache( rwLocker lock.RWLocker, ucs []*upstream.Cache, ) (*cache.Cache, error) { - configStore, narInfoStore, narStore, err := getStorageBackend(ctx, cmd) + configStore, narInfoStore, narStore, fileStore, err := getStorageBackend(ctx, cmd) if err != nil { return nil, err } @@ -861,6 +866,7 @@ func createCache( configStore, narInfoStore, narStore, + fileStore, cmd.String("cache-secret-key-path"), locker, rwLocker, @@ -873,6 +879,8 @@ func createCache( c.SetTempDir(cmd.String("cache-temp-path")) c.SetCacheSignNarinfo(cmd.Bool("cache-sign-narinfo")) + c.SetExperimentalCacheIndex(cmd.Bool("experimental-cache-index")) + c.SetExperimentalCacheIndexHTTPS(cmd.Bool("experimental-cache-index-https")) c.AddUpstreamCaches(ctx, ucs...) // Trigger the health-checker to speed-up the boot but do not wait for the check to complete. 
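For illustration (not part of the diff): the sketch below mirrors how createCache now threads the fourth FileStore value into cache.New and flips the two experimental flags. Only cache.New, local.New, and the two setters come from this change; the lock-package import path, host name, data path, and TTLs are assumptions.

	package main

	import (
		"context"
		"time"

		"github.com/kalbasit/ncps/pkg/cache"
		"github.com/kalbasit/ncps/pkg/database"
		locklocal "github.com/kalbasit/ncps/pkg/lock/local" // assumed import path
		"github.com/kalbasit/ncps/pkg/storage/local"
	)

	func buildCache(ctx context.Context, db database.Querier) (*cache.Cache, error) {
		// One local store implements ConfigStore, NarInfoStore, NarStore,
		// and the new FileStore, so it is passed four times.
		store, err := local.New(ctx, "/var/lib/ncps")
		if err != nil {
			return nil, err
		}

		c, err := cache.New(ctx, "cache.example.com", db,
			store, store, store, store, // config, narinfo, nar, file
			"", locklocal.NewLocker(), locklocal.NewRWLocker(),
			5*time.Minute, 30*time.Minute)
		if err != nil {
			return nil, err
		}

		// Equivalent to passing --experimental-cache-index and
		// --experimental-cache-index-https on the CLI.
		c.SetExperimentalCacheIndex(true)
		c.SetExperimentalCacheIndexHTTPS(true)

		return c, nil
	}
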
diff --git a/pkg/server/server.go b/pkg/server/server.go index 343a9bfc..ab5185e9 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -34,6 +34,7 @@ const ( routeNarCompression = "/nar/{hash:[a-z0-9]+}.nar.{compression:*}" routeNarInfo = "/{hash:[a-z0-9]+}.narinfo" routeCacheInfo = "/nix-cache-info" + routeCacheIndex = "/nix-cache-index/*" routeCachePublicKey = "/pubkey" contentLength = "Content-Length" @@ -146,6 +147,8 @@ func (s *Server) createRouter() { s.router.Put(routeNar, s.putNar) s.router.Delete(routeNar, s.deleteNar) + s.router.Get(routeCacheIndex, s.getCacheIndex) + // Add Prometheus metrics endpoint if gatherer is configured if prometheusGatherer != nil { s.router.Get("/metrics", promhttp.HandlerFor(prometheusGatherer, promhttp.HandlerOpts{}).ServeHTTP) @@ -610,3 +613,50 @@ func (s *Server) deleteNar(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) }).ServeHTTP(w, r) } + +func (s *Server) getCacheIndex(w http.ResponseWriter, r *http.Request) { + path := "/nix-cache-index/" + chi.URLParam(r, "*") + + ctx, span := tracer.Start( + r.Context(), + "server.getCacheIndex", + trace.WithSpanKind(trace.SpanKindServer), + trace.WithAttributes( + attribute.String("path", path), + ), + ) + defer span.End() + + r = r.WithContext( + zerolog.Ctx(ctx). + With(). + Str("path", path). + Logger(). + WithContext(ctx)) + + rc, err := s.cache.Fetch(r.Context(), path) + if err != nil { + if errors.Is(err, storage.ErrNotFound) { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + + return + } + + zerolog.Ctx(r.Context()). + Error(). + Err(err). + Msg("error fetching cache index file") + + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + defer rc.Close() + + if _, err := io.Copy(w, rc); err != nil { + zerolog.Ctx(r.Context()). + Error(). + Err(err). 
+ Msg("error writing cache index file to response") + } +} diff --git a/pkg/server/server_internal_test.go b/pkg/server/server_internal_test.go index 7a4124be..6a1a58d2 100644 --- a/pkg/server/server_internal_test.go +++ b/pkg/server/server_internal_test.go @@ -41,7 +41,7 @@ func TestSetDeletePermitted(t *testing.T) { downloadLocker := locklocal.NewLocker() cacheLocker := locklocal.NewRWLocker() - c, err := cache.New(newContext(), cacheName, db, localStore, localStore, localStore, "", + c, err := cache.New(newContext(), cacheName, db, localStore, localStore, localStore, localStore, "", downloadLocker, cacheLocker, 5*time.Minute, 30*time.Minute) require.NoError(t, err) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index f72dfd5f..98420b28 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -41,11 +41,12 @@ func newTestCache( configStore storage.ConfigStore, narInfoStore storage.NarInfoStore, narStore storage.NarStore, + fileStore storage.FileStore, ) (*cache.Cache, error) { downloadLocker := locklocal.NewLocker() cacheLocker := locklocal.NewRWLocker() - return cache.New(ctx, cacheName, db, configStore, narInfoStore, narStore, "", + return cache.New(ctx, cacheName, db, configStore, narInfoStore, narStore, fileStore, "", downloadLocker, cacheLocker, 5*time.Minute, 30*time.Minute) } @@ -74,7 +75,7 @@ func TestServeHTTP(t *testing.T) { localStore, err := local.New(newContext(), dir) require.NoError(t, err) - c, err := newTestCache(newContext(), db, localStore, localStore, localStore) + c, err := newTestCache(newContext(), db, localStore, localStore, localStore, localStore) require.NoError(t, err) c.AddUpstreamCaches(newContext(), uc) @@ -121,7 +122,7 @@ func TestServeHTTP(t *testing.T) { localStore, err := local.New(newContext(), dir) require.NoError(t, err) - c, err := newTestCache(newContext(), db, localStore, localStore, localStore) + c, err := newTestCache(newContext(), db, localStore, localStore, localStore, localStore) require.NoError(t, err) c.AddUpstreamCaches(newContext(), uc) @@ -261,7 +262,7 @@ func TestServeHTTP(t *testing.T) { localStore, err := local.New(newContext(), dir) require.NoError(t, err) - c, err := newTestCache(newContext(), db, localStore, localStore, localStore) + c, err := newTestCache(newContext(), db, localStore, localStore, localStore, localStore) require.NoError(t, err) c.AddUpstreamCaches(newContext(), uc) @@ -378,7 +379,7 @@ func TestServeHTTP(t *testing.T) { localStore, err := local.New(newContext(), dir) require.NoError(t, err) - c, err := newTestCache(newContext(), db, localStore, localStore, localStore) + c, err := newTestCache(newContext(), db, localStore, localStore, localStore, localStore) require.NoError(t, err) c.AddUpstreamCaches(newContext(), uc) diff --git a/pkg/storage/local/local.go b/pkg/storage/local/local.go index 43e0267c..6d15af8c 100644 --- a/pkg/storage/local/local.go +++ b/pkg/storage/local/local.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "path/filepath" + "strings" "github.com/nix-community/go-nix/pkg/narinfo" "github.com/nix-community/go-nix/pkg/narinfo/signature" @@ -379,9 +380,161 @@ func (s *Store) DeleteNar(ctx context.Context, narURL nar.URL) error { return nil } -func (s *Store) configPath() string { return filepath.Join(s.path, "config") } -func (s *Store) secretKeyPath() string { return filepath.Join(s.configPath(), "cache.key") } -func (s *Store) storePath() string { return filepath.Join(s.path, "store") } +// HasFile returns true if the store has the file at the given path. 
+func (s *Store) HasFile(ctx context.Context, path string) bool { + filePath, err := s.sanitizePath(path) + if err != nil { + return false + } + + _, span := tracer.Start( + ctx, + "local.HasFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("file_path", filePath), + ), + ) + defer span.End() + + _, err = os.Stat(filePath) + + return err == nil +} + +// GetFile returns the file from the store at the given path. +// NOTE: The caller must close the returned io.ReadCloser! +func (s *Store) GetFile(ctx context.Context, path string) (int64, io.ReadCloser, error) { + filePath, err := s.sanitizePath(path) + if err != nil { + return 0, nil, err + } + + _, span := tracer.Start( + ctx, + "local.GetFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("file_path", filePath), + ), + ) + defer span.End() + + info, err := os.Stat(filePath) + if err != nil { + if os.IsNotExist(err) { + return 0, nil, storage.ErrNotFound + } + + return 0, nil, fmt.Errorf("error stating the file %q: %w", filePath, err) + } + + f, err := os.Open(filePath) + if err != nil { + return 0, nil, fmt.Errorf("error opening the file %q: %w", filePath, err) + } + + return info.Size(), f, nil +} + +// PutFile puts the file in the store at the given path. +func (s *Store) PutFile(ctx context.Context, path string, body io.Reader) (int64, error) { + filePath, err := s.sanitizePath(path) + if err != nil { + return 0, err + } + + _, span := tracer.Start( + ctx, + "local.PutFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("file_path", filePath), + ), + ) + defer span.End() + + if err := os.MkdirAll(filepath.Dir(filePath), dirMode); err != nil { + return 0, fmt.Errorf("error creating the directories for %q: %w", filePath, err) + } + + f, err := os.CreateTemp(s.storeTMPPath(), filepath.Base(path)+"-*") + if err != nil { + return 0, fmt.Errorf("error creating the temporary file: %w", err) + } + + written, err := io.Copy(f, body) + if err != nil { + f.Close() + os.Remove(f.Name()) + + return 0, fmt.Errorf("error writing to the temporary file: %w", err) + } + + if err := f.Close(); err != nil { + return 0, fmt.Errorf("error closing the temporary file: %w", err) + } + + if err := os.Rename(f.Name(), filePath); err != nil { + return 0, fmt.Errorf("error moving the file to %q: %w", filePath, err) + } + + if err := os.Chmod(filePath, fileMode); err != nil { + return 0, fmt.Errorf("error changing mode of %q: %w", filePath, err) + } + + return written, nil +} + +// DeleteFile deletes the file from the store at the given path. 
+func (s *Store) DeleteFile(ctx context.Context, path string) error { + filePath, err := s.sanitizePath(path) + if err != nil { + return err + } + + _, span := tracer.Start( + ctx, + "local.DeleteFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("file_path", filePath), + ), + ) + defer span.End() + + if err := os.Remove(filePath); err != nil { + if os.IsNotExist(err) { + return storage.ErrNotFound + } + + return fmt.Errorf("error deleting file %q: %w", filePath, err) + } + + return nil +} + +func (s *Store) configPath() string { return filepath.Join(s.path, "config") } +func (s *Store) secretKeyPath() string { return filepath.Join(s.configPath(), "cache.key") } +func (s *Store) storePath() string { return filepath.Join(s.path, "store") } + +func (s *Store) sanitizePath(path string) (string, error) { + // Sanitize path to prevent traversal. + relativePath := strings.TrimPrefix(path, "/") + filePath := filepath.Join(s.storePath(), relativePath) + + // Final check to ensure the path is within the store directory. + if !strings.HasPrefix(filePath, s.storePath()) { + return "", storage.ErrNotFound + } + + return filePath, nil +} + func (s *Store) storeNarInfoPath() string { return filepath.Join(s.storePath(), "narinfo") } func (s *Store) storeNarPath() string { return filepath.Join(s.storePath(), "nar") } func (s *Store) storeTMPPath() string { return filepath.Join(s.storePath(), "tmp") } diff --git a/pkg/storage/local/local_test.go b/pkg/storage/local/local_test.go index 39a2b42d..74123d7d 100644 --- a/pkg/storage/local/local_test.go +++ b/pkg/storage/local/local_test.go @@ -797,3 +797,45 @@ func newContext() context.Context { New(io.Discard). WithContext(context.Background()) } + +func TestPathTraversal(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "cache-path-") + require.NoError(t, err) + + defer os.RemoveAll(dir) + + ctx := newContext() + s, err := local.New(ctx, dir) + require.NoError(t, err) + + maliciousPath := "../../../etc/passwd" + + t.Run("HasFile", func(t *testing.T) { + t.Parallel() + + assert.False(t, s.HasFile(ctx, maliciousPath)) + }) + + t.Run("GetFile", func(t *testing.T) { + t.Parallel() + + _, _, err := s.GetFile(ctx, maliciousPath) + assert.ErrorIs(t, err, storage.ErrNotFound) + }) + + t.Run("PutFile", func(t *testing.T) { + t.Parallel() + + _, err := s.PutFile(ctx, maliciousPath, strings.NewReader("malicious content")) + assert.ErrorIs(t, err, storage.ErrNotFound) + }) + + t.Run("DeleteFile", func(t *testing.T) { + t.Parallel() + + err := s.DeleteFile(ctx, maliciousPath) + assert.ErrorIs(t, err, storage.ErrNotFound) + }) +} diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index 3f9941b0..16171ddb 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -501,6 +501,128 @@ func (s *Store) DeleteNar(ctx context.Context, narURL nar.URL) error { return nil } +// HasFile returns true if the store has the file at the given path. +func (s *Store) HasFile(ctx context.Context, path string) bool { + key := "store/" + strings.TrimPrefix(path, "/") + + _, span := tracer.Start( + ctx, + "s3.HasFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("key", key), + ), + ) + defer span.End() + + _, err := s.client.StatObject(ctx, s.bucket, key, minio.StatObjectOptions{}) + + return err == nil +} + +// GetFile returns the file from the store at the given path. 
+// NOTE: The caller must close the returned io.ReadCloser! +func (s *Store) GetFile(ctx context.Context, path string) (int64, io.ReadCloser, error) { + key := "store/" + strings.TrimPrefix(path, "/") + + _, span := tracer.Start( + ctx, + "s3.GetFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("key", key), + ), + ) + defer span.End() + + obj, err := s.client.GetObject(ctx, s.bucket, key, minio.GetObjectOptions{}) + if err != nil { + return 0, nil, fmt.Errorf("error getting file from S3: %w", err) + } + + // Get object info for size + info, err := obj.Stat() + if err != nil { + obj.Close() + + errResp := minio.ToErrorResponse(err) + if errResp.Code == s3NoSuchKey { + return 0, nil, storage.ErrNotFound + } + + return 0, nil, fmt.Errorf("error getting file info from S3: %w", err) + } + + return info.Size, obj, nil +} + +// PutFile puts the file in the store at the given path. +func (s *Store) PutFile(ctx context.Context, path string, body io.Reader) (int64, error) { + key := "store/" + strings.TrimPrefix(path, "/") + + _, span := tracer.Start( + ctx, + "s3.PutFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("key", key), + ), + ) + defer span.End() + + // Put the file + info, err := s.client.PutObject( + ctx, + s.bucket, + key, + body, + -1, // unknown size + minio.PutObjectOptions{}, + ) + if err != nil { + return 0, fmt.Errorf("error putting file to S3: %w", err) + } + + return info.Size, nil +} + +// DeleteFile deletes the file from the store at the given path. +func (s *Store) DeleteFile(ctx context.Context, path string) error { + key := "store/" + strings.TrimPrefix(path, "/") + + _, span := tracer.Start( + ctx, + "s3.DeleteFile", + trace.WithSpanKind(trace.SpanKindInternal), + trace.WithAttributes( + attribute.String("path", path), + attribute.String("key", key), + ), + ) + defer span.End() + + // Check if key exists + _, err := s.client.StatObject(ctx, s.bucket, key, minio.StatObjectOptions{}) + if err != nil { + errResp := minio.ToErrorResponse(err) + if errResp.Code == s3NoSuchKey { + return storage.ErrNotFound + } + + return fmt.Errorf("error checking if file exists: %w", err) + } + + err = s.client.RemoveObject(ctx, s.bucket, key, minio.RemoveObjectOptions{}) + if err != nil { + return fmt.Errorf("error deleting file from S3: %w", err) + } + + return nil +} + // Helper methods for key generation. func (s *Store) secretKeyPath() string { return "config/cache.key" diff --git a/pkg/storage/store.go b/pkg/storage/store.go index b5a228ec..8d5a01d7 100644 --- a/pkg/storage/store.go +++ b/pkg/storage/store.go @@ -72,3 +72,19 @@ type NarStore interface { // DeleteNar deletes the nar from the store. DeleteNar(ctx context.Context, narURL nar.URL) error } + +// FileStore represents a store capable of storing arbitrary files. +type FileStore interface { + // PutFile puts the file in the store at the given path. + PutFile(ctx context.Context, path string, body io.Reader) (int64, error) + + // GetFile returns the file from the store at the given path. + // NOTE: The caller must close the returned io.ReadCloser! + GetFile(ctx context.Context, path string) (int64, io.ReadCloser, error) + + // HasFile returns true if the store has the file at the given path. + HasFile(ctx context.Context, path string) bool + + // DeleteFile deletes the file from the store at the given path. + DeleteFile(ctx context.Context, path string) error +}
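
To make the new FileStore contract concrete, here is a minimal round-trip sketch against any implementation (the local and S3 stores above both qualify). It uses only the interface as declared; the path and payload are illustrative.

	package main

	import (
		"bytes"
		"context"
		"io"

		"github.com/kalbasit/ncps/pkg/storage"
	)

	// roundTrip writes a file and reads it back: the same PutFile/GetFile
	// pairing used by the index generator and the /nix-cache-index/ handler.
	func roundTrip(ctx context.Context, store storage.FileStore) error {
		payload := []byte(`{"version":1}`) // illustrative content

		if _, err := store.PutFile(ctx, "/nix-cache-index/manifest.json", bytes.NewReader(payload)); err != nil {
			return err
		}

		size, rc, err := store.GetFile(ctx, "/nix-cache-index/manifest.json")
		if err != nil {
			// A missing path surfaces as storage.ErrNotFound.
			return err
		}
		defer rc.Close() // the interface requires callers to close the reader

		// Drain exactly the reported size, much as the HTTP handler
		// streams the body with io.Copy.
		_, err = io.CopyN(io.Discard, rc, size)

		return err
	}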