Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 255 additions & 40 deletions datafusion/physical-plan/src/sorts/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
// under the License.

use crate::spill::get_record_batch_memory_size;
use arrow::array::ArrayRef;
use arrow::compute::interleave;
use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use log::warn;
use std::any::Any;
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::Arc;

#[derive(Debug, Copy, Clone, Default)]
Expand All @@ -40,9 +45,24 @@ pub struct BatchBuilder {
/// Maintain a list of [`RecordBatch`] and their corresponding stream
batches: Vec<(usize, RecordBatch)>,

/// Accounts for memory used by buffered batches
/// Accounts for memory used by buffered batches.
///
/// May include pre-reserved bytes (from `sort_spill_reservation_bytes`)
/// that were transferred via [`MemoryReservation::take()`] to prevent
/// starvation when concurrent sort partitions compete for pool memory.
reservation: MemoryReservation,

/// Tracks the actual memory used by buffered batches (not including
/// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool
/// allocation requests when the pre-reserved bytes cover the batch.
batches_mem_used: usize,

/// The initial reservation size at construction time. When the reservation
/// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this
/// records that amount so we never shrink below it, maintaining the
/// anti-starvation guarantee throughout the merge.
initial_reservation: usize,

/// The current [`BatchCursor`] for each stream
cursors: Vec<BatchCursor>,

Expand All @@ -59,19 +79,26 @@ impl BatchBuilder {
batch_size: usize,
reservation: MemoryReservation,
) -> Self {
let initial_reservation = reservation.size();
Self {
schema,
batches: Vec::with_capacity(stream_count * 2),
cursors: vec![BatchCursor::default(); stream_count],
indices: Vec::with_capacity(batch_size),
reservation,
batches_mem_used: 0,
initial_reservation,
}
}

/// Append a new batch in `stream_idx`
pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> {
self.reservation
.try_grow(get_record_batch_memory_size(&batch))?;
let size = get_record_batch_memory_size(&batch);
self.batches_mem_used += size;
// Only request additional memory from the pool when actual batch
// usage exceeds the current reservation (which may include
// pre-reserved bytes from sort_spill_reservation_bytes).
try_grow_reservation_to_at_least(&mut self.reservation, self.batches_mem_used)?;
let batch_idx = self.batches.len();
self.batches.push((stream_idx, batch));
self.cursors[stream_idx] = BatchCursor {
Expand Down Expand Up @@ -104,53 +131,241 @@ impl BatchBuilder {
&self.schema
}

/// Try to interleave all columns using the given index slice.
fn try_interleave_columns(
&self,
indices: &[(usize, usize)],
) -> Result<Vec<ArrayRef>> {
(0..self.schema.fields.len())
.map(|column_idx| {
let arrays: Vec<_> = self
.batches
.iter()
.map(|(_, batch)| batch.column(column_idx).as_ref())
.collect();
recover_offset_overflow_from_panic(|| interleave(&arrays, indices))
})
.collect::<Result<Vec<_>>>()
}

/// Builds a record batch from the first `rows_to_emit` buffered rows.
/// Builds a record batch from the first `rows_to_emit` buffered rows.
///
/// `columns` must be the already-interleaved output arrays covering exactly
/// `rows_to_emit` rows (see [`Self::try_interleave_columns`]). After emitting,
/// fully-consumed buffered batches are dropped and the memory reservation is
/// shrunk toward actual usage — but never below `initial_reservation`.
fn finish_record_batch(
    &mut self,
    rows_to_emit: usize,
    columns: Vec<ArrayRef>,
) -> Result<RecordBatch> {
    // Remove consumed indices, keeping any remaining for the next call.
    self.indices.drain(..rows_to_emit);

    // Only clean up fully-consumed batches when all indices are drained,
    // because remaining indices may still reference earlier batches.
    if self.indices.is_empty() {
        // New cursors are only created once the previous cursor for the stream
        // is finished. This means all remaining rows from all but the last batch
        // for each stream have been yielded to the newly created record batch
        //
        // We can therefore drop all but the last batch for each stream
        let mut batch_idx = 0;
        let mut retained = 0;
        self.batches.retain(|(stream_idx, batch)| {
            let stream_cursor = &mut self.cursors[*stream_idx];
            let retain = stream_cursor.batch_idx == batch_idx;
            batch_idx += 1;

            if retain {
                // Re-point the cursor at the batch's new, compacted position.
                stream_cursor.batch_idx = retained;
                retained += 1;
            } else {
                // Dropped batches no longer count toward actual memory use.
                self.batches_mem_used -= get_record_batch_memory_size(batch);
            }
            retain
        });
    }

    // Release excess memory back to the pool, but never shrink below
    // initial_reservation to maintain the anti-starvation guarantee
    // for the merge phase.
    let target = self.batches_mem_used.max(self.initial_reservation);
    if self.reservation.size() > target {
        self.reservation.shrink(self.reservation.size() - target);
    }

    RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(Into::into)
}

/// Drains the in_progress row indexes, and builds a new RecordBatch from them
///
/// Will then drop any batches for which all rows have been yielded to the output
/// Will then drop any batches for which all rows have been yielded to the output.
/// If an offset overflow occurs (e.g. string/list offsets exceed i32::MAX),
/// retries with progressively fewer rows until it succeeds.
///
/// Returns `None` if no pending rows
pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
if self.is_empty() {
return Ok(None);
}

let columns = (0..self.schema.fields.len())
.map(|column_idx| {
let arrays: Vec<_> = self
.batches
.iter()
.map(|(_, batch)| batch.column(column_idx).as_ref())
.collect();
Ok(interleave(&arrays, &self.indices)?)
})
.collect::<Result<Vec<_>>>()?;

self.indices.clear();

// New cursors are only created once the previous cursor for the stream
// is finished. This means all remaining rows from all but the last batch
// for each stream have been yielded to the newly created record batch
//
// We can therefore drop all but the last batch for each stream
let mut batch_idx = 0;
let mut retained = 0;
self.batches.retain(|(stream_idx, batch)| {
let stream_cursor = &mut self.cursors[*stream_idx];
let retain = stream_cursor.batch_idx == batch_idx;
batch_idx += 1;

if retain {
stream_cursor.batch_idx = retained;
retained += 1;
let (rows_to_emit, columns) =
retry_interleave(self.indices.len(), self.indices.len(), |rows_to_emit| {
self.try_interleave_columns(&self.indices[..rows_to_emit])
})?;

Ok(Some(self.finish_record_batch(rows_to_emit, columns)?))
}
}

/// Try to grow `reservation` so it covers at least `needed` bytes.
///
/// When a reservation has been pre-loaded with bytes (e.g. via
/// [`MemoryReservation::take()`]), this avoids redundant pool
/// allocations: if the reservation already covers `needed`, this is
/// a no-op; otherwise only the deficit is requested from the pool.
/// Try to grow `reservation` so it covers at least `needed` bytes.
///
/// When a reservation has been pre-loaded with bytes (e.g. via
/// [`MemoryReservation::take()`]), this avoids redundant pool
/// allocations: if the reservation already covers `needed`, this is
/// a no-op; otherwise only the deficit is requested from the pool.
pub(crate) fn try_grow_reservation_to_at_least(
    reservation: &mut MemoryReservation,
    needed: usize,
) -> Result<()> {
    // Request only the shortfall; a zero deficit means the current
    // reservation already covers the requirement.
    let deficit = needed.saturating_sub(reservation.size());
    if deficit > 0 {
        reservation.try_grow(deficit)?;
    }
    Ok(())
}

/// Returns true if the error is an Arrow offset overflow.
fn is_offset_overflow(e: &DataFusionError) -> bool {
    if let DataFusionError::ArrowError(inner, _) = e {
        matches!(inner.as_ref(), ArrowError::OffsetOverflowError(_))
    } else {
        false
    }
}

/// Constructs a synthetic Arrow offset-overflow error, used to represent
/// a caught offset-overflow panic as a recoverable error.
fn offset_overflow_error() -> DataFusionError {
    let arrow_error = ArrowError::OffsetOverflowError(0);
    DataFusionError::ArrowError(Box::new(arrow_error), None)
}

/// Runs `f`, converting Arrow offset-overflow panics into errors.
///
/// Arrow's interleave can panic on i32 offset overflow with
/// `.expect("overflow")` / `.expect("offset overflow")`.
/// Catch only those specific panics so the caller can retry
/// with fewer rows while unrelated defects still unwind.
fn recover_offset_overflow_from_panic<T, F>(f: F) -> Result<T>
where
    F: FnOnce() -> std::result::Result<T, ArrowError>,
{
    let outcome = catch_unwind(AssertUnwindSafe(f));
    let panic_payload = match outcome {
        // Closure completed: propagate its result, converting ArrowError.
        Ok(result) => return result.map_err(Into::into),
        Err(payload) => payload,
    };
    if is_arrow_offset_overflow_panic(panic_payload.as_ref()) {
        // Known overflow panic: surface it as a retryable error.
        Err(offset_overflow_error())
    } else {
        // Any other panic is a real defect — keep unwinding.
        std::panic::resume_unwind(panic_payload)
    }
}

/// Calls `interleave` with `rows_to_emit` rows, halving the row count on
/// each Arrow offset-overflow error until it succeeds or reaches zero.
///
/// Returns the row count that succeeded together with the produced value.
/// Non-overflow errors, and overflow at a single row, are returned as-is.
fn retry_interleave<T, F>(
    mut rows_to_emit: usize,
    total_rows: usize,
    mut interleave: F,
) -> Result<(usize, T)>
where
    F: FnMut(usize) -> Result<T>,
{
    loop {
        let err = match interleave(rows_to_emit) {
            Ok(value) => return Ok((rows_to_emit, value)),
            Err(e) => e,
        };
        // Only offset overflows are retryable; everything else propagates.
        if !is_offset_overflow(&err) {
            return Err(err);
        }
        rows_to_emit /= 2;
        if rows_to_emit == 0 {
            // Even a single row overflowed — give up with the overflow error.
            return Err(err);
        }
        warn!(
            "Interleave offset overflow with {total_rows} rows, retrying with {rows_to_emit}"
        );
    }
}

/// Extracts the string message from a panic payload, if it carries one.
///
/// Panic payloads are typically either a `&'static str` (from `panic!("…")`)
/// or a `String` (from `panic!` with formatting); anything else yields `None`.
fn panic_message(payload: &(dyn Any + Send)) -> Option<&str> {
    payload
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
}

/// Returns true if a caught panic payload matches the Arrow offset overflows
/// raised by interleave's offset builders.
fn is_arrow_offset_overflow_panic(payload: &(dyn Any + Send)) -> bool {
    // Arrow raises these via `.expect("overflow")` / `.expect("offset overflow")`;
    // any other message (or a non-string payload) is not ours to swallow.
    panic_message(payload)
        .map(|msg| msg == "overflow" || msg == "offset overflow")
        .unwrap_or(false)
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::error::ArrowError;

    #[test]
    fn test_retry_interleave_halves_rows_until_success() {
        let mut attempts = Vec::new();

        // Fail with an offset overflow until a single row is requested,
        // recording each attempted row count along the way.
        let (rows_to_emit, result) = retry_interleave(4, 4, |rows_to_emit| {
            attempts.push(rows_to_emit);
            if rows_to_emit > 1 {
                Err(offset_overflow_error())
            } else {
                Ok("ok")
            }
        })
        .unwrap();

        assert_eq!(rows_to_emit, 1);
        assert_eq!(result, "ok");
        assert_eq!(attempts, vec![4, 2, 1]);
    }

    #[test]
    fn test_recover_offset_overflow_from_panic() {
        // An "offset overflow" panic is converted into a retryable error.
        let error = recover_offset_overflow_from_panic(
            || -> std::result::Result<(), ArrowError> { panic!("offset overflow") },
        )
        .unwrap_err();

        assert!(is_offset_overflow(&error));
    }

    #[test]
    fn test_recover_offset_overflow_from_panic_rethrows_unrelated_panics() {
        // Unrelated panic messages must keep unwinding instead of being
        // converted into offset-overflow errors.
        let panic_payload = catch_unwind(AssertUnwindSafe(|| {
            let _ = recover_offset_overflow_from_panic(
                || -> std::result::Result<(), ArrowError> {
                    panic!("capacity overflow")
                },
            );
        }));

        assert!(panic_payload.is_err());
    }

    #[test]
    fn test_is_arrow_offset_overflow_panic() {
        let overflow = Box::new("overflow") as Box<dyn Any + Send>;
        assert!(is_arrow_offset_overflow_panic(overflow.as_ref()));

        let offset_overflow =
            Box::new(String::from("offset overflow")) as Box<dyn Any + Send>;
        assert!(is_arrow_offset_overflow_panic(offset_overflow.as_ref()));

        let capacity_overflow = Box::new("capacity overflow") as Box<dyn Any + Send>;
        assert!(!is_arrow_offset_overflow_panic(capacity_overflow.as_ref()));

        let arithmetic_overflow =
            Box::new(String::from("attempt to multiply with overflow"))
                as Box<dyn Any + Send>;
        assert!(!is_arrow_offset_overflow_panic(
            arithmetic_overflow.as_ref()
        ));
    }
}
Loading
Loading