Skip to content

Commit 219859b

Browse files
authored
feat: Support int to timestamp casts (apache#3541)
1 parent 2829ce8 commit 219859b

5 files changed

Lines changed: 328 additions & 20 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ harness = false
9292
name = "to_csv"
9393
harness = false
9494

95+
[[bench]]
96+
name = "cast_int_to_timestamp"
97+
harness = false
98+
9599
[[test]]
96100
name = "test_udf_registration"
97101
path = "tests/spark_expr_reg.rs"
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::builder::{Int16Builder, Int32Builder, Int64Builder, Int8Builder};
19+
use arrow::array::RecordBatch;
20+
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
21+
use criterion::{criterion_group, criterion_main, Criterion};
22+
use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
23+
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
24+
use std::sync::Arc;
25+
26+
const BATCH_SIZE: usize = 8192;
27+
28+
fn criterion_benchmark(c: &mut Criterion) {
29+
// Test with UTC timezone
30+
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
31+
let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
32+
33+
let mut group = c.benchmark_group("cast_int_to_timestamp");
34+
35+
// Int8 -> Timestamp
36+
let batch_i8 = create_int8_batch();
37+
let expr_i8 = Arc::new(Column::new("a", 0));
38+
let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
39+
group.bench_function("cast_i8_to_timestamp", |b| {
40+
b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
41+
});
42+
43+
// Int16 -> Timestamp
44+
let batch_i16 = create_int16_batch();
45+
let expr_i16 = Arc::new(Column::new("a", 0));
46+
let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
47+
group.bench_function("cast_i16_to_timestamp", |b| {
48+
b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
49+
});
50+
51+
// Int32 -> Timestamp
52+
let batch_i32 = create_int32_batch();
53+
let expr_i32 = Arc::new(Column::new("a", 0));
54+
let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
55+
group.bench_function("cast_i32_to_timestamp", |b| {
56+
b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
57+
});
58+
59+
// Int64 -> Timestamp
60+
let batch_i64 = create_int64_batch();
61+
let expr_i64 = Arc::new(Column::new("a", 0));
62+
let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
63+
group.bench_function("cast_i64_to_timestamp", |b| {
64+
b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
65+
});
66+
67+
group.finish();
68+
}
69+
70+
fn create_int8_batch() -> RecordBatch {
71+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8, true)]));
72+
let mut b = Int8Builder::with_capacity(BATCH_SIZE);
73+
for i in 0..BATCH_SIZE {
74+
if i % 10 == 0 {
75+
b.append_null();
76+
} else {
77+
b.append_value(rand::random::<i8>());
78+
}
79+
}
80+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
81+
}
82+
83+
fn create_int16_batch() -> RecordBatch {
84+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int16, true)]));
85+
let mut b = Int16Builder::with_capacity(BATCH_SIZE);
86+
for i in 0..BATCH_SIZE {
87+
if i % 10 == 0 {
88+
b.append_null();
89+
} else {
90+
b.append_value(rand::random::<i16>());
91+
}
92+
}
93+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
94+
}
95+
96+
fn create_int32_batch() -> RecordBatch {
97+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
98+
let mut b = Int32Builder::with_capacity(BATCH_SIZE);
99+
for i in 0..BATCH_SIZE {
100+
if i % 10 == 0 {
101+
b.append_null();
102+
} else {
103+
b.append_value(rand::random::<i32>());
104+
}
105+
}
106+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
107+
}
108+
109+
fn create_int64_batch() -> RecordBatch {
110+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
111+
let mut b = Int64Builder::with_capacity(BATCH_SIZE);
112+
for i in 0..BATCH_SIZE {
113+
if i % 10 == 0 {
114+
b.append_null();
115+
} else {
116+
b.append_value(rand::random::<i64>());
117+
}
118+
}
119+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
120+
}
121+
122+
fn config() -> Criterion {
123+
Criterion::default()
124+
}
125+
126+
criterion_group! {
127+
name = benches;
128+
config = config();
129+
targets = criterion_benchmark
130+
}
131+
criterion_main!(benches);

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,23 @@ macro_rules! cast_decimal_to_int32_up {
613613
}};
614614
}
615615

616+
/// Appends every element of the integer `$array` (of arrow primitive type
/// `$primitive_type`) to the timestamp `$builder`, converting seconds to
/// microseconds and preserving nulls.
macro_rules! cast_int_to_timestamp_impl {
    ($array:expr, $builder:expr, $primitive_type:ty) => {{
        let typed = $array.as_primitive::<$primitive_type>();
        for value in typed.iter() {
            match value {
                None => $builder.append_null(),
                Some(seconds) => {
                    // saturating_mul clamps to i64::MIN/MAX on overflow instead of
                    // panicking, which could otherwise occur for extreme inputs
                    // (e.g. Long.MIN_VALUE) — matching Spark behavior regardless
                    // of EvalMode.
                    let micros = (seconds as i64).saturating_mul(MICROS_PER_SECOND);
                    $builder.append_value(micros);
                }
            }
        }
    }};
}
632+
616633
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
617634
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
618635
let (sign, rest) = match value_str.strip_prefix('-') {
@@ -915,6 +932,7 @@ fn cast_array(
915932
(Boolean, Decimal128(precision, scale)) => {
916933
cast_boolean_to_decimal(&array, *precision, *scale)
917934
}
935+
(Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
918936
_ if cast_options.is_adapting_schema
919937
|| is_datafusion_spark_compatible(from_type, to_type) =>
920938
{
@@ -933,6 +951,29 @@ fn cast_array(
933951
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
934952
}
935953

954+
fn cast_int_to_timestamp(
955+
array_ref: &ArrayRef,
956+
target_tz: &Option<Arc<str>>,
957+
) -> SparkResult<ArrayRef> {
958+
// Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds.
959+
let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
960+
961+
match array_ref.data_type() {
962+
DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
963+
DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
964+
DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
965+
DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
966+
dt => {
967+
return Err(SparkError::Internal(format!(
968+
"Unsupported type for cast_int_to_timestamp: {:?}",
969+
dt
970+
)))
971+
}
972+
}
973+
974+
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
975+
}
976+
936977
fn cast_date_to_timestamp(
937978
array_ref: &ArrayRef,
938979
cast_options: &SparkCastOptions,
@@ -3519,4 +3560,94 @@ mod tests {
35193560
assert_eq!(r#"[null]"#, string_array.value(2));
35203561
assert_eq!(r#"[]"#, string_array.value(3));
35213562
}
3563+
3564+
#[test]
fn test_cast_int_to_timestamp() {
    // Include `None` to exercise the timezone-less output path — the function
    // accepts Option<Arc<str>> but the original coverage only used Some(..).
    // The stored epoch micros must be identical for every timezone, since the
    // timezone is carried only as metadata on the output array.
    let timezones: [Option<Arc<str>>; 7] = [
        None,
        Some(Arc::from("UTC")),
        Some(Arc::from("America/New_York")),
        Some(Arc::from("America/Los_Angeles")),
        Some(Arc::from("Europe/London")),
        Some(Arc::from("Asia/Tokyo")),
        Some(Arc::from("Australia/Sydney")),
    ];

    for tz in &timezones {
        // Int8: zero, +/-1, the full i8 range, and a null.
        let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![
            Some(0),
            Some(1),
            Some(-1),
            Some(127),
            Some(-128),
            None,
        ]));

        let result = cast_int_to_timestamp(&int8_array, tz).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();

        assert_eq!(ts_array.value(0), 0);
        assert_eq!(ts_array.value(1), 1_000_000);
        assert_eq!(ts_array.value(2), -1_000_000);
        assert_eq!(ts_array.value(3), 127_000_000);
        assert_eq!(ts_array.value(4), -128_000_000);
        assert!(ts_array.is_null(5));
        assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));

        // Int16: zero, +/-1, the full i16 range, and a null.
        let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![
            Some(0),
            Some(1),
            Some(-1),
            Some(32767),
            Some(-32768),
            None,
        ]));

        let result = cast_int_to_timestamp(&int16_array, tz).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();

        assert_eq!(ts_array.value(0), 0);
        assert_eq!(ts_array.value(1), 1_000_000);
        assert_eq!(ts_array.value(2), -1_000_000);
        assert_eq!(ts_array.value(3), 32_767_000_000_i64);
        assert_eq!(ts_array.value(4), -32_768_000_000_i64);
        assert!(ts_array.is_null(5));
        assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));

        // Int32: includes a realistic epoch-seconds value (2024-01-01 UTC).
        let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
            Some(0),
            Some(1),
            Some(-1),
            Some(1704067200),
            None,
        ]));

        let result = cast_int_to_timestamp(&int32_array, tz).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();

        assert_eq!(ts_array.value(0), 0);
        assert_eq!(ts_array.value(1), 1_000_000);
        assert_eq!(ts_array.value(2), -1_000_000);
        assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64);
        assert!(ts_array.is_null(4));
        assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));

        // Int64: i64::MAX/MIN must saturate to i64::MAX/MIN instead of
        // panicking on multiplication overflow.
        let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
            Some(0),
            Some(1),
            Some(-1),
            Some(i64::MAX),
            Some(i64::MIN),
        ]));

        let result = cast_int_to_timestamp(&int64_array, tz).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();

        assert_eq!(ts_array.value(0), 0);
        assert_eq!(ts_array.value(1), 1_000_000_i64);
        assert_eq!(ts_array.value(2), -1_000_000_i64);
        assert_eq!(ts_array.value(3), i64::MAX);
        assert_eq!(ts_array.value(4), i64::MIN);
        assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
    }
}
35223653
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
299299
Compatible()
300300
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
301301
Compatible()
302+
case DataTypes.TimestampType =>
303+
Compatible()
302304
case _ =>
303305
unsupported(DataTypes.ByteType, toType)
304306
}
@@ -313,6 +315,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
313315
Compatible()
314316
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
315317
Compatible()
318+
case DataTypes.TimestampType =>
319+
Compatible()
316320
case _ =>
317321
unsupported(DataTypes.ShortType, toType)
318322
}
@@ -328,6 +332,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
328332
case _: DecimalType =>
329333
Compatible()
330334
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
335+
case DataTypes.TimestampType =>
336+
Compatible()
331337
case _ =>
332338
unsupported(DataTypes.IntegerType, toType)
333339
}
@@ -343,6 +349,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
343349
case _: DecimalType =>
344350
Compatible()
345351
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
352+
case DataTypes.TimestampType =>
353+
Compatible()
346354
case _ =>
347355
unsupported(DataTypes.LongType, toType)
348356
}

0 commit comments

Comments
 (0)