From 35449a5ea7df4c9eece7f4242db79ef4c020a193 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 06:46:15 -0600 Subject: [PATCH 01/12] feat: scaffold CometSortMergeJoinExec module structure --- native/core/src/execution/joins/mod.rs | 20 +++ .../src/execution/joins/sort_merge_join.rs | 154 ++++++++++++++++++ native/core/src/execution/mod.rs | 1 + 3 files changed, 175 insertions(+) create mode 100644 native/core/src/execution/joins/mod.rs create mode 100644 native/core/src/execution/joins/sort_merge_join.rs diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs new file mode 100644 index 0000000000..b460f09ed6 --- /dev/null +++ b/native/core/src/execution/joins/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod sort_merge_join; + +pub(crate) use sort_merge_join::CometSortMergeJoinExec; diff --git a/native/core/src/execution/joins/sort_merge_join.rs b/native/core/src/execution/joins/sort_merge_join.rs new file mode 100644 index 0000000000..d215b0facd --- /dev/null +++ b/native/core/src/execution/joins/sort_merge_join.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; + +use arrow::compute::SortOptions; +use arrow::datatypes::SchemaRef; +use datafusion::common::Result; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::joins::utils::{build_join_schema, check_join_is_valid, JoinFilter}; +use datafusion::physical_plan::joins::JoinOn; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, +}; +use datafusion::common::NullEquality; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::JoinType; + +/// A Comet-specific sort merge join operator that replaces DataFusion's +/// `SortMergeJoinExec` with Spark-compatible semantics. +#[derive(Debug)] +pub(crate) struct CometSortMergeJoinExec { + left: Arc, + right: Arc, + join_on: JoinOn, + join_filter: Option, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, + schema: SchemaRef, + properties: PlanProperties, + metrics: ExecutionPlanMetricsSet, +} + +impl CometSortMergeJoinExec { + /// Create a new `CometSortMergeJoinExec`. 
+ pub fn try_new( + left: Arc, + right: Arc, + join_on: JoinOn, + join_filter: Option, + join_type: JoinType, + sort_options: Vec, + null_equality: NullEquality, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + check_join_is_valid(&left_schema, &right_schema, &join_on)?; + + let (schema, _column_indices) = + build_join_schema(&left_schema, &right_schema, &join_type); + let schema = Arc::new(schema); + + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(left.properties().output_partitioning().partition_count()), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Ok(Self { + left, + right, + join_on, + join_filter, + join_type, + sort_options, + null_equality, + schema, + properties, + metrics: ExecutionPlanMetricsSet::default(), + }) + } +} + +impl DisplayAs for CometSortMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CometSortMergeJoinExec: join_type={:?}", self.join_type) + } + DisplayFormatType::TreeRender => unimplemented!(), + } + } +} + +impl ExecutionPlan for CometSortMergeJoinExec { + fn name(&self) -> &str { + "CometSortMergeJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(CometSortMergeJoinExec::try_new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.join_on.clone(), + self.join_filter.clone(), + self.join_type, + self.sort_options.clone(), + self.null_equality, + )?)) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + 
unimplemented!("SortMergeJoinStream not yet implemented") + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index f556fce41c..01aedfc62b 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -20,6 +20,7 @@ pub mod columnar_to_row; pub mod expressions; pub mod jni_api; pub(crate) mod metrics; +pub(crate) mod joins; pub mod operators; pub(crate) mod planner; pub mod serde; From 0d9fca1425a7030d50fc43aab72c2f440c64dac6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 06:49:38 -0600 Subject: [PATCH 02/12] feat: add SortMergeJoinMetrics matching CometMetricNode.scala --- native/core/src/execution/joins/metrics.rs | 56 ++++++++++++++++++++++ native/core/src/execution/joins/mod.rs | 1 + 2 files changed, 57 insertions(+) create mode 100644 native/core/src/execution/joins/metrics.rs diff --git a/native/core/src/execution/joins/metrics.rs b/native/core/src/execution/joins/metrics.rs new file mode 100644 index 0000000000..3248a7c949 --- /dev/null +++ b/native/core/src/execution/joins/metrics.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion::physical_plan::metrics::{ + Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, Time, +}; + +/// Metrics for CometSortMergeJoinExec, matching CometMetricNode.scala definitions. +#[derive(Debug, Clone)] +pub(super) struct SortMergeJoinMetrics { + pub input_rows: Count, + pub input_batches: Count, + pub output_rows: Count, + pub output_batches: Count, + pub join_time: Time, + pub peak_mem_used: Gauge, + pub spill_count: Count, + pub spilled_bytes: Count, + pub spilled_rows: Count, +} + +impl SortMergeJoinMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + input_rows: MetricBuilder::new(metrics).counter("input_rows", partition), + input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), + output_rows: MetricBuilder::new(metrics).output_rows(partition), + output_batches: MetricBuilder::new(metrics).counter("output_batches", partition), + join_time: MetricBuilder::new(metrics).subset_time("join_time", partition), + peak_mem_used: MetricBuilder::new(metrics).gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(metrics).counter("spilled_rows", partition), + } + } + + pub fn update_peak_mem(&self, current_mem: usize) { + if current_mem > self.peak_mem_used.value() { + self.peak_mem_used.set(current_mem); + } + } +} diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs index b460f09ed6..f0d41aa286 100644 --- a/native/core/src/execution/joins/mod.rs +++ b/native/core/src/execution/joins/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+mod metrics; mod sort_merge_join; pub(crate) use sort_merge_join::CometSortMergeJoinExec; From 2204b23bc14131c5be2e2c943bfd90a756d01f62 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 06:53:36 -0600 Subject: [PATCH 03/12] feat: add OutputBuilder for join batch materialization Add OutputBuilder that accumulates matched/null-joined index pairs during the sort merge join's Joining state and materializes them into Arrow RecordBatches. Includes BufferedMatchGroup stub with spill support and join filter evaluation module. --- .../src/execution/joins/buffered_batch.rs | 257 ++++++++++++ native/core/src/execution/joins/filter.rs | 232 +++++++++++ native/core/src/execution/joins/mod.rs | 3 + .../src/execution/joins/output_builder.rs | 394 ++++++++++++++++++ 4 files changed, 886 insertions(+) create mode 100644 native/core/src/execution/joins/buffered_batch.rs create mode 100644 native/core/src/execution/joins/filter.rs create mode 100644 native/core/src/execution/joins/output_builder.rs diff --git a/native/core/src/execution/joins/buffered_batch.rs b/native/core/src/execution/joins/buffered_batch.rs new file mode 100644 index 0000000000..fa8c7970c9 --- /dev/null +++ b/native/core/src/execution/joins/buffered_batch.rs @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Buffered batch management for the sort merge join operator. +//! +//! [`BufferedMatchGroup`] holds all rows from the buffered (right) side that +//! share the current join key. When memory is tight, individual batches are +//! spilled to Arrow IPC files on disk and reloaded on demand. + +use arrow::array::ArrayRef; +use arrow::record_batch::RecordBatch; +use datafusion::common::utils::memory::get_record_batch_memory_size; +use datafusion::common::{DataFusionError, Result}; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::memory_pool::MemoryReservation; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::spill::SpillManager; + +use super::metrics::SortMergeJoinMetrics; + +/// State of a single buffered batch — either held in memory or spilled to disk. +#[derive(Debug)] +enum BatchState { + /// The batch is available in memory. + InMemory(RecordBatch), + /// The batch has been spilled to an Arrow IPC file. + Spilled(RefCountedTempFile), +} + +/// A single batch in a [`BufferedMatchGroup`]. +/// +/// Tracks the batch data (in-memory or spilled), pre-evaluated join key arrays, +/// row count, estimated memory size, and per-row match flags for outer joins. +#[derive(Debug)] +pub(super) struct BufferedBatch { + /// The batch data, either in memory or spilled to disk. + state: BatchState, + /// Pre-evaluated join key column arrays. `None` when the batch has been spilled. + join_arrays: Option>, + /// Number of rows in this batch (cached so we don't need the batch to know). + pub num_rows: usize, + /// Estimated memory footprint in bytes (batch + join arrays). + pub size_estimate: usize, + /// For full/right outer joins: tracks which rows have been matched. + /// `None` for inner/left joins where unmatched tracking is not needed. 
+ pub matched: Option>, +} + +impl BufferedBatch { + /// Create a new in-memory buffered batch. + /// + /// `full_outer` controls whether per-row match tracking is allocated. + fn new_in_memory( + batch: RecordBatch, + join_arrays: Vec, + full_outer: bool, + ) -> Self { + let num_rows = batch.num_rows(); + let mut size_estimate = get_record_batch_memory_size(&batch); + for arr in &join_arrays { + size_estimate += arr.get_array_memory_size(); + } + let matched = if full_outer { + Some(vec![false; num_rows]) + } else { + None + }; + Self { + state: BatchState::InMemory(batch), + join_arrays: Some(join_arrays), + num_rows, + size_estimate, + matched, + } + } + + /// Return the batch. If it was spilled, read it back from disk via the spill manager. + pub fn get_batch(&self, spill_manager: &SpillManager) -> Result { + match &self.state { + BatchState::InMemory(batch) => Ok(batch.clone()), + BatchState::Spilled(file) => { + let reader = + spill_manager.read_spill_as_stream(file.clone(), None)?; + let rt = tokio::runtime::Handle::current(); + let batches = rt.block_on(async { + use futures::StreamExt; + let mut stream = reader; + let mut batches = Vec::new(); + while let Some(batch) = stream.next().await { + batches.push(batch?); + } + Ok::<_, DataFusionError>(batches) + })?; + // A single batch was spilled per file, but concatenate just in case. + if batches.len() == 1 { + Ok(batches.into_iter().next().unwrap()) + } else { + arrow::compute::concat_batches( + &spill_manager.schema().clone(), + &batches, + ) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + } + } + } + } + + /// Return join key arrays. If in memory, returns the cached arrays directly. + /// If spilled, deserializes the batch and re-evaluates the join expressions. 
+ pub fn get_join_arrays( + &self, + spill_manager: &SpillManager, + join_exprs: &[PhysicalExprRef], + ) -> Result> { + if let Some(ref arrays) = self.join_arrays { + return Ok(arrays.clone()); + } + // Spilled — reload and re-evaluate + let batch = self.get_batch(spill_manager)?; + evaluate_join_keys(&batch, join_exprs) + } +} + +/// A group of buffered batches that share the same join key values. +/// +/// Batches may be held in memory or spilled to disk when the memory reservation +/// cannot accommodate them. When spilled, they can be loaded back on demand via +/// the spill manager. +#[derive(Debug)] +pub(super) struct BufferedMatchGroup { + /// All batches in this match group. + pub batches: Vec, + /// Total number of rows across all batches. + pub num_rows: usize, + /// Total estimated memory usage of in-memory batches. + pub memory_size: usize, +} + +impl BufferedMatchGroup { + /// Create a new empty match group. + pub fn new() -> Self { + Self { + batches: Vec::new(), + num_rows: 0, + memory_size: 0, + } + } + + /// Add a batch to this match group. + /// + /// First attempts to grow the memory reservation to hold the batch in memory. + /// If that fails, the batch is spilled to disk via the spill manager and the + /// spill metrics are updated accordingly. 
+ pub fn add_batch( + &mut self, + batch: RecordBatch, + join_arrays: Vec, + full_outer: bool, + reservation: &mut MemoryReservation, + spill_manager: &SpillManager, + metrics: &SortMergeJoinMetrics, + ) -> Result<()> { + let buffered = BufferedBatch::new_in_memory(batch.clone(), join_arrays, full_outer); + let size = buffered.size_estimate; + let num_rows = buffered.num_rows; + + if reservation.try_grow(size).is_ok() { + // Fits in memory + self.memory_size += size; + self.num_rows += num_rows; + self.batches.push(buffered); + } else { + // Spill to disk + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "SortMergeJoin buffered batch")?; + match spill_file { + Some(file) => { + metrics.spill_count.add(1); + metrics.spilled_bytes.add( + std::fs::metadata(file.path()) + .map(|m| m.len() as usize) + .unwrap_or(0), + ); + metrics.spilled_rows.add(num_rows); + let matched = if full_outer { + Some(vec![false; num_rows]) + } else { + None + }; + self.num_rows += num_rows; + self.batches.push(BufferedBatch { + state: BatchState::Spilled(file), + join_arrays: None, + num_rows, + size_estimate: 0, // not consuming memory + matched, + }); + } + None => { + // Empty batch, nothing to do + } + } + } + Ok(()) + } + + /// Clear all batches and release the memory reservation. + pub fn clear(&mut self, reservation: &mut MemoryReservation) { + self.batches.clear(); + reservation.shrink(self.memory_size); + self.num_rows = 0; + self.memory_size = 0; + } + + /// Get a batch by index. If the batch was spilled, it is read back from disk. + pub fn get_batch( + &self, + batch_idx: usize, + spill_manager: &SpillManager, + ) -> Result { + self.batches[batch_idx].get_batch(spill_manager) + } + + /// Returns `true` if this group contains no batches. + pub fn is_empty(&self) -> bool { + self.batches.is_empty() + } +} + +/// Evaluate join key physical expressions against a record batch and return the +/// resulting column arrays. 
+pub(super) fn evaluate_join_keys( + batch: &RecordBatch, + join_exprs: &[PhysicalExprRef], +) -> Result> { + join_exprs + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|cv| cv.into_array(batch.num_rows())) + }) + .collect() +} diff --git a/native/core/src/execution/joins/filter.rs b/native/core/src/execution/joins/filter.rs new file mode 100644 index 0000000000..b90944e685 --- /dev/null +++ b/native/core/src/execution/joins/filter.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Join filter evaluation with corrected masks for outer, semi, and anti joins. +//! +//! In outer joins, if all candidate pairs for a streamed row fail the filter, +//! the streamed row must be null-joined (not dropped). Semi joins emit a +//! streamed row if ANY pair passes. Anti joins emit if NO pair passes. This +//! module groups filter results by streamed row to implement these semantics. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, RecordBatch, UInt32Array}; +use arrow::compute::take; +use datafusion::common::{internal_err, JoinSide, Result}; +use datafusion::logical_expr::JoinType; +use datafusion::physical_plan::joins::utils::JoinFilter; + +use super::output_builder::JoinIndex; + +/// Result of applying a join filter to a set of candidate pairs. +pub(super) struct FilteredOutput { + /// Pairs that passed the filter (or were selected for null-join in outer). + pub passed_indices: Vec, + /// Streamed row indices that had no passing pair and should be null-joined + /// (applies to outer and anti joins). + pub streamed_null_joins: Vec, + /// (batch_idx, buffered_idx) pairs that passed the filter, used for + /// tracking matched buffered rows in full outer joins. + pub buffered_matched: Vec<(usize, usize)>, +} + +/// Evaluate a join filter on candidate pairs and return corrected results +/// based on the join type. +/// +/// `pair_indices` contains the candidate pairs as `JoinIndex` values. +/// `candidate_batch` is the intermediate batch built for filter evaluation. +pub(super) fn apply_join_filter( + filter: &JoinFilter, + candidate_batch: &RecordBatch, + pair_indices: &[JoinIndex], + join_type: &JoinType, +) -> Result { + // Evaluate the filter expression on the candidate batch + let filter_result = filter + .expression() + .evaluate(candidate_batch)? 
+ .into_array(candidate_batch.num_rows())?; + + let mask = filter_result + .as_any() + .downcast_ref::() + .expect("join filter expression must return BooleanArray"); + + match join_type { + JoinType::Inner => Ok(apply_inner_filter(mask, pair_indices)), + JoinType::Left | JoinType::Right => Ok(apply_outer_filter(mask, pair_indices)), + JoinType::Full => Ok(apply_full_outer_filter(mask, pair_indices)), + JoinType::LeftSemi | JoinType::RightSemi => Ok(apply_semi_filter(mask, pair_indices)), + JoinType::LeftAnti | JoinType::RightAnti => Ok(apply_anti_filter(mask, pair_indices)), + _ => Ok(apply_inner_filter(mask, pair_indices)), + } +} + +/// Build the intermediate batch used for filter evaluation. +/// +/// For each column in the filter's `column_indices`, we take the appropriate +/// rows from either the streamed or buffered batch using the provided index +/// arrays. +pub(super) fn build_filter_candidate_batch( + filter: &JoinFilter, + streamed_batch: &RecordBatch, + buffered_batch: &RecordBatch, + streamed_indices: &UInt32Array, + buffered_indices: &UInt32Array, +) -> Result { + let columns: Vec = filter + .column_indices() + .iter() + .map(|col_idx| { + let (batch, indices) = match col_idx.side { + JoinSide::Left => (streamed_batch, streamed_indices), + JoinSide::Right => (buffered_batch, buffered_indices), + }; + let column = batch.column(col_idx.index); + Ok(take(column.as_ref(), indices, None)?) + }) + .collect::>>()?; + + Ok(RecordBatch::try_new( + Arc::clone(filter.schema()), + columns, + )?) +} + +/// Returns true if the mask value at `i` is true and not null. +#[inline] +fn mask_passed(mask: &BooleanArray, i: usize) -> bool { + mask.value(i) && !mask.is_null(i) +} + +/// Inner join: keep rows where mask is true (and not null). 
+fn apply_inner_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let passed_indices: Vec = indices + .iter() + .enumerate() + .filter(|(i, _)| mask_passed(mask, *i)) + .map(|(_, idx)| *idx) + .collect(); + + let buffered_matched = passed_indices + .iter() + .map(|idx| (idx.batch_idx, idx.buffered_idx)) + .collect(); + + FilteredOutput { + passed_indices, + streamed_null_joins: Vec::new(), + buffered_matched, + } +} + +/// Outer join (Left/Right): group by streamed_idx. If any pair passes for a +/// streamed row, keep those passing pairs. If none pass, add streamed row to +/// null-joins. +fn apply_outer_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push((i, *idx)); + } + + let mut passed_indices = Vec::new(); + let mut streamed_null_joins = Vec::new(); + let mut buffered_matched = Vec::new(); + + for (streamed_idx, pairs) in &groups { + let passing: Vec = pairs + .iter() + .filter(|(i, _)| mask_passed(mask, *i)) + .map(|(_, idx)| *idx) + .collect(); + + if passing.is_empty() { + streamed_null_joins.push(*streamed_idx); + } else { + for idx in &passing { + buffered_matched.push((idx.batch_idx, idx.buffered_idx)); + } + passed_indices.extend(passing); + } + } + + FilteredOutput { + passed_indices, + streamed_null_joins, + buffered_matched, + } +} + +/// Full outer join: same grouping logic as outer, but buffered tracking is +/// done via the matched bitvector on BufferedBatch (caller uses +/// `buffered_matched`). +fn apply_full_outer_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + // Same logic as outer — the caller handles buffered-side null joins + // via the BufferedBatch matched bitvector. + apply_outer_filter(mask, indices) +} + +/// Semi join: group by streamed_idx. 
If any pair passes for a streamed row, +/// emit one JoinIndex for that row (the first passing pair). +fn apply_semi_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push((i, *idx)); + } + + let mut passed_indices = Vec::new(); + let mut buffered_matched = Vec::new(); + + for (_streamed_idx, pairs) in &groups { + if let Some((_, idx)) = pairs.iter().find(|(i, _)| mask_passed(mask, *i)) { + passed_indices.push(*idx); + buffered_matched.push((idx.batch_idx, idx.buffered_idx)); + } + } + + FilteredOutput { + passed_indices, + streamed_null_joins: Vec::new(), + buffered_matched, + } +} + +/// Anti join: group by streamed_idx. If no pair passes for a streamed row, +/// add it to streamed_null_joins. +fn apply_anti_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutput { + let mut groups: HashMap> = HashMap::new(); + for (i, idx) in indices.iter().enumerate() { + groups.entry(idx.streamed_idx).or_default().push(i); + } + + let mut streamed_null_joins = Vec::new(); + + for (streamed_idx, mask_indices) in &groups { + let any_passed = mask_indices.iter().any(|i| mask_passed(mask, *i)); + + if !any_passed { + streamed_null_joins.push(*streamed_idx); + } + } + + FilteredOutput { + passed_indices: Vec::new(), + streamed_null_joins, + buffered_matched: Vec::new(), + } +} diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs index f0d41aa286..be0f28dac2 100644 --- a/native/core/src/execution/joins/mod.rs +++ b/native/core/src/execution/joins/mod.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
+mod buffered_batch; +mod filter; mod metrics; +mod output_builder; mod sort_merge_join; pub(crate) use sort_merge_join::CometSortMergeJoinExec; diff --git a/native/core/src/execution/joins/output_builder.rs b/native/core/src/execution/joins/output_builder.rs new file mode 100644 index 0000000000..5a857309dd --- /dev/null +++ b/native/core/src/execution/joins/output_builder.rs @@ -0,0 +1,394 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Output batch builder for the sort merge join operator. +//! +//! The [`OutputBuilder`] accumulates matched and null-joined index pairs during +//! the join's Joining state and materializes them into Arrow [`RecordBatch`]es +//! during the OutputReady state. + +use std::sync::Arc; + +use arrow::array::{new_null_array, ArrayRef, RecordBatch, UInt32Array}; +use arrow::compute::kernels::concat::concat; +use arrow::compute::kernels::take::take; +use arrow::datatypes::SchemaRef; +use datafusion::common::{JoinType, Result}; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::spill::SpillManager; + +use super::buffered_batch::BufferedMatchGroup; + +/// An index pair representing a matched row from the streamed and buffered sides. 
+#[derive(Debug, Clone, Copy)] +pub(super) struct JoinIndex { + /// Row index in the current streamed batch. + pub streamed_idx: usize, + /// Index of the buffered batch within the match group. + pub batch_idx: usize, + /// Row index within the buffered batch. + pub buffered_idx: usize, +} + +/// Accumulates join output indices and materializes them into Arrow record batches. +/// +/// During the join process, matched pairs and null-joined rows are recorded as +/// index tuples. When enough indices have been accumulated (reaching the target +/// batch size), they are materialized into a `RecordBatch` by gathering columns +/// from the streamed and buffered sides using Arrow's `take` kernel. +pub(super) struct OutputBuilder { + /// Schema of the output record batch. + output_schema: SchemaRef, + /// Schema of the streamed (left) side. + streamed_schema: SchemaRef, + /// Schema of the buffered (right) side. + buffered_schema: SchemaRef, + /// The type of join being performed. + join_type: JoinType, + /// Target number of rows per output batch. + target_batch_size: usize, + /// Matched pairs: (streamed_idx, batch_idx, buffered_idx). + indices: Vec, + /// Streamed row indices that need a null buffered counterpart (left outer, left anti). + streamed_null_joins: Vec, + /// Buffered row indices that need a null streamed counterpart (full outer). + /// Each entry is (batch_idx, row_idx). + buffered_null_joins: Vec<(usize, usize)>, +} + +impl OutputBuilder { + /// Create a new `OutputBuilder`. + pub fn new( + output_schema: SchemaRef, + streamed_schema: SchemaRef, + buffered_schema: SchemaRef, + join_type: JoinType, + target_batch_size: usize, + ) -> Self { + Self { + output_schema, + streamed_schema, + buffered_schema, + join_type, + target_batch_size, + indices: Vec::new(), + streamed_null_joins: Vec::new(), + buffered_null_joins: Vec::new(), + } + } + + /// Record a matched pair between streamed and buffered rows. 
+ pub fn add_match(&mut self, streamed_idx: usize, batch_idx: usize, buffered_idx: usize) { + self.indices.push(JoinIndex { + streamed_idx, + batch_idx, + buffered_idx, + }); + } + + /// Record a streamed row that needs a null buffered counterpart + /// (used for outer joins and anti joins). + pub fn add_streamed_null_join(&mut self, streamed_idx: usize) { + self.streamed_null_joins.push(streamed_idx); + } + + /// Record a buffered row that needs a null streamed counterpart + /// (used for full outer joins). + pub fn add_buffered_null_join(&mut self, batch_idx: usize, buffered_idx: usize) { + self.buffered_null_joins.push((batch_idx, buffered_idx)); + } + + /// Return the total number of pending output rows. + pub fn pending_count(&self) -> usize { + self.indices.len() + self.streamed_null_joins.len() + self.buffered_null_joins.len() + } + + /// Returns `true` if the pending row count has reached or exceeded the target batch size. + pub fn should_flush(&self) -> bool { + self.pending_count() >= self.target_batch_size + } + + /// Returns `true` if there are any pending output rows. + pub fn has_pending(&self) -> bool { + self.pending_count() > 0 + } + + /// Materialize the accumulated indices into a [`RecordBatch`]. + /// + /// For `LeftSemi` and `LeftAnti` joins, only streamed columns are included. + /// For all other join types, columns from both sides are concatenated. + /// + /// After building, all accumulated indices are cleared. 
+ pub fn build( + &mut self, + streamed_batch: &RecordBatch, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + _buffered_join_exprs: &[PhysicalExprRef], + ) -> Result { + let result = match self.join_type { + JoinType::LeftSemi | JoinType::LeftAnti => self.build_semi_anti(streamed_batch), + _ => self.build_full(streamed_batch, match_group, spill_manager), + }; + + // Clear all accumulated indices after building + self.indices.clear(); + self.streamed_null_joins.clear(); + self.buffered_null_joins.clear(); + + result + } + + /// Build output for LeftSemi/LeftAnti joins (streamed columns only). + fn build_semi_anti(&self, streamed_batch: &RecordBatch) -> Result { + // For semi/anti joins, we only output streamed rows from matched pairs + // and streamed null joins. No buffered columns. + let indices: Vec = self + .indices + .iter() + .map(|idx| idx.streamed_idx as u32) + .chain(self.streamed_null_joins.iter().map(|&idx| idx as u32)) + .collect(); + + let indices_array = UInt32Array::from(indices); + + let columns: Vec = streamed_batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None).map_err(Into::into)) + .collect::>()?; + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + columns, + )?) + } + + /// Build output for all other join types (both streamed and buffered columns). + fn build_full( + &self, + streamed_batch: &RecordBatch, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result { + let streamed_columns = self.build_streamed_columns(streamed_batch)?; + let buffered_columns = + self.build_buffered_columns(match_group, spill_manager)?; + + let mut columns = streamed_columns; + columns.extend(buffered_columns); + + Ok(RecordBatch::try_new( + Arc::clone(&self.output_schema), + columns, + )?) + } + + /// Build the streamed side columns by gathering rows using the `take` kernel. + /// + /// The order is: matched pairs, streamed null joins, then None for buffered null joins. 
+ fn build_streamed_columns(&self, streamed_batch: &RecordBatch) -> Result> { + let total_rows = self.pending_count(); + let num_buffered_nulls = self.buffered_null_joins.len(); + + // Build indices: matched streamed_idx, then streamed_null_join indices, then None + let indices: Vec> = self + .indices + .iter() + .map(|idx| Some(idx.streamed_idx as u32)) + .chain(self.streamed_null_joins.iter().map(|&idx| Some(idx as u32))) + .chain(std::iter::repeat_n(None, num_buffered_nulls)) + .collect(); + + debug_assert_eq!(indices.len(), total_rows); + + let indices_array = UInt32Array::from(indices); + + streamed_batch + .columns() + .iter() + .map(|col| take(col.as_ref(), &indices_array, None).map_err(Into::into)) + .collect() + } + + /// Build the buffered side columns by gathering rows from buffered batches. + /// + /// For matched pairs: take from buffered batches (loading spilled ones as needed). + /// For streamed null joins: null arrays. + /// For buffered null joins: take from buffered batches. + fn build_buffered_columns( + &self, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result> { + let num_buffered_cols = self.buffered_schema.fields().len(); + let num_streamed_nulls = self.streamed_null_joins.len(); + + // Build one column at a time + (0..num_buffered_cols) + .map(|col_idx| { + self.build_single_buffered_column( + col_idx, + match_group, + spill_manager, + num_streamed_nulls, + ) + }) + .collect() + } + + /// Build a single buffered column by concatenating parts from matched pairs, + /// null arrays for streamed null joins, and parts from buffered null joins. 
+ fn build_single_buffered_column( + &self, + col_idx: usize, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + num_streamed_nulls: usize, + ) -> Result { + let data_type = self.buffered_schema.field(col_idx).data_type(); + let mut parts: Vec = Vec::new(); + + // Part 1: Matched pairs — gather from buffered batches + if !self.indices.is_empty() { + let matched_part = + self.take_buffered_matched(col_idx, match_group, spill_manager)?; + parts.push(matched_part); + } + + // Part 2: Streamed null joins — null arrays for buffered side + if num_streamed_nulls > 0 { + parts.push(new_null_array(data_type, num_streamed_nulls)); + } + + // Part 3: Buffered null joins — gather from buffered batches + if !self.buffered_null_joins.is_empty() { + let null_join_part = + self.take_buffered_null_joins(col_idx, match_group, spill_manager)?; + parts.push(null_join_part); + } + + if parts.is_empty() { + return Ok(new_null_array(data_type, 0)); + } + + if parts.len() == 1 { + return Ok(parts.into_iter().next().unwrap()); + } + + let part_refs: Vec<&dyn arrow::array::Array> = + parts.iter().map(|a| a.as_ref()).collect(); + Ok(concat(&part_refs)?) + } + + /// Take values from buffered batches for matched pairs. + fn take_buffered_matched( + &self, + col_idx: usize, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result { + // Group indices by batch_idx to minimize batch lookups + // For simplicity, we build per-index and concatenate, but a more + // optimized version would group by batch_idx. + + // Simple approach: gather one at a time using take with single-element indices + // A more efficient approach groups by batch_idx, but this is correct and clear. 
+ let mut parts: Vec = Vec::new(); + + // Group consecutive indices by batch_idx for efficiency + let mut current_batch_idx: Option = None; + let mut current_row_indices: Vec = Vec::new(); + + for join_idx in &self.indices { + if current_batch_idx == Some(join_idx.batch_idx) { + current_row_indices.push(join_idx.buffered_idx as u32); + } else { + // Flush previous group + if let Some(batch_idx) = current_batch_idx { + let batch = + match_group.batches[batch_idx].get_batch(spill_manager)?; + let col = batch.column(col_idx); + let row_indices = UInt32Array::from(std::mem::take(&mut current_row_indices)); + parts.push(take(col.as_ref(), &row_indices, None)?); + } + current_batch_idx = Some(join_idx.batch_idx); + current_row_indices.push(join_idx.buffered_idx as u32); + } + } + + // Flush last group + if let Some(batch_idx) = current_batch_idx { + let batch = match_group.batches[batch_idx].get_batch(spill_manager)?; + let col = batch.column(col_idx); + let row_indices = UInt32Array::from(current_row_indices); + parts.push(take(col.as_ref(), &row_indices, None)?); + } + + if parts.len() == 1 { + return Ok(parts.into_iter().next().unwrap()); + } + + let part_refs: Vec<&dyn arrow::array::Array> = + parts.iter().map(|a| a.as_ref()).collect(); + Ok(concat(&part_refs)?) + } + + /// Take values from buffered batches for buffered null joins (full outer). 
+ fn take_buffered_null_joins( + &self, + col_idx: usize, + match_group: &BufferedMatchGroup, + spill_manager: &SpillManager, + ) -> Result { + let mut parts: Vec = Vec::new(); + + // Group consecutive entries by batch_idx + let mut current_batch_idx: Option = None; + let mut current_row_indices: Vec = Vec::new(); + + for &(batch_idx, row_idx) in &self.buffered_null_joins { + if current_batch_idx == Some(batch_idx) { + current_row_indices.push(row_idx as u32); + } else { + if let Some(bi) = current_batch_idx { + let batch = match_group.batches[bi].get_batch(spill_manager)?; + let col = batch.column(col_idx); + let row_indices = UInt32Array::from(std::mem::take(&mut current_row_indices)); + parts.push(take(col.as_ref(), &row_indices, None)?); + } + current_batch_idx = Some(batch_idx); + current_row_indices.push(row_idx as u32); + } + } + + if let Some(bi) = current_batch_idx { + let batch = match_group.batches[bi].get_batch(spill_manager)?; + let col = batch.column(col_idx); + let row_indices = UInt32Array::from(current_row_indices); + parts.push(take(col.as_ref(), &row_indices, None)?); + } + + if parts.len() == 1 { + return Ok(parts.into_iter().next().unwrap()); + } + + let part_refs: Vec<&dyn arrow::array::Array> = + parts.iter().map(|a| a.as_ref()).collect(); + Ok(concat(&part_refs)?) + } +} From 6798bfd9e29444f5bc4cbca6ed86f5e2ec3f5e8f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 06:54:07 -0600 Subject: [PATCH 04/12] feat: add join filter evaluation with corrected masks for outer/semi/anti Handle JoinSide::None variant in build_filter_candidate_batch to fix compilation with DataFusion 52.4.0 which includes a None variant in the JoinSide enum. 
--- native/core/src/execution/joins/filter.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/native/core/src/execution/joins/filter.rs b/native/core/src/execution/joins/filter.rs index b90944e685..e2fdeb312f 100644 --- a/native/core/src/execution/joins/filter.rs +++ b/native/core/src/execution/joins/filter.rs @@ -96,6 +96,11 @@ pub(super) fn build_filter_candidate_batch( let (batch, indices) = match col_idx.side { JoinSide::Left => (streamed_batch, streamed_indices), JoinSide::Right => (buffered_batch, buffered_indices), + JoinSide::None => { + return internal_err!( + "unexpected JoinSide::None in join filter column index" + ); + } }; let column = batch.column(col_idx.index); Ok(take(column.as_ref(), indices, None)?) From e91b50f18ddddfe948182c22b79e48077e5a29cf Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:04:29 -0600 Subject: [PATCH 05/12] feat: implement SortMergeJoinStream state machine and wire up execute() Implement the streaming sort merge join state machine that drives two sorted input streams, compares join keys, and produces joined output batches. The state machine handles all join types (inner, left/right outer, full outer, semi, anti) with key-reuse optimization via RowConverter, multi-batch match group collection, null key handling, optional join filter evaluation, and memory-aware spilling. 
--- native/core/src/execution/joins/mod.rs | 1 + .../src/execution/joins/sort_merge_join.rs | 77 +- .../execution/joins/sort_merge_join_stream.rs | 1067 +++++++++++++++++ 3 files changed, 1137 insertions(+), 8 deletions(-) create mode 100644 native/core/src/execution/joins/sort_merge_join_stream.rs diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs index be0f28dac2..ce8c0195dd 100644 --- a/native/core/src/execution/joins/mod.rs +++ b/native/core/src/execution/joins/mod.rs @@ -20,5 +20,6 @@ mod filter; mod metrics; mod output_builder; mod sort_merge_join; +mod sort_merge_join_stream; pub(crate) use sort_merge_join::CometSortMergeJoinExec; diff --git a/native/core/src/execution/joins/sort_merge_join.rs b/native/core/src/execution/joins/sort_merge_join.rs index d215b0facd..50a8530ad3 100644 --- a/native/core/src/execution/joins/sort_merge_join.rs +++ b/native/core/src/execution/joins/sort_merge_join.rs @@ -21,19 +21,23 @@ use std::sync::Arc; use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; -use datafusion::common::Result; +use datafusion::common::{NullEquality, Result}; +use datafusion::execution::memory_pool::MemoryConsumer; +use datafusion::execution::TaskContext; +use datafusion::logical_expr::JoinType; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::joins::utils::{build_join_schema, check_join_is_valid, JoinFilter}; use datafusion::physical_plan::joins::JoinOn; -use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics}; +use datafusion::physical_plan::spill::SpillManager; use datafusion::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }; -use datafusion::common::NullEquality; -use 
datafusion::execution::TaskContext; -use datafusion::logical_expr::JoinType; + +use super::metrics::SortMergeJoinMetrics; +use super::sort_merge_join_stream::SortMergeJoinStream; /// A Comet-specific sort merge join operator that replaces DataFusion's /// `SortMergeJoinExec` with Spark-compatible semantics. @@ -142,10 +146,67 @@ impl ExecutionPlan for CometSortMergeJoinExec { fn execute( &self, - _partition: usize, - _context: Arc, + partition: usize, + context: Arc, ) -> Result { - unimplemented!("SortMergeJoinStream not yet implemented") + // Determine streamed/buffered assignment based on join type. + // RightOuter: right is streamed, left is buffered. + // All others: left is streamed, right is buffered. + let (streamed_child, buffered_child, streamed_join_exprs, buffered_join_exprs) = + match self.join_type { + JoinType::Right => ( + Arc::clone(&self.right), + Arc::clone(&self.left), + self.join_on.iter().map(|(_, r)| Arc::clone(r)).collect::>(), + self.join_on.iter().map(|(l, _)| Arc::clone(l)).collect::>(), + ), + _ => ( + Arc::clone(&self.left), + Arc::clone(&self.right), + self.join_on.iter().map(|(l, _)| Arc::clone(l)).collect::>(), + self.join_on.iter().map(|(_, r)| Arc::clone(r)).collect::>(), + ), + }; + + let streamed_schema = streamed_child.schema(); + let buffered_schema = buffered_child.schema(); + + let streamed_input = streamed_child.execute(partition, Arc::clone(&context))?; + let buffered_input = buffered_child.execute(partition, Arc::clone(&context))?; + + // Create memory reservation. + let reservation = MemoryConsumer::new("CometSortMergeJoin") + .with_can_spill(true) + .register(context.memory_pool()); + + // Create spill manager. 
+ let spill_metrics = SpillMetrics::new(&self.metrics, partition); + let spill_manager = SpillManager::new( + context.runtime_env(), + spill_metrics, + Arc::clone(&buffered_schema), + ); + + let metrics = SortMergeJoinMetrics::new(&self.metrics, partition); + let target_batch_size = context.session_config().batch_size(); + + Ok(Box::pin(SortMergeJoinStream::try_new( + Arc::clone(&self.schema), + streamed_schema, + buffered_schema, + self.join_type, + self.null_equality, + self.join_filter.clone(), + self.sort_options.clone(), + streamed_input, + buffered_input, + streamed_join_exprs, + buffered_join_exprs, + reservation, + spill_manager, + metrics, + target_batch_size, + )?)) } fn metrics(&self) -> Option { diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs new file mode 100644 index 0000000000..f1a905d86f --- /dev/null +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -0,0 +1,1067 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Streaming state machine for the sort merge join operator. +//! +//! The [`SortMergeJoinStream`] drives two sorted input streams (streamed and +//! 
buffered), compares join keys, collects matching buffered rows into a +//! [`BufferedMatchGroup`], and produces joined output batches via the +//! [`OutputBuilder`]. + +use std::cmp::Ordering; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{ArrayRef, RecordBatch, UInt32Array}; +use arrow::compute::SortOptions; +use arrow::datatypes::SchemaRef; +use arrow::row::{OwnedRow, RowConverter, SortField}; +use datafusion::common::{NullEquality, Result}; +use datafusion::execution::memory_pool::MemoryReservation; +use datafusion::logical_expr::JoinType; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_plan::joins::utils::{compare_join_arrays, JoinFilter}; +use datafusion::physical_plan::spill::SpillManager; +use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; + +use futures::{Stream, StreamExt}; + +use super::buffered_batch::{evaluate_join_keys, BufferedMatchGroup}; +use super::filter::{apply_join_filter, build_filter_candidate_batch}; +use super::metrics::SortMergeJoinMetrics; +use super::output_builder::{JoinIndex, OutputBuilder}; + +/// States of the sort merge join state machine. +#[derive(Debug, PartialEq, Eq)] +enum JoinState { + /// Need to poll the next streamed row. + PollStreamed, + /// Need to poll the next buffered batch. + PollBuffered, + /// Initial state: decide what to poll next. + Init, + /// Compare the current streamed key with the current buffered key. + Comparing, + /// Collecting more buffered batches into the match group (key spans batches). + CollectingBuffered, + /// Produce join output for the current streamed row against the match group. + Joining, + /// Flush accumulated output. + OutputReady, + /// Drain unmatched rows after one side is exhausted. + DrainUnmatched, + /// No more output. + Exhausted, +} + +/// A streaming sort merge join that merges two sorted inputs by join keys. 
pub(super) struct SortMergeJoinStream {
    /// The type of join (Inner, Left, Right, Full, LeftSemi, LeftAnti).
    join_type: JoinType,
    /// How nulls compare during key matching.
    null_equality: NullEquality,
    /// Optional post-join filter.
    join_filter: Option<JoinFilter>,
    /// Sort options for each join key column.
    sort_options: Vec<SortOptions>,

    /// The streamed (driving) input.
    streamed_input: SendableRecordBatchStream,
    /// The buffered (probe) input.
    buffered_input: SendableRecordBatchStream,
    /// Expressions to evaluate join keys on the streamed side.
    streamed_join_exprs: Vec<PhysicalExprRef>,
    /// Expressions to evaluate join keys on the buffered side.
    buffered_join_exprs: Vec<PhysicalExprRef>,

    /// Current streamed batch.
    streamed_batch: Option<RecordBatch>,
    /// Pre-evaluated join key arrays for the current streamed batch.
    streamed_join_arrays: Option<Vec<ArrayRef>>,
    /// Current row index within the streamed batch.
    streamed_idx: usize,
    /// Whether the streamed input is exhausted.
    streamed_exhausted: bool,

    /// Pending buffered batch (batch + join arrays) not yet consumed.
    buffered_pending: Option<(RecordBatch, Vec<ArrayRef>)>,
    /// Whether the buffered input is exhausted.
    buffered_exhausted: bool,
    /// The current match group of buffered rows sharing the same join key.
    match_group: BufferedMatchGroup,

    /// Converts join keys to comparable row format for key-reuse optimization.
    row_converter: RowConverter,
    /// Cached key of the previous streamed row (for key-reuse detection).
    cached_streamed_key: Option<OwnedRow>,

    /// Accumulates output indices and builds result batches.
    output_builder: OutputBuilder,
    /// Schema of the output record batches.
    output_schema: SchemaRef,

    /// Memory reservation for buffered data.
    reservation: MemoryReservation,
    /// Manages spilling buffered batches to disk when memory is tight.
    spill_manager: SpillManager,
    /// Metrics for this join operator.
    metrics: SortMergeJoinMetrics,

    /// Current state of the state machine.
    state: JoinState,
}

impl SortMergeJoinStream {
    /// Create a new `SortMergeJoinStream`.
    #[allow(clippy::too_many_arguments)]
    pub fn try_new(
        output_schema: SchemaRef,
        streamed_schema: SchemaRef,
        buffered_schema: SchemaRef,
        join_type: JoinType,
        null_equality: NullEquality,
        join_filter: Option<JoinFilter>,
        sort_options: Vec<SortOptions>,
        streamed_input: SendableRecordBatchStream,
        buffered_input: SendableRecordBatchStream,
        streamed_join_exprs: Vec<PhysicalExprRef>,
        buffered_join_exprs: Vec<PhysicalExprRef>,
        reservation: MemoryReservation,
        spill_manager: SpillManager,
        metrics: SortMergeJoinMetrics,
        target_batch_size: usize,
    ) -> Result<Self> {
        // Build SortFields from the streamed join key data types so the
        // RowConverter produces keys whose byte order matches the sort order.
        let sort_fields: Vec<SortField> = streamed_join_exprs
            .iter()
            .zip(sort_options.iter())
            .map(|(expr, opts)| {
                let dt = expr.data_type(&streamed_schema)?;
                Ok(SortField::new_with_options(dt, *opts))
            })
            .collect::<Result<Vec<_>>>()?;

        let row_converter = RowConverter::new(sort_fields)
            .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?;

        let output_builder = OutputBuilder::new(
            Arc::clone(&output_schema),
            streamed_schema,
            buffered_schema,
            join_type,
            target_batch_size,
        );

        Ok(Self {
            join_type,
            null_equality,
            join_filter,
            sort_options,
            streamed_input,
            buffered_input,
            streamed_join_exprs,
            buffered_join_exprs,
            streamed_batch: None,
            streamed_join_arrays: None,
            streamed_idx: 0,
            streamed_exhausted: false,
            buffered_pending: None,
            buffered_exhausted: false,
            match_group: BufferedMatchGroup::new(),
            row_converter,
            cached_streamed_key: None,
            output_builder,
            output_schema,
            reservation,
            spill_manager,
            metrics,
            state: JoinState::Init,
        })
    }

    /// Drive the state machine, returning `Poll::Ready(Some(batch))` when a
    /// batch is available, `Poll::Ready(None)` when done, or `Poll::Pending`
    /// if waiting on input.
+ fn poll_next_inner( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.state { + JoinState::Init => { + // Decide what to poll based on current state. + if self.streamed_batch.is_none() && !self.streamed_exhausted { + self.state = JoinState::PollStreamed; + } else if self.buffered_pending.is_none() && !self.buffered_exhausted { + self.state = JoinState::PollBuffered; + } else if self.streamed_exhausted && self.buffered_exhausted { + self.state = JoinState::DrainUnmatched; + } else if self.streamed_exhausted { + self.state = JoinState::DrainUnmatched; + } else if self.buffered_exhausted { + // Streamed has data but buffered is done. + self.state = JoinState::Comparing; + } else { + self.state = JoinState::Comparing; + } + } + + JoinState::PollStreamed => { + match self.streamed_input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + // Skip empty batches. + continue; + } + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); + let join_arrays = + evaluate_join_keys(&batch, &self.streamed_join_exprs)?; + self.streamed_batch = Some(batch); + self.streamed_join_arrays = Some(join_arrays); + self.streamed_idx = 0; + // Now ensure we have buffered data too. 
+ if self.buffered_pending.is_none() && !self.buffered_exhausted { + self.state = JoinState::PollBuffered; + } else { + self.state = JoinState::Comparing; + } + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.streamed_exhausted = true; + self.state = JoinState::DrainUnmatched; + } + Poll::Pending => return Poll::Pending, + } + } + + JoinState::PollBuffered => { + match self.buffered_input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + continue; + } + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); + let join_arrays = + evaluate_join_keys(&batch, &self.buffered_join_exprs)?; + self.buffered_pending = Some((batch, join_arrays)); + if self.streamed_batch.is_some() { + self.state = JoinState::Comparing; + } else if !self.streamed_exhausted { + self.state = JoinState::PollStreamed; + } else { + self.state = JoinState::DrainUnmatched; + } + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.buffered_exhausted = true; + if self.streamed_batch.is_some() { + self.state = JoinState::Comparing; + } else { + self.state = JoinState::DrainUnmatched; + } + } + Poll::Pending => return Poll::Pending, + } + } + + JoinState::Comparing => { + // We have a streamed row. Compare its key against the + // buffered key (first row of buffered_pending). + let streamed_idx = self.streamed_idx; + + // Check if the streamed key has nulls. + let streamed_has_null = self + .streamed_join_arrays + .as_ref() + .unwrap() + .iter() + .any(|a| a.is_null(streamed_idx)); + + // For inner/semi joins, skip null keys entirely. 
+ if streamed_has_null + && self.null_equality == NullEquality::NullEqualsNothing + { + match self.join_type { + JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => { + self.advance_streamed()?; + self.determine_next_state(); + continue; + } + JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftAnti + | JoinType::RightAnti => { + self.output_builder + .add_streamed_null_join(streamed_idx); + self.advance_streamed()?; + if self.output_builder.should_flush() { + self.state = JoinState::OutputReady; + } else { + self.determine_next_state(); + } + continue; + } + _ => { + self.advance_streamed()?; + self.determine_next_state(); + continue; + } + } + } + + // Check if the streamed key matches the cached key (reuse + // the existing match group). + if self.try_reuse_match_group()? { + self.state = JoinState::Joining; + continue; + } + + // Clear old match group. + self.match_group.clear(&mut self.reservation); + self.cached_streamed_key = None; + + if self.buffered_exhausted && self.buffered_pending.is_none() { + // No buffered data at all. Streamed row is unmatched. + self.emit_streamed_unmatched(streamed_idx); + self.advance_streamed()?; + if self.output_builder.should_flush() { + self.state = JoinState::OutputReady; + } else { + self.determine_next_state(); + } + continue; + } + + // Compare streamed key with buffered key. + let ordering = { + let streamed_arrays = + self.streamed_join_arrays.as_ref().unwrap(); + let (_buffered_batch, buffered_arrays) = + self.buffered_pending.as_ref().unwrap(); + compare_join_arrays( + streamed_arrays, + streamed_idx, + buffered_arrays, + 0, + &self.sort_options, + self.null_equality, + )? + }; + + match ordering { + Ordering::Less => { + // Streamed key < buffered key: streamed row has no match. 
+ self.emit_streamed_unmatched(streamed_idx); + self.advance_streamed()?; + if self.output_builder.should_flush() { + self.state = JoinState::OutputReady; + } else { + self.determine_next_state(); + } + } + Ordering::Greater => { + // Streamed key > buffered key: advance buffered. + self.buffered_pending = None; + if self.buffered_exhausted { + self.emit_streamed_unmatched(streamed_idx); + self.advance_streamed()?; + if self.output_builder.should_flush() { + self.state = JoinState::OutputReady; + } else { + self.determine_next_state(); + } + } else { + self.state = JoinState::PollBuffered; + } + } + Ordering::Equal => { + // Keys match. Build the match group. + let needs_more = self.build_match_group()?; + self.cache_streamed_key()?; + if needs_more { + self.state = JoinState::CollectingBuffered; + } else { + self.state = JoinState::Joining; + } + } + } + } + + JoinState::CollectingBuffered => { + // We consumed an entire buffered batch into the match group + // and need to check if more buffered rows have the same key. + if self.buffered_exhausted { + // No more buffered data. Match group is complete. + self.state = JoinState::Joining; + continue; + } + match self.buffered_input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + continue; + } + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); + let join_arrays = + evaluate_join_keys(&batch, &self.buffered_join_exprs)?; + self.buffered_pending = Some((batch, join_arrays)); + let needs_more = self.build_match_group()?; + if needs_more { + // Still consuming; keep collecting. 
+ continue; + } + self.state = JoinState::Joining; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.buffered_exhausted = true; + self.state = JoinState::Joining; + } + Poll::Pending => return Poll::Pending, + } + } + + JoinState::Joining => { + // Produce join pairs for the current streamed row against + // all rows in the match group. + let streamed_idx = self.streamed_idx; + self.produce_join_pairs(streamed_idx)?; + + self.advance_streamed()?; + + if self.output_builder.should_flush() { + self.state = JoinState::OutputReady; + } else { + self.determine_next_state(); + } + } + + JoinState::OutputReady => { + let batch = self.flush_output()?; + if batch.num_rows() > 0 { + self.metrics.output_rows.add(batch.num_rows()); + self.metrics.output_batches.add(1); + // After flushing, figure out what to do next. + self.determine_next_state(); + return Poll::Ready(Some(Ok(batch))); + } + self.determine_next_state(); + } + + JoinState::DrainUnmatched => { + // Drain remaining streamed rows as null-joined (for outer/anti). + self.drain_remaining()?; + + if self.output_builder.has_pending() { + let batch = self.flush_output()?; + self.state = JoinState::Exhausted; + if batch.num_rows() > 0 { + self.metrics.output_rows.add(batch.num_rows()); + self.metrics.output_batches.add(1); + return Poll::Ready(Some(Ok(batch))); + } + } + self.state = JoinState::Exhausted; + } + + JoinState::Exhausted => { + self.metrics + .update_peak_mem(self.reservation.size()); + return Poll::Ready(None); + } + } + } + } + + /// Determine the next state after flushing output. + fn determine_next_state(&mut self) { + if self.streamed_batch.is_some() + && self.streamed_idx < self.streamed_batch.as_ref().unwrap().num_rows() + { + // More rows in current streamed batch. 
+ self.state = JoinState::Comparing; + } else if !self.streamed_exhausted { + self.state = JoinState::Init; + } else { + self.state = JoinState::DrainUnmatched; + } + } + + /// Advance the streamed side to the next row. If the current batch is + /// exhausted, clear it so we poll the next one. + fn advance_streamed(&mut self) -> Result<()> { + self.streamed_idx += 1; + if let Some(batch) = &self.streamed_batch { + if self.streamed_idx >= batch.num_rows() { + self.streamed_batch = None; + self.streamed_join_arrays = None; + self.streamed_idx = 0; + } + } + Ok(()) + } + + /// Try to reuse the existing match group if the current streamed key + /// matches the cached key. + fn try_reuse_match_group(&mut self) -> Result { + if self.cached_streamed_key.is_none() || self.match_group.is_empty() { + return Ok(false); + } + + let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); + let rows = self + .row_converter + .convert_columns(streamed_arrays) + .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?; + let current_key = rows.row(self.streamed_idx); + + if let Some(ref cached) = self.cached_streamed_key { + if current_key == cached.row() { + return Ok(true); + } + } + + Ok(false) + } + + /// Cache the current streamed key as an OwnedRow. + fn cache_streamed_key(&mut self) -> Result<()> { + let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); + let rows = self + .row_converter + .convert_columns(streamed_arrays) + .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?; + self.cached_streamed_key = Some(rows.row(self.streamed_idx).owned()); + Ok(()) + } + + /// Build a match group by collecting all buffered rows with the same key + /// as the current streamed row. + /// + /// Returns `true` if the entire buffered batch was consumed and more data + /// may need to be polled to complete the match group. 
+ fn build_match_group(&mut self) -> Result { + let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); + let streamed_idx = self.streamed_idx; + let full_outer = matches!(self.join_type, JoinType::Full); + + // Take the pending buffered batch. + let (batch, arrays) = self.buffered_pending.take().unwrap(); + + // Find how many rows from this batch have the same key. + let boundary = find_key_boundary( + streamed_arrays, + streamed_idx, + &arrays, + &self.sort_options, + self.null_equality, + )?; + + let needs_more = boundary == batch.num_rows(); + + if needs_more { + // Entire batch matches. Add it to the group. + self.match_group.add_batch( + batch, + arrays, + full_outer, + &mut self.reservation, + &self.spill_manager, + &self.metrics, + )?; + // buffered_pending remains None; caller should poll more. + } else { + // Split the batch: rows [0..boundary) match, [boundary..) don't. + if boundary > 0 { + let matching = batch.slice(0, boundary); + let matching_arrays: Vec = + arrays.iter().map(|a| a.slice(0, boundary)).collect(); + self.match_group.add_batch( + matching, + matching_arrays, + full_outer, + &mut self.reservation, + &self.spill_manager, + &self.metrics, + )?; + } + + // Keep the remaining rows as the new pending batch. + let remaining = batch.slice(boundary, batch.num_rows() - boundary); + let remaining_arrays: Vec = arrays + .iter() + .map(|a| a.slice(boundary, a.len() - boundary)) + .collect(); + self.buffered_pending = Some((remaining, remaining_arrays)); + } + + self.metrics.update_peak_mem(self.reservation.size()); + Ok(needs_more) + } + + /// Emit a streamed row as unmatched (null-joined) for outer/anti joins, + /// or skip it for inner/semi joins. 
+ fn emit_streamed_unmatched(&mut self, streamed_idx: usize) { + match self.join_type { + JoinType::Left | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => { + self.output_builder.add_streamed_null_join(streamed_idx); + } + // For Right outer: the streamed side is actually the right side, + // so unmatched streamed rows need null-joining. + JoinType::Right => { + self.output_builder.add_streamed_null_join(streamed_idx); + } + // Inner and semi joins: unmatched rows are dropped. + _ => {} + } + } + + /// Produce join pairs for the current streamed row against all rows in + /// the match group, applying the join filter if present. + fn produce_join_pairs(&mut self, streamed_idx: usize) -> Result<()> { + if self.match_group.is_empty() { + return Ok(()); + } + + match self.join_type { + JoinType::LeftSemi | JoinType::RightSemi => { + self.produce_semi_pairs(streamed_idx)?; + } + JoinType::LeftAnti | JoinType::RightAnti => { + self.produce_anti_pairs(streamed_idx)?; + } + _ => { + self.produce_standard_pairs(streamed_idx)?; + } + } + Ok(()) + } + + /// Produce pairs for inner/outer joins. + fn produce_standard_pairs(&mut self, streamed_idx: usize) -> Result<()> { + if let Some(ref filter) = self.join_filter { + // Build candidate pairs and apply filter. + let mut pair_indices = Vec::new(); + for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() { + for row_idx in 0..buffered_batch.num_rows { + pair_indices.push(JoinIndex { + streamed_idx, + batch_idx, + buffered_idx: row_idx, + }); + } + } + + if pair_indices.is_empty() { + self.emit_streamed_unmatched(streamed_idx); + return Ok(()); + } + + // Build candidate batch for filter evaluation. 
+ let streamed_batch = self.streamed_batch.as_ref().unwrap(); + let candidate_batch = self.build_filter_batch( + filter, + streamed_batch, + &pair_indices, + )?; + + let filtered = apply_join_filter( + filter, + &candidate_batch, + &pair_indices, + &self.join_type, + )?; + + // Apply filtered results. + for idx in &filtered.passed_indices { + self.output_builder + .add_match(idx.streamed_idx, idx.batch_idx, idx.buffered_idx); + } + for &si in &filtered.streamed_null_joins { + self.output_builder.add_streamed_null_join(si); + } + // Mark matched buffered rows for full outer. + for &(batch_idx, buffered_idx) in &filtered.buffered_matched { + if let Some(ref mut matched) = self.match_group.batches[batch_idx].matched { + matched[buffered_idx] = true; + } + } + } else { + // No filter: all pairs match. + for (batch_idx, buffered_batch) in self.match_group.batches.iter_mut().enumerate() { + for row_idx in 0..buffered_batch.num_rows { + self.output_builder + .add_match(streamed_idx, batch_idx, row_idx); + if let Some(ref mut matched) = buffered_batch.matched { + matched[row_idx] = true; + } + } + } + } + Ok(()) + } + + /// Produce pairs for semi joins: emit the streamed row if any match passes. + fn produce_semi_pairs(&mut self, streamed_idx: usize) -> Result<()> { + if let Some(ref filter) = self.join_filter { + let mut pair_indices = Vec::new(); + for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() { + for row_idx in 0..buffered_batch.num_rows { + pair_indices.push(JoinIndex { + streamed_idx, + batch_idx, + buffered_idx: row_idx, + }); + } + } + + if pair_indices.is_empty() { + return Ok(()); + } + + let streamed_batch = self.streamed_batch.as_ref().unwrap(); + let candidate_batch = + self.build_filter_batch(filter, streamed_batch, &pair_indices)?; + + let filtered = apply_join_filter( + filter, + &candidate_batch, + &pair_indices, + &self.join_type, + )?; + + // Semi: emit the streamed row if any pair passed. 
+ if !filtered.passed_indices.is_empty() { + let idx = &filtered.passed_indices[0]; + self.output_builder + .add_match(idx.streamed_idx, idx.batch_idx, idx.buffered_idx); + } + } else { + // No filter: key match is sufficient for semi join. + if !self.match_group.is_empty() { + self.output_builder.add_match(streamed_idx, 0, 0); + } + } + Ok(()) + } + + /// Produce pairs for anti joins: emit the streamed row if no match passes. + fn produce_anti_pairs(&mut self, streamed_idx: usize) -> Result<()> { + if let Some(ref filter) = self.join_filter { + let mut pair_indices = Vec::new(); + for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() { + for row_idx in 0..buffered_batch.num_rows { + pair_indices.push(JoinIndex { + streamed_idx, + batch_idx, + buffered_idx: row_idx, + }); + } + } + + if pair_indices.is_empty() { + // No buffered matches at all => emit for anti. + self.output_builder.add_streamed_null_join(streamed_idx); + return Ok(()); + } + + let streamed_batch = self.streamed_batch.as_ref().unwrap(); + let candidate_batch = + self.build_filter_batch(filter, streamed_batch, &pair_indices)?; + + let filtered = apply_join_filter( + filter, + &candidate_batch, + &pair_indices, + &self.join_type, + )?; + + // Anti: emit streamed rows that had no passing pair. + for &si in &filtered.streamed_null_joins { + self.output_builder.add_streamed_null_join(si); + } + } else { + // No filter: key match means the streamed row is NOT emitted (anti). + // Do nothing. + } + Ok(()) + } + + /// Build a filter candidate batch for the given pairs. + fn build_filter_batch( + &self, + filter: &JoinFilter, + streamed_batch: &RecordBatch, + pair_indices: &[JoinIndex], + ) -> Result { + // We need to combine rows from potentially multiple buffered batches + // into a single batch for filter evaluation. + // First, build streamed and buffered index arrays. 
+        let streamed_indices: Vec<u32> = pair_indices
+            .iter()
+            .map(|idx| idx.streamed_idx as u32)
+            .collect();
+        let streamed_idx_array = UInt32Array::from(streamed_indices);
+
+        // For the buffered side, we need to build a single batch containing
+        // all referenced rows.
+        let buffered_batch = self.collect_buffered_rows(pair_indices)?;
+        let buffered_indices: Vec<u32> = (0..pair_indices.len() as u32).collect();
+        let buffered_idx_array = UInt32Array::from(buffered_indices);
+
+        build_filter_candidate_batch(
+            filter,
+            streamed_batch,
+            &buffered_batch,
+            &streamed_idx_array,
+            &buffered_idx_array,
+        )
+    }
+
+    /// Collect all referenced buffered rows into a single batch.
+    fn collect_buffered_rows(
+        &self,
+        pair_indices: &[JoinIndex],
+    ) -> Result<RecordBatch> {
+        if pair_indices.is_empty() {
+            // Return an empty batch with the correct schema.
+            let schema = self.spill_manager.schema().clone();
+            return Ok(RecordBatch::new_empty(schema));
+        }
+
+        // Group indices by batch_idx, then take rows and concatenate.
+        let schema = self.spill_manager.schema().clone();
+        let num_cols = schema.fields().len();
+        let mut result_columns: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_cols];
+
+        let mut current_batch_idx: Option<usize> = None;
+        let mut current_row_indices: Vec<u32> = Vec::new();
+
+        let flush = |batch_idx: usize,
+                     row_indices: &[u32],
+                     result_columns: &mut Vec<Vec<ArrayRef>>,
+                     match_group: &BufferedMatchGroup,
+                     spill_manager: &SpillManager|
+         -> Result<()> {
+            let batch = match_group.get_batch(batch_idx, spill_manager)?;
+            let idx_array = UInt32Array::from(row_indices.to_vec());
+            for (col_idx, col_parts) in result_columns.iter_mut().enumerate() {
+                let col = batch.column(col_idx);
+                let taken = arrow::compute::take(col.as_ref(), &idx_array, None)?;
+                col_parts.push(taken);
+            }
+            Ok(())
+        };
+
+        for idx in pair_indices {
+            if current_batch_idx == Some(idx.batch_idx) {
+                current_row_indices.push(idx.buffered_idx as u32);
+            } else {
+                if let Some(bi) = current_batch_idx {
+                    flush(
+                        bi,
+                        &current_row_indices,
+                        &mut result_columns,
+                        &self.match_group,
+                        &self.spill_manager,
+                    )?;
+                    current_row_indices.clear();
+                }
+                current_batch_idx = Some(idx.batch_idx);
+                current_row_indices.push(idx.buffered_idx as u32);
+            }
+        }
+        if let Some(bi) = current_batch_idx {
+            flush(
+                bi,
+                &current_row_indices,
+                &mut result_columns,
+                &self.match_group,
+                &self.spill_manager,
+            )?;
+        }
+
+        // Concatenate column parts.
+        let columns: Vec<ArrayRef> = result_columns
+            .into_iter()
+            .map(|parts| {
+                if parts.len() == 1 {
+                    Ok(parts.into_iter().next().unwrap())
+                } else {
+                    let refs: Vec<&dyn arrow::array::Array> =
+                        parts.iter().map(|a| a.as_ref()).collect();
+                    Ok(arrow::compute::concat(&refs)?)
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(RecordBatch::try_new(schema, columns)?)
+    }
+
+    /// Flush accumulated output indices into a RecordBatch.
+ fn flush_output(&mut self) -> Result { + let streamed_batch = match &self.streamed_batch { + Some(b) => b.clone(), + None => { + // If the streamed batch has been consumed, we might still + // have pending output from DrainUnmatched (buffered null joins). + // Create an empty streamed batch. + let schema = self.output_builder_streamed_schema(); + RecordBatch::new_empty(schema) + } + }; + + self.output_builder.build( + &streamed_batch, + &self.match_group, + &self.spill_manager, + &self.buffered_join_exprs, + ) + } + + /// Get the streamed schema from the output builder (needed for empty batch creation). + fn output_builder_streamed_schema(&self) -> SchemaRef { + // We can derive this from the output schema and join type, + // but for simplicity use the streamed input's schema. + self.streamed_input.schema() + } + + /// Drain remaining rows after one side is exhausted. + fn drain_remaining(&mut self) -> Result<()> { + match self.join_type { + JoinType::Left | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => { + // Drain remaining streamed rows as null-joined. + if let Some(batch) = &self.streamed_batch { + let num_rows = batch.num_rows(); + for idx in self.streamed_idx..num_rows { + self.output_builder.add_streamed_null_join(idx); + } + } + } + JoinType::Right => { + // Right outer: streamed side is right, so drain remaining. + if let Some(batch) = &self.streamed_batch { + let num_rows = batch.num_rows(); + for idx in self.streamed_idx..num_rows { + self.output_builder.add_streamed_null_join(idx); + } + } + } + _ => {} + } + + // For full outer: drain unmatched buffered rows. 
+ if matches!(self.join_type, JoinType::Full) { + for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() { + if let Some(ref matched) = buffered_batch.matched { + for (row_idx, &was_matched) in matched.iter().enumerate() { + if !was_matched { + self.output_builder + .add_buffered_null_join(batch_idx, row_idx); + } + } + } + } + } + + // Clear streamed state. + self.streamed_batch = None; + self.streamed_join_arrays = None; + self.streamed_idx = 0; + + Ok(()) + } +} + +/// Find the boundary index in a buffered batch where the key changes relative +/// to the streamed key. Returns the number of rows from the start that have +/// the same key as the streamed row at `streamed_idx`. +/// +/// Uses binary search since the buffered side is sorted. +fn find_key_boundary( + streamed_arrays: &[ArrayRef], + streamed_idx: usize, + buffered_arrays: &[ArrayRef], + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + let num_rows = buffered_arrays[0].len(); + if num_rows == 0 { + return Ok(0); + } + + // Quick check: if the last row also matches, the entire batch is in the group. + let last_cmp = compare_join_arrays( + streamed_arrays, + streamed_idx, + buffered_arrays, + num_rows - 1, + sort_options, + null_equality, + )?; + if last_cmp == Ordering::Equal { + return Ok(num_rows); + } + + // Binary search for the boundary. 
+ let mut lo = 0usize; + let mut hi = num_rows; + while lo < hi { + let mid = lo + (hi - lo) / 2; + let cmp = compare_join_arrays( + streamed_arrays, + streamed_idx, + buffered_arrays, + mid, + sort_options, + null_equality, + )?; + if cmp == Ordering::Equal { + lo = mid + 1; + } else { + hi = mid; + } + } + Ok(lo) +} + +impl Stream for SortMergeJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let join_time = self.metrics.join_time.clone(); + let timer = join_time.timer(); + let result = self.poll_next_inner(cx); + timer.done(); + result + } +} + +impl RecordBatchStream for SortMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.output_schema) + } +} From 175e54f7df0e1d930b15e67310c816fcd17d6eba Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:05:43 -0600 Subject: [PATCH 06/12] fix: resolve clippy warnings in sort_merge_join_stream Simplify Init state logic to avoid identical if-branches and use Arc::clone instead of .clone() on ref-counted pointers. --- .../core/src/execution/joins/sort_merge_join_stream.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index f1a905d86f..bb95d8d8cf 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -209,14 +209,10 @@ impl SortMergeJoinStream { self.state = JoinState::PollStreamed; } else if self.buffered_pending.is_none() && !self.buffered_exhausted { self.state = JoinState::PollBuffered; - } else if self.streamed_exhausted && self.buffered_exhausted { - self.state = JoinState::DrainUnmatched; } else if self.streamed_exhausted { self.state = JoinState::DrainUnmatched; - } else if self.buffered_exhausted { - // Streamed has data but buffered is done. 
- self.state = JoinState::Comparing; } else { + // Have streamed data; compare regardless of buffered state. self.state = JoinState::Comparing; } } @@ -845,12 +841,12 @@ impl SortMergeJoinStream { ) -> Result { if pair_indices.is_empty() { // Return an empty batch with the correct schema. - let schema = self.spill_manager.schema().clone(); + let schema = Arc::clone(self.spill_manager.schema()); return Ok(RecordBatch::new_empty(schema)); } // Group indices by batch_idx, then take rows and concatenate. - let schema = self.spill_manager.schema().clone(); + let schema = Arc::clone(self.spill_manager.schema()); let num_cols = schema.fields().len(); let mut result_columns: Vec> = vec![Vec::new(); num_cols]; From 679a77d618c672c43002141685132f5eb1e86408 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:15:27 -0600 Subject: [PATCH 07/12] feat: add useNative config toggle and planner integration for Comet SMJ Adds spark.comet.exec.sortMergeJoin.useNative config (default true) to switch between Comet's native sort merge join and DataFusion's SortMergeJoinExec. Passes spark config through JNI to the native planner for runtime selection. 
--- .../scala/org/apache/comet/CometConf.scala | 9 ++ native/core/src/execution/jni_api.rs | 6 +- native/core/src/execution/planner.rs | 113 ++++++++++++------ native/core/src/execution/spark_config.rs | 1 + 4 files changed, 91 insertions(+), 38 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..3e58707748 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -280,6 +280,15 @@ object CometConf extends ShimCometConf { createExecEnabledConfig("hashJoin", defaultValue = true) val COMET_EXEC_SORT_MERGE_JOIN_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("sortMergeJoin", defaultValue = true) + val COMET_EXEC_SMJ_USE_NATIVE: ConfigEntry[Boolean] = + conf("spark.comet.exec.sortMergeJoin.useNative") + .category(CATEGORY_EXEC) + .doc( + "When true, use Comet's native sort merge join implementation. " + + "When false, use DataFusion's SortMergeJoinExec. " + + "This is useful for benchmarking the two implementations.") + .booleanConf + .createWithDefault(true) val COMET_EXEC_AGGREGATE_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("aggregate", defaultValue = true) val COMET_EXEC_COLLECT_LIMIT_ENABLED: ConfigEntry[Boolean] = diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index e0a395ebbf..0e02099aed 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -179,6 +179,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Spark configuration map passed from JVM + pub spark_config: HashMap, } /// Accept serialized query plan and return the address of the native query plan. 
@@ -327,6 +329,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + spark_config, }); Ok(Box::into_raw(exec_context) as i64) @@ -545,7 +548,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_spark_config(exec_context.spark_config.clone()); let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0f96c829e7..74524506d1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,8 +21,10 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use crate::execution::joins::CometSortMergeJoinExec; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; +use crate::execution::spark_config::{SparkConfig, COMET_EXEC_SMJ_USE_NATIVE}; use crate::execution::{ expressions::subquery::Subquery, operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, @@ -163,6 +165,7 @@ pub struct PhysicalPlanner { partition: i32, session_ctx: Arc, query_context_registry: Arc, + spark_config: HashMap, } impl Default for PhysicalPlanner { @@ -178,6 +181,7 @@ impl PhysicalPlanner { session_ctx, partition, query_context_registry: datafusion_comet_spark_expr::create_query_context_map(), + spark_config: HashMap::new(), } } @@ -187,6 +191,14 @@ impl PhysicalPlanner { partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), query_context_registry: Arc::clone(&self.query_context_registry), + spark_config: self.spark_config, + } + } + + pub fn with_spark_config(self, spark_config: HashMap) -> Self { + 
Self { + spark_config, + ..self } } @@ -1625,43 +1637,19 @@ impl PhysicalPlanner { let left = Arc::clone(&join_params.left.native_plan); let right = Arc::clone(&join_params.right.native_plan); - let join = Arc::new(SortMergeJoinExec::try_new( - Arc::clone(&left), - Arc::clone(&right), - join_params.join_on, - join_params.join_filter, - join_params.join_type, - sort_options, - // null doesn't equal to null in Spark join key. If the join key is - // `EqualNullSafe`, Spark will rewrite it during planning. - NullEquality::NullEqualsNothing, - )?); - - if join.filter.is_some() { - // SMJ with join filter produces lots of tiny batches - let coalesce_batches: Arc = - Arc::new(CoalesceBatchesExec::new( - Arc::::clone(&join), - self.session_ctx - .state() - .config_options() - .execution - .batch_size, - )); - Ok(( - scans, - shuffle_scans, - Arc::new(SparkPlan::new_with_additional( - spark_plan.plan_id, - coalesce_batches, - vec![ - Arc::clone(&join_params.left), - Arc::clone(&join_params.right), - ], - vec![join], - )), - )) - } else { + let use_native_smj = self.spark_config.get_bool(COMET_EXEC_SMJ_USE_NATIVE); + + if use_native_smj { + let join: Arc = + Arc::new(CometSortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + join_params.join_type, + sort_options, + NullEquality::NullEqualsNothing, + )?); Ok(( scans, shuffle_scans, @@ -1674,6 +1662,57 @@ impl PhysicalPlanner { ], )), )) + } else { + let join = Arc::new(SortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + join_params.join_on, + join_params.join_filter, + join_params.join_type, + sort_options, + // null doesn't equal to null in Spark join key. If the join key is + // `EqualNullSafe`, Spark will rewrite it during planning. 
+ NullEquality::NullEqualsNothing, + )?); + + if join.filter.is_some() { + // SMJ with join filter produces lots of tiny batches + let coalesce_batches: Arc = + Arc::new(CoalesceBatchesExec::new( + Arc::::clone(&join), + self.session_ctx + .state() + .config_options() + .execution + .batch_size, + )); + Ok(( + scans, + shuffle_scans, + Arc::new(SparkPlan::new_with_additional( + spark_plan.plan_id, + coalesce_batches, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + vec![join], + )), + )) + } else { + Ok(( + scans, + shuffle_scans, + Arc::new(SparkPlan::new( + spark_plan.plan_id, + join, + vec![ + Arc::clone(&join_params.left), + Arc::clone(&join_params.right), + ], + )), + )) + } } } OpStruct::HashJoin(join) => { diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs index 277c0eb43b..36f2fd417f 100644 --- a/native/core/src/execution/spark_config.rs +++ b/native/core/src/execution/spark_config.rs @@ -22,6 +22,7 @@ pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled"; pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled"; pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize"; pub(crate) const COMET_DEBUG_MEMORY: &str = "spark.comet.debug.memory"; +pub(crate) const COMET_EXEC_SMJ_USE_NATIVE: &str = "spark.comet.exec.sortMergeJoin.useNative"; pub(crate) const SPARK_EXECUTOR_CORES: &str = "spark.executor.cores"; pub(crate) trait SparkConfig { From c4632bfc4001c17749d671ef88a7d721104b4419 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:21:40 -0600 Subject: [PATCH 08/12] test: add inner join unit tests for CometSortMergeJoinExec Fix two bugs discovered by the tests: - Output builder indices referenced stale streamed batch after it was cleared; defer batch clearing until after flush - Advancing buffered side discarded entire batch instead of advancing past just the first row; use slice 
to preserve remaining rows --- native/core/src/execution/joins/mod.rs | 3 + .../execution/joins/sort_merge_join_stream.rs | 73 +++++-- native/core/src/execution/joins/tests.rs | 181 ++++++++++++++++++ 3 files changed, 238 insertions(+), 19 deletions(-) create mode 100644 native/core/src/execution/joins/tests.rs diff --git a/native/core/src/execution/joins/mod.rs b/native/core/src/execution/joins/mod.rs index ce8c0195dd..7dc4d73b67 100644 --- a/native/core/src/execution/joins/mod.rs +++ b/native/core/src/execution/joins/mod.rs @@ -23,3 +23,6 @@ mod sort_merge_join; mod sort_merge_join_stream; pub(crate) use sort_merge_join::CometSortMergeJoinExec; + +#[cfg(test)] +mod tests; diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index bb95d8d8cf..9dabf64218 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -332,6 +332,13 @@ impl SortMergeJoinStream { continue; } + // Before clearing the match group, flush any pending output + // that references the current match group's batches. + if self.output_builder.has_pending() { + self.state = JoinState::OutputReady; + continue; + } + // Clear old match group. self.match_group.clear(&mut self.reservation); self.cached_streamed_key = None; @@ -376,9 +383,21 @@ impl SortMergeJoinStream { } } Ordering::Greater => { - // Streamed key > buffered key: advance buffered. - self.buffered_pending = None; - if self.buffered_exhausted { + // Streamed key > buffered key: advance past the + // first buffered row. If the pending batch has + // more rows, slice it; otherwise discard and poll + // the next batch. 
+ let (batch, arrays) = self.buffered_pending.take().unwrap(); + if batch.num_rows() > 1 { + let remaining = batch.slice(1, batch.num_rows() - 1); + let remaining_arrays: Vec = arrays + .iter() + .map(|a| a.slice(1, a.len() - 1)) + .collect(); + self.buffered_pending = Some((remaining, remaining_arrays)); + // Re-compare with the next buffered row. + self.state = JoinState::Comparing; + } else if self.buffered_exhausted { self.emit_streamed_unmatched(streamed_idx); self.advance_streamed()?; if self.output_builder.should_flush() { @@ -489,31 +508,47 @@ impl SortMergeJoinStream { } } - /// Determine the next state after flushing output. + /// Determine the next state after processing a row. + /// + /// When the current streamed batch is exhausted, any pending output must be + /// flushed first (the output builder holds row indices into the batch). + /// After the flush the batch is cleared and we move on to poll new data. fn determine_next_state(&mut self) { - if self.streamed_batch.is_some() - && self.streamed_idx < self.streamed_batch.as_ref().unwrap().num_rows() - { - // More rows in current streamed batch. - self.state = JoinState::Comparing; - } else if !self.streamed_exhausted { + // Check if the current streamed batch is exhausted. + if let Some(batch) = &self.streamed_batch { + if self.streamed_idx < batch.num_rows() { + // More rows in current streamed batch. + self.state = JoinState::Comparing; + return; + } + // Batch exhausted. If there are pending output rows that reference + // indices in this batch we must flush before clearing. + if self.output_builder.has_pending() { + self.state = JoinState::OutputReady; + return; + } + // Safe to clear — no pending references. + self.streamed_batch = None; + self.streamed_join_arrays = None; + self.streamed_idx = 0; + } + + if !self.streamed_exhausted { self.state = JoinState::Init; } else { self.state = JoinState::DrainUnmatched; } } - /// Advance the streamed side to the next row. 
If the current batch is - /// exhausted, clear it so we poll the next one. + /// Advance the streamed side to the next row. + /// + /// Note: we do NOT clear the streamed batch here even when all rows have + /// been consumed, because the output builder may still hold index references + /// into the batch that need to be materialized during the next flush. + /// The batch is cleared lazily in `determine_next_state` when we transition + /// to polling a new batch. fn advance_streamed(&mut self) -> Result<()> { self.streamed_idx += 1; - if let Some(batch) = &self.streamed_batch { - if self.streamed_idx >= batch.num_rows() { - self.streamed_batch = None; - self.streamed_join_arrays = None; - self.streamed_idx = 0; - } - } Ok(()) } diff --git a/native/core/src/execution/joins/tests.rs b/native/core/src/execution/joins/tests.rs new file mode 100644 index 0000000000..c6a5f41472 --- /dev/null +++ b/native/core/src/execution/joins/tests.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion::common::{NullEquality, Result}; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::datasource::source::DataSourceExec; +use datafusion::logical_expr::JoinType; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use futures::StreamExt; + +use super::CometSortMergeJoinExec; + +fn left_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("l_key", DataType::Int32, true), + Field::new("l_val", DataType::Utf8, true), + ])) +} + +fn right_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("r_key", DataType::Int32, true), + Field::new("r_val", DataType::Utf8, true), + ])) +} + +fn make_sorted_batches( + schema: SchemaRef, + keys: Vec>, + vals: Vec>, +) -> Vec { + vec![RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(keys)), + Arc::new(StringArray::from(vals)), + ], + ) + .unwrap()] +} + +async fn execute_join( + join_type: JoinType, + left_batches: Vec, + right_batches: Vec, +) -> Result> { + let l_schema = left_batches[0].schema(); + let r_schema = right_batches[0].schema(); + + let left = Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&[left_batches], l_schema, None)?, + ))); + let right = Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&[right_batches], r_schema, None)?, + ))); + + let on = vec![( + Arc::new(datafusion::physical_expr::expressions::Column::new("l_key", 0)) + as Arc, + Arc::new(datafusion::physical_expr::expressions::Column::new("r_key", 0)) + as Arc, + )]; + + let join = CometSortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + )?; + + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); 
+ let stream = join.execute(0, task_ctx)?; + + let mut results = Vec::new(); + let mut stream = stream; + while let Some(batch) = stream.next().await { + results.push(batch?); + } + Ok(results) +} + +fn total_row_count(batches: &[RecordBatch]) -> usize { + batches.iter().map(|b| b.num_rows()).sum() +} + +#[tokio::test] +async fn test_inner_join_basic() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(3), Some(4)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +#[tokio::test] +async fn test_inner_join_with_duplicates() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(1), Some(2)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(1), Some(1), Some(3)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 4); + Ok(()) +} + +#[tokio::test] +async fn test_inner_join_null_keys_skipped() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![None, Some(1), Some(2)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![None, Some(1), Some(2)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +#[tokio::test] +async fn test_inner_join_empty_result() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(3), Some(4)], + vec![Some("x"), Some("y")], + ); + + let 
result = execute_join(JoinType::Inner, left, right).await?; + assert_eq!(total_row_count(&result), 0); + Ok(()) +} From 1083a5cc44797dd6de88925febd2dc2c854f3919 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:26:08 -0600 Subject: [PATCH 09/12] test: add outer, semi, anti, and spill tests for sort merge join Add 7 new test cases covering left/right/full outer joins, left semi/anti joins, null key handling in outer joins, and an inner join under memory pressure that forces spilling. Fix two bugs found by the new tests: - Full/right outer join now correctly emits unmatched buffered rows that remain after the streamed side is exhausted (new DrainBuffered state in the join stream state machine). - Spill read-back no longer panics inside an async runtime by using block_in_place to allow the nested block_on call. --- .../src/execution/joins/buffered_batch.rs | 20 +- .../execution/joins/sort_merge_join_stream.rs | 79 ++++++++ native/core/src/execution/joins/tests.rs | 177 ++++++++++++++++++ 3 files changed, 267 insertions(+), 9 deletions(-) diff --git a/native/core/src/execution/joins/buffered_batch.rs b/native/core/src/execution/joins/buffered_batch.rs index fa8c7970c9..5624f4f2d0 100644 --- a/native/core/src/execution/joins/buffered_batch.rs +++ b/native/core/src/execution/joins/buffered_batch.rs @@ -95,15 +95,17 @@ impl BufferedBatch { BatchState::Spilled(file) => { let reader = spill_manager.read_spill_as_stream(file.clone(), None)?; - let rt = tokio::runtime::Handle::current(); - let batches = rt.block_on(async { - use futures::StreamExt; - let mut stream = reader; - let mut batches = Vec::new(); - while let Some(batch) = stream.next().await { - batches.push(batch?); - } - Ok::<_, DataFusionError>(batches) + let batches = tokio::task::block_in_place(|| { + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + use futures::StreamExt; + let mut stream = reader; + let mut batches = Vec::new(); + while let Some(batch) = stream.next().await 
{ + batches.push(batch?); + } + Ok::<_, DataFusionError>(batches) + }) })?; // A single batch was spilled per file, but concatenate just in case. if batches.len() == 1 { diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index 9dabf64218..257f1bcc1b 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -65,6 +65,8 @@ enum JoinState { OutputReady, /// Drain unmatched rows after one side is exhausted. DrainUnmatched, + /// Drain remaining buffered rows as null-joined (Full/Right outer). + DrainBuffered, /// No more output. Exhausted, } @@ -487,6 +489,83 @@ impl SortMergeJoinStream { // Drain remaining streamed rows as null-joined (for outer/anti). self.drain_remaining()?; + if self.output_builder.has_pending() { + let batch = self.flush_output()?; + // For Full/Right outer, we may still need to drain buffered rows. + if matches!(self.join_type, JoinType::Full | JoinType::Right) { + self.state = JoinState::DrainBuffered; + } else { + self.state = JoinState::Exhausted; + } + if batch.num_rows() > 0 { + self.metrics.output_rows.add(batch.num_rows()); + self.metrics.output_batches.add(1); + return Poll::Ready(Some(Ok(batch))); + } + } + if matches!(self.join_type, JoinType::Full | JoinType::Right) { + self.state = JoinState::DrainBuffered; + } else { + self.state = JoinState::Exhausted; + } + } + + JoinState::DrainBuffered => { + // For Full/Right outer: emit remaining buffered rows as null-joined. + // First, clear the match group so we can reuse it for pending rows. + self.match_group.clear(&mut self.reservation); + + // Add buffered_pending rows to the match group. 
+ if let Some((batch, arrays)) = self.buffered_pending.take() { + let num_rows = batch.num_rows(); + self.match_group.add_batch( + batch, + arrays, + true, // track matched status + &mut self.reservation, + &self.spill_manager, + &self.metrics, + )?; + // All these rows are unmatched. + for row_idx in 0..num_rows { + self.output_builder.add_buffered_null_join(0, row_idx); + } + } + + // Poll remaining buffered batches. + if !self.buffered_exhausted { + match self.buffered_input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() > 0 { + let num_rows = batch.num_rows(); + let join_arrays = + evaluate_join_keys(&batch, &self.buffered_join_exprs)?; + let batch_idx = self.match_group.batches.len(); + self.match_group.add_batch( + batch, + join_arrays, + true, + &mut self.reservation, + &self.spill_manager, + &self.metrics, + )?; + for row_idx in 0..num_rows { + self.output_builder + .add_buffered_null_join(batch_idx, row_idx); + } + } + // Stay in DrainBuffered to poll more. + continue; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + self.buffered_exhausted = true; + } + Poll::Pending => return Poll::Pending, + } + } + + // All buffered rows drained. Flush and finish. 
if self.output_builder.has_pending() { let batch = self.flush_output()?; self.state = JoinState::Exhausted; diff --git a/native/core/src/execution/joins/tests.rs b/native/core/src/execution/joins/tests.rs index c6a5f41472..c5f6bc7de5 100644 --- a/native/core/src/execution/joins/tests.rs +++ b/native/core/src/execution/joins/tests.rs @@ -179,3 +179,180 @@ async fn test_inner_join_empty_result() -> Result<()> { assert_eq!(total_row_count(&result), 0); Ok(()) } + +// --- Outer join tests --- + +#[tokio::test] +async fn test_left_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(4)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Left, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +#[tokio::test] +async fn test_left_outer_null_keys() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![None, Some(1)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(1), Some(2)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Left, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +#[tokio::test] +async fn test_right_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(3)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::Right, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +#[tokio::test] +async fn test_full_outer_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2)], + vec![Some("a"), Some("b")], + ); + let right = make_sorted_batches( + 
right_schema(), + vec![Some(2), Some(3)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::Full, left, right).await?; + assert_eq!(total_row_count(&result), 3); + Ok(()) +} + +// --- Semi/Anti join tests --- + +#[tokio::test] +async fn test_left_semi_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(3), Some(4)], + vec![Some("x"), Some("y"), Some("z")], + ); + + let result = execute_join(JoinType::LeftSemi, left, right).await?; + assert_eq!(total_row_count(&result), 2); + // Semi join should only output left columns + assert_eq!(result[0].num_columns(), 2); + Ok(()) +} + +#[tokio::test] +async fn test_left_anti_join() -> Result<()> { + let left = make_sorted_batches( + left_schema(), + vec![Some(1), Some(2), Some(3)], + vec![Some("a"), Some("b"), Some("c")], + ); + let right = make_sorted_batches( + right_schema(), + vec![Some(2), Some(4)], + vec![Some("x"), Some("y")], + ); + + let result = execute_join(JoinType::LeftAnti, left, right).await?; + assert_eq!(total_row_count(&result), 2); + Ok(()) +} + +// --- Spill test --- + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_inner_join_with_spill() -> Result<()> { + use datafusion::execution::runtime_env::RuntimeEnvBuilder; + + let l_schema = left_schema(); + let r_schema = right_schema(); + + let left_batches = make_sorted_batches( + l_schema.clone(), + vec![Some(1), Some(1), Some(1), Some(2), Some(2)], + vec![Some("a"), Some("b"), Some("c"), Some("d"), Some("e")], + ); + let right_batches = make_sorted_batches( + r_schema.clone(), + vec![Some(1), Some(1), Some(1), Some(2)], + vec![Some("w"), Some("x"), Some("y"), Some("z")], + ); + + let left_exec = Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&[left_batches], l_schema, None)?, + ))); + let right_exec = 
Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(&[right_batches], r_schema, None)?, + ))); + + let on = vec![( + Arc::new(datafusion::physical_expr::expressions::Column::new("l_key", 0)) + as Arc, + Arc::new(datafusion::physical_expr::expressions::Column::new("r_key", 0)) + as Arc, + )]; + + let join = CometSortMergeJoinExec::try_new( + left_exec, + right_exec, + on, + None, + JoinType::Inner, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + )?; + + let config = datafusion::prelude::SessionConfig::new().with_batch_size(2); + let runtime = Arc::new( + RuntimeEnvBuilder::new() + .with_memory_limit(1024, 1.0) + .build()?, + ); + let ctx = SessionContext::new_with_config_rt(config, runtime); + let task_ctx = ctx.task_ctx(); + let mut stream = join.execute(0, task_ctx)?; + + let mut results = Vec::new(); + while let Some(batch) = stream.next().await { + results.push(batch?); + } + // 3*3 for key=1 + 2*1 for key=2 = 11 + assert_eq!(total_row_count(&results), 11); + Ok(()) +} From c626c65ec098e774553dab61491341ea8e7f64f7 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 07:29:08 -0600 Subject: [PATCH 10/12] chore: fix clippy warnings and format sort merge join code --- .../src/execution/joins/buffered_batch.rs | 26 ++-- native/core/src/execution/joins/filter.rs | 11 +- .../src/execution/joins/output_builder.rs | 19 +-- .../src/execution/joins/sort_merge_join.rs | 27 +++- .../execution/joins/sort_merge_join_stream.rs | 127 +++++++----------- native/core/src/execution/joins/tests.rs | 56 +++++--- 6 files changed, 118 insertions(+), 148 deletions(-) diff --git a/native/core/src/execution/joins/buffered_batch.rs b/native/core/src/execution/joins/buffered_batch.rs index 5624f4f2d0..357d5f7536 100644 --- a/native/core/src/execution/joins/buffered_batch.rs +++ b/native/core/src/execution/joins/buffered_batch.rs @@ -21,6 +21,8 @@ //! share the current join key. When memory is tight, individual batches are //! 
spilled to Arrow IPC files on disk and reloaded on demand. +use std::sync::Arc; + use arrow::array::ArrayRef; use arrow::record_batch::RecordBatch; use datafusion::common::utils::memory::get_record_batch_memory_size; @@ -50,6 +52,7 @@ pub(super) struct BufferedBatch { /// The batch data, either in memory or spilled to disk. state: BatchState, /// Pre-evaluated join key column arrays. `None` when the batch has been spilled. + #[allow(dead_code)] join_arrays: Option>, /// Number of rows in this batch (cached so we don't need the batch to know). pub num_rows: usize, @@ -64,11 +67,7 @@ impl BufferedBatch { /// Create a new in-memory buffered batch. /// /// `full_outer` controls whether per-row match tracking is allocated. - fn new_in_memory( - batch: RecordBatch, - join_arrays: Vec, - full_outer: bool, - ) -> Self { + fn new_in_memory(batch: RecordBatch, join_arrays: Vec, full_outer: bool) -> Self { let num_rows = batch.num_rows(); let mut size_estimate = get_record_batch_memory_size(&batch); for arr in &join_arrays { @@ -93,8 +92,7 @@ impl BufferedBatch { match &self.state { BatchState::InMemory(batch) => Ok(batch.clone()), BatchState::Spilled(file) => { - let reader = - spill_manager.read_spill_as_stream(file.clone(), None)?; + let reader = spill_manager.read_spill_as_stream(file.clone(), None)?; let batches = tokio::task::block_in_place(|| { let rt = tokio::runtime::Handle::current(); rt.block_on(async { @@ -111,11 +109,8 @@ impl BufferedBatch { if batches.len() == 1 { Ok(batches.into_iter().next().unwrap()) } else { - arrow::compute::concat_batches( - &spill_manager.schema().clone(), - &batches, - ) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) + arrow::compute::concat_batches(&Arc::clone(spill_manager.schema()), &batches) + .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) } } } @@ -123,6 +118,7 @@ impl BufferedBatch { /// Return join key arrays. If in memory, returns the cached arrays directly. 
/// If spilled, deserializes the batch and re-evaluates the join expressions. + #[allow(dead_code)] pub fn get_join_arrays( &self, spill_manager: &SpillManager, @@ -229,11 +225,7 @@ impl BufferedMatchGroup { } /// Get a batch by index. If the batch was spilled, it is read back from disk. - pub fn get_batch( - &self, - batch_idx: usize, - spill_manager: &SpillManager, - ) -> Result { + pub fn get_batch(&self, batch_idx: usize, spill_manager: &SpillManager) -> Result { self.batches[batch_idx].get_batch(spill_manager) } diff --git a/native/core/src/execution/joins/filter.rs b/native/core/src/execution/joins/filter.rs index e2fdeb312f..e6ef4b416f 100644 --- a/native/core/src/execution/joins/filter.rs +++ b/native/core/src/execution/joins/filter.rs @@ -97,9 +97,7 @@ pub(super) fn build_filter_candidate_batch( JoinSide::Left => (streamed_batch, streamed_indices), JoinSide::Right => (buffered_batch, buffered_indices), JoinSide::None => { - return internal_err!( - "unexpected JoinSide::None in join filter column index" - ); + return internal_err!("unexpected JoinSide::None in join filter column index"); } }; let column = batch.column(col_idx.index); @@ -107,10 +105,7 @@ pub(super) fn build_filter_candidate_batch( }) .collect::>>()?; - Ok(RecordBatch::try_new( - Arc::clone(filter.schema()), - columns, - )?) + Ok(RecordBatch::try_new(Arc::clone(filter.schema()), columns)?) } /// Returns true if the mask value at `i` is true and not null. 
@@ -197,7 +192,7 @@ fn apply_semi_filter(mask: &BooleanArray, indices: &[JoinIndex]) -> FilteredOutp let mut passed_indices = Vec::new(); let mut buffered_matched = Vec::new(); - for (_streamed_idx, pairs) in &groups { + for pairs in groups.values() { if let Some((_, idx)) = pairs.iter().find(|(i, _)| mask_passed(mask, *i)) { passed_indices.push(*idx); buffered_matched.push((idx.batch_idx, idx.buffered_idx)); diff --git a/native/core/src/execution/joins/output_builder.rs b/native/core/src/execution/joins/output_builder.rs index 5a857309dd..4e79925be8 100644 --- a/native/core/src/execution/joins/output_builder.rs +++ b/native/core/src/execution/joins/output_builder.rs @@ -54,6 +54,7 @@ pub(super) struct OutputBuilder { /// Schema of the output record batch. output_schema: SchemaRef, /// Schema of the streamed (left) side. + #[allow(dead_code)] streamed_schema: SchemaRef, /// Schema of the buffered (right) side. buffered_schema: SchemaRef, @@ -186,8 +187,7 @@ impl OutputBuilder { spill_manager: &SpillManager, ) -> Result { let streamed_columns = self.build_streamed_columns(streamed_batch)?; - let buffered_columns = - self.build_buffered_columns(match_group, spill_manager)?; + let buffered_columns = self.build_buffered_columns(match_group, spill_manager)?; let mut columns = streamed_columns; columns.extend(buffered_columns); @@ -265,8 +265,7 @@ impl OutputBuilder { // Part 1: Matched pairs — gather from buffered batches if !self.indices.is_empty() { - let matched_part = - self.take_buffered_matched(col_idx, match_group, spill_manager)?; + let matched_part = self.take_buffered_matched(col_idx, match_group, spill_manager)?; parts.push(matched_part); } @@ -290,8 +289,7 @@ impl OutputBuilder { return Ok(parts.into_iter().next().unwrap()); } - let part_refs: Vec<&dyn arrow::array::Array> = - parts.iter().map(|a| a.as_ref()).collect(); + let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); Ok(concat(&part_refs)?) 
} @@ -320,8 +318,7 @@ impl OutputBuilder { } else { // Flush previous group if let Some(batch_idx) = current_batch_idx { - let batch = - match_group.batches[batch_idx].get_batch(spill_manager)?; + let batch = match_group.batches[batch_idx].get_batch(spill_manager)?; let col = batch.column(col_idx); let row_indices = UInt32Array::from(std::mem::take(&mut current_row_indices)); parts.push(take(col.as_ref(), &row_indices, None)?); @@ -343,8 +340,7 @@ impl OutputBuilder { return Ok(parts.into_iter().next().unwrap()); } - let part_refs: Vec<&dyn arrow::array::Array> = - parts.iter().map(|a| a.as_ref()).collect(); + let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); Ok(concat(&part_refs)?) } @@ -387,8 +383,7 @@ impl OutputBuilder { return Ok(parts.into_iter().next().unwrap()); } - let part_refs: Vec<&dyn arrow::array::Array> = - parts.iter().map(|a| a.as_ref()).collect(); + let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); Ok(concat(&part_refs)?) 
} } diff --git a/native/core/src/execution/joins/sort_merge_join.rs b/native/core/src/execution/joins/sort_merge_join.rs index 50a8530ad3..6909f260b0 100644 --- a/native/core/src/execution/joins/sort_merge_join.rs +++ b/native/core/src/execution/joins/sort_merge_join.rs @@ -71,13 +71,14 @@ impl CometSortMergeJoinExec { check_join_is_valid(&left_schema, &right_schema, &join_on)?; - let (schema, _column_indices) = - build_join_schema(&left_schema, &right_schema, &join_type); + let (schema, _column_indices) = build_join_schema(&left_schema, &right_schema, &join_type); let schema = Arc::new(schema); let properties = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&schema)), - Partitioning::UnknownPartitioning(left.properties().output_partitioning().partition_count()), + Partitioning::UnknownPartitioning( + left.properties().output_partitioning().partition_count(), + ), EmissionType::Incremental, Boundedness::Bounded, ); @@ -157,14 +158,26 @@ impl ExecutionPlan for CometSortMergeJoinExec { JoinType::Right => ( Arc::clone(&self.right), Arc::clone(&self.left), - self.join_on.iter().map(|(_, r)| Arc::clone(r)).collect::>(), - self.join_on.iter().map(|(l, _)| Arc::clone(l)).collect::>(), + self.join_on + .iter() + .map(|(_, r)| Arc::clone(r)) + .collect::>(), + self.join_on + .iter() + .map(|(l, _)| Arc::clone(l)) + .collect::>(), ), _ => ( Arc::clone(&self.left), Arc::clone(&self.right), - self.join_on.iter().map(|(l, _)| Arc::clone(l)).collect::>(), - self.join_on.iter().map(|(_, r)| Arc::clone(r)).collect::>(), + self.join_on + .iter() + .map(|(l, _)| Arc::clone(l)) + .collect::>(), + self.join_on + .iter() + .map(|(_, r)| Arc::clone(r)) + .collect::>(), ), }; diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index 257f1bcc1b..7478a27537 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -199,10 
+199,7 @@ impl SortMergeJoinStream { /// Drive the state machine, returning `Poll::Ready(Some(batch))` when a /// batch is available, `Poll::Ready(None)` when done, or `Poll::Pending` /// if waiting on input. - fn poll_next_inner( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { + fn poll_next_inner(&mut self, cx: &mut Context<'_>) -> Poll>> { loop { match self.state { JoinState::Init => { @@ -249,37 +246,34 @@ impl SortMergeJoinStream { } } - JoinState::PollBuffered => { - match self.buffered_input.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => { - if batch.num_rows() == 0 { - continue; - } - self.metrics.input_batches.add(1); - self.metrics.input_rows.add(batch.num_rows()); - let join_arrays = - evaluate_join_keys(&batch, &self.buffered_join_exprs)?; - self.buffered_pending = Some((batch, join_arrays)); - if self.streamed_batch.is_some() { - self.state = JoinState::Comparing; - } else if !self.streamed_exhausted { - self.state = JoinState::PollStreamed; - } else { - self.state = JoinState::DrainUnmatched; - } + JoinState::PollBuffered => match self.buffered_input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + continue; } - Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), - Poll::Ready(None) => { - self.buffered_exhausted = true; - if self.streamed_batch.is_some() { - self.state = JoinState::Comparing; - } else { - self.state = JoinState::DrainUnmatched; - } + self.metrics.input_batches.add(1); + self.metrics.input_rows.add(batch.num_rows()); + let join_arrays = evaluate_join_keys(&batch, &self.buffered_join_exprs)?; + self.buffered_pending = Some((batch, join_arrays)); + if self.streamed_batch.is_some() { + self.state = JoinState::Comparing; + } else if !self.streamed_exhausted { + self.state = JoinState::PollStreamed; + } else { + self.state = JoinState::DrainUnmatched; } - Poll::Pending => return Poll::Pending, } - } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + 
Poll::Ready(None) => { + self.buffered_exhausted = true; + if self.streamed_batch.is_some() { + self.state = JoinState::Comparing; + } else { + self.state = JoinState::DrainUnmatched; + } + } + Poll::Pending => return Poll::Pending, + }, JoinState::Comparing => { // We have a streamed row. Compare its key against the @@ -295,9 +289,7 @@ impl SortMergeJoinStream { .any(|a| a.is_null(streamed_idx)); // For inner/semi joins, skip null keys entirely. - if streamed_has_null - && self.null_equality == NullEquality::NullEqualsNothing - { + if streamed_has_null && self.null_equality == NullEquality::NullEqualsNothing { match self.join_type { JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => { self.advance_streamed()?; @@ -309,8 +301,7 @@ impl SortMergeJoinStream { | JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => { - self.output_builder - .add_streamed_null_join(streamed_idx); + self.output_builder.add_streamed_null_join(streamed_idx); self.advance_streamed()?; if self.output_builder.should_flush() { self.state = JoinState::OutputReady; @@ -359,8 +350,7 @@ impl SortMergeJoinStream { // Compare streamed key with buffered key. let ordering = { - let streamed_arrays = - self.streamed_join_arrays.as_ref().unwrap(); + let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); let (_buffered_batch, buffered_arrays) = self.buffered_pending.as_ref().unwrap(); compare_join_arrays( @@ -392,10 +382,8 @@ impl SortMergeJoinStream { let (batch, arrays) = self.buffered_pending.take().unwrap(); if batch.num_rows() > 1 { let remaining = batch.slice(1, batch.num_rows() - 1); - let remaining_arrays: Vec = arrays - .iter() - .map(|a| a.slice(1, a.len() - 1)) - .collect(); + let remaining_arrays: Vec = + arrays.iter().map(|a| a.slice(1, a.len() - 1)).collect(); self.buffered_pending = Some((remaining, remaining_arrays)); // Re-compare with the next buffered row. 
self.state = JoinState::Comparing; @@ -579,8 +567,7 @@ impl SortMergeJoinStream { } JoinState::Exhausted => { - self.metrics - .update_peak_mem(self.reservation.size()); + self.metrics.update_peak_mem(self.reservation.size()); return Poll::Ready(None); } } @@ -789,18 +776,10 @@ impl SortMergeJoinStream { // Build candidate batch for filter evaluation. let streamed_batch = self.streamed_batch.as_ref().unwrap(); - let candidate_batch = self.build_filter_batch( - filter, - streamed_batch, - &pair_indices, - )?; + let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?; - let filtered = apply_join_filter( - filter, - &candidate_batch, - &pair_indices, - &self.join_type, - )?; + let filtered = + apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?; // Apply filtered results. for idx in &filtered.passed_indices { @@ -850,15 +829,10 @@ impl SortMergeJoinStream { } let streamed_batch = self.streamed_batch.as_ref().unwrap(); - let candidate_batch = - self.build_filter_batch(filter, streamed_batch, &pair_indices)?; - - let filtered = apply_join_filter( - filter, - &candidate_batch, - &pair_indices, - &self.join_type, - )?; + let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?; + + let filtered = + apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?; // Semi: emit the streamed row if any pair passed. 
if !filtered.passed_indices.is_empty() { @@ -896,15 +870,10 @@ impl SortMergeJoinStream { } let streamed_batch = self.streamed_batch.as_ref().unwrap(); - let candidate_batch = - self.build_filter_batch(filter, streamed_batch, &pair_indices)?; - - let filtered = apply_join_filter( - filter, - &candidate_batch, - &pair_indices, - &self.join_type, - )?; + let candidate_batch = self.build_filter_batch(filter, streamed_batch, &pair_indices)?; + + let filtered = + apply_join_filter(filter, &candidate_batch, &pair_indices, &self.join_type)?; // Anti: emit streamed rows that had no passing pair. for &si in &filtered.streamed_null_joins { @@ -949,10 +918,7 @@ impl SortMergeJoinStream { } /// Collect all referenced buffered rows into a single batch. - fn collect_buffered_rows( - &self, - pair_indices: &[JoinIndex], - ) -> Result { + fn collect_buffered_rows(&self, pair_indices: &[JoinIndex]) -> Result { if pair_indices.is_empty() { // Return an empty batch with the correct schema. let schema = Arc::clone(self.spill_manager.schema()); @@ -1158,10 +1124,7 @@ fn find_key_boundary( impl Stream for SortMergeJoinStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let join_time = self.metrics.join_time.clone(); let timer = join_time.timer(); let result = self.poll_next_inner(cx); diff --git a/native/core/src/execution/joins/tests.rs b/native/core/src/execution/joins/tests.rs index c5f6bc7de5..d916f67747 100644 --- a/native/core/src/execution/joins/tests.rs +++ b/native/core/src/execution/joins/tests.rs @@ -68,18 +68,24 @@ async fn execute_join( let l_schema = left_batches[0].schema(); let r_schema = right_batches[0].schema(); - let left = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(&[left_batches], l_schema, None)?, - ))); - let right = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(&[right_batches], 
r_schema, None)?, - ))); + let left = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + l_schema, + None, + )?))); + let right = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + r_schema, + None, + )?))); let on = vec![( - Arc::new(datafusion::physical_expr::expressions::Column::new("l_key", 0)) - as Arc, - Arc::new(datafusion::physical_expr::expressions::Column::new("r_key", 0)) - as Arc, + Arc::new(datafusion::physical_expr::expressions::Column::new( + "l_key", 0, + )) as Arc, + Arc::new(datafusion::physical_expr::expressions::Column::new( + "r_key", 0, + )) as Arc, )]; let join = CometSortMergeJoinExec::try_new( @@ -304,28 +310,34 @@ async fn test_inner_join_with_spill() -> Result<()> { let r_schema = right_schema(); let left_batches = make_sorted_batches( - l_schema.clone(), + Arc::clone(&l_schema), vec![Some(1), Some(1), Some(1), Some(2), Some(2)], vec![Some("a"), Some("b"), Some("c"), Some("d"), Some("e")], ); let right_batches = make_sorted_batches( - r_schema.clone(), + Arc::clone(&r_schema), vec![Some(1), Some(1), Some(1), Some(2)], vec![Some("w"), Some("x"), Some("y"), Some("z")], ); - let left_exec = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(&[left_batches], l_schema, None)?, - ))); - let right_exec = Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(&[right_batches], r_schema, None)?, - ))); + let left_exec = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[left_batches], + l_schema, + None, + )?))); + let right_exec = Arc::new(DataSourceExec::new(Arc::new(MemorySourceConfig::try_new( + &[right_batches], + r_schema, + None, + )?))); let on = vec![( - Arc::new(datafusion::physical_expr::expressions::Column::new("l_key", 0)) - as Arc, - Arc::new(datafusion::physical_expr::expressions::Column::new("r_key", 0)) - as Arc, + Arc::new(datafusion::physical_expr::expressions::Column::new( + "l_key", 0, + )) as Arc, + 
Arc::new(datafusion::physical_expr::expressions::Column::new( + "r_key", 0, + )) as Arc, )]; let join = CometSortMergeJoinExec::try_new( From ba04a610eb6d07ce9ec7b365eebc2b9e3d18241b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 08:24:42 -0600 Subject: [PATCH 11/12] refactor: simplify output builder and encapsulate matched tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix N×M spill read issue: cache loaded batches across all columns instead of reloading per column. Extracted group_by_batch helpers. - Remove unused streamed_schema field from OutputBuilder - Encapsulate BufferedBatch.matched behind mark_matched/unmatched_indices - Remove duplicate take_buffered_matched/take_buffered_null_joins with shared take_from_groups helper --- .../src/execution/joins/buffered_batch.rs | 20 +- .../src/execution/joins/output_builder.rs | 269 ++++++------------ .../execution/joins/sort_merge_join_stream.rs | 19 +- 3 files changed, 114 insertions(+), 194 deletions(-) diff --git a/native/core/src/execution/joins/buffered_batch.rs b/native/core/src/execution/joins/buffered_batch.rs index 357d5f7536..ebd725233a 100644 --- a/native/core/src/execution/joins/buffered_batch.rs +++ b/native/core/src/execution/joins/buffered_batch.rs @@ -59,11 +59,27 @@ pub(super) struct BufferedBatch { /// Estimated memory footprint in bytes (batch + join arrays). pub size_estimate: usize, /// For full/right outer joins: tracks which rows have been matched. - /// `None` for inner/left joins where unmatched tracking is not needed. - pub matched: Option>, + matched: Option>, } impl BufferedBatch { + /// Mark a buffered row as matched (for full outer join tracking). + pub fn mark_matched(&mut self, row_idx: usize) { + if let Some(ref mut matched) = self.matched { + matched[row_idx] = true; + } + } + + /// Iterate over unmatched row indices. 
+ pub fn unmatched_indices(&self) -> impl Iterator + '_ { + self.matched.as_ref().into_iter().flat_map(|m| { + m.iter() + .enumerate() + .filter(|(_, &matched)| !matched) + .map(|(idx, _)| idx) + }) + } + /// Create a new in-memory buffered batch. /// /// `full_outer` controls whether per-row match tracking is allocated. diff --git a/native/core/src/execution/joins/output_builder.rs b/native/core/src/execution/joins/output_builder.rs index 4e79925be8..7f04de438c 100644 --- a/native/core/src/execution/joins/output_builder.rs +++ b/native/core/src/execution/joins/output_builder.rs @@ -36,53 +36,32 @@ use super::buffered_batch::BufferedMatchGroup; /// An index pair representing a matched row from the streamed and buffered sides. #[derive(Debug, Clone, Copy)] pub(super) struct JoinIndex { - /// Row index in the current streamed batch. pub streamed_idx: usize, - /// Index of the buffered batch within the match group. pub batch_idx: usize, - /// Row index within the buffered batch. pub buffered_idx: usize, } /// Accumulates join output indices and materializes them into Arrow record batches. -/// -/// During the join process, matched pairs and null-joined rows are recorded as -/// index tuples. When enough indices have been accumulated (reaching the target -/// batch size), they are materialized into a `RecordBatch` by gathering columns -/// from the streamed and buffered sides using Arrow's `take` kernel. pub(super) struct OutputBuilder { - /// Schema of the output record batch. output_schema: SchemaRef, - /// Schema of the streamed (left) side. - #[allow(dead_code)] - streamed_schema: SchemaRef, - /// Schema of the buffered (right) side. buffered_schema: SchemaRef, - /// The type of join being performed. join_type: JoinType, - /// Target number of rows per output batch. target_batch_size: usize, - /// Matched pairs: (streamed_idx, batch_idx, buffered_idx). indices: Vec, - /// Streamed row indices that need a null buffered counterpart (left outer, left anti). 
streamed_null_joins: Vec, - /// Buffered row indices that need a null streamed counterpart (full outer). - /// Each entry is (batch_idx, row_idx). buffered_null_joins: Vec<(usize, usize)>, } impl OutputBuilder { - /// Create a new `OutputBuilder`. pub fn new( output_schema: SchemaRef, - streamed_schema: SchemaRef, + _streamed_schema: SchemaRef, buffered_schema: SchemaRef, join_type: JoinType, target_batch_size: usize, ) -> Self { Self { output_schema, - streamed_schema, buffered_schema, join_type, target_batch_size, @@ -92,7 +71,6 @@ impl OutputBuilder { } } - /// Record a matched pair between streamed and buffered rows. pub fn add_match(&mut self, streamed_idx: usize, batch_idx: usize, buffered_idx: usize) { self.indices.push(JoinIndex { streamed_idx, @@ -101,38 +79,28 @@ impl OutputBuilder { }); } - /// Record a streamed row that needs a null buffered counterpart - /// (used for outer joins and anti joins). pub fn add_streamed_null_join(&mut self, streamed_idx: usize) { self.streamed_null_joins.push(streamed_idx); } - /// Record a buffered row that needs a null streamed counterpart - /// (used for full outer joins). pub fn add_buffered_null_join(&mut self, batch_idx: usize, buffered_idx: usize) { self.buffered_null_joins.push((batch_idx, buffered_idx)); } - /// Return the total number of pending output rows. pub fn pending_count(&self) -> usize { self.indices.len() + self.streamed_null_joins.len() + self.buffered_null_joins.len() } - /// Returns `true` if the pending row count has reached or exceeded the target batch size. pub fn should_flush(&self) -> bool { self.pending_count() >= self.target_batch_size } - /// Returns `true` if there are any pending output rows. pub fn has_pending(&self) -> bool { self.pending_count() > 0 } /// Materialize the accumulated indices into a [`RecordBatch`]. /// - /// For `LeftSemi` and `LeftAnti` joins, only streamed columns are included. - /// For all other join types, columns from both sides are concatenated. 
- /// /// After building, all accumulated indices are cleared. pub fn build( &mut self, @@ -146,7 +114,6 @@ impl OutputBuilder { _ => self.build_full(streamed_batch, match_group, spill_manager), }; - // Clear all accumulated indices after building self.indices.clear(); self.streamed_null_joins.clear(); self.buffered_null_joins.clear(); @@ -154,10 +121,7 @@ impl OutputBuilder { result } - /// Build output for LeftSemi/LeftAnti joins (streamed columns only). fn build_semi_anti(&self, streamed_batch: &RecordBatch) -> Result { - // For semi/anti joins, we only output streamed rows from matched pairs - // and streamed null joins. No buffered columns. let indices: Vec = self .indices .iter() @@ -179,7 +143,6 @@ impl OutputBuilder { )?) } - /// Build output for all other join types (both streamed and buffered columns). fn build_full( &self, streamed_batch: &RecordBatch, @@ -198,14 +161,10 @@ impl OutputBuilder { )?) } - /// Build the streamed side columns by gathering rows using the `take` kernel. - /// - /// The order is: matched pairs, streamed null joins, then None for buffered null joins. fn build_streamed_columns(&self, streamed_batch: &RecordBatch) -> Result> { let total_rows = self.pending_count(); let num_buffered_nulls = self.buffered_null_joins.len(); - // Build indices: matched streamed_idx, then streamed_null_join indices, then None let indices: Vec> = self .indices .iter() @@ -225,165 +184,119 @@ impl OutputBuilder { .collect() } - /// Build the buffered side columns by gathering rows from buffered batches. - /// - /// For matched pairs: take from buffered batches (loading spilled ones as needed). - /// For streamed null joins: null arrays. - /// For buffered null joins: take from buffered batches. + /// Build all buffered columns at once, loading each batch only once across all columns. 
fn build_buffered_columns( &self, match_group: &BufferedMatchGroup, spill_manager: &SpillManager, ) -> Result<Vec<ArrayRef>> { - let num_buffered_cols = self.buffered_schema.fields().len(); + let num_cols = self.buffered_schema.fields().len(); let num_streamed_nulls = self.streamed_null_joins.len(); - // Build one column at a time - (0..num_buffered_cols) - .map(|col_idx| { - self.build_single_buffered_column( - col_idx, - match_group, - spill_manager, - num_streamed_nulls, - ) - }) - .collect() - } - - /// Build a single buffered column by concatenating parts from matched pairs, - /// null arrays for streamed null joins, and parts from buffered null joins. - fn build_single_buffered_column( - &self, - col_idx: usize, - match_group: &BufferedMatchGroup, - spill_manager: &SpillManager, - num_streamed_nulls: usize, - ) -> Result<ArrayRef> { - let data_type = self.buffered_schema.field(col_idx).data_type(); - let mut parts: Vec<ArrayRef> = Vec::new(); - - // Part 1: Matched pairs — gather from buffered batches - if !self.indices.is_empty() { - let matched_part = self.take_buffered_matched(col_idx, match_group, spill_manager)?; - parts.push(matched_part); - } - - // Part 2: Streamed null joins — null arrays for buffered side - if num_streamed_nulls > 0 { - parts.push(new_null_array(data_type, num_streamed_nulls)); - } - - // Part 3: Buffered null joins — gather from buffered batches - if !self.buffered_null_joins.is_empty() { - let null_join_part = - self.take_buffered_null_joins(col_idx, match_group, spill_manager)?; - parts.push(null_join_part); - } - - if parts.is_empty() { - return Ok(new_null_array(data_type, 0)); + // Pre-compute which batches we need and their grouped row indices. + // This avoids loading the same spilled batch N times (once per column). 
+ let matched_groups = group_by_batch(&self.indices); + let null_join_groups = group_by_batch_tuple(&self.buffered_null_joins); + + // Load each referenced batch once + let mut batch_cache: std::collections::HashMap<usize, RecordBatch> = + std::collections::HashMap::new(); + for &(batch_idx, _) in matched_groups.iter().chain(null_join_groups.iter()) { + if let std::collections::hash_map::Entry::Vacant(e) = batch_cache.entry(batch_idx) { + e.insert(match_group.batches[batch_idx].get_batch(spill_manager)?); + } } - if parts.len() == 1 { - return Ok(parts.into_iter().next().unwrap()); - } + // Build all columns using the cached batches + let mut result: Vec<ArrayRef> = Vec::with_capacity(num_cols); + for col_idx in 0..num_cols { + let data_type = self.buffered_schema.field(col_idx).data_type(); + let mut parts: Vec<ArrayRef> = Vec::with_capacity(3); - let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); - Ok(concat(&part_refs)?) - } + // Matched pairs + if !matched_groups.is_empty() { + parts.push(take_from_groups(col_idx, &matched_groups, &batch_cache)?); + } - /// Take values from buffered batches for matched pairs. - fn take_buffered_matched( - &self, - col_idx: usize, - match_group: &BufferedMatchGroup, - spill_manager: &SpillManager, - ) -> Result<ArrayRef> { - // Group indices by batch_idx to minimize batch lookups - // For simplicity, we build per-index and concatenate, but a more - // optimized version would group by batch_idx. - - // Simple approach: gather one at a time using take with single-element indices - // A more efficient approach groups by batch_idx, but this is correct and clear. 
- let mut parts: Vec<ArrayRef> = Vec::new(); - - // Group consecutive indices by batch_idx for efficiency - let mut current_batch_idx: Option<usize> = None; - let mut current_row_indices: Vec<u32> = Vec::new(); - - for join_idx in &self.indices { - if current_batch_idx == Some(join_idx.batch_idx) { - current_row_indices.push(join_idx.buffered_idx as u32); - } else { - // Flush previous group - if let Some(batch_idx) = current_batch_idx { - let batch = match_group.batches[batch_idx].get_batch(spill_manager)?; - let col = batch.column(col_idx); - let row_indices = UInt32Array::from(std::mem::take(&mut current_row_indices)); - parts.push(take(col.as_ref(), &row_indices, None)?); - } - current_batch_idx = Some(join_idx.batch_idx); - current_row_indices.push(join_idx.buffered_idx as u32); + // Null arrays for streamed null joins + if num_streamed_nulls > 0 { + parts.push(new_null_array(data_type, num_streamed_nulls)); } - } - // Flush last group - if let Some(batch_idx) = current_batch_idx { - let batch = match_group.batches[batch_idx].get_batch(spill_manager)?; - let col = batch.column(col_idx); - let row_indices = UInt32Array::from(current_row_indices); - parts.push(take(col.as_ref(), &row_indices, None)?); - } + // Buffered null joins + if !null_join_groups.is_empty() { + parts.push(take_from_groups(col_idx, &null_join_groups, &batch_cache)?); + } - if parts.len() == 1 { - return Ok(parts.into_iter().next().unwrap()); + result.push(concat_parts(parts, data_type)?); } - let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); - Ok(concat(&part_refs)?) + Ok(result) } +} - /// Take values from buffered batches for buffered null joins (full outer). 
- fn take_buffered_null_joins( - &self, - col_idx: usize, - match_group: &BufferedMatchGroup, - spill_manager: &SpillManager, - ) -> Result<ArrayRef> { - let mut parts: Vec<ArrayRef> = Vec::new(); - - // Group consecutive entries by batch_idx - let mut current_batch_idx: Option<usize> = None; - let mut current_row_indices: Vec<u32> = Vec::new(); - - for &(batch_idx, row_idx) in &self.buffered_null_joins { - if current_batch_idx == Some(batch_idx) { - current_row_indices.push(row_idx as u32); - } else { - if let Some(bi) = current_batch_idx { - let batch = match_group.batches[bi].get_batch(spill_manager)?; - let col = batch.column(col_idx); - let row_indices = UInt32Array::from(std::mem::take(&mut current_row_indices)); - parts.push(take(col.as_ref(), &row_indices, None)?); - } - current_batch_idx = Some(batch_idx); - current_row_indices.push(row_idx as u32); +/// Group JoinIndex entries by batch_idx into (batch_idx, row_indices) pairs. +fn group_by_batch(indices: &[JoinIndex]) -> Vec<(usize, Vec<u32>)> { + let mut groups: Vec<(usize, Vec<u32>)> = Vec::new(); + for idx in indices { + if let Some(last) = groups.last_mut() { + if last.0 == idx.batch_idx { + last.1.push(idx.buffered_idx as u32); + continue; + } } } + groups.push((idx.batch_idx, vec![idx.buffered_idx as u32])); + } + groups +} - if let Some(bi) = current_batch_idx { - let batch = match_group.batches[bi].get_batch(spill_manager)?; - let col = batch.column(col_idx); - let row_indices = UInt32Array::from(current_row_indices); - parts.push(take(col.as_ref(), &row_indices, None)?); +/// Group (batch_idx, row_idx) tuples by batch_idx into (batch_idx, row_indices) pairs. 
+fn group_by_batch_tuple(indices: &[(usize, usize)]) -> Vec<(usize, Vec<u32>)> { + let mut groups: Vec<(usize, Vec<u32>)> = Vec::new(); + for &(batch_idx, row_idx) in indices { + if let Some(last) = groups.last_mut() { + if last.0 == batch_idx { + last.1.push(row_idx as u32); + continue; + } + } } + groups.push((batch_idx, vec![row_idx as u32])); + } + groups +} - if parts.len() == 1 { - return Ok(parts.into_iter().next().unwrap()); - } +/// Take a single column from pre-loaded batches using grouped indices. +fn take_from_groups( + col_idx: usize, + groups: &[(usize, Vec<u32>)], + batch_cache: &std::collections::HashMap<usize, RecordBatch>, +) -> Result<ArrayRef> { + let mut parts: Vec<ArrayRef> = Vec::with_capacity(groups.len()); + for (batch_idx, row_indices) in groups { + let batch = &batch_cache[batch_idx]; + let col = batch.column(col_idx); + let index_array = UInt32Array::from(row_indices.clone()); + parts.push(take(col.as_ref(), &index_array, None)?); + } + concat_parts( + parts, + batch_cache + .values() + .next() + .unwrap() + .column(col_idx) + .data_type(), + ) +} - let part_refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); - Ok(concat(&part_refs)?) +/// Concat array parts, handling empty and single-element cases. +fn concat_parts(parts: Vec<ArrayRef>, data_type: &arrow::datatypes::DataType) -> Result<ArrayRef> { + if parts.is_empty() { + return Ok(new_null_array(data_type, 0)); + } + if parts.len() == 1 { + return Ok(parts.into_iter().next().expect("checked len == 1")); } + let refs: Vec<&dyn arrow::array::Array> = parts.iter().map(|a| a.as_ref()).collect(); + Ok(concat(&refs)?) 
} diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index 7478a27537..4db270df8a 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -789,11 +789,8 @@ impl SortMergeJoinStream { for &si in &filtered.streamed_null_joins { self.output_builder.add_streamed_null_join(si); } - // Mark matched buffered rows for full outer. for &(batch_idx, buffered_idx) in &filtered.buffered_matched { - if let Some(ref mut matched) = self.match_group.batches[batch_idx].matched { - matched[buffered_idx] = true; - } + self.match_group.batches[batch_idx].mark_matched(buffered_idx); } } else { // No filter: all pairs match. @@ -801,9 +798,7 @@ impl SortMergeJoinStream { for row_idx in 0..buffered_batch.num_rows { self.output_builder .add_match(streamed_idx, batch_idx, row_idx); - if let Some(ref mut matched) = buffered_batch.matched { - matched[row_idx] = true; - } + buffered_batch.mark_matched(row_idx); } } } @@ -1049,13 +1044,9 @@ impl SortMergeJoinStream { // For full outer: drain unmatched buffered rows. if matches!(self.join_type, JoinType::Full) { for (batch_idx, buffered_batch) in self.match_group.batches.iter().enumerate() { - if let Some(ref matched) = buffered_batch.matched { - for (row_idx, &was_matched) in matched.iter().enumerate() { - if !was_matched { - self.output_builder - .add_buffered_null_join(batch_idx, row_idx); - } - } + for row_idx in buffered_batch.unmatched_indices() { + self.output_builder + .add_buffered_null_join(batch_idx, row_idx); } } } From 93012303ee3a8e09eca1faa95238dbc461e95b5f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 1 Apr 2026 09:36:43 -0600 Subject: [PATCH 12/12] perf: cache row-format keys per batch instead of per-row conversion RowConverter.convert_columns() was called on the full streamed batch for every row during key-reuse checks. 
Now converted once when a new streamed batch arrives and cached as streamed_rows. Key-reuse and cache_streamed_key index directly into the cached Rows. --- .../execution/joins/sort_merge_join_stream.rs | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/native/core/src/execution/joins/sort_merge_join_stream.rs b/native/core/src/execution/joins/sort_merge_join_stream.rs index 4db270df8a..ace92651b0 100644 --- a/native/core/src/execution/joins/sort_merge_join_stream.rs +++ b/native/core/src/execution/joins/sort_merge_join_stream.rs @@ -30,7 +30,7 @@ use std::task::{Context, Poll}; use arrow::array::{ArrayRef, RecordBatch, UInt32Array}; use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; -use arrow::row::{OwnedRow, RowConverter, SortField}; +use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; use datafusion::common::{NullEquality, Result}; use datafusion::execution::memory_pool::MemoryReservation; use datafusion::logical_expr::JoinType; @@ -109,6 +109,8 @@ pub(super) struct SortMergeJoinStream { /// Converts join keys to comparable row format for key-reuse optimization. row_converter: RowConverter, + /// Row-format conversion of current streamed batch's join keys (computed once per batch). + streamed_rows: Option<Rows>, /// Cached key of the previous streamed row (for key-reuse detection). 
cached_streamed_key: Option<OwnedRow>, @@ -186,6 +188,7 @@ impl SortMergeJoinStream { buffered_exhausted: false, match_group: BufferedMatchGroup::new(), row_converter, + streamed_rows: None, cached_streamed_key: None, output_builder, output_schema, @@ -227,6 +230,17 @@ impl SortMergeJoinStream { self.metrics.input_rows.add(batch.num_rows()); let join_arrays = evaluate_join_keys(&batch, &self.streamed_join_exprs)?; + // Convert join keys to row format once per batch for key-reuse checks + let rows = + self.row_converter + .convert_columns(&join_arrays) + .map_err(|e| { + datafusion::common::DataFusionError::ArrowError( + Box::new(e), + None, + ) + })?; + self.streamed_rows = Some(rows); self.streamed_batch = Some(batch); self.streamed_join_arrays = Some(join_arrays); self.streamed_idx = 0; @@ -596,6 +610,7 @@ impl SortMergeJoinStream { // Safe to clear — no pending references. self.streamed_batch = None; self.streamed_join_arrays = None; + self.streamed_rows = None; self.streamed_idx = 0; } @@ -619,17 +634,13 @@ impl SortMergeJoinStream { } /// Try to reuse the existing match group if the current streamed key - /// matches the cached key. - fn try_reuse_match_group(&mut self) -> Result<bool> { + /// matches the cached key. Uses pre-computed row-format keys (once per batch). + fn try_reuse_match_group(&self) -> Result<bool> { if self.cached_streamed_key.is_none() || self.match_group.is_empty() { return Ok(false); } - let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); - let rows = self - .row_converter - .convert_columns(streamed_arrays) - .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?; + let rows = self.streamed_rows.as_ref().unwrap(); let current_key = rows.row(self.streamed_idx); if let Some(ref cached) = self.cached_streamed_key { @@ -643,11 +654,7 @@ impl SortMergeJoinStream { /// Cache the current streamed key as an OwnedRow. 
fn cache_streamed_key(&mut self) -> Result<()> { - let streamed_arrays = self.streamed_join_arrays.as_ref().unwrap(); - let rows = self - .row_converter - .convert_columns(streamed_arrays) - .map_err(|e| datafusion::common::DataFusionError::ArrowError(Box::new(e), None))?; + let rows = self.streamed_rows.as_ref().unwrap(); self.cached_streamed_key = Some(rows.row(self.streamed_idx).owned()); Ok(()) }