var (
	// ErrInvalidGolombK is returned when the Golomb parameter k is invalid.
	ErrInvalidGolombK = errors.New("invalid golomb parameter k")
	// ErrInvalidEncodingParams is returned when encoding parameters are invalid.
	ErrInvalidEncodingParams = errors.New("invalid encoding parameters")
	// ErrValueOverflow is returned by Decoder.Decode when the encoded value
	// does not fit in a uint64.
	ErrValueOverflow = errors.New("decoded value overflows uint64")
)

// BitWriter packs single bits, most significant bit first, into bytes and
// hands every completed byte to an io.ByteWriter.
type BitWriter struct {
	w           io.ByteWriter
	currentByte byte
	bitCount    uint8 // bits already filled in currentByte (0..7)
}

// NewBitWriter creates a new BitWriter on top of an io.ByteWriter
// (e.g. *bufio.Writer or *bytes.Buffer).
func NewBitWriter(w io.ByteWriter) *BitWriter {
	return &BitWriter{w: w}
}

// WriteBit writes a single bit (true = 1, false = 0).
//
// The writer's state is only advanced once the bit has been accepted: when
// the bit completes a byte and the underlying WriteByte fails, the error is
// returned and the bit is not recorded, so the call may be retried.
func (bw *BitWriter) WriteBit(bit bool) error {
	next := bw.currentByte
	if bit {
		next |= 1 << (7 - bw.bitCount)
	}

	if bw.bitCount == 7 {
		// This bit completes the byte: emit it before committing anything.
		if err := bw.w.WriteByte(next); err != nil {
			return err
		}

		bw.currentByte = 0
		bw.bitCount = 0

		return nil
	}

	bw.currentByte = next
	bw.bitCount++

	return nil
}

// WriteBits writes the n least significant bits of val, most significant first.
// Nothing is written when n <= 0; bit positions at or above 64 are written as
// zeros.
func (bw *BitWriter) WriteBits(val uint64, n int) error {
	for i := n - 1; i >= 0; i-- {
		if err := bw.WriteBit(val>>uint(i)&1 == 1); err != nil {
			return err
		}
	}

	return nil
}

// Flush writes any pending bits to the underlying writer, padded with zeros
// up to the next byte boundary.
//
// After a successful Flush the writer is empty again, so calling Flush twice
// emits nothing the second time and later writes start on a fresh byte.
func (bw *BitWriter) Flush() error {
	if bw.bitCount == 0 {
		return nil
	}

	if err := bw.w.WriteByte(bw.currentByte); err != nil {
		return err
	}

	bw.currentByte = 0
	bw.bitCount = 0

	return nil
}

// WriteBigIntBits writes the n least significant bits of val, most significant
// first. It returns ErrInvalidEncodingParams when val is nil.
func (bw *BitWriter) WriteBigIntBits(val *big.Int, n int) error {
	if val == nil {
		return fmt.Errorf("%w: value must not be nil", ErrInvalidEncodingParams)
	}

	// Walk from bit n-1 down to bit 0; Bit(0) is the least significant bit.
	for i := n - 1; i >= 0; i-- {
		if err := bw.WriteBit(val.Bit(i) == 1); err != nil {
			return err
		}
	}

	return nil
}

// BitReader reads bits, most significant bit first, from an io.ByteReader.
type BitReader struct {
	r           io.ByteReader
	currentByte byte
	bitCount    uint8 // bits still unread in currentByte
}

// NewBitReader creates a new BitReader on top of an io.ByteReader.
func NewBitReader(r io.ByteReader) *BitReader {
	return &BitReader{r: r}
}

// ReadBit reads a single bit. It returns the error of the underlying
// ReadByte (io.EOF for the readers in the standard library) when a new byte
// is needed and cannot be read.
func (br *BitReader) ReadBit() (bool, error) {
	if br.bitCount == 0 {
		b, err := br.r.ReadByte()
		if err != nil {
			return false, err
		}

		br.currentByte = b
		br.bitCount = 8
	}

	br.bitCount--

	return br.currentByte>>br.bitCount&1 == 1, nil
}

// ReadBits reads n bits and returns them as a uint64, first bit read in the
// most significant position. It returns 0 when n <= 0 and
// ErrInvalidEncodingParams when n > 64, because the extra bits could not be
// represented in the result.
func (br *BitReader) ReadBits(n int) (uint64, error) {
	if n > 64 {
		return 0, fmt.Errorf("%w: cannot read %d bits into a uint64", ErrInvalidEncodingParams, n)
	}

	var val uint64

	for i := 0; i < n; i++ {
		bit, err := br.ReadBit()
		if err != nil {
			return 0, err
		}

		val <<= 1
		if bit {
			val |= 1
		}
	}

	return val, nil
}

// ReadBigIntBits reads n bits and returns them as a big.Int, first bit read
// in the most significant position.
func (br *BitReader) ReadBigIntBits(n int) (*big.Int, error) {
	val := new(big.Int)

	for i := 0; i < n; i++ {
		bit, err := br.ReadBit()
		if err != nil {
			return nil, err
		}

		val.Lsh(val, 1)

		if bit {
			val.SetBit(val, 0, 1)
		}
	}

	return val, nil
}

// Encoder encodes integers using Golomb-Rice coding: a value d is written as
// the quotient d >> k in unary (that many 1 bits and a terminating 0 bit)
// followed by the remainder in exactly k binary digits.
type Encoder struct {
	// BitWriter is the bit-level sink that receives every code word.
	BitWriter *BitWriter
	// flusher is the bufio.Writer created by NewEncoder when the destination
	// had no WriteByte method; it is nil otherwise.
	flusher interface {
		Flush() error
	}
	k int    // Rice parameter; the divisor is M = 2^k
	m uint64 // M = 2^k
}

// NewEncoder creates a new encoder writing to w with parameter k.
//
// k must be in [0, 63]; otherwise an error wrapping ErrInvalidGolombK is
// returned. When w does not implement io.ByteWriter it is wrapped in a
// bufio.Writer, which Flush drains.
func NewEncoder(w io.Writer, k int) (*Encoder, error) {
	if k < 0 || k >= 64 {
		return nil, fmt.Errorf("%w: %d, must be in range [0, 63]", ErrInvalidGolombK, k)
	}

	enc := &Encoder{k: k, m: 1 << uint(k)}

	if direct, ok := w.(io.ByteWriter); ok {
		enc.BitWriter = NewBitWriter(direct)

		return enc, nil
	}

	buffered := bufio.NewWriter(w)
	enc.BitWriter = NewBitWriter(buffered)
	enc.flusher = buffered

	return enc, nil
}

// Encode encodes a value d. The unary part takes d >> k + 1 bits, so small
// values of k make large inputs expensive.
func (ge *Encoder) Encode(d uint64) error {
	q := d >> uint(ge.k) // d / M
	r := d & (ge.m - 1)  // d % M

	// Unary part: q ones followed by a single zero.
	for i := uint64(0); i < q; i++ {
		if err := ge.BitWriter.WriteBit(true); err != nil {
			return err
		}
	}

	if err := ge.BitWriter.WriteBit(false); err != nil {
		return err
	}

	// Binary part: the remainder in exactly k bits.
	return ge.BitWriter.WriteBits(r, ge.k)
}

// EncodeBig encodes a big.Int delta with the same code as Encode.
//
// d must be non-nil and non-negative; otherwise an error wrapping
// ErrInvalidEncodingParams is returned and nothing is written.
func (ge *Encoder) EncodeBig(d *big.Int) error {
	if d == nil {
		return fmt.Errorf("%w: value must not be nil", ErrInvalidEncodingParams)
	}

	if d.Sign() < 0 {
		return fmt.Errorf("%w: value %s is negative", ErrInvalidEncodingParams, d.String())
	}

	// q = d >> k; it is counted down to zero while the unary ones are emitted.
	remaining := new(big.Int).Rsh(d, uint(ge.k))
	one := big.NewInt(1)

	for remaining.Sign() > 0 {
		if err := ge.BitWriter.WriteBit(true); err != nil {
			return err
		}

		remaining.Sub(remaining, one)
	}

	// Zero delimiter that terminates the unary part.
	if err := ge.BitWriter.WriteBit(false); err != nil {
		return err
	}

	// r = d & (2^k - 1) is simply the lowest k bits of d, so they can be
	// taken from d directly.
	return ge.BitWriter.WriteBigIntBits(d, ge.k)
}

// Flush flushes the underlying bit writer and, when NewEncoder had to wrap
// the destination, the buffered writer it created.
func (ge *Encoder) Flush() error {
	if err := ge.BitWriter.Flush(); err != nil {
		return err
	}

	if ge.flusher == nil {
		return nil
	}

	return ge.flusher.Flush()
}

// Decoder decodes integers using Golomb-Rice coding.
type Decoder struct {
	br *BitReader
	k  int
	m  uint64
}

// NewDecoder creates a new decoder reading from r with parameter k.
//
// The constructor cannot report errors; a k that is unusable is reported by
// Decode and DecodeBig as ErrInvalidGolombK.
func NewDecoder(r io.ByteReader, k int) *Decoder {
	dec := &Decoder{br: NewBitReader(r), k: k}

	// Shifting by a negative count panics, so m is only derived for a
	// usable k and stays zero otherwise.
	if k >= 0 && k < 64 {
		dec.m = 1 << uint(k)
	}

	return dec
}

// Decode decodes a value.
//
// It returns an error wrapping ErrInvalidGolombK when k is outside [0, 63]
// and one wrapping ErrValueOverflow when the encoded value does not fit in a
// uint64. Errors of the underlying reader are returned unchanged. After any
// error the position in the stream is undefined.
func (gd *Decoder) Decode() (uint64, error) {
	if gd.k < 0 || gd.k >= 64 {
		return 0, fmt.Errorf("%w: %d, must be in range [0, 63]", ErrInvalidGolombK, gd.k)
	}

	// Largest quotient for which q*M + r still fits in 64 bits.
	maxQ := ^uint64(0) >> uint(gd.k)

	// Unary part: count ones until the terminating zero.
	var q uint64

	for {
		bit, err := gd.br.ReadBit()
		if err != nil {
			return 0, err
		}

		if !bit {
			break
		}

		if q == maxQ {
			return 0, fmt.Errorf("%w: quotient exceeds %d for k=%d", ErrValueOverflow, maxQ, gd.k)
		}

		q++
	}

	// Binary part: k bits of remainder.
	r, err := gd.br.ReadBits(gd.k)
	if err != nil {
		return 0, err
	}

	return q*gd.m + r, nil
}

// DecodeBig decodes a value as big.Int. It returns an error wrapping
// ErrInvalidGolombK when k is negative.
func (gd *Decoder) DecodeBig() (*big.Int, error) {
	if gd.k < 0 {
		return nil, fmt.Errorf("%w: %d, must not be negative", ErrInvalidGolombK, gd.k)
	}

	// Unary part: count ones until the terminating zero.
	q := new(big.Int)
	one := big.NewInt(1)

	for {
		bit, err := gd.br.ReadBit()
		if err != nil {
			return nil, err
		}

		if !bit {
			break
		}

		q.Add(q, one)
	}

	// Binary part: k bits of remainder.
	r, err := gd.br.ReadBigIntBits(gd.k)
	if err != nil {
		return nil, err
	}

	// d = q*M + r with M = 2^k; since r < 2^k this equals (q << k) | r.
	d := new(big.Int).Lsh(q, uint(gd.k)) //nolint:gosec
	d.Or(d, r)

	return d, nil
}