Skip to content

Commit d9f16cc

Browse files
committed
DataFusion 52 migration
1 parent 921a7d0 commit d9f16cc

5 files changed

Lines changed: 379 additions & 26 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,6 @@ impl PhysicalPlanner {
11131113
))
11141114
}
11151115
OpStruct::Scan(scan) => {
1116-
// dbg!(&scan);
11171116
let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();
11181117

11191118
// If it is not test execution context for unit test, we should have at least one
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use arrow::{
18+
array::{ArrayRef, TimestampMicrosecondArray, TimestampMillisecondArray},
19+
compute::CastOptions,
20+
datatypes::{DataType, FieldRef, Schema, TimeUnit},
21+
record_batch::RecordBatch,
22+
};
23+
24+
use datafusion::common::format::DEFAULT_CAST_OPTIONS;
25+
use datafusion::common::Result as DataFusionResult;
26+
use datafusion::common::{cast_column, ScalarValue};
27+
use datafusion::logical_expr::ColumnarValue;
28+
use datafusion::physical_expr::PhysicalExpr;
29+
use std::{
30+
any::Any,
31+
fmt::{self, Display},
32+
hash::Hash,
33+
sync::Arc,
34+
};
35+
36+
/// Casts a Timestamp(Microsecond) array to Timestamp(Millisecond) by dividing values by 1000.
37+
/// Preserves the timezone from the target type.
38+
fn cast_timestamp_micros_to_millis_array(
39+
array: &ArrayRef,
40+
target_tz: Option<Arc<str>>,
41+
) -> ArrayRef {
42+
let micros_array = array
43+
.as_any()
44+
.downcast_ref::<TimestampMicrosecondArray>()
45+
.expect("Expected TimestampMicrosecondArray");
46+
47+
let millis_values: TimestampMillisecondArray = micros_array
48+
.iter()
49+
.map(|opt| opt.map(|v| v / 1000))
50+
.collect();
51+
52+
// Apply timezone if present
53+
let result = if let Some(tz) = target_tz {
54+
millis_values.with_timezone(tz)
55+
} else {
56+
millis_values
57+
};
58+
59+
Arc::new(result)
60+
}
61+
62+
/// Casts a Timestamp(Microsecond) scalar to Timestamp(Millisecond) by dividing the value by 1000.
63+
/// Preserves the timezone from the target type.
64+
fn cast_timestamp_micros_to_millis_scalar(
65+
opt_val: Option<i64>,
66+
target_tz: Option<Arc<str>>,
67+
) -> ScalarValue {
68+
let new_val = opt_val.map(|v| v / 1000);
69+
ScalarValue::TimestampMillisecond(new_val, target_tz)
70+
}
71+
72+
#[derive(Debug, Clone, Eq)]
73+
pub struct
74+
CometCastColumnExpr {
75+
/// The physical expression producing the value to cast.
76+
expr: Arc<dyn PhysicalExpr>,
77+
/// The physical field of the input column.
78+
input_physical_field: FieldRef,
79+
/// The field type required by query
80+
target_field: FieldRef,
81+
/// Options forwarded to [`cast_column`].
82+
cast_options: CastOptions<'static>,
83+
}
84+
85+
// Manually derive `PartialEq`/`Hash` as `Arc<dyn PhysicalExpr>` does not
86+
// implement these traits by default for the trait object.
87+
impl PartialEq for CometCastColumnExpr {
88+
fn eq(&self, other: &Self) -> bool {
89+
self.expr.eq(&other.expr)
90+
&& self.input_physical_field.eq(&other.input_physical_field)
91+
&& self.target_field.eq(&other.target_field)
92+
&& self.cast_options.eq(&other.cast_options)
93+
}
94+
}
95+
96+
impl Hash for CometCastColumnExpr {
97+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
98+
self.expr.hash(state);
99+
self.input_physical_field.hash(state);
100+
self.target_field.hash(state);
101+
self.cast_options.hash(state);
102+
}
103+
}
104+
105+
impl CometCastColumnExpr {
106+
/// Create a new [`CometCastColumnExpr`].
107+
pub fn new(
108+
expr: Arc<dyn PhysicalExpr>,
109+
physical_field: FieldRef,
110+
target_field: FieldRef,
111+
cast_options: Option<CastOptions<'static>>,
112+
) -> Self {
113+
Self {
114+
expr,
115+
input_physical_field: physical_field,
116+
target_field,
117+
cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS),
118+
}
119+
}
120+
}
121+
122+
impl Display for CometCastColumnExpr {
123+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124+
write!(
125+
f,
126+
"COMET_CAST_COLUMN({} AS {})",
127+
self.expr,
128+
self.target_field.data_type()
129+
)
130+
}
131+
}
132+
133+
impl PhysicalExpr for CometCastColumnExpr {
134+
fn as_any(&self) -> &dyn Any {
135+
self
136+
}
137+
138+
fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
139+
Ok(self.target_field.data_type().clone())
140+
}
141+
142+
fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
143+
Ok(self.target_field.is_nullable())
144+
}
145+
146+
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
147+
let value = self.expr.evaluate(batch)?;
148+
149+
if value.data_type().equals_datatype(self.target_field.data_type()) {
150+
return Ok(value)
151+
}
152+
153+
let input_physical_field = self.input_physical_field.data_type();
154+
let target_field = self.target_field.data_type();
155+
156+
// dbg!(&input_physical_field, &target_field, &value);
157+
158+
// Handle specific type conversions with custom casts
159+
match (input_physical_field, target_field) {
160+
// Timestamp(Microsecond) -> Timestamp(Millisecond)
161+
(
162+
DataType::Timestamp(TimeUnit::Microsecond, _),
163+
DataType::Timestamp(TimeUnit::Millisecond, target_tz),
164+
) => match value {
165+
ColumnarValue::Array(array) => {
166+
let casted = cast_timestamp_micros_to_millis_array(&array, target_tz.clone());
167+
Ok(ColumnarValue::Array(casted))
168+
}
169+
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(opt_val, _)) => {
170+
let casted = cast_timestamp_micros_to_millis_scalar(opt_val, target_tz.clone());
171+
Ok(ColumnarValue::Scalar(casted))
172+
}
173+
_ => Ok(value),
174+
},
175+
_ => Ok(value),
176+
}
177+
}
178+
179+
fn return_field(&self, _input_schema: &Schema) -> DataFusionResult<FieldRef> {
180+
Ok(Arc::clone(&self.target_field))
181+
}
182+
183+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
184+
vec![&self.expr]
185+
}
186+
187+
fn with_new_children(
188+
self: Arc<Self>,
189+
mut children: Vec<Arc<dyn PhysicalExpr>>,
190+
) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
191+
assert_eq!(children.len(), 1);
192+
let child = children.pop().expect("CastColumnExpr child");
193+
Ok(Arc::new(Self::new(
194+
child,
195+
Arc::clone(&self.input_physical_field),
196+
Arc::clone(&self.target_field),
197+
Some(self.cast_options.clone()),
198+
)))
199+
}
200+
201+
fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202+
Display::fmt(self, f)
203+
}
204+
}
205+
206+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;
    use arrow::datatypes::Field;
    use datafusion::physical_expr::expressions::Column;

    #[test]
    fn test_cast_timestamp_micros_to_millis_array() {
        // Mixed positive, negative, zero, and null microsecond values.
        let input: TimestampMicrosecondArray = vec![
            Some(1_000_000),  // 1 second
            Some(2_500_000),  // 2.5 seconds
            None,             // null
            Some(0),          // epoch
            Some(-1_000_000), // 1 second before epoch
        ]
        .into();
        let input: ArrayRef = Arc::new(input);

        // No timezone on the target type.
        let out = cast_timestamp_micros_to_millis_array(&input, None);
        let out = out
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("Expected TimestampMillisecondArray");

        assert_eq!(out.len(), 5);
        assert_eq!(out.value(0), 1000);
        assert_eq!(out.value(1), 2500);
        assert!(out.is_null(2));
        assert_eq!(out.value(3), 0);
        assert_eq!(out.value(4), -1000);
    }

    #[test]
    fn test_cast_timestamp_micros_to_millis_array_with_timezone() {
        let input: TimestampMicrosecondArray = vec![Some(1_000_000), Some(2_000_000)].into();
        let input: ArrayRef = Arc::new(input);

        let tz: Option<Arc<str>> = Some(Arc::from("UTC"));
        let out = cast_timestamp_micros_to_millis_array(&input, tz);

        // Values are scaled down by 1000...
        let millis = out
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("Expected TimestampMillisecondArray");
        assert_eq!(millis.value(0), 1000);
        assert_eq!(millis.value(1), 2000);

        // ...and the target timezone is attached to the result type.
        assert_eq!(
            out.data_type(),
            &DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC")))
        );
    }

    #[test]
    fn test_cast_timestamp_micros_to_millis_scalar() {
        // Non-null value is scaled down by 1000.
        assert_eq!(
            cast_timestamp_micros_to_millis_scalar(Some(1_500_000), None),
            ScalarValue::TimestampMillisecond(Some(1500), None)
        );

        // Null stays null.
        assert_eq!(
            cast_timestamp_micros_to_millis_scalar(None, None),
            ScalarValue::TimestampMillisecond(None, None)
        );

        // Timezone is carried through to the result scalar.
        let tz: Option<Arc<str>> = Some(Arc::from("UTC"));
        assert_eq!(
            cast_timestamp_micros_to_millis_scalar(Some(2_000_000), tz.clone()),
            ScalarValue::TimestampMillisecond(Some(2000), tz)
        );
    }

    /// Builds the nullable input (microsecond) and target (millisecond)
    /// fields, both named "ts", used by the expression-level tests below.
    fn ts_fields() -> (FieldRef, FieldRef) {
        let input = Arc::new(Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            true,
        ));
        let target = Arc::new(Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            true,
        ));
        (input, target)
    }

    #[test]
    fn test_comet_cast_column_expr_evaluate_micros_to_millis_array() {
        let (input_field, target_field) = ts_fields();
        let schema = Schema::new(vec![Arc::clone(&input_field)]);

        // Column reference into the single-column batch below.
        let col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("ts", 0));
        let expr = CometCastColumnExpr::new(col, input_field, target_field, None);

        let data: TimestampMicrosecondArray =
            vec![Some(1_000_000), Some(2_000_000), None].into();
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)]).unwrap();

        match expr.evaluate(&batch).unwrap() {
            ColumnarValue::Array(arr) => {
                let millis = arr
                    .as_any()
                    .downcast_ref::<TimestampMillisecondArray>()
                    .expect("Expected TimestampMillisecondArray");
                assert_eq!(millis.value(0), 1000);
                assert_eq!(millis.value(1), 2000);
                assert!(millis.is_null(2));
            }
            _ => panic!("Expected Array result"),
        }
    }

    #[test]
    fn test_comet_cast_column_expr_evaluate_micros_to_millis_scalar() {
        let (input_field, target_field) = ts_fields();
        let schema = Schema::new(vec![Arc::clone(&input_field)]);

        // A literal producing a microsecond timestamp scalar.
        let scalar = ScalarValue::TimestampMicrosecond(Some(1_500_000), None);
        let literal: Arc<dyn PhysicalExpr> =
            Arc::new(datafusion::physical_expr::expressions::Literal::new(scalar));
        let expr = CometCastColumnExpr::new(literal, input_field, target_field, None);

        // Scalars are evaluated without rows.
        let batch = RecordBatch::new_empty(Arc::new(schema));

        match expr.evaluate(&batch).unwrap() {
            ColumnarValue::Scalar(s) => {
                assert_eq!(s, ScalarValue::TimestampMillisecond(Some(1500), None));
            }
            _ => panic!("Expected Scalar result"),
        }
    }
}

native/core/src/parquet/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub mod parquet_support;
2727
pub mod read;
2828
pub mod schema_adapter;
2929

30+
mod cast_column;
3031
mod objectstore;
3132

3233
use std::collections::HashMap;

0 commit comments

Comments
 (0)