Skip to content

Commit 926ffe3

Browse files
committed
Df52 migration
1 parent b267574 commit 926ffe3

2 files changed

Lines changed: 56 additions & 15 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@ use arrow::array::{
2424
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
2525
};
2626
use arrow::compute::can_cast_types;
27-
use arrow::datatypes::{
28-
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
29-
Schema,
30-
};
27+
use arrow::datatypes::{i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType, Schema, TimeUnit};
3128
use arrow::{
3229
array::{
3330
cast::AsArray,
@@ -964,9 +961,11 @@ fn cast_array(
964961
cast_options: &SparkCastOptions,
965962
) -> DataFusionResult<ArrayRef> {
966963
use DataType::*;
967-
let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
968964
let from_type = array.data_type().clone();
969965

966+
let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
967+
let eval_mode = cast_options.eval_mode;
968+
970969
let native_cast_options: CastOptions = CastOptions {
971970
safe: !matches!(cast_options.eval_mode, EvalMode::Ansi), // take safe mode from cast_options passed
972971
format_options: FormatOptions::new()
@@ -1015,10 +1014,8 @@ fn cast_array(
10151014
}
10161015
}
10171016
};
1018-
let from_type = array.data_type();
1019-
let eval_mode = cast_options.eval_mode;
10201017

1021-
let cast_result = match (from_type, to_type) {
1018+
let cast_result = match (&from_type, to_type) {
10221019
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
10231020
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
10241021
(Utf8, Timestamp(_, _)) => {
@@ -1044,10 +1041,10 @@ fn cast_array(
10441041
| (Int16, Int8)
10451042
if eval_mode != EvalMode::Try =>
10461043
{
1047-
spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
1044+
spark_cast_int_to_int(&array, eval_mode, &from_type, to_type)
10481045
}
10491046
(Int8 | Int16 | Int32 | Int64, Decimal128(precision, scale)) => {
1050-
cast_int_to_decimal128(&array, eval_mode, from_type, to_type, *precision, *scale)
1047+
cast_int_to_decimal128(&array, eval_mode, &from_type, to_type, *precision, *scale)
10511048
}
10521049
(Utf8, Int8 | Int16 | Int32 | Int64) => {
10531050
cast_string_to_int::<i32>(to_type, &array, eval_mode)
@@ -1079,19 +1076,19 @@ fn cast_array(
10791076
| (Decimal128(_, _), Int64)
10801077
if eval_mode != EvalMode::Try =>
10811078
{
1082-
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
1079+
spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, &from_type, to_type)
10831080
}
10841081
(Decimal128(_p, _s), Boolean) => spark_cast_decimal_to_boolean(&array),
10851082
(Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?),
10861083
(Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
10871084
(Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
10881085
array.as_struct(),
1089-
from_type,
1086+
&from_type,
10901087
to_type,
10911088
cast_options,
10921089
)?),
10931090
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
1094-
(List(_), List(_)) if can_cast_types(from_type, to_type) => {
1091+
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
10951092
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
10961093
}
10971094
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -1102,7 +1099,7 @@ fn cast_array(
11021099
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
11031100
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
11041101
_ if cast_options.is_adapting_schema
1105-
|| is_datafusion_spark_compatible(from_type, to_type) =>
1102+
|| is_datafusion_spark_compatible(&from_type, to_type) =>
11061103
{
11071104
// use DataFusion cast only when we know that it is compatible with Spark
11081105
Ok(cast_with_options(&array, to_type, &native_cast_options)?)
@@ -1116,7 +1113,7 @@ fn cast_array(
11161113
)))
11171114
}
11181115
};
1119-
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
1116+
Ok(spark_cast_postprocess(cast_result?, &from_type, to_type))
11201117
}
11211118

11221119
fn cast_date_to_timestamp(

native/spark-expr/src/utils.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use arrow::{
3535
array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray},
3636
temporal_conversions::as_datetime,
3737
};
38+
use arrow::array::TimestampMicrosecondArray;
3839
use chrono::{DateTime, Offset, TimeZone};
3940

4041
/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or
@@ -71,6 +72,49 @@ pub fn array_with_timezone(
7172
to_type: Option<&DataType>,
7273
) -> Result<ArrayRef, ArrowError> {
7374
match array.data_type() {
75+
DataType::Timestamp(TimeUnit::Millisecond, None) => {
76+
assert!(!timezone.is_empty());
77+
match to_type {
78+
Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
79+
Some(DataType::Timestamp(_, Some(_))) => {
80+
timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
81+
}
82+
Some(DataType::Timestamp(TimeUnit::Microsecond, None)) => {
83+
// Convert from Timestamp(Millisecond, None) to Timestamp(Microsecond, None)
84+
let millis_array = as_primitive_array::<TimestampMillisecondType>(&array);
85+
let micros_array: TimestampMicrosecondArray = millis_array
86+
.iter()
87+
.map(|opt| opt.map(|v| v * 1000))
88+
.collect();
89+
Ok(Arc::new(micros_array))
90+
}
91+
_ => {
92+
// Not supported
93+
panic!(
94+
"Cannot convert from {:?} to {:?}",
95+
array.data_type(),
96+
to_type.unwrap()
97+
)
98+
}
99+
}
100+
}
101+
DataType::Timestamp(TimeUnit::Microsecond, None) => {
102+
assert!(!timezone.is_empty());
103+
match to_type {
104+
Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array),
105+
Some(DataType::Timestamp(_, Some(_))) => {
106+
timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str()))
107+
}
108+
_ => {
109+
// Not supported
110+
panic!(
111+
"Cannot convert from {:?} to {:?}",
112+
array.data_type(),
113+
to_type.unwrap()
114+
)
115+
}
116+
}
117+
}
74118
DataType::Timestamp(_, None) => {
75119
assert!(!timezone.is_empty());
76120
match to_type {

0 commit comments

Comments
 (0)