@@ -21,6 +21,8 @@ import (
2121 "errors"
2222 "fmt"
2323 "io"
24+ "sync"
25+ "sync/atomic"
2426 "testing"
2527
2628 "github.com/apache/arrow-go/v18/arrow"
@@ -484,3 +486,251 @@ type flightStreamWriter struct{}
484486func (f * flightStreamWriter ) Send (data * flight.FlightData ) error { return nil }
485487
486488var _ flight.DataStreamWriter = (* flightStreamWriter )(nil )
489+
490+ // callbackRecordReader wraps a record reader and invokes a callback on each Next() call.
491+ // It tracks whether batches are properly released and the reader itself is released.
492+ type callbackRecordReader struct {
493+ mem memory.Allocator
494+ schema * arrow.Schema
495+ numBatches int
496+ currentBatch atomic.Int32
497+ onNext func (batchIndex int ) // callback invoked before returning from Next()
498+ released atomic.Bool
499+ batchesCreated atomic.Int32
500+ totalRetains atomic.Int32
501+ totalReleases atomic.Int32
502+ createdBatches []arrow.RecordBatch // track all created batches for cleanup
503+ mu sync.Mutex
504+ }
505+
506+ func newCallbackRecordReader (mem memory.Allocator , schema * arrow.Schema , numBatches int , onNext func (int )) * callbackRecordReader {
507+ return & callbackRecordReader {
508+ mem : mem ,
509+ schema : schema ,
510+ numBatches : numBatches ,
511+ onNext : onNext ,
512+ }
513+ }
514+
515+ func (r * callbackRecordReader ) Schema () * arrow.Schema {
516+ return r .schema
517+ }
518+
519+ func (r * callbackRecordReader ) Next () bool {
520+ current := r .currentBatch .Load ()
521+ if int (current ) >= r .numBatches {
522+ return false
523+ }
524+ r .currentBatch .Add (1 )
525+
526+ if r .onNext != nil {
527+ r .onNext (int (current ))
528+ }
529+
530+ return true
531+ }
532+
533+ func (r * callbackRecordReader ) RecordBatch () arrow.RecordBatch {
534+ bldr := array .NewInt64Builder (r .mem )
535+ defer bldr .Release ()
536+
537+ currentBatch := r .currentBatch .Load ()
538+ bldr .AppendValues ([]int64 {int64 (currentBatch )}, nil )
539+ arr := bldr .NewArray ()
540+
541+ rec := array .NewRecordBatch (r .schema , []arrow.Array {arr }, 1 )
542+ arr .Release ()
543+
544+ tracked := & trackedRecordBatch {
545+ RecordBatch : rec ,
546+ onRetain : func () {
547+ r .totalRetains .Add (1 )
548+ },
549+ onRelease : func () {
550+ r .totalReleases .Add (1 )
551+ },
552+ }
553+
554+ r .mu .Lock ()
555+ r .createdBatches = append (r .createdBatches , tracked )
556+ r .mu .Unlock ()
557+
558+ r .batchesCreated .Add (1 )
559+ return tracked
560+ }
561+
562+ func (r * callbackRecordReader ) ReleaseAll () {
563+ r .mu .Lock ()
564+ defer r .mu .Unlock ()
565+ for _ , batch := range r .createdBatches {
566+ batch .Release ()
567+ }
568+ r .createdBatches = nil
569+ }
570+
571+ func (r * callbackRecordReader ) Retain () {}
572+
573+ func (r * callbackRecordReader ) Release () {
574+ r .released .Store (true )
575+ }
576+
577+ func (r * callbackRecordReader ) Record () arrow.RecordBatch {
578+ return r .RecordBatch ()
579+ }
580+
581+ func (r * callbackRecordReader ) Err () error {
582+ return nil
583+ }
584+
585+ // trackedRecordBatch wraps a RecordBatch to track Retain/Release calls.
586+ type trackedRecordBatch struct {
587+ arrow.RecordBatch
588+ onRetain func ()
589+ onRelease func ()
590+ }
591+
592+ func (t * trackedRecordBatch ) Retain () {
593+ if t .onRetain != nil {
594+ t .onRetain ()
595+ }
596+ t .RecordBatch .Retain ()
597+ }
598+
599+ func (t * trackedRecordBatch ) Release () {
600+ if t .onRelease != nil {
601+ t .onRelease ()
602+ }
603+ t .RecordBatch .Release ()
604+ }
605+
606+ func TestStreamChunksFromReader_OK (t * testing.T ) {
607+ mem := memory .NewCheckedAllocator (memory .DefaultAllocator )
608+ defer mem .AssertSize (t , 0 )
609+
610+ schema := arrow .NewSchema ([]arrow.Field {{Name : "value" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
611+
612+ rdr := newCallbackRecordReader (mem , schema , 5 , nil )
613+ defer rdr .ReleaseAll ()
614+
615+ ch := make (chan flight.StreamChunk , 5 )
616+
617+ ctx := context .Background ()
618+
619+ go flight .StreamChunksFromReader (ctx , rdr , ch )
620+
621+ var chunksReceived int
622+ for chunk := range ch {
623+ if chunk .Err != nil {
624+ t .Errorf ("unexpected error chunk: %v" , chunk .Err )
625+ continue
626+ }
627+ if chunk .Data != nil {
628+ chunksReceived ++
629+ chunk .Data .Release ()
630+ }
631+ }
632+
633+ require .Equal (t , 5 , chunksReceived , "should receive all 5 batches" )
634+ require .True (t , rdr .released .Load (), "reader should be released" )
635+
636+ }
637+
638+ // TestStreamChunksFromReader_HandlesCancellation verifies that context cancellation
639+ // causes StreamChunksFromReader to exit cleanly and release the reader.
640+ func TestStreamChunksFromReader_HandlesCancellation (t * testing.T ) {
641+ ctx , cancel := context .WithCancel (context .Background ())
642+ defer cancel ()
643+
644+ mem := memory .DefaultAllocator
645+ schema := arrow .NewSchema ([]arrow.Field {{Name : "value" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
646+
647+ rdr := newCallbackRecordReader (mem , schema , 10 , nil )
648+ defer rdr .ReleaseAll ()
649+ ch := make (chan flight.StreamChunk ) // unbuffered channel
650+
651+ go flight .StreamChunksFromReader (ctx , rdr , ch )
652+
653+ chunksReceived := 0
654+ for chunk := range ch {
655+ if chunk .Data != nil {
656+ chunksReceived ++
657+ chunk .Data .Release ()
658+ }
659+
660+ // Cancel context after 2 batches (simulating server detecting client disconnect)
661+ if chunksReceived == 2 {
662+ cancel ()
663+ }
664+ }
665+
666+ // After canceling context, StreamChunksFromReader exits and closes the channel.
667+ // The for-range loop above exits when the channel closes.
668+ // By the time we reach here, the channel is closed, which means StreamChunksFromReader's
669+ // defer stack has already executed, so the reader must be released.
670+
671+ require .True (t , rdr .released .Load (), "reader must be released when context is canceled" )
672+
673+ }
674+
675+ // TestStreamChunksFromReader_CancellationReleasesBatches verifies that batches are
676+ // properly tracked and demonstrates memory leaks without cleanup, then proves cleanup fixes it.
677+ func TestStreamChunksFromReader_CancellationReleasesBatches (t * testing.T ) {
678+ mem := memory .NewCheckedAllocator (memory .DefaultAllocator )
679+
680+ schema := arrow .NewSchema ([]arrow.Field {{Name : "value" , Type : arrow .PrimitiveTypes .Int64 }}, nil )
681+
682+ ctx , cancel := context .WithCancel (context .Background ())
683+ defer cancel ()
684+
685+ // Create reader that will produce 10 batches, but we'll cancel after 3
686+ reader := newCallbackRecordReader (mem , schema , 10 , func (batchIndex int ) {
687+ if batchIndex == 2 {
688+ cancel ()
689+ }
690+ })
691+
692+ ch := make (chan flight.StreamChunk , 5 )
693+
694+ // Start streaming
695+ go flight .StreamChunksFromReader (ctx , reader , ch )
696+
697+ // Consume chunks until channel closes
698+ var chunksReceived int
699+ for chunk := range ch {
700+ if chunk .Err != nil {
701+ t .Errorf ("unexpected error chunk: %v" , chunk .Err )
702+ continue
703+ }
704+ if chunk .Data != nil {
705+ chunksReceived ++
706+ chunk .Data .Release ()
707+ }
708+ }
709+
710+ // Verify the reader was released
711+ require .True (t , reader .released .Load (), "reader should be released" )
712+
713+ // We should have received at most 3-4 chunks (depending on timing)
714+ // The important part is we didn't receive all 10
715+ require .LessOrEqual (t , chunksReceived , 4 , "should not receive all 10 chunks, got %d" , chunksReceived )
716+ require .Greater (t , chunksReceived , 0 , "should receive at least 1 chunk" )
717+
718+ // Check that Retain and Release don't balance - proving there's a leak without manual cleanup
719+ retains := reader .totalRetains .Load ()
720+ releases := reader .totalReleases .Load ()
721+ batchesCreated := reader .batchesCreated .Load ()
722+
723+ // Each batch starts with refcount=1, then StreamChunksFromReader calls Retain() (refcount=2)
724+ // For sent batches: we call Release() (refcount=1), batch still has initial ref
725+ // For unsent batches due to cancellation: they keep refcount=1 from creation
726+ // So we expect: releases < retains + batchesCreated
727+ require .Less (t , releases , retains + batchesCreated ,
728+ "without cleanup, releases should be less than retains+created: retains=%d, releases=%d, created=%d" ,
729+ retains , releases , batchesCreated )
730+
731+ // Now manually release all created batches to show proper cleanup fixes the leak
732+ reader .ReleaseAll ()
733+
734+ // After cleanup, memory should be freed
735+ mem .AssertSize (t , 0 )
736+ }