Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
378 changes: 378 additions & 0 deletions vortex-array/public-api.lock

Large diffs are not rendered by default.

166 changes: 166 additions & 0 deletions vortex-array/src/aggregate_fn/accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;

use crate::AnyCanonical;
use crate::ArrayRef;
use crate::Canonical;
use crate::DynArray;
use crate::VortexSessionExecute;
use crate::aggregate_fn::AggregateFn;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::session::AggregateFnSessionExt;
use crate::dtype::DType;
use crate::executor::MAX_ITERATIONS;
use crate::scalar::Scalar;

/// Reference-counted type-erased accumulator.
pub type AccumulatorRef = Box<dyn DynAccumulator>;

/// An accumulator used for computing aggregates over an entire stream of arrays.
pub struct Accumulator<V: AggregateFnVTable> {
/// The vtable of the aggregate function.
vtable: V,
/// Type-erased aggregate function used for kernel dispatch.
aggregate_fn: AggregateFnRef,
/// The DType of the input.
dtype: DType,
/// The DType of the aggregate.
return_dtype: DType,
/// The DType of the accumulator state.
state_dtype: DType,
/// The current state of the accumulator, updated after each accumulate/merge call.
current_state: V::GroupState,
/// A session used to lookup custom aggregate kernels.
session: VortexSession,
}

impl<V: AggregateFnVTable> Accumulator<V> {
pub fn try_new(
vtable: V,
options: V::Options,
dtype: DType,
session: VortexSession,
) -> VortexResult<Self> {
let return_dtype = vtable.return_dtype(&options, &dtype)?;
let state_dtype = vtable.state_dtype(&options, &dtype)?;
let current_state = vtable.state_new(&options, &dtype)?;
let aggregate_fn = AggregateFn::new(vtable.clone(), options).erased();

Ok(Self {
vtable,
aggregate_fn,
dtype,
return_dtype,
state_dtype,
current_state,
session,
})
}
}

/// A trait object for type-erased accumulators, used for dynamic dispatch when the aggregate
/// function is not known at compile time.
pub trait DynAccumulator: 'static + Send {
/// Accumulate a new array into the accumulator's state.
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()>;

/// Whether the accumulator's result is fully determined.
fn is_saturated(&self) -> bool;

/// Flush the accumulation state and return the partial aggregate result as a scalar.
///
/// Resets the accumulator state back to the initial state.
fn flush(&mut self) -> VortexResult<Scalar>;

/// Finish the accumulation and return the final aggregate result as a scalar.
///
/// Resets the accumulator state back to the initial state.
fn finish(&mut self) -> VortexResult<Scalar>;
}

impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
fn accumulate(&mut self, batch: &ArrayRef) -> VortexResult<()> {
if self.is_saturated() {
return Ok(());
}

vortex_ensure!(
batch.dtype() == &self.dtype,
"Input DType mismatch: expected {}, got {}",
self.dtype,
batch.dtype()
);

let kernels = &self.session.aggregate_fns().kernels;

let mut ctx = self.session.create_execution_ctx();
let mut batch = batch.clone();
for _ in 0..*MAX_ITERATIONS {
if batch.is::<AnyCanonical>() {
break;
}

let kernel_key = (self.vtable.id(), batch.encoding_id());
if let Some(kernel) = kernels.read().get(&kernel_key)
&& let Some(result) = kernel.aggregate(&self.aggregate_fn, &batch)?
{
vortex_ensure!(
result.dtype() == &self.state_dtype,
"Aggregate kernel returned {}, expected {}",
result.dtype(),
self.state_dtype,
);
self.vtable.state_merge(&mut self.current_state, result)?;
return Ok(());
}

// Execute one step and try again
batch = batch.execute(&mut ctx)?;
}

// Otherwise, execute the batch until it is canonical and accumulate it into the state.
let canonical = batch.execute::<Canonical>(&mut ctx)?;

self.vtable
.state_accumulate(&mut self.current_state, &canonical, &mut ctx)
}

fn is_saturated(&self) -> bool {
self.vtable.state_is_saturated(&self.current_state)
}

fn flush(&mut self) -> VortexResult<Scalar> {
let partial = self.vtable.state_flush(&mut self.current_state)?;

#[cfg(debug_assertions)]
{
vortex_ensure!(
partial.dtype() == &self.state_dtype,
"Aggregate kernel returned incorrect DType on flush: expected {}, got {}",
self.state_dtype,
partial.dtype(),
);
}

Ok(partial)
}

fn finish(&mut self) -> VortexResult<Scalar> {
let partial = self.flush()?;
let result = self.vtable.finalize_scalar(partial)?;

vortex_ensure!(
result.dtype() == &self.return_dtype,
"Aggregate kernel returned incorrect DType on finalize: expected {}, got {}",
self.return_dtype,
result.dtype(),
);

Ok(result)
}
}
Loading
Loading