Skip to content

Commit c76c245

Browse files
committed
Review feedback
1 parent c875c0b commit c76c245

File tree

3 files changed

+110
-23
lines changed

3 files changed

+110
-23
lines changed

block/internal/da/client.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func (c *client) Submit(ctx context.Context, data [][]byte, _ float64, namespace
113113
c.logger.Debug().
114114
Int("original_size", len(raw)).
115115
Int("compressed_size", len(compressed)).
116-
Float64("ratio", float64(len(compressed))/float64(len(raw))).
116+
Float64("ratio", float64(len(compressed))/float64(max(len(raw), 1))).
117117
Int("level", int(compLevel)).
118118
Msg("compressed blob for DA submission")
119119

@@ -311,13 +311,14 @@ func (c *client) Retrieve(ctx context.Context, height uint64, namespace []byte)
311311
data := make([]datypes.Blob, len(blobs))
312312
for i, b := range blobs {
313313
ids[i] = blobrpc.MakeID(height, b.Commitment)
314-
decompressed, decompErr := da.Decompress(b.Data())
314+
decompressed, decompErr := da.Decompress(ctx, b.Data())
315315
if decompErr != nil {
316316
return datypes.ResultRetrieve{
317317
BaseResult: datypes.BaseResult{
318-
Code: datypes.StatusError,
319-
Message: fmt.Sprintf("decompress blob %d at height %d: %v", i, height, decompErr),
320-
Height: height,
318+
Code: datypes.StatusError,
319+
Message: fmt.Sprintf("decompress blob %d at height %d: %v", i, height, decompErr),
320+
Height: height,
321+
Timestamp: blockTime,
321322
},
322323
}
323324
}
@@ -399,7 +400,7 @@ func (c *client) Get(ctx context.Context, ids []datypes.ID, namespace []byte) ([
399400
if b == nil {
400401
continue
401402
}
402-
decompressed, decompErr := da.Decompress(b.Data())
403+
decompressed, decompErr := da.Decompress(ctx, b.Data())
403404
if decompErr != nil {
404405
return nil, fmt.Errorf("decompress blob: %w", decompErr)
405406
}

pkg/da/compression.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package da
22

33
import (
4+
"context"
5+
"errors"
46
"fmt"
7+
"time"
58

69
"github.com/klauspost/compress/zstd"
710
)
@@ -10,6 +13,20 @@ import (
1013
// ASCII "ZSTD" = 0x5A 0x53 0x54 0x44.
1114
var magic = []byte{0x5A, 0x53, 0x54, 0x44}
1215

16+
// maxDecompressedSize is the maximum allowed decompressed output size.
17+
// Matches the WithDecoderMaxMemory cap and provides early rejection
18+
// by inspecting the zstd frame header before allocating anything.
19+
const maxDecompressedSize = 7 * 1024 * 1024 // 7 MiB
20+
21+
// decompressTimeout is the hard wall-clock cap on a single decompression.
22+
// This guards against CPU-based decompression bombs (crafted inputs that
23+
// are slow to decode) independently of the caller's context deadline.
24+
const decompressTimeout = 500 * time.Millisecond
25+
26+
// ErrDecompressedSizeExceeded is returned when the zstd frame header
27+
// declares a decompressed size that exceeds the allowed limit.
28+
var ErrDecompressedSizeExceeded = errors.New("compression: declared decompressed size exceeds limit")
29+
1330
// CompressionLevel controls the speed/ratio trade-off for blob compression.
1431
type CompressionLevel int
1532

@@ -44,8 +61,9 @@ func init() {
4461
}
4562
encoders[i] = enc
4663
}
64+
const maxDecoderMemory = 7 * 1024 * 1024 // 7 MiB cap
4765
var err error
48-
decoder, err = zstd.NewReader(nil)
66+
decoder, err = zstd.NewReader(nil, zstd.WithDecoderMaxMemory(maxDecoderMemory))
4967
if err != nil {
5068
panic(fmt.Sprintf("compression: create zstd decoder: %v", err))
5169
}
@@ -72,21 +90,51 @@ func Compress(data []byte, level CompressionLevel) ([]byte, error) {
7290
return result, nil
7391
}
7492

93+
// decodeResult carries the output of a DecodeAll goroutine.
94+
type decodeResult struct {
95+
data []byte
96+
err error
97+
}
98+
7599
// Decompress decompresses data that was compressed with Compress.
76100
// If the data does not have the magic prefix, it is returned as-is
77101
// (backward-compatible passthrough for uncompressed blobs).
78-
func Decompress(data []byte) ([]byte, error) {
102+
func Decompress(ctx context.Context, data []byte) ([]byte, error) {
79103
if !IsCompressed(data) {
80104
return data, nil
81105
}
82106

83-
// Strip magic prefix and decompress
84-
decompressed, err := decoder.DecodeAll(data[len(magic):], nil)
85-
if err != nil {
86-
return nil, fmt.Errorf("compression: zstd decompress: %w", err)
107+
payload := data[len(magic):]
108+
109+
// Layer 1: Parse frame header to check declared decompressed size
110+
// before allocating anything. This is a zero-cost upfront rejection.
111+
var hdr zstd.Header
112+
if err := hdr.Decode(payload); err == nil && hdr.HasFCS {
113+
if hdr.FrameContentSize > maxDecompressedSize {
114+
return nil, fmt.Errorf("%w: %d bytes declared, %d allowed",
115+
ErrDecompressedSizeExceeded, hdr.FrameContentSize, maxDecompressedSize)
116+
}
87117
}
88118

89-
return decompressed, nil
119+
// Layer 3: Apply the shorter of caller deadline and our hard cap.
120+
ctx, cancel := context.WithTimeout(ctx, decompressTimeout)
121+
defer cancel()
122+
123+
ch := make(chan decodeResult, 1)
124+
go func() {
125+
out, err := decoder.DecodeAll(payload, nil)
126+
ch <- decodeResult{data: out, err: err}
127+
}()
128+
129+
select {
130+
case res := <-ch:
131+
if res.err != nil {
132+
return nil, fmt.Errorf("zstd decompress: %w", res.err)
133+
}
134+
return res.data, nil
135+
case <-ctx.Done():
136+
return nil, fmt.Errorf("zstd decompress timeout: %w", ctx.Err())
137+
}
90138
}
91139

92140
// IsCompressed reports whether data starts with the compression magic prefix.

pkg/da/compression_test.go

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ package da
22

33
import (
44
"bytes"
5+
"context"
56
"crypto/rand"
67
"testing"
78

9+
"github.com/klauspost/compress/zstd"
810
"github.com/stretchr/testify/assert"
911
"github.com/stretchr/testify/require"
1012
)
@@ -39,7 +41,7 @@ func TestCompressDecompress_RoundTrip(t *testing.T) {
3941

4042
assert.True(t, IsCompressed(compressed), "compressed data should have magic prefix")
4143

42-
decompressed, err := Decompress(compressed)
44+
decompressed, err := Decompress(context.Background(), compressed)
4345
require.NoError(t, err)
4446

4547
assert.Equal(t, tt.data, decompressed, "round-trip should preserve data")
@@ -67,7 +69,7 @@ func TestCompress_AllLevelsRoundTrip(t *testing.T) {
6769

6870
sizes = append(sizes, len(compressed))
6971

70-
decompressed, err := Decompress(compressed)
72+
decompressed, err := Decompress(context.Background(), compressed)
7173
require.NoError(t, err)
7274
assert.Equal(t, data, decompressed)
7375

@@ -94,32 +96,32 @@ func TestCompress_Empty(t *testing.T) {
9496
func TestDecompress_UncompressedPassthrough(t *testing.T) {
9597
// Data without magic prefix should pass through unchanged
9698
raw := []byte("this is uncompressed protobuf data")
97-
result, err := Decompress(raw)
99+
result, err := Decompress(context.Background(), raw)
98100
require.NoError(t, err)
99101
assert.Equal(t, raw, result)
100102
}
101103

102104
func TestDecompress_Empty(t *testing.T) {
103-
result, err := Decompress(nil)
105+
result, err := Decompress(context.Background(), nil)
104106
require.NoError(t, err)
105107
assert.Nil(t, result)
106108

107-
result, err = Decompress([]byte{})
109+
result, err = Decompress(context.Background(), []byte{})
108110
require.NoError(t, err)
109111
assert.Empty(t, result)
110112
}
111113

112114
func TestDecompress_ShortData(t *testing.T) {
113115
// Data shorter than magic prefix should pass through
114-
result, err := Decompress([]byte{0x5A, 0x53})
116+
result, err := Decompress(context.Background(), []byte{0x5A, 0x53})
115117
require.NoError(t, err)
116118
assert.Equal(t, []byte{0x5A, 0x53}, result)
117119
}
118120

119121
func TestDecompress_CorruptCompressedData(t *testing.T) {
120122
// Magic prefix followed by invalid zstd data
121123
corrupt := append([]byte{0x5A, 0x53, 0x54, 0x44}, []byte("not valid zstd")...)
122-
_, err := Decompress(corrupt)
124+
_, err := Decompress(context.Background(), corrupt)
123125
assert.Error(t, err, "should fail on corrupt compressed data")
124126
}
125127

@@ -163,14 +165,14 @@ func TestCompress_RandomDataStillWorks(t *testing.T) {
163165
compressed, err := Compress(data, LevelFastest)
164166
require.NoError(t, err)
165167

166-
decompressed, err := Decompress(compressed)
168+
decompressed, err := Decompress(context.Background(), compressed)
167169
require.NoError(t, err)
168170
assert.Equal(t, data, decompressed)
169171
}
170172

171173
func TestDecompress_DataStartingWithMagicButUncompressed(t *testing.T) {
172174
fakeCompressed := append([]byte{0x5A, 0x53, 0x54, 0x44}, bytes.Repeat([]byte{0x00}, 100)...)
173-
_, err := Decompress(fakeCompressed)
175+
_, err := Decompress(context.Background(), fakeCompressed)
174176
assert.Error(t, err, "data starting with magic but containing invalid zstd should error")
175177
}
176178

@@ -180,7 +182,43 @@ func TestCompress_InvalidLevel(t *testing.T) {
180182
compressed, err := Compress(data, CompressionLevel(99))
181183
require.NoError(t, err)
182184

183-
decompressed, err := Decompress(compressed)
185+
decompressed, err := Decompress(context.Background(), compressed)
184186
require.NoError(t, err)
185187
assert.Equal(t, data, decompressed)
186188
}
189+
190+
func TestDecompress_ContextCanceled(t *testing.T) {
191+
data := []byte("test data for context cancellation")
192+
compressed, err := Compress(data, LevelDefault)
193+
require.NoError(t, err)
194+
195+
// Pre-canceled context should cause Decompress to return an error.
196+
ctx, cancel := context.WithCancel(context.Background())
197+
cancel()
198+
199+
_, err = Decompress(ctx, compressed)
200+
assert.Error(t, err, "decompress with canceled context should fail")
201+
assert.ErrorIs(t, err, context.Canceled)
202+
}
203+
204+
func TestDecompress_OversizedFrameRejected(t *testing.T) {
205+
// Build a fake zstd frame header that declares 100 MiB decompressed size.
206+
// This should be rejected by the frame header pre-check before any
207+
// decompression occurs.
208+
hdr := zstd.Header{
209+
SingleSegment: true,
210+
HasFCS: true,
211+
FrameContentSize: 100 * 1024 * 1024, // 100 MiB — way over the 7 MiB limit
212+
}
213+
frame, err := hdr.AppendTo(nil)
214+
require.NoError(t, err)
215+
216+
// Prepend our custom magic prefix
217+
blob := make([]byte, len(magic)+len(frame))
218+
copy(blob, magic)
219+
copy(blob[len(magic):], frame)
220+
221+
_, err = Decompress(context.Background(), blob)
222+
assert.Error(t, err, "should reject blob declaring oversized decompressed output")
223+
assert.ErrorIs(t, err, ErrDecompressedSizeExceeded)
224+
}

0 commit comments

Comments
 (0)