Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::reset(&self, pa

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down Expand Up @@ -120,7 +120,7 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::coerce_args(&self,

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>

Expand Down Expand Up @@ -228,7 +228,7 @@ pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::reset(&self, partial: &

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down Expand Up @@ -306,7 +306,7 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::reset(&self, partia

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down Expand Up @@ -368,7 +368,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::reset(&self, partial: &mut Sel

pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand All @@ -388,6 +388,8 @@ pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_a

pub fn vortex_array::aggregate_fn::kernels::DynGroupedAggregateKernel::grouped_aggregate_fixed_size(&self, aggregate_fn: &vortex_array::aggregate_fn::AggregateFnRef, groups: &vortex_array::arrays::FixedSizeListArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::ArrayRef>>

pub mod vortex_array::aggregate_fn::proto

pub mod vortex_array::aggregate_fn::session

pub struct vortex_array::aggregate_fn::session::AggregateFnSession
Expand Down Expand Up @@ -500,6 +502,12 @@ pub fn vortex_array::aggregate_fn::AggregateFnRef::state_dtype(&self, input_dtyp

pub fn vortex_array::aggregate_fn::AggregateFnRef::vtable_ref<V: vortex_array::aggregate_fn::AggregateFnVTable>(&self) -> core::option::Option<&V>

impl vortex_array::aggregate_fn::AggregateFnRef

pub fn vortex_array::aggregate_fn::AggregateFnRef::from_proto(proto: &vortex_proto::expr::AggregateFn, session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self>

pub fn vortex_array::aggregate_fn::AggregateFnRef::serialize_proto(&self) -> vortex_error::VortexResult<vortex_proto::expr::AggregateFn>

impl core::clone::Clone for vortex_array::aggregate_fn::AggregateFnRef

pub fn vortex_array::aggregate_fn::AggregateFnRef::clone(&self) -> vortex_array::aggregate_fn::AggregateFnRef
Expand Down Expand Up @@ -638,7 +646,7 @@ pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::reset(&self, pa

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::is_constant::IsConstant::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand All @@ -654,7 +662,7 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::coerce_args(&self,

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::deserialize(&self, metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>

pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::empty_partial(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>

Expand Down Expand Up @@ -706,7 +714,7 @@ pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::reset(&self, partial: &

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::min_max::MinMax::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down Expand Up @@ -740,7 +748,7 @@ pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::reset(&self, partia

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down Expand Up @@ -774,7 +782,7 @@ pub fn vortex_array::aggregate_fn::fns::sum::Sum::reset(&self, partial: &mut Sel

pub fn vortex_array::aggregate_fn::fns::sum::Sum::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
pub fn vortex_array::aggregate_fn::fns::sum::Sum::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>

pub fn vortex_array::aggregate_fn::fns::sum::Sum::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

Expand Down
12 changes: 12 additions & 0 deletions vortex-array/src/aggregate_fn/fns/is_constant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,18 @@ impl AggregateFnVTable for IsConstant {
AggregateFnId::new_ref("vortex.is_constant")
}

fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}

fn deserialize(
&self,
_metadata: &[u8],
_session: &vortex_session::VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}

fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
match input_dtype {
DType::Null => None,
Expand Down
24 changes: 24 additions & 0 deletions vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use std::fmt::Formatter;

use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;

use self::bool::check_bool_sorted;
use self::decimal::check_decimal_sorted;
Expand Down Expand Up @@ -231,6 +232,29 @@ impl AggregateFnVTable for IsSorted {
AggregateFnId::new_ref("vortex.is_sorted")
}

fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![u8::from(options.strict)]))
}

fn deserialize(
&self,
metadata: &[u8],
_session: &vortex_session::VortexSession,
) -> VortexResult<Self::Options> {
let &[strict_byte] = metadata else {
vortex_bail!(
"IsSorted: expected 1 byte of metadata, got {}",
metadata.len()
);
};
let strict = match strict_byte {
0 => false,
1 => true,
_ => vortex_bail!("IsSorted: expected 0 or 1 for strict, got {}", strict_byte),
};
Ok(IsSortedOptions { strict })
}

fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
match input_dtype {
DType::Null | DType::Struct(..) | DType::List(..) | DType::FixedSizeList(..) => None,
Expand Down
12 changes: 12 additions & 0 deletions vortex-array/src/aggregate_fn/fns/min_max/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ impl AggregateFnVTable for MinMax {
AggregateFnId::new_ref("vortex.min_max")
}

fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}

fn deserialize(
&self,
_metadata: &[u8],
_session: &vortex_session::VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}

fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
match input_dtype {
DType::Bool(_)
Expand Down
12 changes: 12 additions & 0 deletions vortex-array/src/aggregate_fn/fns/nan_count/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ impl AggregateFnVTable for NanCount {
AggregateFnId::new_ref("vortex.nan_count")
}

fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}

fn deserialize(
&self,
_metadata: &[u8],
_session: &vortex_session::VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}

fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
if let DType::Primitive(ptype, ..) = input_dtype
&& ptype.is_float()
Expand Down
12 changes: 12 additions & 0 deletions vortex-array/src/aggregate_fn/fns/sum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ impl AggregateFnVTable for Sum {
AggregateFnId::new_ref("vortex.sum")
}

fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}

fn deserialize(
&self,
_metadata: &[u8],
_session: &vortex_session::VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}

fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
// When a sum overflows, we return a sum _value_ of null. Therefore, we all return dtypes
// are nullable.
Expand Down
1 change: 1 addition & 0 deletions vortex-array/src/aggregate_fn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub use options::*;

pub mod fns;
pub mod kernels;
pub mod proto;
pub mod session;

/// A unique identifier for an aggregate function.
Expand Down
87 changes: 87 additions & 0 deletions vortex-array/src/aggregate_fn/proto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use arcref::ArcRef;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;

use crate::aggregate_fn::AggregateFnId;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::session::AggregateFnSessionExt;

impl AggregateFnRef {
/// Serialize this aggregate function to its protobuf representation.
///
/// Note: the serialization format is not stable and may change between versions.
pub fn serialize_proto(&self) -> VortexResult<pb::AggregateFn> {
let metadata = self
.options()
.serialize()?
.ok_or_else(|| vortex_err!("Aggregate function '{}' is not serializable", self.id()))?;

Ok(pb::AggregateFn {
id: self.id().to_string(),
metadata: Some(metadata),
})
}

/// Deserialize an aggregate function from its protobuf representation.
///
/// Looks up the aggregate function plugin by ID in the session's registry
/// and delegates deserialization to it.
///
/// Note: the serialization format is not stable and may change between versions.
pub fn from_proto(proto: &pb::AggregateFn, session: &VortexSession) -> VortexResult<Self> {
let agg_fn_id: AggregateFnId = ArcRef::new_arc(Arc::from(proto.id.as_str()));
let plugin = session
.aggregate_fns()
.registry()
.find(&agg_fn_id)
.ok_or_else(|| vortex_err!("unknown aggregate function id: {}", proto.id))?;
let agg_fn = plugin.deserialize(proto.metadata(), session)?;

if agg_fn.id() != agg_fn_id {
vortex_bail!(
"Aggregate function ID mismatch: expected {}, got {}",
agg_fn_id,
agg_fn.id()
);
}

Ok(agg_fn)
}
}

#[cfg(test)]
mod tests {
use prost::Message;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;

use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTableExt;
use crate::aggregate_fn::EmptyOptions;
use crate::aggregate_fn::fns::sum::Sum;
use crate::aggregate_fn::session::AggregateFnSession;
use crate::aggregate_fn::session::AggregateFnSessionExt;

#[test]
fn aggregate_fn_serde() {
let session = VortexSession::empty().with::<AggregateFnSession>();
session.aggregate_fns().register(Sum);

let agg_fn = Sum.bind(EmptyOptions);

let serialized = agg_fn.serialize_proto().unwrap();
let buf = serialized.encode_to_vec();
let deserialized_proto = pb::AggregateFn::decode(buf.as_slice()).unwrap();
let deserialized = AggregateFnRef::from_proto(&deserialized_proto, &session).unwrap();

assert_eq!(deserialized, agg_fn);
}
}
9 changes: 9 additions & 0 deletions vortex-array/src/aggregate_fn/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use crate::aggregate_fn::AggregateFnVTable;
use crate::aggregate_fn::fns::is_constant::IsConstant;
use crate::aggregate_fn::fns::is_sorted::IsSorted;
use crate::aggregate_fn::fns::min_max::MinMax;
use crate::aggregate_fn::fns::nan_count::NanCount;
use crate::aggregate_fn::fns::sum::Sum;
use crate::aggregate_fn::kernels::DynAggregateKernel;
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
use crate::arrays::Chunked;
Expand Down Expand Up @@ -47,6 +49,13 @@ impl Default for AggregateFnSession {
grouped_kernels: RwLock::new(HashMap::default()),
};

// Register the built-in aggregate functions
this.register(IsConstant);
this.register(IsSorted);
this.register(MinMax);
this.register(NanCount);
this.register(Sum);

// Register the built-in aggregate kernels.
this.register_aggregate_kernel(Chunked::ID, None, &ChunkedArrayAggregate);
this.register_aggregate_kernel(Dict::ID, Some(MinMax.id()), &DictMinMaxKernel);
Expand Down
6 changes: 6 additions & 0 deletions vortex-proto/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ message Expr {
optional bytes metadata = 3;
}

// Captures a serialized aggregate function with its ID and options metadata.
message AggregateFn {
string id = 1;
optional bytes metadata = 2;
}

// Options for `vortex.literal`
message LiteralOpts {
vortex.scalar.Scalar value = 1;
Expand Down
Loading
Loading