44use vortex_buffer:: Buffer ;
55use vortex_buffer:: BufferMut ;
66use vortex_error:: VortexResult ;
7+ use vortex_error:: vortex_bail;
78use vortex_error:: vortex_err;
89use vortex_mask:: AllOr ;
910use vortex_mask:: Mask ;
@@ -13,8 +14,11 @@ use crate::ExecutionCtx;
1314use crate :: IntoArray ;
1415use crate :: arrays:: Primitive ;
1516use crate :: arrays:: PrimitiveArray ;
17+ use crate :: compute;
1618use crate :: dtype:: DType ;
1719use crate :: dtype:: NativePType ;
20+ use crate :: dtype:: Nullability ;
21+ use crate :: dtype:: PType ;
1822use crate :: match_each_native_ptype;
1923use crate :: scalar_fn:: fns:: cast:: CastKernel ;
2024use crate :: vtable:: ValidityHelper ;
@@ -36,7 +40,7 @@ impl CastKernel for Primitive {
3640 . clone ( )
3741 . cast_nullability ( new_nullability, array. len ( ) ) ?;
3842
39- // If the bit width is the same, we can short-circuit and simply update the validity
43+ // Same ptype: zero-copy, just update validity.
4044 if array. ptype ( ) == new_ptype {
4145 // SAFETY: validity and data buffer still have same length
4246 return Ok ( Some ( unsafe {
@@ -49,9 +53,35 @@ impl CastKernel for Primitive {
4953 } ) ) ;
5054 }
5155
56+ // Same-width integers have identical bit representations due to 2's
57+ // complement. If all values fit in the target range, reinterpret with
58+ // no allocation.
59+ if array. ptype ( ) . is_int ( )
60+ && new_ptype. is_int ( )
61+ && array. ptype ( ) . byte_width ( ) == new_ptype. byte_width ( )
62+ {
63+ if !values_fit_in ( array, new_ptype) {
64+ vortex_bail ! (
65+ Compute : "Cannot cast {} to {} — values exceed target range" ,
66+ array. ptype( ) ,
67+ new_ptype,
68+ ) ;
69+ }
70+ // SAFETY: both types are integers with the same size and alignment, and
71+ // min/max confirm all valid values are representable in the target type.
72+ return Ok ( Some ( unsafe {
73+ PrimitiveArray :: new_unchecked_from_handle (
74+ array. buffer_handle ( ) . clone ( ) ,
75+ new_ptype,
76+ new_validity,
77+ )
78+ . into_array ( )
79+ } ) ) ;
80+ }
81+
5282 let mask = array. validity_mask ( ) ?;
5383
54- // Otherwise, we need to cast the values one-by-one
84+ // Otherwise, we need to cast the values one-by-one.
5585 Ok ( Some ( match_each_native_ptype ! ( new_ptype, |T | {
5686 match_each_native_ptype!( array. ptype( ) , |F | {
5787 PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) , mask) ?, new_validity)
@@ -61,34 +91,35 @@ impl CastKernel for Primitive {
6191 }
6292}
6393
94+ /// Returns `true` if all valid values in `array` are representable as `target_ptype`.
95+ fn values_fit_in ( array : & PrimitiveArray , target_ptype : PType ) -> bool {
96+ let target_dtype = DType :: Primitive ( target_ptype, Nullability :: NonNullable ) ;
97+ compute:: min_max ( & array. clone ( ) . into_array ( ) )
98+ . ok ( )
99+ . flatten ( )
100+ . is_none_or ( |mm| mm. min . cast ( & target_dtype) . is_ok ( ) && mm. max . cast ( & target_dtype) . is_ok ( ) )
101+ }
102+
64103fn cast < F : NativePType , T : NativePType > ( array : & [ F ] , mask : Mask ) -> VortexResult < Buffer < T > > {
104+ let try_cast = |src : F | -> VortexResult < T > {
105+ T :: from ( src) . ok_or_else ( || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , src, T :: PTYPE ) )
106+ } ;
65107 match mask. bit_buffer ( ) {
108+ AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
66109 AllOr :: All => {
67110 let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
68- for item in array {
69- let item = T :: from ( * item) . ok_or_else (
70- || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , item, T :: PTYPE ) ,
71- ) ?;
111+ for & src in array {
72112 // SAFETY: we've pre-allocated the required capacity
73- unsafe { buffer. push_unchecked ( item ) }
113+ unsafe { buffer. push_unchecked ( try_cast ( src ) ? ) }
74114 }
75115 Ok ( buffer. freeze ( ) )
76116 }
77- AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
78117 AllOr :: Some ( b) => {
79- // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
80118 let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
81- for ( item, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
82- if valid {
83- let item = T :: from ( * item) . ok_or_else (
84- || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , item, T :: PTYPE ) ,
85- ) ?;
86- // SAFETY: we've pre-allocated the required capacity
87- unsafe { buffer. push_unchecked ( item) }
88- } else {
89- // SAFETY: we've pre-allocated the required capacity
90- unsafe { buffer. push_unchecked ( T :: default ( ) ) }
91- }
119+ for ( & src, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
120+ let dst = if valid { try_cast ( src) ? } else { T :: default ( ) } ;
121+ // SAFETY: we've pre-allocated the required capacity
122+ unsafe { buffer. push_unchecked ( dst) }
92123 }
93124 Ok ( buffer. freeze ( ) )
94125 }
@@ -183,7 +214,7 @@ mod test {
183214 . and_then ( |a| a. to_canonical ( ) . map ( |c| c. into_array ( ) ) )
184215 . unwrap_err ( ) ;
185216 assert ! ( matches!( error, VortexError :: Compute ( ..) ) ) ;
186- assert ! ( error. to_string( ) . contains( "Failed to cast -1 to U32 " ) ) ;
217+ assert ! ( error. to_string( ) . contains( "values exceed target range " ) ) ;
187218 }
188219
189220 #[ test]
@@ -223,6 +254,69 @@ mod test {
223254 ) ;
224255 }
225256
257+ /// Same-width integer cast where all values fit: should reinterpret the
258+ /// buffer without allocation (pointer identity).
259+ #[ test]
260+ fn cast_same_width_int_reinterprets_buffer ( ) -> vortex_error:: VortexResult < ( ) > {
261+ let src = PrimitiveArray :: from_iter ( [ 0u32 , 10 , 100 ] ) ;
262+ let src_ptr = src. as_slice :: < u32 > ( ) . as_ptr ( ) ;
263+
264+ let dst = src. into_array ( ) . cast ( PType :: I32 . into ( ) ) ?. to_primitive ( ) ;
265+ let dst_ptr = dst. as_slice :: < i32 > ( ) . as_ptr ( ) ;
266+
267+ // Zero-copy: the data pointer should be identical.
268+ assert_eq ! ( src_ptr as usize , dst_ptr as usize ) ;
269+ assert_arrays_eq ! ( dst, PrimitiveArray :: from_iter( [ 0i32 , 10 , 100 ] ) ) ;
270+ Ok ( ( ) )
271+ }
272+
273+ /// Same-width integer cast where values don't fit: should fall through
274+ /// to the allocating path and produce an error.
275+ #[ test]
276+ fn cast_same_width_int_out_of_range_errors ( ) {
277+ let arr = buffer ! [ u32 :: MAX ] . into_array ( ) ;
278+ let err = arr
279+ . cast ( PType :: I32 . into ( ) )
280+ . and_then ( |a| a. to_canonical ( ) . map ( |c| c. into_array ( ) ) )
281+ . unwrap_err ( ) ;
282+ assert ! ( matches!( err, VortexError :: Compute ( ..) ) ) ;
283+ }
284+
285+ /// All-null array cast between same-width types should succeed without
286+ /// touching the buffer contents.
287+ #[ test]
288+ fn cast_same_width_all_null ( ) -> vortex_error:: VortexResult < ( ) > {
289+ let arr = PrimitiveArray :: new ( buffer ! [ 0xFFu8 , 0xFF ] , Validity :: AllInvalid ) ;
290+ let casted = arr
291+ . into_array ( )
292+ . cast ( DType :: Primitive ( PType :: I8 , Nullability :: Nullable ) ) ?
293+ . to_primitive ( ) ;
294+ assert_eq ! ( casted. len( ) , 2 ) ;
295+ assert ! ( matches!( casted. validity( ) , Validity :: AllInvalid ) ) ;
296+ Ok ( ( ) )
297+ }
298+
299+ /// Same-width integer cast with nullable values: out-of-range nulls should
300+ /// not prevent the cast from succeeding.
301+ #[ test]
302+ fn cast_same_width_int_nullable_with_out_of_range_nulls ( ) -> vortex_error:: VortexResult < ( ) > {
303+ // The null position holds u32::MAX which doesn't fit in i32, but it's
304+ // masked as invalid so the cast should still succeed via reinterpret.
305+ let arr = PrimitiveArray :: new (
306+ buffer ! [ u32 :: MAX , 0u32 , 42u32 ] ,
307+ Validity :: from_iter ( [ false , true , true ] ) ,
308+ ) ;
309+ let casted = arr
310+ . into_array ( )
311+ . cast ( DType :: Primitive ( PType :: I32 , Nullability :: Nullable ) ) ?
312+ . to_primitive ( ) ;
313+ assert_arrays_eq ! (
314+ casted,
315+ PrimitiveArray :: from_option_iter( [ None , Some ( 0i32 ) , Some ( 42 ) ] )
316+ ) ;
317+ Ok ( ( ) )
318+ }
319+
226320 #[ rstest]
227321 #[ case( buffer![ 0u8 , 1 , 2 , 3 , 255 ] . into_array( ) ) ]
228322 #[ case( buffer![ 0u16 , 100 , 1000 , 65535 ] . into_array( ) ) ]
0 commit comments