diff --git a/native/core/src/parquet/parquet_support.rs b/native/core/src/parquet/parquet_support.rs
index 3418a17c43..7369a510a2 100644
--- a/native/core/src/parquet/parquet_support.rs
+++ b/native/core/src/parquet/parquet_support.rs
@@ -200,6 +200,9 @@ fn parquet_convert_array(
                 .with_timezone(Arc::clone(tz)),
             ))
         }
+        // Keep scan-time nested parquet conversion aligned with Spark's legacy
+        // array<date> -> array<int> behavior without affecting scalar Date -> Int casts.
+        (Date32, Int32) => Ok(new_null_array(to_type, array.len())),
         (Map(_, ordered_from), Map(_, ordered_to)) if ordered_from == ordered_to => {
             parquet_convert_map_to_map(array.as_map(), to_type, parquet_options, *ordered_to)
         }
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 6e5a80c84a..d1da9ba923 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -41,9 +41,9 @@ use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
 use crate::{EvalMode, SparkError};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
+    new_null_array, BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray,
+    StringArray, StructArray,
 };
-use arrow::compute::can_cast_types;
 use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
 use arrow::datatypes::{Field, Fields, GenericBinaryType};
 use arrow::error::ArrowError;
@@ -294,6 +294,9 @@ pub(crate) fn cast_array(
     };
 
     let cast_result = match (&from_type, to_type) {
+        // Null arrays carry no concrete values, so Arrow's native cast can change only the
+        // logical type while preserving length and nullness.
+        (Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
         (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
         (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
         (Utf8, Timestamp(_, _)) => {
@@ -366,8 +369,23 @@ pub(crate) fn cast_array(
             cast_options,
         )?),
         (List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
-        (List(_), List(_)) if can_cast_types(&from_type, to_type) => {
-            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        (List(_), List(to)) => {
+            let list_array = array.as_list::<i32>();
+            let casted_values = match (list_array.values().data_type(), to.data_type()) {
+                // Spark's legacy cast of array<date> -> array<int> produces null elements.
+                (Date32, Int32) => new_null_array(to.data_type(), list_array.values().len()),
+                _ => cast_array(
+                    Arc::clone(list_array.values()),
+                    to.data_type(),
+                    cast_options,
+                )?,
+            };
+            Ok(Arc::new(ListArray::new(
+                Arc::clone(to),
+                list_array.offsets().clone(),
+                casted_values,
+                list_array.nulls().cloned(),
+            )) as ArrayRef)
         }
         (Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
         (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -803,7 +821,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::StringArray;
+    use arrow::array::{ListArray, NullArray, StringArray};
+    use arrow::buffer::OffsetBuffer;
     use arrow::datatypes::TimestampMicrosecondType;
     use arrow::datatypes::{Field, Fields};
     #[test]
@@ -929,8 +948,6 @@ mod tests {
 
     #[test]
     fn test_cast_string_array_to_string() {
-        use arrow::array::ListArray;
-        use arrow::buffer::OffsetBuffer;
         let values_array =
             StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
         let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
@@ -955,8 +972,6 @@ mod tests {
 
     #[test]
     fn test_cast_i32_array_to_string() {
-        use arrow::array::ListArray;
-        use arrow::buffer::OffsetBuffer;
         let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
         let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
         let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -977,4 +992,33 @@ mod tests {
         assert_eq!(r#"[null]"#, string_array.value(2));
         assert_eq!(r#"[]"#, string_array.value(3));
     }
+
+    #[test]
+    fn test_cast_array_of_nulls_to_array() {
+        let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 2, 3, 3].into());
+        let from_item_field = Arc::new(Field::new("item", DataType::Null, true));
+        let from_array: ArrayRef = Arc::new(ListArray::new(
+            from_item_field,
+            offsets_buffer,
+            Arc::new(NullArray::new(3)),
+            None,
+        ));
+
+        let to_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
+        let to_array = cast_array(
+            from_array,
+            &to_type,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
+        )
+        .unwrap();
+
+        let result = to_array.as_list::<i32>();
+        assert_eq!(3, result.len());
+        assert_eq!(result.value_offsets(), &[0, 2, 3, 3]);
+
+        let values = result.values().as_primitive::<Int32Type>();
+        assert_eq!(3, values.len());
+        assert_eq!(3, values.null_count());
+        assert!(values.iter().all(|value| value.is_none()));
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 2188f8e9af..cb14718ca8 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -141,6 +141,9 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
     (fromType, toType) match {
       case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType =>
         Compatible()
+      case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
+          if toElementType != DataTypes.IntegerType && toElementType != DataTypes.StringType =>
+        unsupported(fromType, toType)
       case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
         Incompatible()
       case (dt: ArrayType, DataTypes.StringType) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index f3eb9033b5..c191229e8b 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -29,9 +29,9 @@ import org.apache.spark.sql.{CometTestBase, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.{col, monotonically_increasing_id}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType}
 
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
 import org.apache.comet.rules.CometScanTypeChecker
@@ -1328,14 +1328,99 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast ArrayType to ArrayType") {
+    val types = Seq(
+      BooleanType,
+      StringType,
+      ByteType,
+      IntegerType,
+      LongType,
+      ShortType,
+      FloatType,
+      DoubleType,
+      DecimalType(10, 2),
+      DecimalType(38, 18),
+      DateType,
+      TimestampType,
+      BinaryType)
+    testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _))
+  }
+
+  // TODO: This test failed with java.lang.OutOfMemoryError: Java heap space
+  ignore("cast nested ArrayType to nested ArrayType") {
+    val types = Seq(
+      BooleanType,
+      StringType,
+      ByteType,
+      IntegerType,
+      LongType,
+      ShortType,
+      FloatType,
+      DoubleType,
+      DecimalType(10, 2),
+      DecimalType(38, 18),
+      DateType,
+      TimestampType,
+      BinaryType)
+    testArrayCastMatrix(
+      types,
+      dt => ArrayType(ArrayType(dt)),
+      dt => generateArrays(100, ArrayType(dt)))
+  }
+
+  private def testArrayCastMatrix(
+      elementTypes: Seq[DataType],
+      wrapType: DataType => DataType,
+      generateInput: DataType => DataFrame): Unit = {
+    for (fromType <- elementTypes) {
+      val input = generateInput(fromType)
+      for (toType <- elementTypes) {
+        val name = s"cast $fromType to $toType"
+        val fromWrappedType = wrapType(fromType)
+        val toWrappedType = wrapType(toType)
+        if (fromType != toType &&
+          testNames.contains(name) &&
+          !tags
+            .get(name)
+            .exists(s => s.contains("org.scalatest.Ignore")) &&
+          Cast.canCast(fromWrappedType, toWrappedType) &&
+          CometCast.isSupported(fromWrappedType, toWrappedType, None, CometEvalMode.LEGACY) ==
+            Compatible()) {
+          val legacyOnly =
+            fromType == DateType || (fromType == BooleanType && toType == TimestampType)
+          val ansiSupported =
+            CometCast.isSupported(fromWrappedType, toWrappedType, None, CometEvalMode.ANSI) ==
+              Compatible()
+          val trySupported =
+            CometCast.isSupported(fromWrappedType, toWrappedType, None, CometEvalMode.TRY) ==
+              Compatible()
+          castTest(
+            input,
+            toWrappedType,
+            hasIncompatibleType = legacyOnly,
+            testAnsi = !legacyOnly && ansiSupported,
+            testTry = !legacyOnly && trySupported)
+        }
+      }
+    }
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
 
+  private def generateSafeFloatValues(): Seq[Float] = {
+    Seq(-123456.75f, -1.0f, 0.0f, 1.0f, 123456.75f)
+  }
+
   private def generateDoubles(): DataFrame = {
     withNulls(gen.generateDoubles(dataSize)).toDF("a")
   }
 
+  private def generateSafeDoubleValues(): Seq[Double] = {
+    Seq(-123456.75d, -1.0d, 0.0d, 1.0d, 123456.75d)
+  }
+
   private def generateBools(): DataFrame = {
     withNulls(Seq(true, false)).toDF("a")
   }
@@ -1356,10 +1441,63 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(gen.generateLongs(dataSize)).toDF("a")
   }
 
-  private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
+  private def generateSafeLongValues(): Seq[Long] = {
+    Seq(-123456789L, -1L, 0L, 1L, 123456789L)
+  }
+
+  private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
     import scala.collection.JavaConverters._
     val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
-    spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
+    def buildRows(values: Seq[Any]): Seq[Row] = {
+      Range(0, rowNum).map { i =>
+        Row(
+          Seq[Any](
+            values(i % values.length),
+            if (i % 3 == 0) null else values((i + 1) % values.length),
+            values((i + 2) % values.length)))
+      }
+    }
+
+    def withEdgeCaseRows(generatedRows: Seq[Row]): Seq[Row] = {
+      val sampleValue =
+        generatedRows.find(_.get(0) != null).flatMap(_.getSeq[Any](0).headOption).orNull
+      Seq(Row(Seq(sampleValue, null, sampleValue)), Row(Seq.empty[Any]), Row(null)) ++
+        generatedRows
+    }
+
+    elementType match {
+      case DateType =>
+        val stringSchema = StructType(Seq(StructField("a", ArrayType(StringType), true)))
+        spark
+          .createDataFrame(
+            withEdgeCaseRows(buildRows(generateDateLiterals())).asJava,
+            stringSchema)
+          .select(col("a").cast(ArrayType(DateType)).as("a"))
+      case TimestampType =>
+        val stringSchema = StructType(Seq(StructField("a", ArrayType(StringType), true)))
+        spark
+          .createDataFrame(
+            withEdgeCaseRows(buildRows(generateTimestampLiterals())).asJava,
+            stringSchema)
+          .select(col("a").cast(ArrayType(TimestampType)).as("a"))
+      case FloatType =>
+        spark.createDataFrame(
+          withEdgeCaseRows(buildRows(generateSafeFloatValues())).asJava,
+          schema)
+      case DoubleType =>
+        spark.createDataFrame(
+          withEdgeCaseRows(buildRows(generateSafeDoubleValues())).asJava,
+          schema)
+      case LongType =>
+        spark.createDataFrame(
+          withEdgeCaseRows(buildRows(generateSafeLongValues())).asJava,
+          schema)
+      case BinaryType =>
+        val values = generateBinary().collect().map(_.getAs[Array[Byte]]("a")).toSeq
+        spark.createDataFrame(withEdgeCaseRows(buildRows(values)).asJava, schema)
+      case _ =>
+        spark.createDataFrame(withEdgeCaseRows(gen.generateRows(rowNum, schema)).asJava, schema)
+    }
   }
 
   // https://github.com/apache/datafusion-comet/issues/2038
@@ -1424,7 +1562,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     withNulls(values).toDF("a")
   }
 
-  private def generateDates(): DataFrame = {
+  private def generateDateLiterals(): Seq[String] = {
     // add 1st, 10th, 20th of each month from epoch to 2027
     val sampledDates = (1970 to 2027).flatMap { year =>
       (1 to 12).flatMap { month =>
@@ -1481,7 +1619,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
     // Edge cases
    val edgeCases = Seq("1969-12-31", "2000-02-29", "999-01-01", "12345-01-01")
-    val values = (sampledDates ++ dstDates ++ edgeCases).distinct
+    (sampledDates ++ dstDates ++ edgeCases).distinct
+  }
+
+  private def generateDates(): DataFrame = {
+    val values = generateDateLiterals()
     withNulls(values).toDF("b").withColumn("a", col("b").cast(DataTypes.DateType)).drop("b")
   }
@@ -1493,13 +1635,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       values.toDF("str").select(col("str").cast(DataTypes.TimestampType).as("a")))
   }
 
+  private def generateTimestampLiterals(): Seq[String] =
+    Seq(
+      "2024-01-01T12:34:56.123456",
+      "2024-01-01T01:00:00Z",
+      "9999-12-31T01:00:00-02:00",
+      "2024-12-31T01:00:00+02:00")
+
   private def generateTimestamps(): DataFrame = {
-    val values =
-      Seq(
-        "2024-01-01T12:34:56.123456",
-        "2024-01-01T01:00:00Z",
-        "9999-12-31T01:00:00-02:00",
-        "2024-12-31T01:00:00+02:00")
+    val values = generateTimestampLiterals()
     withNulls(values)
       .toDF("str")
       .withColumn("a", col("str").cast(DataTypes.TimestampType))
@@ -1603,10 +1747,14 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
+      val dataWithRowId = data.withColumn("__row_id", monotonically_increasing_id())
 
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is disabled
-        val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
+        val df = dataWithRowId
+          .select(col("__row_id"), col("a"), col("a").cast(toType).as("casted"))
+          .orderBy(col("__row_id"))
+          .drop("__row_id")
         if (useDataFrameDiff) {
           assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
         } else {
@@ -1618,10 +1766,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
 
       if (testTry) {
-        data.createOrReplaceTempView("t")
-        // try_cast() should always return null for invalid inputs
-        // not using spark DSL since it `try_cast` is only available from Spark 4x
-        val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+        val df2 = tryCastWithRowId(dataWithRowId, toType)
         if (hasIncompatibleType) {
           checkSparkAnswer(df2)
         } else {
@@ -1641,7 +1786,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
 
         // cast() should throw exception on invalid inputs when ansi mode is enabled
-        val df = data.withColumn("converted", col("a").cast(toType))
+        val df = dataWithRowId
+          .select(col("__row_id"), col("a"), col("a").cast(toType).as("converted"))
+          .orderBy(col("__row_id"))
+          .drop("__row_id")
         val res = if (useDataFrameDiff) {
           assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
         } else {
@@ -1686,10 +1834,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         }
       }
 
-      // try_cast() should always return null for invalid inputs
       if (testTry) {
-        data.createOrReplaceTempView("t")
-        val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
+        val df2 = tryCastWithRowId(dataWithRowId, toType)
         if (useDataFrameDiff) {
           assertDataFrameEqualsWithExceptions(df2, assertCometNative = !hasIncompatibleType)
         } else {
@@ -1704,6 +1850,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  private def tryCastWithRowId(dataWithRowId: DataFrame, toType: DataType): DataFrame = {
+    dataWithRowId.createOrReplaceTempView("t")
+    // try_cast() should always return null for invalid inputs
+    // not using spark DSL since `try_cast` is only available from Spark 4.x
+    spark
+      .sql(s"select __row_id, a, try_cast(a as ${toType.sql}) as casted from t order by __row_id")
+      .drop("__row_id")
+  }
+
   private def roundtripParquet(df: DataFrame, tempDir: File): DataFrame = {
     val filename = new File(tempDir, s"castTest_${System.currentTimeMillis()}.parquet").toString
     df.write.mode(SaveMode.Overwrite).parquet(filename)
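
Note for reviewers: the heart of the new `(List(_), List(to))` arm is that it rebuilds the `ListArray` around a replaced child array instead of casting the list wholesale. Here is a minimal standalone sketch of that technique (not part of the patch; the array values and field names are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{new_null_array, Array, ArrayRef, Date32Array, ListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field};

fn main() {
    // Two list rows of dates: [[1970-01-01, 1970-01-02], [1970-01-03]].
    let values: ArrayRef = Arc::new(Date32Array::from(vec![Some(0), Some(1), Some(2)]));
    let offsets = OffsetBuffer::<i32>::new(vec![0, 2, 3].into());
    let field = Arc::new(Field::new("item", DataType::Date32, true));
    let dates = ListArray::new(field, offsets, values, None);

    // Swap in an all-null Int32 child while reusing the original offsets and
    // validity buffer, so list lengths and row-level nulls survive the cast.
    let to_field = Arc::new(Field::new("item", DataType::Int32, true));
    let null_values = new_null_array(to_field.data_type(), dates.values().len());
    let casted = ListArray::new(
        to_field,
        dates.offsets().clone(),
        null_values,
        dates.nulls().cloned(),
    );

    assert_eq!(casted.len(), 2); // row count unchanged
    assert_eq!(casted.value(0).null_count(), 2); // every element is null
}
```

Reusing the offsets and null buffer is what keeps per-row list lengths and row-level nulls intact: Spark's legacy semantics here null out the elements, not the rows.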
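The `(Null, _)` arm leans on a property of Arrow's built-in cast that is worth spelling out: casting a `NullArray` yields an all-null array of the target type with the same length. A sketch of that assumption:

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, NullArray};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;

fn main() {
    // A NullArray has a length but carries no values or buffers.
    let nulls: ArrayRef = Arc::new(NullArray::new(3));

    // Arrow's native cast changes only the logical type: the result is an
    // Int32 array of the same length in which every slot is null.
    let casted = cast_with_options(&nulls, &DataType::Int32, &CastOptions::default()).unwrap();

    assert_eq!(casted.len(), 3);
    assert_eq!(casted.null_count(), 3);
}
```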