diff --git a/native/spark-expr/src/utils.rs b/native/spark-expr/src/utils.rs
index 613f55cf77..1a4ab70241 100644
--- a/native/spark-expr/src/utils.rs
+++ b/native/spark-expr/src/utils.rs
@@ -36,7 +36,7 @@ use arrow::{
     array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
     temporal_conversions::as_datetime,
 };
-use chrono::{DateTime, Offset, TimeZone};
+use chrono::{DateTime, LocalResult, NaiveDateTime, Offset, TimeDelta, TimeZone};
 
 /// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
 /// to apply timezone offset.
@@ -174,6 +174,23 @@ fn datetime_cast_err(value: i64) -> ArrowError {
     ))
 }
 
+fn resolve_local_datetime(tz: &Tz, local_datetime: NaiveDateTime) -> DateTime<Tz> {
+    match tz.from_local_datetime(&local_datetime) {
+        LocalResult::Single(dt) => dt,
+        LocalResult::Ambiguous(dt, _) => dt,
+        LocalResult::None => {
+            // Interpret nonexistent local time by shifting from one hour earlier.
+            // This handles the common case of DST transitions with 1-hour gaps.
+            // NOTE: Some timezones have non-standard DST transitions (e.g., Australia/Lord_Howe
+            // has a 30-minute shift). This function assumes a 1-hour gap and may not correctly
+            // handle those cases.
+            let shift = TimeDelta::hours(1);
+            let before = tz.from_local_datetime(&(local_datetime - shift)).unwrap();
+            before + shift
+        }
+    }
+}
+
 /// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
 /// a Timestamp(Microsecond, Some<_>) array.
 /// The understanding is that the input array has time in the timezone specified in the second
@@ -196,8 +213,8 @@ fn timestamp_ntz_to_timestamp(
             as_datetime::<TimestampMicrosecondType>(value)
                 .ok_or_else(|| datetime_cast_err(value))
                 .map(|local_datetime| {
-                    let datetime: DateTime<Tz> =
-                        tz.from_local_datetime(&local_datetime).unwrap();
+                    let datetime = resolve_local_datetime(&tz, local_datetime);
+
                     datetime.timestamp_micros()
                 })
         })?;
@@ -215,8 +232,8 @@ fn timestamp_ntz_to_timestamp(
             as_datetime::<TimestampMillisecondType>(value)
                 .ok_or_else(|| datetime_cast_err(value))
                 .map(|local_datetime| {
-                    let datetime: DateTime<Tz> =
-                        tz.from_local_datetime(&local_datetime).unwrap();
+                    let datetime = resolve_local_datetime(&tz, local_datetime);
+
                     datetime.timestamp_millis()
                 })
         })?;
@@ -312,6 +329,19 @@ pub fn unlikely(b: bool) -> bool {
 mod tests {
     use super::*;
 
+    fn array_containing(local_datetime: &str) -> ArrayRef {
+        let dt = NaiveDateTime::parse_from_str(local_datetime, "%Y-%m-%d %H:%M:%S").unwrap();
+        let ts = dt.and_utc().timestamp_micros();
+        Arc::new(TimestampMicrosecondArray::from(vec![ts]))
+    }
+
+    fn micros_for(datetime: &str) -> i64 {
+        NaiveDateTime::parse_from_str(datetime, "%Y-%m-%d %H:%M:%S")
+            .unwrap()
+            .and_utc()
+            .timestamp_micros()
+    }
+
     #[test]
     fn test_build_bool_state() {
         let mut builder = BooleanBufferBuilder::new(0);
@@ -330,4 +360,34 @@ mod tests {
         );
         assert_eq!(last, build_bool_state(&mut builder, &EmitTo::All));
     }
+
+    #[test]
+    fn test_timestamp_ntz_to_timestamp_handles_non_existent_time() {
+        let output = timestamp_ntz_to_timestamp(
+            array_containing("2024-03-31 01:30:00"),
+            "Europe/London",
+            None,
+        )
+        .unwrap();
+
+        assert_eq!(
+            as_primitive_array::<TimestampMicrosecondType>(&output).value(0),
+            micros_for("2024-03-31 01:30:00")
+        );
+    }
+
+    #[test]
+    fn test_timestamp_ntz_to_timestamp_handles_ambiguous_time() {
+        let output = timestamp_ntz_to_timestamp(
+            array_containing("2024-10-27 01:30:00"),
+            "Europe/London",
+            None,
+        )
+        .unwrap();
+
+        assert_eq!(
+            as_primitive_array::<TimestampMicrosecondType>(&output).value(0),
+            micros_for("2024-10-27 00:30:00")
+        );
+    }
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
index 9f5413933a..dcac002c8f 100644
--- a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -489,4 +489,22 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
       dummyDF.selectExpr("unix_date(cast(NULL as date))"))
     }
   }
+
+  test("cast TimestampNTZ to Timestamp - DST edge cases") {
+    val data = Seq(
+      Row(java.time.LocalDateTime.parse("2024-03-31T01:30:00")), // Spring forward (Europe/London)
+      Row(java.time.LocalDateTime.parse("2024-10-27T01:30:00")) // Fall back (Europe/London)
+    )
+    val schema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true)))
+    spark
+      .createDataFrame(spark.sparkContext.parallelize(data), schema)
+      .createOrReplaceTempView("dst_tbl")
+
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Europe/London",
+      "spark.comet.expression.Cast.allowIncompatible" -> "true") {
+      checkSparkAnswerAndOperator(
+        "SELECT ts_ntz, CAST(ts_ntz AS TIMESTAMP) FROM dst_tbl ORDER BY ts_ntz")
+    }
+  }
 }