Skip to content

Commit 9a7e616

Browse files
authored
feat: Support Spark expression hours (#3804)
* feat: Add Spark V2 partition transform `Hours` to calculate hours since epoch from timestamps.
1 parent 4caad8b commit 9a7e616

9 files changed

Lines changed: 414 additions & 6 deletions

File tree

native/core/src/execution/expressions/temporal.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ use datafusion::logical_expr::ScalarUDF;
2525
use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
2626
use datafusion_comet_proto::spark_expression::Expr;
2727
use datafusion_comet_spark_expr::{
28-
SparkHour, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
28+
SparkHour, SparkHoursTransform, SparkMinute, SparkSecond, SparkUnixTimestamp,
29+
TimestampTruncExpr,
2930
};
3031

3132
use crate::execution::{
@@ -160,3 +161,29 @@ impl ExpressionBuilder for TruncTimestampBuilder {
160161
Ok(Arc::new(TimestampTruncExpr::new(child, format, timezone)))
161162
}
162163
}
164+
165+
pub struct HoursTransformBuilder;
166+
167+
impl ExpressionBuilder for HoursTransformBuilder {
168+
fn build(
169+
&self,
170+
spark_expr: &Expr,
171+
input_schema: SchemaRef,
172+
planner: &PhysicalPlanner,
173+
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
174+
let expr = extract_expr!(spark_expr, HoursTransform);
175+
let child = planner.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
176+
let args = vec![child];
177+
let comet_hours_transform = Arc::new(ScalarUDF::new_from_impl(SparkHoursTransform::new()));
178+
let field_ref = Arc::new(Field::new("hours_transform", DataType::Int32, true));
179+
let expr: ScalarFunctionExpr = ScalarFunctionExpr::new(
180+
"hours_transform",
181+
comet_hours_transform,
182+
args,
183+
field_ref,
184+
Arc::new(ConfigOptions::default()),
185+
);
186+
187+
Ok(Arc::new(expr))
188+
}
189+
}

native/core/src/execution/planner/expression_registry.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ pub enum ExpressionType {
110110
Second,
111111
TruncTimestamp,
112112
UnixTimestamp,
113+
HoursTransform,
113114
}
114115

115116
/// Registry for expression builders
@@ -310,6 +311,10 @@ impl ExpressionRegistry {
310311
ExpressionType::TruncTimestamp,
311312
Box::new(TruncTimestampBuilder),
312313
);
314+
self.builders.insert(
315+
ExpressionType::HoursTransform,
316+
Box::new(HoursTransformBuilder),
317+
);
313318
}
314319

315320
/// Extract expression type from Spark protobuf expression
@@ -382,6 +387,7 @@ impl ExpressionRegistry {
382387
Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second),
383388
Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp),
384389
Some(ExprStruct::UnixTimestamp(_)) => Ok(ExpressionType::UnixTimestamp),
390+
Some(ExprStruct::HoursTransform(_)) => Ok(ExpressionType::HoursTransform),
385391

386392
Some(other) => Err(ExecutionError::GeneralError(format!(
387393
"Unsupported expression type: {:?}",

native/proto/src/proto/expr.proto

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ message Expr {
8888
UnixTimestamp unix_timestamp = 65;
8989
FromJson from_json = 66;
9090
ToCsv to_csv = 67;
91+
HoursTransform hours_transform = 68;
9192
}
9293

9394
// Optional QueryContext for error reporting (contains SQL text and position)
@@ -356,6 +357,10 @@ message Hour {
356357
string timezone = 2;
357358
}
358359

360+
// Spark V2 `hours` partition transform: number of hours elapsed since the
// Unix epoch, computed from the (timestamp-typed) child expression.
message HoursTransform {
  Expr child = 1;
}
363+
359364
message Minute {
360365
Expr child = 1;
361366
string timezone = 2;
Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Spark-compatible `hours` V2 partition transform.
19+
//!
20+
//! Computes the number of hours since the Unix epoch (1970-01-01 00:00:00 UTC).
21+
//!
22+
//! Both `TimestampType` and `TimestampNTZType` are computationally identical. They
23+
//! extract the absolute hours since the epoch by directly dividing the microsecond
24+
//! value by the number of microseconds in an hour, ignoring session timezone offsets.
25+
26+
use arrow::array::cast::as_primitive_array;
27+
use arrow::array::types::TimestampMicrosecondType;
28+
use arrow::array::{Array, Int32Array};
29+
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
30+
use datafusion::common::{internal_datafusion_err, DataFusionError};
31+
use datafusion::logical_expr::{
32+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
33+
};
34+
use num::integer::div_floor;
35+
use std::{any::Any, fmt::Debug, sync::Arc};
36+
37+
const MICROS_PER_HOUR: i64 = 3_600_000_000;
38+
39+
#[derive(Debug, PartialEq, Eq, Hash)]
40+
pub struct SparkHoursTransform {
41+
signature: Signature,
42+
}
43+
44+
impl SparkHoursTransform {
45+
pub fn new() -> Self {
46+
Self {
47+
signature: Signature::user_defined(Volatility::Immutable),
48+
}
49+
}
50+
}
51+
52+
impl Default for SparkHoursTransform {
53+
fn default() -> Self {
54+
Self::new()
55+
}
56+
}
57+
58+
impl ScalarUDFImpl for SparkHoursTransform {
59+
fn as_any(&self) -> &dyn Any {
60+
self
61+
}
62+
63+
fn name(&self) -> &str {
64+
"hours_transform"
65+
}
66+
67+
fn signature(&self) -> &Signature {
68+
&self.signature
69+
}
70+
71+
fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
72+
Ok(DataType::Int32)
73+
}
74+
75+
fn invoke_with_args(
76+
&self,
77+
args: ScalarFunctionArgs,
78+
) -> datafusion::common::Result<ColumnarValue> {
79+
let args: [ColumnarValue; 1] = args.args.try_into().map_err(|_| {
80+
internal_datafusion_err!("hours_transform expects exactly one argument")
81+
})?;
82+
83+
match args {
84+
[ColumnarValue::Array(array)] => {
85+
let result: Int32Array = match array.data_type() {
86+
DataType::Timestamp(Microsecond, _) => {
87+
let ts_array = as_primitive_array::<TimestampMicrosecondType>(&array);
88+
arrow::compute::kernels::arity::unary(ts_array, |micros| {
89+
div_floor(micros, MICROS_PER_HOUR) as i32
90+
})
91+
}
92+
other => {
93+
return Err(DataFusionError::Execution(format!(
94+
"hours_transform does not support input type: {:?}",
95+
other
96+
)));
97+
}
98+
};
99+
Ok(ColumnarValue::Array(Arc::new(result)))
100+
}
101+
_ => Err(DataFusionError::Execution(
102+
"hours_transform(scalar) should be folded on Spark JVM side.".to_string(),
103+
)),
104+
}
105+
}
106+
}
107+
108+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::TimestampMicrosecondArray;
    use arrow::datatypes::Field;
    use datafusion::config::ConfigOptions;
    use std::sync::Arc;

    /// Runs `hours_transform` over a single-row timestamp array and returns
    /// the resulting `Int32Array`.
    ///
    /// Centralizes the `ScalarFunctionArgs` construction and downcasting
    /// boilerplate that was previously duplicated verbatim in every test.
    ///
    /// NOTE(review): `arg_fields` is left empty while `args` holds one value;
    /// this mirrors the original tests — confirm DataFusion does not require
    /// the two lengths to match.
    fn invoke_hours(ts: TimestampMicrosecondArray) -> Int32Array {
        let udf = SparkHoursTransform::new();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(ts))],
            number_rows: 1,
            return_field: Arc::new(Field::new("hours_transform", DataType::Int32, true)),
            config_options: Arc::new(ConfigOptions::default()),
            arg_fields: vec![],
        };
        match udf.invoke_with_args(args).unwrap() {
            ColumnarValue::Array(arr) => arr
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("hours_transform must return Int32Array")
                .clone(),
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_hours_transform_utc() {
        // 2023-10-01 14:30:00 UTC = 1696171800 s = 1_696_171_800_000_000 µs.
        // Expected hours since epoch = 1_696_171_800_000_000 / 3_600_000_000 = 471158.
        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)])
            .with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), 471158);
    }

    #[test]
    fn test_hours_transform_ntz() {
        // Same instant without timezone metadata (NTZ) must give the same hour.
        let ts = TimestampMicrosecondArray::from(vec![Some(1_696_171_800_000_000i64)]);
        assert_eq!(invoke_hours(ts).value(0), 471158);
    }

    #[test]
    fn test_hours_transform_negative_epoch() {
        // 1969-12-31 23:30:00 UTC = -1800 s = -1_800_000_000 µs.
        // Floor division gives -1 (truncation would wrongly give 0).
        let ts =
            TimestampMicrosecondArray::from(vec![Some(-1_800_000_000i64)]).with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), -1);
    }

    #[test]
    fn test_hours_transform_null() {
        // Nulls propagate through the unary kernel unchanged.
        let ts = TimestampMicrosecondArray::from(vec![None as Option<i64>]).with_timezone("UTC");
        assert!(invoke_hours(ts).is_null(0));
    }

    #[test]
    fn test_hours_transform_epoch_zero() {
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("UTC");
        assert_eq!(invoke_hours(ts).value(0), 0);
    }

    #[test]
    fn test_hours_transform_non_utc_timezone() {
        // The transform reads absolute microseconds since epoch, so array
        // timezone metadata (Asia/Tokyo here) must not shift the result:
        // micros=0 maps to hour 0.
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]).with_timezone("Asia/Tokyo");
        assert_eq!(invoke_hours(ts).value(0), 0);
    }

    #[test]
    fn test_hours_transform_ntz_ignores_timezone() {
        // NTZ wall-clock micros=0 maps to hour 0 — no offset logic is applied
        // to either TimestampType or NTZ (so NOT 9 for a Tokyo session).
        let ts = TimestampMicrosecondArray::from(vec![Some(0i64)]); // No timezone on array
        assert_eq!(invoke_hours(ts).value(0), 0);
    }
}

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
mod date_diff;
1919
mod date_trunc;
2020
mod extract_date_part;
21+
mod hours;
2122
mod make_date;
2223
mod timestamp_trunc;
2324
mod unix_timestamp;
@@ -27,6 +28,7 @@ pub use date_trunc::SparkDateTrunc;
2728
pub use extract_date_part::SparkHour;
2829
pub use extract_date_part::SparkMinute;
2930
pub use extract_date_part::SparkSecond;
31+
pub use hours::SparkHoursTransform;
3032
pub use make_date::SparkMakeDate;
3133
pub use timestamp_trunc::TimestampTruncExpr;
3234
pub use unix_timestamp::SparkUnixTimestamp;

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ pub use comet_scalar_funcs::{
7171
};
7272
pub use csv_funcs::*;
7373
pub use datetime_funcs::{
74-
SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond,
75-
SparkUnixTimestamp, TimestampTruncExpr,
74+
SparkDateDiff, SparkDateTrunc, SparkHour, SparkHoursTransform, SparkMakeDate, SparkMinute,
75+
SparkSecond, SparkUnixTimestamp, TimestampTruncExpr,
7676
};
7777
pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult};
7878
pub use hash_funcs::*;

0 commit comments

Comments
 (0)