Skip to content

Commit 6565de9

Browse files
matthewalex4 (Matthew Alex) authored
fix: handle ambiguous and non-existent local times (#3865)
* fix: choose earliest time when ambiguous
* fix: handle non-existent times due to time change
* test: add unit tests for ambiguous and non-existent local times
* refactor: remove unnecessary catch_unwind in tests
* test: add spark test for timestampntz dst casting
* refactor: remove unwrap from resolve_local_datetime

Co-authored-by: Matthew Alex <malex@palantir.com>
1 parent 735adfb commit 6565de9

2 files changed

Lines changed: 96 additions & 5 deletions

File tree

native/spark-expr/src/utils.rs

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use arrow::{
3636
array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
3737
temporal_conversions::as_datetime,
3838
};
39-
use chrono::{DateTime, Offset, TimeZone};
39+
use chrono::{DateTime, LocalResult, NaiveDateTime, Offset, TimeZone};
4040

4141
/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
4242
/// to apply timezone offset.
@@ -174,6 +174,34 @@ fn datetime_cast_err(value: i64) -> ArrowError {
174174
))
175175
}
176176

177+
/// Resolves a local datetime in the given timezone to an absolute DateTime,
178+
/// handling DST ambiguity and spring-forward gaps.
179+
/// Parameters:
180+
/// tz - timezone used to interpret local_datetime
181+
/// local_datetime - a naive local datetime to resolve
182+
fn resolve_local_datetime(tz: &Tz, local_datetime: NaiveDateTime) -> DateTime<Tz> {
183+
match tz.from_local_datetime(&local_datetime) {
184+
LocalResult::Single(dt) => dt,
185+
LocalResult::Ambiguous(dt, _) => dt,
186+
LocalResult::None => {
187+
// Determine offset before time-change
188+
let probe = local_datetime - chrono::Duration::hours(3);
189+
let pre_offset = match tz.from_local_datetime(&probe) {
190+
LocalResult::Single(dt) => dt.offset().fix(),
191+
LocalResult::Ambiguous(dt, _) => dt.offset().fix(),
192+
LocalResult::None => {
193+
// Cannot determine offset; fall back to UTC interpretation
194+
return local_datetime.and_utc().with_timezone(tz);
195+
}
196+
};
197+
let offset_secs = pre_offset.local_minus_utc() as i64;
198+
199+
let utc_naive = local_datetime - chrono::Duration::seconds(offset_secs);
200+
utc_naive.and_utc().with_timezone(tz)
201+
}
202+
}
203+
}
204+
177205
/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
178206
/// a Timestamp(Microsecond, Some<_>) array.
179207
/// The understanding is that the input array has time in the timezone specified in the second
@@ -196,8 +224,8 @@ fn timestamp_ntz_to_timestamp(
196224
as_datetime::<TimestampMicrosecondType>(value)
197225
.ok_or_else(|| datetime_cast_err(value))
198226
.map(|local_datetime| {
199-
let datetime: DateTime<Tz> =
200-
tz.from_local_datetime(&local_datetime).unwrap();
227+
let datetime = resolve_local_datetime(&tz, local_datetime);
228+
201229
datetime.timestamp_micros()
202230
})
203231
})?;
@@ -215,8 +243,8 @@ fn timestamp_ntz_to_timestamp(
215243
as_datetime::<TimestampMillisecondType>(value)
216244
.ok_or_else(|| datetime_cast_err(value))
217245
.map(|local_datetime| {
218-
let datetime: DateTime<Tz> =
219-
tz.from_local_datetime(&local_datetime).unwrap();
246+
let datetime = resolve_local_datetime(&tz, local_datetime);
247+
220248
datetime.timestamp_millis()
221249
})
222250
})?;
@@ -312,6 +340,19 @@ pub fn unlikely(b: bool) -> bool {
312340
mod tests {
313341
use super::*;
314342

343+
fn array_containing(local_datetime: &str) -> ArrayRef {
344+
let dt = NaiveDateTime::parse_from_str(local_datetime, "%Y-%m-%d %H:%M:%S").unwrap();
345+
let ts = dt.and_utc().timestamp_micros();
346+
Arc::new(TimestampMicrosecondArray::from(vec![ts]))
347+
}
348+
349+
fn micros_for(datetime: &str) -> i64 {
350+
NaiveDateTime::parse_from_str(datetime, "%Y-%m-%d %H:%M:%S")
351+
.unwrap()
352+
.and_utc()
353+
.timestamp_micros()
354+
}
355+
315356
#[test]
316357
fn test_build_bool_state() {
317358
let mut builder = BooleanBufferBuilder::new(0);
@@ -330,4 +371,34 @@ mod tests {
330371
);
331372
assert_eq!(last, build_bool_state(&mut builder, &EmitTo::All));
332373
}
374+
375+
#[test]
fn test_timestamp_ntz_to_timestamp_handles_non_existent_time() {
    // 2024-03-31 01:30 does not exist in Europe/London (clocks jump from
    // 01:00 GMT to 02:00 BST): the conversion must not panic and should
    // use the pre-transition GMT offset, i.e. 01:30 UTC.
    let result = timestamp_ntz_to_timestamp(
        array_containing("2024-03-31 01:30:00"),
        "Europe/London",
        None,
    )
    .unwrap();

    let actual = as_primitive_array::<TimestampMicrosecondType>(&result).value(0);
    assert_eq!(actual, micros_for("2024-03-31 01:30:00"));
}
389+
390+
#[test]
fn test_timestamp_ntz_to_timestamp_handles_ambiguous_time() {
    // 2024-10-27 01:30 occurs twice in Europe/London (clocks fall back
    // from 02:00 BST to 01:00 GMT): the earliest occurrence (BST, +01:00)
    // must win, which corresponds to 00:30 UTC.
    let result = timestamp_ntz_to_timestamp(
        array_containing("2024-10-27 01:30:00"),
        "Europe/London",
        None,
    )
    .unwrap();

    let actual = as_primitive_array::<TimestampMicrosecondType>(&result).value(0);
    assert_eq!(actual, micros_for("2024-10-27 00:30:00"));
}
333404
}

spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,4 +489,24 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
489489
dummyDF.selectExpr("unix_date(cast(NULL as date))"))
490490
}
491491
}
492+
493+
test("cast TimestampNTZ to Timestamp - DST edge cases") {
  // One row per DST corner case in Europe/London:
  // spring-forward (non-existent local time) and fall-back (ambiguous local time).
  val rows = Seq(
    Row(java.time.LocalDateTime.parse("2024-03-31T01:30:00")), // Spring forward (Europe/London)
    Row(java.time.LocalDateTime.parse("2024-10-27T01:30:00")) // Fall back (Europe/London)
  )
  val schema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true)))
  val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
  df.createOrReplaceTempView("dst_tbl")

  // We `allowIncompatible` here because casts involving TimestampNTZ are marked
  // as Incompatible (due to incorrect behaviour when casting from a string)
  withSQLConf(
    SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Europe/London",
    "spark.comet.expression.Cast.allowIncompatible" -> "true") {
    checkSparkAnswerAndOperator(
      "SELECT ts_ntz, CAST(ts_ntz AS TIMESTAMP) FROM dst_tbl ORDER BY ts_ntz")
  }
}
492512
}

0 commit comments

Comments
 (0)