Skip to content

Commit fcbc002

Browse files
committed
fix: array to array cast
1 parent 99aed8e commit fcbc002

4 files changed

Lines changed: 235 additions & 30 deletions

File tree

native/core/src/parquet/parquet_support.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ fn parquet_convert_array(
200200
.with_timezone(Arc::clone(tz)),
201201
))
202202
}
203+
// Keep scan-time nested parquet conversion aligned with Spark's legacy
204+
// array<Date> -> array<Int> behavior without affecting scalar Date -> Int casts.
205+
(Date32, Int32) => Ok(new_null_array(to_type, array.len())),
203206
(Map(_, ordered_from), Map(_, ordered_to)) if ordered_from == ordered_to =>
204207
parquet_convert_map_to_map(array.as_map(), to_type, parquet_options, *ordered_to)
205208
,

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
4141
use crate::{EvalMode, SparkError};
4242
use arrow::array::builder::StringBuilder;
4343
use arrow::array::{
44-
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
44+
new_null_array, BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray,
45+
StringArray, StructArray,
4546
};
46-
use arrow::compute::can_cast_types;
4747
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4848
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4949
use arrow::error::ArrowError;
@@ -294,6 +294,9 @@ pub(crate) fn cast_array(
294294
};
295295

296296
let cast_result = match (&from_type, to_type) {
297+
// Null arrays carry no concrete values, so Arrow's native cast can change only the
298+
// logical type while preserving length and nullness.
299+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
297300
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
298301
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
299302
(Utf8, Timestamp(_, _)) => {
@@ -366,8 +369,23 @@ pub(crate) fn cast_array(
366369
cast_options,
367370
)?),
368371
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
369-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
370-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
372+
(List(_), List(to)) => {
373+
let list_array = array.as_list::<i32>();
374+
let casted_values = match (list_array.values().data_type(), to.data_type()) {
375+
// Spark legacy array casts produce null elements for array<Date> -> array<Int>.
376+
(Date32, Int32) => new_null_array(to.data_type(), list_array.values().len()),
377+
_ => cast_array(
378+
Arc::clone(list_array.values()),
379+
to.data_type(),
380+
cast_options,
381+
)?,
382+
};
383+
Ok(Arc::new(ListArray::new(
384+
Arc::clone(to),
385+
list_array.offsets().clone(),
386+
casted_values,
387+
list_array.nulls().cloned(),
388+
)) as ArrayRef)
371389
}
372390
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
373391
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -803,7 +821,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
803821
#[cfg(test)]
804822
mod tests {
805823
use super::*;
806-
use arrow::array::StringArray;
824+
use arrow::array::{ListArray, NullArray, StringArray};
825+
use arrow::buffer::OffsetBuffer;
807826
use arrow::datatypes::TimestampMicrosecondType;
808827
use arrow::datatypes::{Field, Fields};
809828
#[test]
@@ -929,8 +948,6 @@ mod tests {
929948

930949
#[test]
931950
fn test_cast_string_array_to_string() {
932-
use arrow::array::ListArray;
933-
use arrow::buffer::OffsetBuffer;
934951
let values_array =
935952
StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
936953
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
@@ -955,8 +972,6 @@ mod tests {
955972

956973
#[test]
957974
fn test_cast_i32_array_to_string() {
958-
use arrow::array::ListArray;
959-
use arrow::buffer::OffsetBuffer;
960975
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
961976
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
962977
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -977,4 +992,33 @@ mod tests {
977992
assert_eq!(r#"[null]"#, string_array.value(2));
978993
assert_eq!(r#"[]"#, string_array.value(3));
979994
}
995+
996+
#[test]
997+
fn test_cast_array_of_nulls_to_array() {
998+
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 2, 3, 3].into());
999+
let from_item_field = Arc::new(Field::new("item", DataType::Null, true));
1000+
let from_array: ArrayRef = Arc::new(ListArray::new(
1001+
from_item_field,
1002+
offsets_buffer,
1003+
Arc::new(NullArray::new(3)),
1004+
None,
1005+
));
1006+
1007+
let to_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1008+
let to_array = cast_array(
1009+
from_array,
1010+
&to_type,
1011+
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
1012+
)
1013+
.unwrap();
1014+
1015+
let result = to_array.as_list::<i32>();
1016+
assert_eq!(3, result.len());
1017+
assert_eq!(result.value_offsets(), &[0, 2, 3, 3]);
1018+
1019+
let values = result.values().as_primitive::<Int32Type>();
1020+
assert_eq!(3, values.len());
1021+
assert_eq!(3, values.null_count());
1022+
assert!(values.iter().all(|value| value.is_none()));
1023+
}
9801024
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
141141

142142
(fromType, toType) match {
143143
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
144+
case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
145+
if toElementType != DataTypes.IntegerType && toElementType != DataTypes.StringType =>
146+
unsupported(fromType, toType)
144147
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
145148
Incompatible()
146149
case (dt: ArrayType, DataTypes.StringType) =>

0 commit comments

Comments (0)