diff --git a/go.mod b/go.mod index 43a35447..2d572238 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.5 require ( github.com/XSAM/otelsql v0.41.0 + github.com/cespare/xxhash/v2 v2.3.0 github.com/go-chi/chi/v5 v5.2.4 github.com/go-redsync/redsync/v4 v4.15.0 github.com/go-sql-driver/mysql v1.9.3 @@ -49,7 +50,6 @@ require ( github.com/BurntSushi/toml v1.6.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -77,6 +77,7 @@ require ( github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect github.com/rs/xid v1.6.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tinylib/msgp v1.6.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect diff --git a/go.sum b/go.sum index 9f9dd94f..27fb2db4 100644 --- a/go.sum +++ b/go.sum @@ -139,6 +139,8 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -212,12 +214,8 @@ golang.org/x/sync v0.19.0/go.mod 
h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= diff --git a/nix/packages/ncps/default.nix b/nix/packages/ncps/default.nix index 9ff2a894..cd2f021e 100644 --- a/nix/packages/ncps/default.nix +++ b/nix/packages/ncps/default.nix @@ -15,7 +15,7 @@ version = if tag != "" then tag else rev; - vendorHash = "sha256-nnt4HIG4Fs7RhHjVb7mYJ39UgvFKc46Cu42cURMmr1s="; + vendorHash = "sha256-VxIMr0QgsvkZpTe+fvGNF+cT3xfAa0m22q31z+Rf+Ds="; in pkgs.buildGoModule { name = "ncps-${shortRev}"; diff --git a/pkg/nixcacheindex/base32.go b/pkg/nixcacheindex/base32.go new file mode 100644 index 00000000..50bf3312 --- /dev/null +++ b/pkg/nixcacheindex/base32.go @@ -0,0 +1,85 @@ +package nixcacheindex + +import ( + "fmt" + "math/big" + "strings" +) + +const ( + // Alphabet is the Nix Base32 alphabet. + // Note: 'e', 'o', 'u', 't' are excluded to avoid offensive words. + Alphabet = "0123456789abcdfghijklmnpqrsvwxyz" + + // HashLength is the length of a Nix store path hash in base32 characters. 
+ HashLength = 32 + + // HashBits is the number of bits in a full store path hash (160). + HashBits = 160 +) + +//nolint:gochecknoglobals +var alphabetMap map[rune]int64 + +//nolint:gochecknoinits +func init() { + alphabetMap = make(map[rune]int64) + for i, c := range Alphabet { + alphabetMap[c] = int64(i) + } +} + +// ParseHash parses a 32-character Nix base32 string into a big.Int. +// The string is interpreted as a Big-Endian 160-bit unsigned integer. +// This means the first character is the most significant. +func ParseHash(s string) (*big.Int, error) { + if len(s) != HashLength { + return nil, fmt.Errorf("%w: expected %d, got %d", ErrInvalidHashLength, HashLength, len(s)) + } + + result := new(big.Int) + + for _, char := range s { + val, ok := alphabetMap[char] + if !ok { + return nil, fmt.Errorf("%w: %q", ErrInvalidHashChar, char) + } + + // result = (result << 5) | val + result.Lsh(result, 5) + result.Or(result, big.NewInt(val)) + } + + return result, nil +} + +// FormatHash formats a big.Int into a 32-character Nix base32 string. +// The integer is treated as Big-Endian. +func FormatHash(i *big.Int) string { + if i == nil { + return strings.Repeat("0", HashLength) + } + + // Work with a copy since we'll act on it + n := new(big.Int).Set(i) + + // Create a buffer for 32 characters + chars := make([]byte, HashLength) + + // Extract 5 bits at a time from right to left (least significant first) + // But we fill the string from right to left too, so it matches Big-Endian + // i.e. 
last 5 bits of integer -> last char of string + + mask := big.NewInt(0x1f) // 5 ones + + for idx := HashLength - 1; idx >= 0; idx-- { + // val = n & 0x1f + val := new(big.Int).And(n, mask) + chars[idx] = Alphabet[val.Int64()] + + // n = n >> 5 + n.Rsh(n, 5) + } + + return string(chars) +} diff --git a/pkg/nixcacheindex/base32_test.go b/pkg/nixcacheindex/base32_test.go new file mode 100644 index 00000000..9dd06b72 --- /dev/null +++ b/pkg/nixcacheindex/base32_test.go @@ -0,0 +1,101 @@ +package nixcacheindex_test + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/nixcacheindex" +) + +func TestParseHash(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want *big.Int + wantErr bool + }{ + { + name: "Zero hash", + input: "00000000000000000000000000000000", + want: big.NewInt(0), + }, + { + name: "One (last bit set)", + input: "00000000000000000000000000000001", + want: big.NewInt(1), + }, + { + // RFC Example: "100...000" maps to 2^155 + // First char '1' (value 1) is most significant 5 bits + name: "2^155 (first bit set)", + input: "10000000000000000000000000000000", + want: new(big.Int).Exp(big.NewInt(2), big.NewInt(155), nil), + }, + { + // RFC Example: "010...000" maps to 2^150 + // Second char '1' (value 1) shifted left by (32-2)*5 = 30*5 = 150 + name: "2^150 (second char set)", + input: "01000000000000000000000000000000", + want: new(big.Int).Exp(big.NewInt(2), big.NewInt(150), nil), + }, + { + // RFC Example had a typo saying g=16, but g is 15 in the 0-indexed alphabet. + // 0-9 (10), a-d (4), f (1), g (1) -> 10+4+1 = 15. 
+ name: "Max single char (g)", + input: "g0000000000000000000000000000000", + want: new(big.Int).Mul(big.NewInt(15), new(big.Int).Exp(big.NewInt(2), big.NewInt(155), nil)), + }, + { + // Max value: all z's + name: "Max value", + input: "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz", + want: new(big.Int).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(160), nil), big.NewInt(1)), + }, + { + name: "Invalid length (short)", + input: "000", + wantErr: true, + }, + { + name: "Invalid length (long)", + input: "000000000000000000000000000000000", + wantErr: true, + }, + { + name: "Invalid character", + input: "0000000000000000000000000000000e", // 'e' is not in alphabet + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := nixcacheindex.ParseHash(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, 0, tt.want.Cmp(got), "expected %s, got %s", tt.want, got) + + // Verify round trip + formatted := nixcacheindex.FormatHash(got) + assert.Equal(t, tt.input, formatted, "FormatHash mismatch") + } + }) + } +} + +func TestFormatHash(t *testing.T) { + t.Parallel() + + // Focus on round-trip property for random big ints within range + // But since we covered round trip in TestParseHash, we just add explicit nil check + assert.Equal(t, "00000000000000000000000000000000", nixcacheindex.FormatHash(nil)) +} diff --git a/pkg/nixcacheindex/client.go b/pkg/nixcacheindex/client.go new file mode 100644 index 00000000..4454bf91 --- /dev/null +++ b/pkg/nixcacheindex/client.go @@ -0,0 +1,311 @@ +package nixcacheindex + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "sync" + + "github.com/klauspost/compress/zstd" + "github.com/rs/zerolog" +) + +// Result of a query. 
+type Result int + +const ( + DefiniteMiss Result = iota + DefiniteHit + ProbableHit +) + +func (r Result) String() string { + switch r { + case DefiniteMiss: + return "DEFINITE_MISS" + case DefiniteHit: + return "DEFINITE_HIT" + case ProbableHit: + return "PROBABLE_HIT" + default: + return fmt.Sprintf("Result(%d)", int(r)) + } +} + +// Fetcher abstraction for retrieving files (e.g., HTTP, local file system). +type Fetcher interface { + Fetch(path string) (io.ReadCloser, error) +} + +// Client for querying the binary cache index. +type Client struct { + fetcher Fetcher + manifest *Manifest + + shardCacheMu sync.Mutex + shardCache map[string]*ShardReader +} + +// NewClient creates a new client. +func NewClient(fetcher Fetcher) *Client { + return &Client{ + fetcher: fetcher, + shardCache: make(map[string]*ShardReader), + } +} + +// LoadManifest fetches and parses the manifest. +func (c *Client) LoadManifest() error { + r, err := c.fetcher.Fetch(ManifestPath) + if err != nil { + return err + } + defer r.Close() + + m, err := LoadManifest(r) + if err != nil { + return err + } + + c.manifest = m + + return nil +} + +// Query checks if the cache contains the given hash. +// hashStr is the 32-character base32 hash. +func (c *Client) Query(ctx context.Context, hashStr string) (Result, error) { + if c.manifest == nil { + if err := c.LoadManifest(); err != nil { + return DefiniteMiss, err + } + } + + // 1. Check Journal (Layer 1) + // Iterate through active segments. + // RFC: "Check journal for recent mutations" + // "segments_to_compact = get_segments_older_than..." + // Need to know WHICH segments to check. + // Manifest says `journal.current_segment`. + // RFC says: "Writer appends to current... Segments older than retention...". + // So we should check `current_segment` and previous segments? + // RFC Section 7 Step 2: "FOR segment IN manifest.journal.segments". + // Manifest JSON example doesn't have "segments" list. 
It has `current_segment`, `segment_duration`, `retention`. + // We infer the segments: [current, current-duration, current-2*duration ... up to retention]. + + current := c.manifest.Journal.CurrentSegment + duration := int64(c.manifest.Journal.SegmentDurationSeconds) + count := c.manifest.Journal.RetentionCount + + // We check newest to oldest? Order doesn't matter for "set" semantics, but strictly speaking + // if something was Added then Deleted, order matters. + // Journal text format preserves order. + // Across files? + // "Segments older... are archived". + // If we have strict linearization, we should check segments in chronological order? + // And within a segment, order matters. + // "Lines beginning with - indicate deletions". + // "On artifact push: Append +hash". + // "On GC: Append -hash". + // If we see -hash (deleted), we return DEFINITE_MISS. + // If we see +hash (added), we return PROBABLE_HIT. + // What if -hash comes AFTER +hash? (deleted). + // What if +hash comes AFTER -hash? (re-added? unlikely for immutable store path, but possible if GC'd then re-pushed). + // So we need to process all journal entries in Chronological order to find the FINAL state? + // Or just look for ANY usage? + // RFC Query Algo: + // "IF "-" + target IN journal: RETURN MISS" + // "IF "+" + target IN journal: RETURN PROBABLE_HIT" + // This implies checking all journals together? Or order? + // If I see a deletion 5 mins ago, and addition 1 min ago. It's present. + // If checking in Reverse Chronological (Newest first): + // If I see +hash: It's present (ignoring older deletion). Return PROBABLE_HIT. + // If I see -hash: It's deleted (ignoring older addition). Return DEFINITE_MISS. + // So Reverse Chronological check is correct and efficient. + + // Generate segment timestamps (start times) + // Current is start time of current segment. + + segments := make([]int64, 0, count+1) + for i := 0; i < count+1; i++ { // +1 for current? "retention_count ... 
segments retained before archival". + // Usually retention doesn't include current? + // RFC Example 1: `retention_count: 24`, `current`: 1705147200. + // Files: `1705147200.log` (current). Previous ones... + // We will check `current` and `count` previous ones. + t := current - int64(i)*duration + segments = append(segments, t) + } + + // Check segments (Newest first) + targetHash := hashStr // We look for strict string match? RFC says yes "line[1:]". + + for _, segTime := range segments { + path := fmt.Sprintf("%s%d.log", c.manifest.Urls.JournalBase, segTime) + + // Fetch journal + // Note: fetcher should handle caching or 404s (empty journal?) + // Fetcher.Fetch returns error if not found? + // If 404, we assume empty/ignore? + // RFC says "Fetch journal...". + // In reality, some segments might not exist if no writes happened? + // Or if rotation logic is strict. + + rc, err := c.fetcher.Fetch(path) + if err != nil { + // If journal segment is missing, we assume no mutations in that window. + zerolog.Ctx(ctx).Debug().Err(err).Str("path", path). + Msg("journal segment missing, assuming no mutations in this window") + + continue + } + defer rc.Close() + + entries, err := ParseJournal(rc) + if err != nil { + // Bad journal. Ignore? + return DefiniteMiss, fmt.Errorf("failed to parse journal %s: %w", path, err) + } + + // Search entries in Reverse Order (Newest lines are at bottom) + for i := len(entries) - 1; i >= 0; i-- { + e := entries[i] + if e.Hash == targetHash { + if e.Op == OpDelete { + return DefiniteMiss, nil + } + + if e.Op == OpAdd { + return ProbableHit, nil + } + } + } + } + + // 2. Check Shards (Layer 2) + // Determine Shard Path + // prefix = target_hash[0:depth] + depth := c.manifest.Sharding.Depth + + var prefix string + if depth > 0 && len(hashStr) >= depth { + prefix = hashStr[:depth] + } + + epoch := c.manifest.Epoch.Current + + // Path: /nix-cache-index/shards//.idx + // If depth=0, prefix is usually "root" or empty? 
+ // RFC Example 1: `depth: 0`. Path: `shards/3/root.idx`. + // RFC 5: "For sharding.depth = 0 ... shards/42/root.idx". + // For depth > 0: `shards/42/b6.idx`. + + var shardName string + if depth == 0 { + shardName = "root" + } else { + shardName = prefix + } + + shardPath := fmt.Sprintf("%s%d/%s.idx.zst", c.manifest.Urls.ShardsBase, epoch, shardName) + + c.shardCacheMu.Lock() + shardReader, ok := c.shardCache[shardPath] + c.shardCacheMu.Unlock() + + if ok { + return c.queryShard(shardReader, hashStr) + } + + rc, err := c.fetcher.Fetch(shardPath) + if err == nil { + defer rc.Close() + + return c.processShardResponse(shardPath, rc, hashStr) + } + + // Shard missing? + // RFC 9.2: "If shard fetch returns 404 AND epoch.previous exists... retry previous". + if c.manifest.Epoch.Previous == 0 { + return DefiniteMiss, nil // Missing shard -> Miss + } + + prevEpoch := c.manifest.Epoch.Previous + shardPath = fmt.Sprintf("%s%d/%s.idx.zst", c.manifest.Urls.ShardsBase, prevEpoch, shardName) + + c.shardCacheMu.Lock() + shardReader, ok = c.shardCache[shardPath] + c.shardCacheMu.Unlock() + + if ok { + return c.queryShard(shardReader, hashStr) + } + + rc, err = c.fetcher.Fetch(shardPath) + if err != nil { + return DefiniteMiss, err // Both missing -> Miss (or error) + } + defer rc.Close() + + return c.processShardResponse(shardPath, rc, hashStr) +} + +func (c *Client) processShardResponse(shardPath string, rc io.ReadCloser, hashStr string) (Result, error) { + // We assume Fetcher returns a Seekable stream needed for ReadShard? + // fetcher.Fetch returns io.ReadCloser. + // ReadShard needs io.ReadSeeker. + // If fetcher is HTTP, it might not be seekable. + // We might need to ReadAll into memory. + // For shards (small/medium), this is fine (hundreds of KB). + // For large shards (1MB+), memory is still fine. + + // Buffer it. + // Note: This is an optimization point (Range requests). + // For now, read all. 
+ var reader io.Reader = rc + if strings.HasSuffix(shardPath, ".zst") { + zstdReader, err := zstd.NewReader(rc) + if err != nil { + return DefiniteMiss, fmt.Errorf("failed to create zstd reader for %s: %w", shardPath, err) + } + defer zstdReader.Close() + + reader = zstdReader + } + + data, err := io.ReadAll(reader) + if err != nil { + return DefiniteMiss, err + } + + shardReader, err := ReadShard(bytes.NewReader(data)) + if err != nil { + return DefiniteMiss, err + } + + c.shardCacheMu.Lock() + c.shardCache[shardPath] = shardReader + c.shardCacheMu.Unlock() + + return c.queryShard(shardReader, hashStr) +} + +func (c *Client) queryShard(shardReader *ShardReader, hashStr string) (Result, error) { + // Parse Hash + h, err := ParseHash(hashStr) + if err != nil { + return DefiniteMiss, fmt.Errorf("invalid hash: %w", err) + } + + hit, err := shardReader.Contains(h) + if err != nil { + return DefiniteMiss, err + } + + if hit { + return DefiniteHit, nil + } + + return DefiniteMiss, nil +} diff --git a/pkg/nixcacheindex/client_test.go b/pkg/nixcacheindex/client_test.go new file mode 100644 index 00000000..5585c507 --- /dev/null +++ b/pkg/nixcacheindex/client_test.go @@ -0,0 +1,194 @@ +package nixcacheindex_test + +import ( + "bytes" + "context" + "fmt" + "io" + "math/big" + "testing" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/nixcacheindex" +) + +// MockFetcher. +type MockFetcher struct { + mock.Mock + files map[string][]byte + fetchCalls int +} + +func (m *MockFetcher) Fetch(path string) (io.ReadCloser, error) { + m.fetchCalls++ + if data, ok := m.files[path]; ok { + return io.NopCloser(bytes.NewReader(data)), nil + } + + return nil, fmt.Errorf("%w: %s", nixcacheindex.ErrShardNotFound, path) +} + +func TestClientQuery_EndToEnd(t *testing.T) { + t.Parallel() + + // Setup + // 1. 
Manifest + manifest := nixcacheindex.NewManifest() + manifest.Sharding.Depth = 0 // Single root shard for simplicity + manifest.Epoch.Current = 10 + manifest.Journal.CurrentSegment = 1000 + manifest.Journal.SegmentDurationSeconds = 100 + manifest.Journal.RetentionCount = 1 + // Use mock URLs + manifest.Urls.JournalBase = "https://mock/journal/" + manifest.Urls.ShardsBase = "https://mock/shards/" + + var manifestBuf bytes.Buffer + require.NoError(t, manifest.Write(&manifestBuf)) + + // 2. Journal + // Current segment (1000). Contains +hashA, -hashB + hashA := "0000000000000000000000000000000a" // Added + hashB := "0000000000000000000000000000000b" // Deleted + + journalEntries := []nixcacheindex.JournalEntry{ + {Op: nixcacheindex.OpAdd, Hash: hashA}, + {Op: nixcacheindex.OpDelete, Hash: hashB}, + } + + var journalBuf bytes.Buffer + require.NoError(t, nixcacheindex.WriteJournal(&journalBuf, journalEntries)) + + // 3. Shard (Epoch 10, root) + // Contains hashC + hashC := "0000000000000000000000000000000c" + bnC, _ := nixcacheindex.ParseHash(hashC) + + // Also contains hashB (which was deleted in journal, so journal should take precedence) + bnB, _ := nixcacheindex.ParseHash(hashB) + + hashes := []*big.Int{bnB, bnC} // Sorted? B < C. Yes. + + var shardBuf bytes.Buffer + + params := nixcacheindex.Encoding{Parameter: 4, HashBits: 160, PrefixBits: 0} + require.NoError(t, nixcacheindex.WriteShard(&shardBuf, hashes, params)) + + // Compress the shard data + var compressedShardBuf bytes.Buffer + + enc, err := zstd.NewWriter(&compressedShardBuf) + require.NoError(t, err) + _, err = enc.Write(shardBuf.Bytes()) + require.NoError(t, err) + require.NoError(t, enc.Close()) + + // 4. 
Mock Files + mockFiles := map[string][]byte{ + nixcacheindex.ManifestPath: manifestBuf.Bytes(), + "https://mock/journal/1000.log": journalBuf.Bytes(), + "https://mock/journal/900.log": nil, // Previous segment empty/missing + "https://mock/shards/10/root.idx.zst": compressedShardBuf.Bytes(), + } + + fetcher := &MockFetcher{files: mockFiles} + client := nixcacheindex.NewClient(fetcher) + + // Test Cases + + // Case 1: HashA (Added in Journal) -> ProbableHit + res, err := client.Query(context.Background(), hashA) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.ProbableHit, res, "HashA should be ProbableHit") + + // Case 2: HashB (Deleted in Journal) -> DefiniteMiss + // Even though it is in Shard! + res, err = client.Query(context.Background(), hashB) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteMiss, res, "HashB should be DefiniteMiss (deleted in journal)") + + // Case 3: HashC (In Shard, not in Journal) -> DefiniteHit + res, err = client.Query(context.Background(), hashC) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteHit, res, "HashC should be DefiniteHit") + + // Case 4: HashD (Missing) -> DefiniteMiss + hashD := "0000000000000000000000000000000d" + res, err = client.Query(context.Background(), hashD) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteMiss, res, "HashD should be DefiniteMiss") +} + +func TestClientQuery_Caching(t *testing.T) { + t.Parallel() + + // Setup + manifest := nixcacheindex.NewManifest() + manifest.Sharding.Depth = 0 + manifest.Epoch.Current = 1 + manifest.Journal.RetentionCount = 0 // Minimize journal fetches + manifest.Urls.ShardsBase = "https://mock/shards/" + + var manifestBuf bytes.Buffer + require.NoError(t, manifest.Write(&manifestBuf)) + + hashA := "0000000000000000000000000000000a" + bnA, _ := nixcacheindex.ParseHash(hashA) + hashes := []*big.Int{bnA} + + var shardBuf bytes.Buffer + + params := nixcacheindex.Encoding{Parameter: 4, HashBits: 160, PrefixBits: 0} + 
require.NoError(t, nixcacheindex.WriteShard(&shardBuf, hashes, params)) + + mockFiles := map[string][]byte{ + nixcacheindex.ManifestPath: manifestBuf.Bytes(), + "https://mock/shards/1/root.idx.zst": shardBuf.Bytes(), + } + // Note: client.go checks fmt.Sprintf("%s%d/%s.idx.zst", ...), so the path MUST end in .zst + // even if we don't actually compress it in the mock for this test (zstd reader might error if not valid zstd) + // Let's compress it to be safe. + + var compressedShardBuf bytes.Buffer + + enc, err := zstd.NewWriter(&compressedShardBuf) + require.NoError(t, err) + _, err = enc.Write(shardBuf.Bytes()) + require.NoError(t, err) + require.NoError(t, enc.Close()) + + mockFiles["https://mock/shards/1/root.idx.zst"] = compressedShardBuf.Bytes() + + fetcher := &MockFetcher{files: mockFiles} + client := nixcacheindex.NewClient(fetcher) + + // First query: should fetch shard + res, err := client.Query(context.Background(), hashA) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteHit, res) + // 1 (manifest) + 1 (current journal) + 1 (shard) = 3 + assert.Equal(t, 3, fetcher.fetchCalls) + + // Second query for same hash: should use cache for shard + res, err = client.Query(context.Background(), hashA) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteHit, res) + // No new manifest fetch (cached in Client.manifest if already loaded), + // but Query currently re-fetches manifest if manifest is nil. + // Wait, LoadManifest is only called if c.manifest == nil. + // But it re-fetches journal segments every time. 
+ // So 3 + 1 (journal) = 4 + assert.Equal(t, 4, fetcher.fetchCalls, "Should not have fetched shard again") + + // Third query for different hash in same shard: should use cache for shard + hashB := "0000000000000000000000000000000b" + res, err = client.Query(context.Background(), hashB) + require.NoError(t, err) + assert.Equal(t, nixcacheindex.DefiniteMiss, res) + // 4 + 1 (journal) = 5 + assert.Equal(t, 5, fetcher.fetchCalls, "Should not have fetched shard again") +} diff --git a/pkg/nixcacheindex/errors.go b/pkg/nixcacheindex/errors.go new file mode 100644 index 00000000..2301e25c --- /dev/null +++ b/pkg/nixcacheindex/errors.go @@ -0,0 +1,20 @@ +package nixcacheindex + +import "errors" + +var ( + // ErrInvalidHashLength is returned when a hash string has an incorrect length. + ErrInvalidHashLength = errors.New("invalid hash length") + // ErrInvalidHashChar is returned when a hash string contains invalid characters. + ErrInvalidHashChar = errors.New("invalid character in hash") + // ErrInvalidMagic is returned when a shard file has an incorrect magic number. + ErrInvalidMagic = errors.New("invalid magic number") + // ErrEmptyShard is returned when trying to write a shard with no hashes. + ErrEmptyShard = errors.New("cannot write empty shard") + // ErrManifestNotFound is returned when the manifest cannot be fetched. + ErrManifestNotFound = errors.New("manifest not found") + // ErrShardNotFound is returned when a shard cannot be fetched. + ErrShardNotFound = errors.New("shard not found") + // ErrInvalidJournalOp is returned when a journal operation is invalid. + ErrInvalidJournalOp = errors.New("invalid journal operation") +) diff --git a/pkg/nixcacheindex/journal.go b/pkg/nixcacheindex/journal.go new file mode 100644 index 00000000..11287d42 --- /dev/null +++ b/pkg/nixcacheindex/journal.go @@ -0,0 +1,102 @@ +package nixcacheindex + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +// JournalOp represents a journal operation (add or delete). 
+type JournalOp int + +const ( + OpAdd JournalOp = iota + OpDelete +) + +// JournalEntry is a single entry in the journal. +type JournalEntry struct { + Op JournalOp + Hash string // 32-char Nix base32 hash +} + +// ParseJournal parses journal entries from an io.Reader. +// Format is line-based: +// + +// - +// Empty lines are ignored. +func ParseJournal(r io.Reader) ([]JournalEntry, error) { + scanner := bufio.NewScanner(r) + + var entries []JournalEntry + + lineNum := 0 + for scanner.Scan() { + lineNum++ + + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + if len(line) != HashLength+1 { + return nil, fmt.Errorf("%w: line %d: got %d (expected %d)", ErrInvalidHashLength, lineNum, len(line), HashLength+1) + } + + opChar := line[0] + hash := line[1:] + + var op JournalOp + + switch opChar { + case '+': + op = OpAdd + case '-': + op = OpDelete + default: + return nil, fmt.Errorf("%w: line %d: %q", ErrInvalidJournalOp, lineNum, opChar) + } + + // Validate hash characters roughly (checked fully if we ParseHash later) + // For now just length is checked above. + + entries = append(entries, JournalEntry{ + Op: op, + Hash: hash, + }) + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return entries, nil +} + +// WriteJournal writes journal entries to an io.Writer. 
+func WriteJournal(w io.Writer, entries []JournalEntry) error { + for _, entry := range entries { + var opChar string + + switch entry.Op { + case OpAdd: + opChar = "+" + case OpDelete: + opChar = "-" + default: + return fmt.Errorf("%w: %v", ErrInvalidJournalOp, entry.Op) + } + + if len(entry.Hash) != HashLength { + return fmt.Errorf("%w: %d (expected %d)", ErrInvalidHashLength, len(entry.Hash), HashLength) + } + + _, err := fmt.Fprintf(w, "%s%s\n", opChar, entry.Hash) + if err != nil { + return err + } + } + + return nil +} diff --git a/pkg/nixcacheindex/journal_test.go b/pkg/nixcacheindex/journal_test.go new file mode 100644 index 00000000..9ffc8e96 --- /dev/null +++ b/pkg/nixcacheindex/journal_test.go @@ -0,0 +1,83 @@ +package nixcacheindex_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/nixcacheindex" +) + +func TestJournalRoundTrip(t *testing.T) { + t.Parallel() + + entries := []nixcacheindex.JournalEntry{ + {Op: nixcacheindex.OpAdd, Hash: "b6gvzjyb2pg0kjfwn6a6llj3k1bq6dwi"}, + {Op: nixcacheindex.OpAdd, Hash: "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"}, + {Op: nixcacheindex.OpDelete, Hash: "x9y8z7w6v5u4t3s2r1q0p9o8n7m6l5k4"}, + } + + var buf bytes.Buffer + + err := nixcacheindex.WriteJournal(&buf, entries) + require.NoError(t, err) + + parsed, err := nixcacheindex.ParseJournal(&buf) + require.NoError(t, err) + + assert.Equal(t, entries, parsed) +} + +func TestParseJournal_Errors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantErr string + }{ + { + name: "Invalid Op", + input: "*b6gvzjyb2pg0kjfwn6a6llj3k1bq6dwi", + wantErr: "invalid journal operation", + }, + { + name: "Short Line", + input: "+short", + wantErr: "invalid hash length", + }, + { + name: "Long Line", + input: "+b6gvzjyb2pg0kjfwn6a6llj3k1bq6dwi1", + wantErr: "invalid hash length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + t.Parallel() + + entries, err := nixcacheindex.ParseJournal(strings.NewReader(tt.input)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + assert.Nil(t, entries) + }) + } +} + +func TestWriteJournal_Errors(t *testing.T) { + t.Parallel() + + entries := []nixcacheindex.JournalEntry{ + {Op: nixcacheindex.OpAdd, Hash: "short"}, // Invalid hash + } + + var buf bytes.Buffer + + err := nixcacheindex.WriteJournal(&buf, entries) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid hash length") +} diff --git a/pkg/nixcacheindex/manifest.go b/pkg/nixcacheindex/manifest.go new file mode 100644 index 00000000..4a0dc72f --- /dev/null +++ b/pkg/nixcacheindex/manifest.go @@ -0,0 +1,131 @@ +package nixcacheindex + +import ( + "encoding/json" + "io" + "os" + "time" +) + +// ManifestPath is the well-known path for the cache index manifest. +const ManifestPath = "/nix-cache-index/manifest.json" + +// Manifest describes the index topology. +type Manifest struct { + Version int `json:"version"` + Format string `json:"format"` + CreatedAt time.Time `json:"created_at"` //nolint:tagliatelle // RFC 0195 + ItemCount int64 `json:"item_count"` //nolint:tagliatelle // RFC 0195 + Sharding Sharding `json:"sharding"` + Encoding Encoding `json:"encoding"` + Urls Urls `json:"urls"` + Journal Journal `json:"journal"` + Epoch Epoch `json:"epoch"` + Deltas Deltas `json:"deltas"` +} + +// Sharding configuration. +type Sharding struct { + Depth int `json:"depth"` + Alphabet string `json:"alphabet"` +} + +// Encoding configuration for shards. +type Encoding struct { + Type string `json:"type"` // e.g. "golomb-rice" + Parameter int `json:"parameter"` // Golomb parameter k (M = 2^k) + HashBits int `json:"hash_bits"` //nolint:tagliatelle // RFC 0195 + PrefixBits int `json:"prefix_bits"` //nolint:tagliatelle // RFC 0195 +} + +// Urls configuration. 
+type Urls struct { + JournalBase string `json:"journal_base"` //nolint:tagliatelle // RFC 0195 + ShardsBase string `json:"shards_base"` //nolint:tagliatelle // RFC 0195 + DeltasBase string `json:"deltas_base"` //nolint:tagliatelle // RFC 0195 +} + +// Journal configuration. +type Journal struct { + CurrentSegment int64 `json:"current_segment"` //nolint:tagliatelle // RFC 0195 + SegmentDurationSeconds int `json:"segment_duration_seconds"` //nolint:tagliatelle // RFC 0195 + RetentionCount int `json:"retention_count"` //nolint:tagliatelle // RFC 0195 +} + +// Epoch information. +type Epoch struct { + Current int64 `json:"current"` + Previous int64 `json:"previous,omitempty"` +} + +// Deltas configuration. +type Deltas struct { + Enabled bool `json:"enabled"` + OldestBase int64 `json:"oldest_base"` //nolint:tagliatelle // RFC 0195 + Compression string `json:"compression"` // "none", "gzip", "zstd" +} + +// NewManifest creates a default manifest. +func NewManifest() *Manifest { + return &Manifest{ + Version: 1, + Format: "hlssi", + CreatedAt: time.Now().UTC(), + Sharding: Sharding{ + Depth: 2, + Alphabet: Alphabet, + }, + Encoding: Encoding{ + Type: "golomb-rice", + Parameter: 8, + HashBits: HashBits, + PrefixBits: 10, + }, + Urls: Urls{ + JournalBase: "https://cache.example.com/nix-cache-index/journal/", + ShardsBase: "https://cache.example.com/nix-cache-index/shards/", + DeltasBase: "https://cache.example.com/nix-cache-index/deltas/", + }, + Journal: Journal{ + CurrentSegment: time.Now().Unix(), + SegmentDurationSeconds: 300, + RetentionCount: 12, + }, + Epoch: Epoch{ + Current: 1, + }, + Deltas: Deltas{ + Enabled: true, + Compression: "zstd", + }, + } +} + +// LoadManifest reads a manifest from an io.Reader. +func LoadManifest(r io.Reader) (*Manifest, error) { + var m Manifest + if err := json.NewDecoder(r).Decode(&m); err != nil { + return nil, err + } + + return &m, nil +} + +// LoadManifestFromFile reads a manifest from a file. 
+func LoadManifestFromFile(path string) (*Manifest, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	return LoadManifest(f)
+}
+
+// Write writes the manifest to an io.Writer.
+func (m *Manifest) Write(w io.Writer) error {
+	enc := json.NewEncoder(w)
+	enc.SetIndent("", " ")
+
+	return enc.Encode(m)
+}
diff --git a/pkg/nixcacheindex/manifest_test.go b/pkg/nixcacheindex/manifest_test.go
new file mode 100644
index 00000000..32ee430d
--- /dev/null
+++ b/pkg/nixcacheindex/manifest_test.go
@@ -0,0 +1,86 @@
+package nixcacheindex_test
+
+import (
+	"bytes"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/kalbasit/ncps/pkg/nixcacheindex"
+)
+
+func TestManifestSerialization(t *testing.T) {
+	t.Parallel()
+
+	// RFC Example (approximate)
+	jsonStr := `{
+  "version": 1,
+  "format": "hlssi",
+  "created_at": "2026-01-13T12:00:00Z",
+  "item_count": 1200000000,
+  "sharding": {
+    "depth": 2,
+    "alphabet": "0123456789abcdfghijklmnpqrsvwxyz"
+  },
+  "encoding": {
+    "type": "golomb-rice",
+    "parameter": 8,
+    "hash_bits": 160,
+    "prefix_bits": 10
+  },
+  "urls": {
+    "journal_base": "https://cache.example.com/nix-cache-index/journal/",
+    "shards_base": "https://cache.example.com/nix-cache-index/shards/",
+    "deltas_base": "https://cache.example.com/nix-cache-index/deltas/"
+  },
+  "journal": {
+    "current_segment": 1705147200,
+    "segment_duration_seconds": 300,
+    "retention_count": 12
+  },
+  "epoch": {
+    "current": 42,
+    "previous": 41
+  },
+  "deltas": {
+    "enabled": true,
+    "oldest_base": 35,
+    "compression": "zstd"
+  }
+}`
+
+	m, err := nixcacheindex.LoadManifest(strings.NewReader(jsonStr))
+	require.NoError(t, err)
+
+	assert.Equal(t, 1, m.Version)
+	assert.Equal(t, "hlssi", m.Format)
+	assert.Equal(t, int64(1200000000), m.ItemCount)
+	assert.Equal(t, 2, m.Sharding.Depth)
+	assert.Equal(t, "golomb-rice", m.Encoding.Type)
+	assert.Equal(t,
"https://cache.example.com/nix-cache-index/journal/", m.Urls.JournalBase) + assert.Equal(t, 42, int(m.Epoch.Current)) // Cast to int for comparison convenience if needed + assert.Equal(t, 41, int(m.Epoch.Previous)) + + // Test Serialization + var buf bytes.Buffer + + err = m.Write(&buf) + require.NoError(t, err) + + // Read back + m2, err := nixcacheindex.LoadManifest(&buf) + require.NoError(t, err) + assert.Equal(t, m, m2) +} + +func TestNewManifest(t *testing.T) { + t.Parallel() + + m := nixcacheindex.NewManifest() + assert.Equal(t, 1, m.Version) + assert.Equal(t, "hlssi", m.Format) + assert.Positive(t, m.CreatedAt.Unix()) + assert.NotEmpty(t, m.Urls.JournalBase) +} diff --git a/pkg/nixcacheindex/shard.go b/pkg/nixcacheindex/shard.go new file mode 100644 index 00000000..85b723f4 --- /dev/null +++ b/pkg/nixcacheindex/shard.go @@ -0,0 +1,392 @@ +package nixcacheindex + +import ( + "bufio" + "bytes" + "encoding/binary" + "fmt" + "io" + "math/big" + "sort" + + "github.com/cespare/xxhash/v2" + + "github.com/kalbasit/ncps/pkg/golomb" +) + +const ( + // MagicNumber is "NIXIDX01" in Little-Endian. + // N (4E) I (49) X (58) I (49) D (44) X (58) 0 (30) 1 (31) -> 0x313058444958494E. + MagicNumber = 0x313058444958494E + + // SparseIndexInterval is the number of items between sparse index entries. + SparseIndexInterval = 256 + + // HeaderSize is the fixed size of the shard header. + HeaderSize = 64 + + // SparseEntrySize is the size of a sparse index entry (20 + 8 = 28 bytes). + SparseEntrySize = 28 +) + +// ShardHeader represents the fixed-size header of a shard file. +type ShardHeader struct { + Magic uint64 + ItemCount uint64 + GolombK uint8 + HashSuffixBits uint8 + SparseIndexOffset uint64 + SparseIndexCount uint64 + Checksum uint64 // XXH64 of encoded data + Reserved [22]byte +} + +// SparseIndexEntry represents an entry in the sparse index. 
+type SparseIndexEntry struct { + HashSuffix *big.Int // 20 bytes (160 bits), store path suffix + Offset uint64 // 8 bytes, offset into encoded data +} + +// ShardReader facilitates reading a shard. +type ShardReader struct { + r io.ReadSeeker + Header ShardHeader + SparseIndex []SparseIndexEntry + Params Encoding +} + +// Helper to write Little-Endian uint64. +func writeUint64(w io.Writer, v uint64) error { + return binary.Write(w, binary.LittleEndian, v) +} + +// Helper to write Little-Endian uint8. +func writeUint8(w io.Writer, v uint8) error { + return binary.Write(w, binary.LittleEndian, v) +} + +// WriteShard writes a shard to w given a list of sorted hashes. +// hashes must be sorted numerically (big-endian). +// params defines the encoding parameters (k, prefix bits). +func WriteShard(w io.Writer, hashes []*big.Int, params Encoding) error { + if len(hashes) == 0 { + return ErrEmptyShard + } + + if params.Parameter < 0 || params.Parameter >= 64 { + return fmt.Errorf("%w: %d", golomb.ErrInvalidGolombK, params.Parameter) + } + + if params.PrefixBits < 0 || params.HashBits < params.PrefixBits { + return fmt.Errorf("%w: prefix/hash bits %d/%d", golomb.ErrInvalidEncodingParams, params.PrefixBits, params.HashBits) + } + + // 1. Prepare buffers + var ( + encodedData bytes.Buffer + sparseIndex []SparseIndexEntry + ) + + // Golomb encoder + ge, err := golomb.NewEncoder(&encodedData, params.Parameter) + if err != nil { + return err + } + + // Mask for stripping prefix + // Hash is HashBits length. Prefix is PrefixBits. Suffix is HashBits - PrefixBits. + suffixBits := params.HashBits - params.PrefixBits + mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), uint(suffixBits)), big.NewInt(1)) //nolint:gosec + + var prevSuffix *big.Int + + for i, h := range hashes { + // Strip prefix to get suffix + suffix := new(big.Int).And(h, mask) + + if i == 0 { //nolint:nestif + // First hash: raw suffix bits. 
+ if err := ge.BitWriter.WriteBigIntBits(suffix, suffixBits); err != nil { + return err + } + // Flush to byte boundary for valid Sparse Index Offset + if err := ge.Flush(); err != nil { + return err + } + // Offset is current buffer length (which points to byte immediately after hash[0]) + sparseIndex = append(sparseIndex, SparseIndexEntry{ + HashSuffix: new(big.Int).Set(suffix), + Offset: uint64(encodedData.Len()), //nolint:gosec + }) + + prevSuffix = suffix + } else { + // Delta = suffix - prev + delta := new(big.Int).Sub(suffix, prevSuffix) + + // Encode Delta + if err := ge.EncodeBig(delta); err != nil { + return err + } + + prevSuffix = suffix + + // If this is a start of a new block (e.g. 256), we flush AFTER writing it, + // and record the offset for the NEXT item to be efficiently seekable? + // NO. Sparse Entry 'j' corresponds to item `j*256`. + // So Entry 1 is for item 256. + // The entry must store `hash[256]` AND `offset` where decoding starts for `hash[257]`. + // So we process item 256, THEN Flush, THEN record Entry 1. + + if i%SparseIndexInterval == 0 { + if err := ge.Flush(); err != nil { + return err + } + + sparseIndex = append(sparseIndex, SparseIndexEntry{ + HashSuffix: new(big.Int).Set(suffix), + Offset: uint64(encodedData.Len()), //nolint:gosec + }) + } + } + } + // Flush final bits (though last block might not be aligned, that's fine, file ends). + if err := ge.Flush(); err != nil { + return err + } + + // 2. Compute Checksum + encodedBytes := encodedData.Bytes() + checksum := xxhash.Sum64(encodedBytes) + + // 3. 
Write Header + h := ShardHeader{ + Magic: MagicNumber, + ItemCount: uint64(len(hashes)), + GolombK: uint8(uint(params.Parameter)), //nolint:gosec + HashSuffixBits: uint8(uint(suffixBits)), //nolint:gosec + SparseIndexOffset: uint64(HeaderSize), + SparseIndexCount: uint64(len(sparseIndex)), + Checksum: checksum, + } + + // Write Header Fields + if err := writeUint64(w, h.Magic); err != nil { + return err + } + + if err := writeUint64(w, h.ItemCount); err != nil { + return err + } + + if err := writeUint8(w, h.GolombK); err != nil { + return err + } + + if err := writeUint8(w, h.HashSuffixBits); err != nil { + return err + } + + if err := writeUint64(w, h.SparseIndexOffset); err != nil { + return err + } + + if err := writeUint64(w, h.SparseIndexCount); err != nil { + return err + } + + if err := writeUint64(w, h.Checksum); err != nil { + return err + } + + if _, err := w.Write(h.Reserved[:]); err != nil { + return err + } + + // 4. Write Sparse Index + for _, entry := range sparseIndex { + // Hash (20 bytes, Big Endian as per RFC "20 bytes, big-endian integer") + // My WriteBigIntBits writes Big Endian bits? + // WriteBigIntBits writes bits. `Write` probably expects bytes. + // "Hash stored is ... 160-bit big-endian integer in 20 bytes". + // I should convert BigInt to 20 bytes. + b := entry.HashSuffix.Bytes() + // Pad to 20 bytes if needed + pad := 20 - len(b) + if pad < 0 { + // Should not happen if HashSuffixBits <= 160 + return fmt.Errorf("%w: %d bytes", ErrInvalidHashLength, len(b)) + } + + if pad > 0 { + if _, err := w.Write(make([]byte, pad)); err != nil { + return err + } + } + + if _, err := w.Write(b); err != nil { + return err + } + + // Offset (8 bytes Little Endian) + if err := writeUint64(w, entry.Offset); err != nil { + return err + } + } + + // 5. Write Encoded Data + if _, err := w.Write(encodedBytes); err != nil { + return err + } + + return nil +} + +// ReadShard opens and parses a shard. 
+func ReadShard(r io.ReadSeeker) (*ShardReader, error) { + // Read Header + var h ShardHeader + + b := make([]byte, HeaderSize) + if _, err := io.ReadFull(r, b); err != nil { + return nil, err + } + + buf := bytes.NewReader(b) + + if err := binary.Read(buf, binary.LittleEndian, &h.Magic); err != nil { + return nil, err + } + + if h.Magic != MagicNumber { + return nil, fmt.Errorf("%w: %x", ErrInvalidMagic, h.Magic) + } + + if err := binary.Read(buf, binary.LittleEndian, &h.ItemCount); err != nil { + return nil, err + } + + if err := binary.Read(buf, binary.LittleEndian, &h.GolombK); err != nil { + return nil, err + } + + if err := binary.Read(buf, binary.LittleEndian, &h.HashSuffixBits); err != nil { + return nil, err + } + + if err := binary.Read(buf, binary.LittleEndian, &h.SparseIndexOffset); err != nil { + return nil, err + } + + if err := binary.Read(buf, binary.LittleEndian, &h.SparseIndexCount); err != nil { + return nil, err + } + + if err := binary.Read(buf, binary.LittleEndian, &h.Checksum); err != nil { + return nil, err + } + + // Load Sparse Index + if _, err := r.Seek(int64(h.SparseIndexOffset), io.SeekStart); err != nil { //nolint:gosec + return nil, err + } + + sparseIndex := make([]SparseIndexEntry, h.SparseIndexCount) + for i := 0; i < int(h.SparseIndexCount); i++ { //nolint:gosec + // Read 20 byte hash + hashBytes := make([]byte, 20) + if _, err := io.ReadFull(r, hashBytes); err != nil { + return nil, err + } + + sparseIndex[i].HashSuffix = new(big.Int).SetBytes(hashBytes) + + // Read 8 byte offset + if err := binary.Read(r, binary.LittleEndian, &sparseIndex[i].Offset); err != nil { + return nil, err + } + } + + // Verify Checksum? RFC says "Clients SHOULD verify checksum". + // We'll skip for now to save IO or do it? + // Reading whole encoded data into memory might be heavy. Let's do it if caller requests or assume lazy? + // For "ReadShard", usually means "Open". 
+	// NOTE: the checksum is intentionally not verified here so opening stays cheap; callers wanting strictness can hash the encoded section themselves.
+
+	return &ShardReader{
+		r:           r,
+		Header:      h,
+		SparseIndex: sparseIndex,
+		Params: Encoding{
+			Parameter:  int(h.GolombK),
+			HashBits:   160, // Assuming standard
+			PrefixBits: 160 - int(h.HashSuffixBits),
+		},
+	}, nil
+}
+
+// Contains checks if the shard contains the given hash.
+// The hash must match the shard's prefix (which is implicit, we only check suffix).
+func (sr *ShardReader) Contains(hash *big.Int) (bool, error) {
+	// Strip prefix
+	bits := uint(sr.Params.HashBits - sr.Params.PrefixBits) //nolint:gosec
+	mask := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), bits), big.NewInt(1))
+	targetSuffix := new(big.Int).And(hash, mask)
+
+	// Binary Search Sparse Index
+	// Find the largest entry <= targetSuffix
+	idx := sort.Search(len(sr.SparseIndex), func(i int) bool {
+		return sr.SparseIndex[i].HashSuffix.Cmp(targetSuffix) > 0
+	})
+	// idx is the first entry whose suffix exceeds the target,
+	// so the bracket to scan starts at idx-1.
+
+	if idx == 0 {
+		// sort.Search returned 0, meaning even the first sparse entry's
+		// suffix is strictly greater than the target. Entry 0 records the
+		// smallest suffix in the shard (it is written for item 0), so the
+		// target is smaller than every stored suffix and cannot be present.
+		// (If the target equaled entry 0, the Search predicate would have
+		// been false at index 0 and idx would be at least 1.)
+		return false, nil
+	}
+
+	bracketIdx := idx - 1
+	startEntry := sr.SparseIndex[bracketIdx]
+
+	// Optimization: check if startEntry IS target
+	if startEntry.HashSuffix.Cmp(targetSuffix) == 0 {
+		return true, nil
+	}
+
+	// Decode from startEntry
+	// Seek to Encoded Data Start + Offset
+	// Encoded Data Start is after Sparse Index.
+ encodedDataStart := sr.Header.SparseIndexOffset + sr.Header.SparseIndexCount*SparseEntrySize + seekPos := int64(encodedDataStart + startEntry.Offset) //nolint:gosec + + if _, err := sr.r.Seek(seekPos, io.SeekStart); err != nil { + return false, err + } + + // Golomb Decoder + gd := golomb.NewDecoder(bufio.NewReader(sr.r), sr.Params.Parameter) + + currentHash := new(big.Int).Set(startEntry.HashSuffix) + + // Loop until current >= target + for currentHash.Cmp(targetSuffix) < 0 { + delta, err := gd.DecodeBig() + if err != nil { + if err == io.EOF { + return false, nil + } + + return false, err + } + + currentHash.Add(currentHash, delta) + } + + return currentHash.Cmp(targetSuffix) == 0, nil +} diff --git a/pkg/nixcacheindex/shard_test.go b/pkg/nixcacheindex/shard_test.go new file mode 100644 index 00000000..b57b5eab --- /dev/null +++ b/pkg/nixcacheindex/shard_test.go @@ -0,0 +1,108 @@ +package nixcacheindex_test + +import ( + "bytes" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/kalbasit/ncps/pkg/nixcacheindex" +) + +func TestShardReadWrite(t *testing.T) { + t.Parallel() + + // Generate some sorted hashes + var hashes []*big.Int + + start := big.NewInt(1000) + for i := 0; i < 500; i++ { + // Add some gaps + h := new(big.Int).Add(start, big.NewInt(int64(i*10+(i%5)))) // deterministic gaps + hashes = append(hashes, h) + } + + params := nixcacheindex.Encoding{ + Parameter: 4, // k=4 + HashBits: 160, + PrefixBits: 0, // Depth 0 + } + + var buf bytes.Buffer + + err := nixcacheindex.WriteShard(&buf, hashes, params) + require.NoError(t, err) + + // Read back + r := bytes.NewReader(buf.Bytes()) + sr, err := nixcacheindex.ReadShard(r) + require.NoError(t, err) + + assert.Equal(t, uint64(500), sr.Header.ItemCount) + assert.Equal(t, uint8(4), sr.Header.GolombK) + + // Sparse Index Count: 500 items. Intervals 256. 0, 256. -> 2 entries. 
+ assert.Equal(t, uint64(2), sr.Header.SparseIndexCount) + + // Verify Contains + for _, h := range hashes { + contains, err := sr.Contains(h) + require.NoError(t, err) + assert.True(t, contains, "Shard should contain %s", h) + } + + // Verify Missing + missing := big.NewInt(999) + contains, err := sr.Contains(missing) + require.NoError(t, err) + assert.False(t, contains, "Shard should not contain %s", missing) + + missing2 := big.NewInt(1005) // In between 1000 and 1011 (gap 11) + contains, err = sr.Contains(missing2) + require.NoError(t, err) + assert.False(t, contains) +} + +func TestWriteShard_Empty(t *testing.T) { + t.Parallel() + + err := nixcacheindex.WriteShard(&bytes.Buffer{}, nil, nixcacheindex.Encoding{}) + assert.Error(t, err) +} + +func TestShardSparseAlignment(t *testing.T) { + t.Parallel() + + // Create enough items to trigger sparse index ( > 256) + count := 300 + + hashes := make([]*big.Int, count) + for i := 0; i < count; i++ { + hashes[i] = big.NewInt(int64(i)) + } + + params := nixcacheindex.Encoding{Parameter: 8, HashBits: 160, PrefixBits: 0} + + var buf bytes.Buffer + + err := nixcacheindex.WriteShard(&buf, hashes, params) + require.NoError(t, err) + + sr, err := nixcacheindex.ReadShard(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + + // Check sparse index entries + require.Len(t, sr.SparseIndex, 2) + assert.Equal(t, 0, sr.SparseIndex[0].HashSuffix.Cmp(big.NewInt(0)), "Entry 0 should be 0") + assert.Equal(t, 0, sr.SparseIndex[1].HashSuffix.Cmp(big.NewInt(256)), "Entry 1 should be 256") + + // Verify lookups around the boundary + ok, _ := sr.Contains(big.NewInt(255)) + assert.True(t, ok) + ok, _ = sr.Contains(big.NewInt(256)) + assert.True(t, ok) + ok, _ = sr.Contains(big.NewInt(257)) + assert.True(t, ok) +}