Skip to content

Commit 12fe78e

Browse files
authored
perf: skip allocation for prim cast if possible (#6997)
Reinterpret int primitive arrays in-place if byte widths match rather than allocating a new buffer. --------- Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 704f560 commit 12fe78e

1 file changed

Lines changed: 115 additions & 21 deletions

File tree

  • vortex-array/src/arrays/primitive/compute

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 115 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
use vortex_buffer::Buffer;
55
use vortex_buffer::BufferMut;
66
use vortex_error::VortexResult;
7+
use vortex_error::vortex_bail;
78
use vortex_error::vortex_err;
89
use vortex_mask::AllOr;
910
use vortex_mask::Mask;
@@ -13,8 +14,11 @@ use crate::ExecutionCtx;
1314
use crate::IntoArray;
1415
use crate::arrays::Primitive;
1516
use crate::arrays::PrimitiveArray;
17+
use crate::compute;
1618
use crate::dtype::DType;
1719
use crate::dtype::NativePType;
20+
use crate::dtype::Nullability;
21+
use crate::dtype::PType;
1822
use crate::match_each_native_ptype;
1923
use crate::scalar_fn::fns::cast::CastKernel;
2024
use crate::vtable::ValidityHelper;
@@ -36,7 +40,7 @@ impl CastKernel for Primitive {
3640
.clone()
3741
.cast_nullability(new_nullability, array.len())?;
3842

39-
// If the bit width is the same, we can short-circuit and simply update the validity
43+
// Same ptype: zero-copy, just update validity.
4044
if array.ptype() == new_ptype {
4145
// SAFETY: validity and data buffer still have same length
4246
return Ok(Some(unsafe {
@@ -49,9 +53,35 @@ impl CastKernel for Primitive {
4953
}));
5054
}
5155

56+
// Same-width signed and unsigned integers share one two's-complement bit
57+
// pattern for every value representable in both types. If all valid values
58+
// fit in the target range, reinterpret the buffer with no allocation.
59+
if array.ptype().is_int()
60+
&& new_ptype.is_int()
61+
&& array.ptype().byte_width() == new_ptype.byte_width()
62+
{
63+
if !values_fit_in(array, new_ptype) {
64+
vortex_bail!(
65+
Compute: "Cannot cast {} to {} — values exceed target range",
66+
array.ptype(),
67+
new_ptype,
68+
);
69+
}
70+
// SAFETY: both types are integers with the same size and alignment, and
71+
// min/max confirm all valid values are representable in the target type.
72+
return Ok(Some(unsafe {
73+
PrimitiveArray::new_unchecked_from_handle(
74+
array.buffer_handle().clone(),
75+
new_ptype,
76+
new_validity,
77+
)
78+
.into_array()
79+
}));
80+
}
81+
5282
let mask = array.validity_mask()?;
5383

54-
// Otherwise, we need to cast the values one-by-one
84+
// Otherwise, we need to cast the values one-by-one.
5585
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
5686
match_each_native_ptype!(array.ptype(), |F| {
5787
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
@@ -61,34 +91,35 @@ impl CastKernel for Primitive {
6191
}
6292
}
6393

94+
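/// Returns `true` if all valid values in `array` are representable as `target_ptype`; also returns `true` when `min_max` yields no result or fails, because its error is discarded via `.ok()`.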
95+
fn values_fit_in(array: &PrimitiveArray, target_ptype: PType) -> bool {
96+
let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
97+
compute::min_max(&array.clone().into_array())
98+
.ok()
99+
.flatten()
100+
.is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
101+
}
102+
64103
fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
104+
let try_cast = |src: F| -> VortexResult<T> {
105+
T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
106+
};
65107
match mask.bit_buffer() {
108+
AllOr::None => Ok(Buffer::zeroed(array.len())),
66109
AllOr::All => {
67110
let mut buffer = BufferMut::with_capacity(array.len());
68-
for item in array {
69-
let item = T::from(*item).ok_or_else(
70-
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
71-
)?;
111+
for &src in array {
72112
// SAFETY: we've pre-allocated the required capacity
73-
unsafe { buffer.push_unchecked(item) }
113+
unsafe { buffer.push_unchecked(try_cast(src)?) }
74114
}
75115
Ok(buffer.freeze())
76116
}
77-
AllOr::None => Ok(Buffer::zeroed(array.len())),
78117
AllOr::Some(b) => {
79-
// TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
80118
let mut buffer = BufferMut::with_capacity(array.len());
81-
for (item, valid) in array.iter().zip(b.iter()) {
82-
if valid {
83-
let item = T::from(*item).ok_or_else(
84-
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
85-
)?;
86-
// SAFETY: we've pre-allocated the required capacity
87-
unsafe { buffer.push_unchecked(item) }
88-
} else {
89-
// SAFETY: we've pre-allocated the required capacity
90-
unsafe { buffer.push_unchecked(T::default()) }
91-
}
119+
for (&src, valid) in array.iter().zip(b.iter()) {
120+
let dst = if valid { try_cast(src)? } else { T::default() };
121+
// SAFETY: we've pre-allocated the required capacity
122+
unsafe { buffer.push_unchecked(dst) }
92123
}
93124
Ok(buffer.freeze())
94125
}
@@ -183,7 +214,7 @@ mod test {
183214
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
184215
.unwrap_err();
185216
assert!(matches!(error, VortexError::Compute(..)));
186-
assert!(error.to_string().contains("Failed to cast -1 to U32"));
217+
assert!(error.to_string().contains("values exceed target range"));
187218
}
188219

189220
#[test]
@@ -223,6 +254,69 @@ mod test {
223254
);
224255
}
225256

257+
/// Same-width integer cast where all values fit: should reinterpret the
258+
/// buffer without allocation (pointer identity).
259+
#[test]
260+
fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
261+
let src = PrimitiveArray::from_iter([0u32, 10, 100]);
262+
let src_ptr = src.as_slice::<u32>().as_ptr();
263+
264+
let dst = src.into_array().cast(PType::I32.into())?.to_primitive();
265+
let dst_ptr = dst.as_slice::<i32>().as_ptr();
266+
267+
// Zero-copy: the data pointer should be identical.
268+
assert_eq!(src_ptr as usize, dst_ptr as usize);
269+
assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100]));
270+
Ok(())
271+
}
272+
273+
/// Same-width integer cast where values don't fit: should be rejected
274+
/// with a compute error instead of reinterpreting the buffer.
275+
#[test]
276+
fn cast_same_width_int_out_of_range_errors() {
277+
let arr = buffer![u32::MAX].into_array();
278+
let err = arr
279+
.cast(PType::I32.into())
280+
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
281+
.unwrap_err();
282+
assert!(matches!(err, VortexError::Compute(..)));
283+
}
284+
285+
/// All-null array cast between same-width types should succeed without
286+
/// touching the buffer contents.
287+
#[test]
288+
fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
289+
let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
290+
let casted = arr
291+
.into_array()
292+
.cast(DType::Primitive(PType::I8, Nullability::Nullable))?
293+
.to_primitive();
294+
assert_eq!(casted.len(), 2);
295+
assert!(matches!(casted.validity(), Validity::AllInvalid));
296+
Ok(())
297+
}
298+
299+
/// Same-width integer cast with nullable values: out-of-range nulls should
300+
/// not prevent the cast from succeeding.
301+
#[test]
302+
fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
303+
// The null position holds u32::MAX which doesn't fit in i32, but it's
304+
// masked as invalid so the cast should still succeed via reinterpret.
305+
let arr = PrimitiveArray::new(
306+
buffer![u32::MAX, 0u32, 42u32],
307+
Validity::from_iter([false, true, true]),
308+
);
309+
let casted = arr
310+
.into_array()
311+
.cast(DType::Primitive(PType::I32, Nullability::Nullable))?
312+
.to_primitive();
313+
assert_arrays_eq!(
314+
casted,
315+
PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
316+
);
317+
Ok(())
318+
}
319+
226320
#[rstest]
227321
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
228322
#[case(buffer![0u16, 100, 1000, 65535].into_array())]

0 commit comments

Comments
 (0)