
Commit 2d0962e

arnoldwakim and awakim authored
fix(flight): make StreamChunksFromReader ctx aware and cancellation-safe (#615)
### Rationale for this change

`StreamChunksFromReader` previously did not observe context cancellation. As a result, if a client disconnected early, the reader could continue producing data indefinitely, potentially blocking on channel sends, leaking `RecordBatch` objects, leaking the reader, and consuming unbounded memory and CPU (this observation triggered this PR). This fix ensures that data streaming stops promptly when the client disconnects.

### What changes are included in this PR?

- `StreamChunksFromReader` now accepts a `context.Context`.
- A small change was made to `DoGet` so it continues to work with the context-aware `StreamChunksFromReader`.

### Are these changes tested?

- To be removed from the description: the tests are a bit tricky to write, similar to those in #437. Maybe @zeroshade has suggestions?

### Are there any user-facing changes?

- `StreamChunksFromReader` now accepts a `context.Context` (see the usage sketch below).

---------

Co-authored-by: awakim <arnold.wakim@gmail.com>
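Below is a minimal caller-side sketch of the new signature, not taken from this commit: the `streamToCaller` helper and its `send` callback are illustrative assumptions, while the `flight.StreamChunksFromReader(ctx, rdr, ch)` call and the select on `ctx.Done()` mirror the `DoGet` change in this diff.

```go
package example

import (
	"context"

	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/flight"
)

// streamToCaller drains chunks produced by the context-aware
// StreamChunksFromReader and forwards them via send. The helper name and the
// send callback are assumptions for illustration only.
func streamToCaller(ctx context.Context, rdr array.RecordReader, send func(flight.StreamChunk) error) error {
	ch := make(chan flight.StreamChunk)

	// The goroutine now observes ctx: on cancellation it releases any pending
	// batch, releases the reader, and closes ch instead of blocking forever.
	go flight.StreamChunksFromReader(ctx, rdr, ch)

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case chunk, ok := <-ch:
			if !ok {
				return nil // channel closed: streaming finished, reader released
			}
			if chunk.Err != nil {
				return chunk.Err
			}
			if err := send(chunk); err != nil {
				return err
			}
			// Drop the extra reference taken by StreamChunksFromReader.
			chunk.Data.Release()
		}
	}
}
```

Draining the channel this way keeps the Retain/Release accounting balanced: every received chunk gets one `Release`, and on cancellation the producer releases both the pending batch and the reader before closing the channel.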
1 parent 7566828 commit 2d0962e

File tree

4 files changed: +289 −19 lines


arrow/flight/flight_test.go

Lines changed: 250 additions & 0 deletions
@@ -21,6 +21,8 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"sync"
+	"sync/atomic"
 	"testing"
 
 	"github.com/apache/arrow-go/v18/arrow"
@@ -484,3 +486,251 @@ type flightStreamWriter struct{}
 func (f *flightStreamWriter) Send(data *flight.FlightData) error { return nil }
 
 var _ flight.DataStreamWriter = (*flightStreamWriter)(nil)
+
+// callbackRecordReader wraps a record reader and invokes a callback on each Next() call.
+// It tracks whether batches are properly released and the reader itself is released.
+type callbackRecordReader struct {
+	mem            memory.Allocator
+	schema         *arrow.Schema
+	numBatches     int
+	currentBatch   atomic.Int32
+	onNext         func(batchIndex int) // callback invoked before returning from Next()
+	released       atomic.Bool
+	batchesCreated atomic.Int32
+	totalRetains   atomic.Int32
+	totalReleases  atomic.Int32
+	createdBatches []arrow.RecordBatch // track all created batches for cleanup
+	mu             sync.Mutex
+}
+
+func newCallbackRecordReader(mem memory.Allocator, schema *arrow.Schema, numBatches int, onNext func(int)) *callbackRecordReader {
+	return &callbackRecordReader{
+		mem:        mem,
+		schema:     schema,
+		numBatches: numBatches,
+		onNext:     onNext,
+	}
+}
+
+func (r *callbackRecordReader) Schema() *arrow.Schema {
+	return r.schema
+}
+
+func (r *callbackRecordReader) Next() bool {
+	current := r.currentBatch.Load()
+	if int(current) >= r.numBatches {
+		return false
+	}
+	r.currentBatch.Add(1)
+
+	if r.onNext != nil {
+		r.onNext(int(current))
+	}
+
+	return true
+}
+
+func (r *callbackRecordReader) RecordBatch() arrow.RecordBatch {
+	bldr := array.NewInt64Builder(r.mem)
+	defer bldr.Release()
+
+	currentBatch := r.currentBatch.Load()
+	bldr.AppendValues([]int64{int64(currentBatch)}, nil)
+	arr := bldr.NewArray()
+
+	rec := array.NewRecordBatch(r.schema, []arrow.Array{arr}, 1)
+	arr.Release()
+
+	tracked := &trackedRecordBatch{
+		RecordBatch: rec,
+		onRetain: func() {
+			r.totalRetains.Add(1)
+		},
+		onRelease: func() {
+			r.totalReleases.Add(1)
+		},
+	}
+
+	r.mu.Lock()
+	r.createdBatches = append(r.createdBatches, tracked)
+	r.mu.Unlock()
+
+	r.batchesCreated.Add(1)
+	return tracked
+}
+
+func (r *callbackRecordReader) ReleaseAll() {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	for _, batch := range r.createdBatches {
+		batch.Release()
+	}
+	r.createdBatches = nil
+}
+
+func (r *callbackRecordReader) Retain() {}
+
+func (r *callbackRecordReader) Release() {
+	r.released.Store(true)
+}
+
+func (r *callbackRecordReader) Record() arrow.RecordBatch {
+	return r.RecordBatch()
+}
+
+func (r *callbackRecordReader) Err() error {
+	return nil
+}
+
+// trackedRecordBatch wraps a RecordBatch to track Retain/Release calls.
+type trackedRecordBatch struct {
+	arrow.RecordBatch
+	onRetain  func()
+	onRelease func()
+}
+
+func (t *trackedRecordBatch) Retain() {
+	if t.onRetain != nil {
+		t.onRetain()
+	}
+	t.RecordBatch.Retain()
+}
+
+func (t *trackedRecordBatch) Release() {
+	if t.onRelease != nil {
+		t.onRelease()
+	}
+	t.RecordBatch.Release()
+}
+
+func TestStreamChunksFromReader_OK(t *testing.T) {
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	defer mem.AssertSize(t, 0)
+
+	schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+	rdr := newCallbackRecordReader(mem, schema, 5, nil)
+	defer rdr.ReleaseAll()
+
+	ch := make(chan flight.StreamChunk, 5)
+
+	ctx := context.Background()
+
+	go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+	var chunksReceived int
+	for chunk := range ch {
+		if chunk.Err != nil {
+			t.Errorf("unexpected error chunk: %v", chunk.Err)
+			continue
+		}
+		if chunk.Data != nil {
+			chunksReceived++
+			chunk.Data.Release()
+		}
+	}
+
+	require.Equal(t, 5, chunksReceived, "should receive all 5 batches")
+	require.True(t, rdr.released.Load(), "reader should be released")
+}
+
+// TestStreamChunksFromReader_HandlesCancellation verifies that context cancellation
+// causes StreamChunksFromReader to exit cleanly and release the reader.
+func TestStreamChunksFromReader_HandlesCancellation(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	mem := memory.DefaultAllocator
+	schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+	rdr := newCallbackRecordReader(mem, schema, 10, nil)
+	defer rdr.ReleaseAll()
+	ch := make(chan flight.StreamChunk) // unbuffered channel
+
+	go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+	chunksReceived := 0
+	for chunk := range ch {
+		if chunk.Data != nil {
+			chunksReceived++
+			chunk.Data.Release()
+		}
+
+		// Cancel context after 2 batches (simulating server detecting client disconnect)
+		if chunksReceived == 2 {
+			cancel()
+		}
+	}
+
+	// After canceling context, StreamChunksFromReader exits and closes the channel.
+	// The for-range loop above exits when the channel closes.
+	// By the time we reach here, the channel is closed, which means StreamChunksFromReader's
+	// defer stack has already executed, so the reader must be released.
+	require.True(t, rdr.released.Load(), "reader must be released when context is canceled")
+}
+
+// TestStreamChunksFromReader_CancellationReleasesBatches verifies that batches are
+// properly tracked and demonstrates memory leaks without cleanup, then proves cleanup fixes it.
+func TestStreamChunksFromReader_CancellationReleasesBatches(t *testing.T) {
+	mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+
+	schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: arrow.PrimitiveTypes.Int64}}, nil)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	// Create reader that will produce 10 batches, but we'll cancel after 3
+	reader := newCallbackRecordReader(mem, schema, 10, func(batchIndex int) {
+		if batchIndex == 2 {
+			cancel()
+		}
+	})
+
+	ch := make(chan flight.StreamChunk, 5)
+
+	// Start streaming
+	go flight.StreamChunksFromReader(ctx, reader, ch)
+
+	// Consume chunks until channel closes
+	var chunksReceived int
+	for chunk := range ch {
+		if chunk.Err != nil {
+			t.Errorf("unexpected error chunk: %v", chunk.Err)
+			continue
+		}
+		if chunk.Data != nil {
+			chunksReceived++
+			chunk.Data.Release()
+		}
+	}
+
+	// Verify the reader was released
+	require.True(t, reader.released.Load(), "reader should be released")
+
+	// We should have received at most 3-4 chunks (depending on timing)
+	// The important part is we didn't receive all 10
+	require.LessOrEqual(t, chunksReceived, 4, "should not receive all 10 chunks, got %d", chunksReceived)
+	require.Greater(t, chunksReceived, 0, "should receive at least 1 chunk")
+
+	// Check that Retain and Release don't balance - proving there's a leak without manual cleanup
+	retains := reader.totalRetains.Load()
+	releases := reader.totalReleases.Load()
+	batchesCreated := reader.batchesCreated.Load()
+
+	// Each batch starts with refcount=1, then StreamChunksFromReader calls Retain() (refcount=2)
+	// For sent batches: we call Release() (refcount=1), batch still has initial ref
+	// For unsent batches due to cancellation: they keep refcount=1 from creation
+	// So we expect: releases < retains + batchesCreated
+	require.Less(t, releases, retains+batchesCreated,
+		"without cleanup, releases should be less than retains+created: retains=%d, releases=%d, created=%d",
+		retains, releases, batchesCreated)
+
+	// Now manually release all created batches to show proper cleanup fixes the leak
+	reader.ReleaseAll()
+
+	// After cleanup, memory should be freed
+	mem.AssertSize(t, 0)
+}

arrow/flight/flightsql/example/sqlite_server.go

Lines changed: 2 additions & 2 deletions
@@ -354,7 +354,7 @@ func (s *SQLiteFlightSQLServer) DoGetTables(ctx context.Context, cmd flightsql.G
 	}
 
 	schema := rdr.Schema()
-	go flight.StreamChunksFromReader(rdr, ch)
+	go flight.StreamChunksFromReader(ctx, rdr, ch)
 	return schema, ch, nil
 }
 
@@ -485,7 +485,7 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, db dbQueryCtx, query
 	}
 
 	ch := make(chan flight.StreamChunk)
-	go flight.StreamChunksFromReader(rdr, ch)
+	go flight.StreamChunksFromReader(ctx, rdr, ch)
 	return schema, ch, nil
 }

arrow/flight/flightsql/server.go

Lines changed: 18 additions & 13 deletions
@@ -381,7 +381,7 @@ func (b *BaseServer) GetFlightInfoSqlInfo(_ context.Context, _ GetSqlInfo, desc
 }
 
 // DoGetSqlInfo returns a flight stream containing the list of sqlinfo results
-func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+func (b *BaseServer) DoGetSqlInfo(ctx context.Context, cmd GetSqlInfo) (*arrow.Schema, <-chan flight.StreamChunk, error) {
 	if b.Alloc == nil {
 		b.Alloc = memory.DefaultAllocator
 	}
@@ -430,7 +430,7 @@ func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) (*arrow.Sch
 	}
 
 	// StreamChunksFromReader will call release on the reader when done
-	go flight.StreamChunksFromReader(rdr, ch)
+	go flight.StreamChunksFromReader(ctx, rdr, ch)
 	return schema_ref.SqlInfo, ch, nil
 }
 
@@ -927,19 +927,24 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, stream flight.FlightServ
 	wr := flight.NewRecordWriter(stream, ipc.WithSchema(sc))
 	defer wr.Close()
 
-	for chunk := range cc {
-		if chunk.Err != nil {
-			return chunk.Err
-		}
-
-		wr.SetFlightDescriptor(chunk.Desc)
-		if err = wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil {
-			return err
+	for {
+		select {
+		case <-stream.Context().Done():
+			return stream.Context().Err()
+		case chunk, ok := <-cc:
+			if !ok {
+				return nil
+			}
+			if chunk.Err != nil {
+				return chunk.Err
+			}
+			wr.SetFlightDescriptor(chunk.Desc)
+			if err := wr.WriteWithAppMetadata(chunk.Data, chunk.AppMetadata); err != nil {
+				return err
+			}
+			chunk.Data.Release()
 		}
-		chunk.Data.Release()
 	}
-
-	return err
 }
 
 type putMetadataWriter struct {

arrow/flight/record_batch_reader.go

Lines changed: 19 additions & 4 deletions
@@ -18,6 +18,7 @@ package flight
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -212,24 +213,38 @@ type haserr interface {
 
 // StreamChunksFromReader is a convenience function to populate a channel
 // from a record reader. It is intended to be run using a separate goroutine
-// by calling `go flight.StreamChunksFromReader(rdr, ch)`.
+// by calling `go flight.StreamChunksFromReader(ctx, rdr, ch)`.
 //
 // If the record reader panics, an error chunk will get sent on the channel.
 //
 // This will close the channel and release the reader when it completes.
-func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) {
+func StreamChunksFromReader(ctx context.Context, rdr array.RecordReader, ch chan<- StreamChunk) {
 	defer close(ch)
 	defer func() {
 		if err := recover(); err != nil {
-			ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)}
+			select {
+			case ch <- StreamChunk{Err: utils.FormatRecoveredError("panic while reading", err)}:
+			case <-ctx.Done():
+			}
 		}
 	}()
 
 	defer rdr.Release()
 	for rdr.Next() {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		rec := rdr.RecordBatch()
 		rec.Retain()
-		ch <- StreamChunk{Data: rec}
+		select {
+		case ch <- StreamChunk{Data: rec}:
+		case <-ctx.Done():
+			rec.Release()
+			return
+		}
 	}
 
 	if e, ok := rdr.(haserr); ok {
