Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,15 @@ object CometConf extends ShimCometConf {
createExecEnabledConfig("hashJoin", defaultValue = true)
val COMET_EXEC_SORT_MERGE_JOIN_ENABLED: ConfigEntry[Boolean] =
createExecEnabledConfig("sortMergeJoin", defaultValue = true)
// Benchmarking toggle: selects Comet's native sort-merge-join implementation
// (true, the default) or DataFusion's SortMergeJoinExec (false).
val COMET_EXEC_SMJ_USE_NATIVE: ConfigEntry[Boolean] =
  conf("spark.comet.exec.sortMergeJoin.useNative")
    .category(CATEGORY_EXEC)
    .doc(
      "When true, use Comet's native sort merge join implementation. " +
        "When false, use DataFusion's SortMergeJoinExec. " +
        "This is useful for benchmarking the two implementations.")
    .booleanConf
    .createWithDefault(true)
val COMET_EXEC_AGGREGATE_ENABLED: ConfigEntry[Boolean] =
createExecEnabledConfig("aggregate", defaultValue = true)
val COMET_EXEC_COLLECT_LIMIT_ENABLED: ConfigEntry[Boolean] =
Expand Down
6 changes: 5 additions & 1 deletion native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ struct ExecutionContext {
pub memory_pool_config: MemoryPoolConfig,
/// Whether to log memory usage on each call to execute_plan
pub tracing_enabled: bool,
/// Spark configuration map passed from JVM
pub spark_config: HashMap<String, String>,
}

/// Accept serialized query plan and return the address of the native query plan.
Expand Down Expand Up @@ -327,6 +329,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
explain_native,
memory_pool_config,
tracing_enabled,
spark_config,
});

Ok(Box::into_raw(exec_context) as i64)
Expand Down Expand Up @@ -545,7 +548,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner =
PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
.with_exec_id(exec_context_id);
.with_exec_id(exec_context_id)
.with_spark_config(exec_context.spark_config.clone());
let (scans, shuffle_scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
Expand Down
267 changes: 267 additions & 0 deletions native/core/src/execution/joins/buffered_batch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Buffered batch management for the sort merge join operator.
//!
//! [`BufferedMatchGroup`] holds all rows from the buffered (right) side that
//! share the current join key. When memory is tight, individual batches are
//! spilled to Arrow IPC files on disk and reloaded on demand.

use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use datafusion::common::utils::memory::get_record_batch_memory_size;
use datafusion::common::{DataFusionError, Result};
use datafusion::execution::disk_manager::RefCountedTempFile;
use datafusion::execution::memory_pool::MemoryReservation;
use datafusion::physical_expr::PhysicalExprRef;
use datafusion::physical_plan::spill::SpillManager;

use super::metrics::SortMergeJoinMetrics;

/// State of a single buffered batch — either held in memory or spilled to disk.
///
/// A batch transitions to [`BatchState::Spilled`] in
/// [`BufferedMatchGroup::add_batch`] when the memory reservation cannot be
/// grown; spilled batches are reloaded on demand by [`BufferedBatch::get_batch`].
#[derive(Debug)]
enum BatchState {
    /// The batch is available in memory.
    InMemory(RecordBatch),
    /// The batch has been spilled to an Arrow IPC file.
    Spilled(RefCountedTempFile),
}

/// A single batch in a [`BufferedMatchGroup`].
///
/// Tracks the batch data (in-memory or spilled), pre-evaluated join key arrays,
/// row count, estimated memory size, and per-row match flags for outer joins.
#[derive(Debug)]
pub(super) struct BufferedBatch {
    /// The batch data, either in memory or spilled to disk.
    state: BatchState,
    /// Pre-evaluated join key column arrays. `None` when the batch has been
    /// spilled; [`BufferedBatch::get_join_arrays`] re-evaluates them on demand.
    #[allow(dead_code)]
    join_arrays: Option<Vec<ArrayRef>>,
    /// Number of rows in this batch (cached so we don't need the batch to know).
    pub num_rows: usize,
    /// Estimated memory footprint in bytes (batch + join arrays). Zero for
    /// spilled batches, which are not charged against the memory reservation.
    pub size_estimate: usize,
    /// For full/right outer joins: tracks which rows have been matched.
    /// `None` when match tracking is not needed for the join type.
    matched: Option<Vec<bool>>,
}

impl BufferedBatch {
    /// Mark a buffered row as matched (for full outer join tracking).
    ///
    /// No-op when match tracking is disabled (`self.matched` is `None`).
    pub fn mark_matched(&mut self, row_idx: usize) {
        if let Some(ref mut matched) = self.matched {
            matched[row_idx] = true;
        }
    }

    /// Iterate over unmatched row indices.
    ///
    /// Yields nothing when match tracking is disabled.
    pub fn unmatched_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.matched.as_ref().into_iter().flat_map(|m| {
            m.iter()
                .enumerate()
                .filter(|(_, &matched)| !matched)
                .map(|(idx, _)| idx)
        })
    }

    /// Create a new in-memory buffered batch.
    ///
    /// `full_outer` controls whether per-row match tracking is allocated.
    /// The size estimate covers both the batch itself and the pre-evaluated
    /// join key arrays, so the caller can charge both to a memory reservation.
    fn new_in_memory(batch: RecordBatch, join_arrays: Vec<ArrayRef>, full_outer: bool) -> Self {
        let num_rows = batch.num_rows();
        let mut size_estimate = get_record_batch_memory_size(&batch);
        for arr in &join_arrays {
            size_estimate += arr.get_array_memory_size();
        }
        let matched = if full_outer {
            Some(vec![false; num_rows])
        } else {
            None
        };
        Self {
            state: BatchState::InMemory(batch),
            join_arrays: Some(join_arrays),
            num_rows,
            size_estimate,
            matched,
        }
    }

    /// Return the batch. If it was spilled, read it back from disk via the
    /// spill manager.
    ///
    /// NOTE(review): the spilled path drains the async spill stream with
    /// `block_in_place` + `Handle::current().block_on`. `block_in_place`
    /// panics on a current-thread tokio runtime and `Handle::current()`
    /// panics outside any runtime — confirm all callers run inside a
    /// multi-threaded tokio runtime.
    pub fn get_batch(&self, spill_manager: &SpillManager) -> Result<RecordBatch> {
        match &self.state {
            BatchState::InMemory(batch) => Ok(batch.clone()),
            BatchState::Spilled(file) => {
                let reader = spill_manager.read_spill_as_stream(file.clone(), None)?;
                // Synchronously collect every batch the spill stream yields,
                // propagating the first error encountered.
                let batches = tokio::task::block_in_place(|| {
                    let rt = tokio::runtime::Handle::current();
                    rt.block_on(async {
                        use futures::StreamExt;
                        let mut stream = reader;
                        let mut batches = Vec::new();
                        while let Some(batch) = stream.next().await {
                            batches.push(batch?);
                        }
                        Ok::<_, DataFusionError>(batches)
                    })
                })?;
                // A single batch was spilled per file, but concatenate just in case.
                if batches.len() == 1 {
                    Ok(batches.into_iter().next().unwrap())
                } else {
                    arrow::compute::concat_batches(&Arc::clone(spill_manager.schema()), &batches)
                        .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))
                }
            }
        }
    }

    /// Return join key arrays. If in memory, returns the cached arrays directly.
    /// If spilled, deserializes the batch and re-evaluates the join expressions
    /// (the recomputed arrays are not cached back onto `self`).
    #[allow(dead_code)]
    pub fn get_join_arrays(
        &self,
        spill_manager: &SpillManager,
        join_exprs: &[PhysicalExprRef],
    ) -> Result<Vec<ArrayRef>> {
        if let Some(ref arrays) = self.join_arrays {
            return Ok(arrays.clone());
        }
        // Spilled — reload the batch from disk and re-evaluate the keys.
        let batch = self.get_batch(spill_manager)?;
        evaluate_join_keys(&batch, join_exprs)
    }
}

/// A group of buffered batches that share the same join key values.
///
/// Batches may be held in memory or spilled to disk when the memory reservation
/// cannot accommodate them. When spilled, they can be loaded back on demand via
/// the spill manager.
#[derive(Debug)]
pub(super) struct BufferedMatchGroup {
    /// All batches in this match group (in-memory and spilled alike).
    pub batches: Vec<BufferedBatch>,
    /// Total number of rows across all batches.
    pub num_rows: usize,
    /// Total estimated memory usage of in-memory batches only; this is the
    /// amount released back to the reservation by `clear`.
    pub memory_size: usize,
}

impl BufferedMatchGroup {
    /// Create a new empty match group.
    pub fn new() -> Self {
        Self {
            batches: Vec::new(),
            num_rows: 0,
            memory_size: 0,
        }
    }

    /// Add a batch to this match group.
    ///
    /// First attempts to grow the memory reservation to hold the batch in memory.
    /// If that fails, the batch is spilled to disk via the spill manager and the
    /// spill metrics are updated accordingly.
    pub fn add_batch(
        &mut self,
        batch: RecordBatch,
        join_arrays: Vec<ArrayRef>,
        full_outer: bool,
        reservation: &mut MemoryReservation,
        spill_manager: &SpillManager,
        metrics: &SortMergeJoinMetrics,
    ) -> Result<()> {
        // Build the in-memory wrapper up front so the size estimate and the
        // match-flag vector are computed exactly once. The batch is moved in
        // rather than cloned; the spill path below reclaims it (and the match
        // flags) by destructuring the wrapper.
        let buffered = BufferedBatch::new_in_memory(batch, join_arrays, full_outer);
        let size = buffered.size_estimate;
        let num_rows = buffered.num_rows;

        if reservation.try_grow(size).is_ok() {
            // Fits in memory.
            self.memory_size += size;
            self.num_rows += num_rows;
            self.batches.push(buffered);
            return Ok(());
        }

        // Reservation could not grow: spill to disk. Reclaim the batch and the
        // already-allocated match flags from the wrapper we just built.
        let BufferedBatch { state, matched, .. } = buffered;
        let BatchState::InMemory(batch) = state else {
            unreachable!("new_in_memory always produces an in-memory batch")
        };
        match spill_manager
            .spill_record_batch_and_finish(&[batch], "SortMergeJoin buffered batch")?
        {
            Some(file) => {
                metrics.spill_count.add(1);
                metrics.spilled_bytes.add(
                    std::fs::metadata(file.path())
                        .map(|m| m.len() as usize)
                        .unwrap_or(0),
                );
                metrics.spilled_rows.add(num_rows);
                self.num_rows += num_rows;
                self.batches.push(BufferedBatch {
                    state: BatchState::Spilled(file),
                    join_arrays: None,
                    num_rows,
                    // Spilled batches are not charged to the reservation.
                    size_estimate: 0,
                    matched,
                });
            }
            None => {
                // Empty batch produced no spill file — nothing to record.
            }
        }
        Ok(())
    }

    /// Clear all batches and release the memory reservation held for the
    /// in-memory ones (spilled batches were never charged to the reservation).
    pub fn clear(&mut self, reservation: &mut MemoryReservation) {
        self.batches.clear();
        reservation.shrink(self.memory_size);
        self.num_rows = 0;
        self.memory_size = 0;
    }

    /// Get a batch by index. If the batch was spilled, it is read back from disk.
    pub fn get_batch(&self, batch_idx: usize, spill_manager: &SpillManager) -> Result<RecordBatch> {
        self.batches[batch_idx].get_batch(spill_manager)
    }

    /// Returns `true` if this group contains no batches.
    pub fn is_empty(&self) -> bool {
        self.batches.is_empty()
    }
}

/// Evaluate join key physical expressions against a record batch and return the
/// resulting column arrays, one per expression, in expression order.
pub(super) fn evaluate_join_keys(
    batch: &RecordBatch,
    join_exprs: &[PhysicalExprRef],
) -> Result<Vec<ArrayRef>> {
    let num_rows = batch.num_rows();
    let mut keys = Vec::with_capacity(join_exprs.len());
    for expr in join_exprs {
        // Materialize each evaluation result (scalar or array) as an array of
        // the batch's row count; propagate the first error encountered.
        let value = expr.evaluate(batch)?;
        keys.push(value.into_array(num_rows)?);
    }
    Ok(keys)
}
Loading
Loading