
Commit 30c57ce

DataFusion 52 migration
1 parent 077005c commit 30c57ce

7 files changed

Lines changed: 80 additions & 77 deletions


native/Cargo.toml

Lines changed: 3 additions & 3 deletions
@@ -38,9 +38,9 @@ arrow = { version = "57.2.0", features = ["prettyprint", "ffi", "chrono-tz"] }
 async-trait = { version = "0.1" }
 bytes = { version = "1.10.0" }
 parquet = { version = "57.2.0", default-features = false, features = ["experimental"] }
-datafusion = { version = "51.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
-datafusion-datasource = { version = "51.0.0" }
-datafusion-spark = { version = "51.0.0" }
+datafusion = { git = "https://github.com/apache/datafusion", branch = "branch-52", default-features = false, features = ["unicode_expressions", "crypto_expressions", "nested_expressions", "parquet"] }
+datafusion-datasource = { git = "https://github.com/apache/datafusion", branch = "branch-52" }
+datafusion-spark = { git = "https://github.com/apache/datafusion", branch = "branch-52" }
 datafusion-comet-spark-expr = { path = "spark-expr" }
 datafusion-comet-proto = { path = "proto" }
 chrono = { version = "0.4", default-features = false, features = ["clock"] }

native/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ jni = { version = "0.21", features = ["invocation"] }
 lazy_static = "1.4"
 assertables = "9"
 hex = "0.4.3"
-datafusion-functions-nested = { version = "51.0.0" }
+datafusion-functions-nested = { git = "https://github.com/apache/datafusion", branch = "branch-52" }

 [features]
 backtrace = ["datafusion/backtrace"]

native/core/src/execution/planner.rs

Lines changed: 6 additions & 7 deletions
@@ -3420,6 +3420,7 @@ mod tests {
 use arrow::array::{Array, DictionaryArray, Int32Array, ListArray, RecordBatch, StringArray};
 use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema};
 use datafusion::catalog::memory::DataSourceExec;
+use datafusion::config::TableParquetOptions;
 use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::object_store::ObjectStoreUrl;
 use datafusion::datasource::physical_plan::{
@@ -4039,16 +4040,14 @@ mod tests {
             }
         }

-        let source = ParquetSource::default().with_schema_adapter_factory(Arc::new(
-            SparkSchemaAdapterFactory::new(
-                SparkParquetOptions::new(EvalMode::Ansi, "", false),
-                None,
-            ),
-        ))?;
+        let source = Arc::new(
+            ParquetSource::new(Arc::new(read_schema.clone()))
+                .with_table_parquet_options(TableParquetOptions::new())
+        ) as Arc<dyn FileSource>;

         let object_store_url = ObjectStoreUrl::local_filesystem();
         let file_scan_config =
-            FileScanConfigBuilder::new(object_store_url, read_schema.into(), source)
+            FileScanConfigBuilder::new(object_store_url, source)
                 .with_file_groups(file_groups)
                 .build();
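
Migration note: this test hunk captures the shape of the DataFusion 52 scan API. ParquetSource::new now takes the file schema up front, the Parquet options move to with_table_parquet_options, and FileScanConfigBuilder::new no longer accepts a schema argument. Below is a minimal standalone sketch of the new construction pattern; the schema, file path, and exact import paths (in particular where FileSource and FileGroup are re-exported) are assumptions based on this commit's use statements, not verbatim from it:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::config::TableParquetOptions;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{FileScanConfigBuilder, FileSource, ParquetSource};
use datafusion::error::Result;
use datafusion_datasource::file_groups::FileGroup;

fn build_scan() -> Result<()> {
    // Placeholder schema; the real test derives this from its fixture data.
    let read_schema = Arc::new(Schema::new(vec![Field::new("c0", DataType::Int32, true)]));

    // DataFusion 52: the schema goes into ParquetSource::new, not the builder.
    let source = Arc::new(
        ParquetSource::new(Arc::clone(&read_schema))
            .with_table_parquet_options(TableParquetOptions::new()),
    ) as Arc<dyn FileSource>;

    // Hypothetical file path, for illustration only.
    let file_groups = vec![FileGroup::new(vec![PartitionedFile::from_path(
        "/tmp/data.parquet".to_string(),
    )?])];

    // FileScanConfigBuilder::new(url, schema, source) became new(url, source).
    let _scan = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source)
        .with_file_groups(file_groups)
        .build();
    Ok(())
}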

native/core/src/parquet/parquet_exec.rs

Lines changed: 28 additions & 51 deletions
@@ -32,6 +32,7 @@ use datafusion::physical_expr::PhysicalExpr;
 use datafusion::prelude::SessionContext;
 use datafusion::scalar::ScalarValue;
 use datafusion_comet_spark_expr::EvalMode;
+use datafusion_datasource::TableSchema;
 use itertools::Itertools;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -78,7 +79,24 @@ pub(crate) fn init_datasource_exec(
         encryption_enabled,
     );

-    let mut parquet_source = ParquetSource::new(table_parquet_options);
+    // Determine the schema to use for ParquetSource
+    let table_schema = if let Some(ref data_schema) = data_schema {
+        if let Some(ref partition_schema) = partition_schema {
+            let partition_fields: Vec<_> = partition_schema
+                .fields()
+                .iter()
+                .map(|f| Arc::new(Field::new(f.name(), f.data_type().clone(), f.is_nullable())) as _)
+                .collect();
+            TableSchema::new(Arc::clone(data_schema), partition_fields)
+        } else {
+            TableSchema::from_file_schema(Arc::clone(data_schema))
+        }
+    } else {
+        TableSchema::from_file_schema(Arc::clone(&required_schema))
+    };
+
+    let mut parquet_source = ParquetSource::new(table_schema)
+        .with_table_parquet_options(table_parquet_options);

     // Create a conjunctive form of the vector because ParquetExecBuilder takes
     // a single expression
@@ -104,37 +122,21 @@ pub(crate) fn init_datasource_exec(
        );
    }

-    let file_source = parquet_source.with_schema_adapter_factory(Arc::new(
-        SparkSchemaAdapterFactory::new(spark_parquet_options, default_values),
-    ))?;
+    let file_source = Arc::new(parquet_source) as Arc<dyn FileSource>;

     let file_groups = file_groups
         .iter()
         .map(|files| FileGroup::new(files.clone()))
         .collect();

-    let file_scan_config = match (data_schema, projection_vector, partition_fields) {
-        (Some(data_schema), Some(projection_vector), Some(partition_fields)) => {
-            get_file_config_builder(
-                data_schema,
-                partition_schema,
-                file_groups,
-                object_store_url,
-                file_source,
-            )
-            .with_projection_indices(Some(projection_vector))
-            .with_table_partition_cols(partition_fields)
-            .build()
-        }
-        _ => get_file_config_builder(
-            required_schema,
-            partition_schema,
-            file_groups,
-            object_store_url,
-            file_source,
-        )
-        .build(),
-    };
+    let mut file_scan_config_builder = FileScanConfigBuilder::new(object_store_url, file_source)
+        .with_file_groups(file_groups);
+
+    if let Some(projection_vector) = projection_vector {
+        file_scan_config_builder = file_scan_config_builder.with_projection_indices(Some(projection_vector))?;
+    }
+
+    let file_scan_config = file_scan_config_builder.build();

     Ok(Arc::new(DataSourceExec::new(Arc::new(file_scan_config))))
 }
@@ -165,28 +167,3 @@ fn get_options(

     (table_parquet_options, spark_parquet_options)
 }
-
-fn get_file_config_builder(
-    schema: SchemaRef,
-    partition_schema: Option<SchemaRef>,
-    file_groups: Vec<FileGroup>,
-    object_store_url: ObjectStoreUrl,
-    file_source: Arc<dyn FileSource>,
-) -> FileScanConfigBuilder {
-    match partition_schema {
-        Some(partition_schema) => {
-            let partition_fields: Vec<Field> = partition_schema
-                .fields()
-                .iter()
-                .map(|field| {
-                    Field::new(field.name(), field.data_type().clone(), field.is_nullable())
-                })
-                .collect_vec();
-            FileScanConfigBuilder::new(object_store_url, Arc::clone(&schema), file_source)
-                .with_file_groups(file_groups)
-                .with_table_partition_cols(partition_fields)
-        }
-        _ => FileScanConfigBuilder::new(object_store_url, Arc::clone(&schema), file_source)
-            .with_file_groups(file_groups),
-    }
-}
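
Migration note: TableSchema is what replaces the schema argument that FileScanConfigBuilder::new used to take, which is why the get_file_config_builder helper above could be deleted. It bundles the file schema with any partition columns. A hedged sketch of the two construction paths used in this hunk, assuming TableSchema::new takes the file schema plus a list of partition fields, as the calls in this commit suggest:

use std::sync::Arc;

use arrow::datatypes::{Field, FieldRef, SchemaRef};
use datafusion_datasource::TableSchema;

// Choose the TableSchema variant the same way init_datasource_exec now does.
fn table_schema_for(data_schema: SchemaRef, partition_schema: Option<SchemaRef>) -> TableSchema {
    match partition_schema {
        // Partitioned scan: carry the partition columns alongside the file schema.
        Some(partition_schema) => {
            let partition_fields: Vec<FieldRef> = partition_schema
                .fields()
                .iter()
                .map(|f| Arc::new(Field::new(f.name(), f.data_type().clone(), f.is_nullable())))
                .collect();
            TableSchema::new(data_schema, partition_fields)
        }
        // Unpartitioned scan: the file schema is the whole table schema.
        None => TableSchema::from_file_schema(data_schema),
    }
}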

native/core/src/parquet/schema_adapter.rs

Lines changed: 5 additions & 5 deletions
@@ -344,14 +344,14 @@ mod test {
         let mut spark_parquet_options = SparkParquetOptions::new(EvalMode::Legacy, "UTC", false);
         spark_parquet_options.allow_cast_unsigned_ints = true;

-        let parquet_source =
-            ParquetSource::new(TableParquetOptions::new()).with_schema_adapter_factory(
-                Arc::new(SparkSchemaAdapterFactory::new(spark_parquet_options, None)),
-            )?;
+        let parquet_source = Arc::new(
+            ParquetSource::new(Arc::clone(&required_schema))
+                .with_table_parquet_options(TableParquetOptions::new())
+        ) as Arc<dyn FileSource>;

         let files = FileGroup::new(vec![PartitionedFile::from_path(filename.to_string())?]);
         let file_scan_config =
-            FileScanConfigBuilder::new(object_store_url, required_schema, parquet_source)
+            FileScanConfigBuilder::new(object_store_url, parquet_source)
                 .with_file_groups(vec![files])
                 .build();

native/spark-expr/src/agg_funcs/covariance.rs

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ use arrow::{
     datatypes::{DataType, Field},
 };
 use datafusion::common::{
-    downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue,
+    downcast_value, unwrap_or_internal_err, Result, ScalarValue,
 };
 use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::type_coercion::aggregates::NUMERICS;

native/spark-expr/src/math_funcs/round.rs

Lines changed: 36 additions & 9 deletions
@@ -19,10 +19,13 @@ use crate::arithmetic_overflow_error;
 use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
 use arrow::array::{Array, ArrowNativeTypeOp};
 use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field};
 use arrow::error::ArrowError;
 use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
-use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
+use datafusion::common::config::ConfigOptions;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion::functions::math::round::RoundFunc;
+use datafusion::logical_expr::{ScalarUDFImpl, ScalarFunctionArgs};
 use std::{cmp::min, sync::Arc};

 macro_rules! integer_round {
@@ -126,10 +129,21 @@ pub fn spark_round(
                 let (precision, scale) = get_precision_scale(data_type);
                 make_decimal_array(array, precision, scale, &f)
             }
-            DataType::Float32 | DataType::Float64 => Ok(ColumnarValue::Array(round(&[
-                Arc::clone(array),
-                args[1].to_array(array.len())?,
-            ])?)),
+            DataType::Float32 | DataType::Float64 => {
+                let round_udf = RoundFunc::new();
+                let return_field = Arc::new(Field::new("round", array.data_type().clone(), true));
+                let args_for_round = ScalarFunctionArgs {
+                    args: vec![
+                        ColumnarValue::Array(Arc::clone(array)),
+                        args[1].clone(),
+                    ],
+                    number_rows: array.len(),
+                    return_field,
+                    arg_fields: vec![],
+                    config_options: Arc::new(ConfigOptions::default()),
+                };
+                round_udf.invoke_with_args(args_for_round)
+            }
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
         ColumnarValue::Scalar(a) => match a {
@@ -150,9 +164,22 @@ pub fn spark_round(
                 let (precision, scale) = get_precision_scale(data_type);
                 make_decimal_scalar(a, precision, scale, &f)
             }
-            ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar(
-                ScalarValue::try_from_array(&round(&[a.to_array()?, args[1].to_array(1)?])?, 0)?,
-            )),
+            ScalarValue::Float32(_) | ScalarValue::Float64(_) => {
+                let round_udf = RoundFunc::new();
+                let data_type = a.data_type();
+                let return_field = Arc::new(Field::new("round", data_type, true));
+                let args_for_round = ScalarFunctionArgs {
+                    args: vec![
+                        ColumnarValue::Scalar(a.clone()),
+                        args[1].clone(),
+                    ],
+                    number_rows: 1,
+                    return_field,
+                    arg_fields: vec![],
+                    config_options: Arc::new(ConfigOptions::default()),
+                };
+                round_udf.invoke_with_args(args_for_round)
+            }
             dt => exec_err!("Not supported datatype for ROUND: {dt}"),
         },
     }
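
Migration note: the free round function is gone from datafusion::functions::math::round, so the expression is now evaluated through the RoundFunc UDF via ScalarUDFImpl::invoke_with_args. A minimal standalone sketch of that calling convention, mirroring the ScalarFunctionArgs fields used in this commit; the input values are made up, and passing the decimal-places argument as an Int64 scalar is an assumption based on round's usual signature:

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::datatypes::{DataType, Field};
use datafusion::common::config::ConfigOptions;
use datafusion::common::{Result, ScalarValue};
use datafusion::functions::math::round::RoundFunc;
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
use datafusion::physical_plan::ColumnarValue;

fn round_to_one_decimal() -> Result<ColumnarValue> {
    let values: ArrayRef = Arc::new(Float64Array::from(vec![1.24, 2.56]));
    let round_udf = RoundFunc::new();
    let args = ScalarFunctionArgs {
        args: vec![
            ColumnarValue::Array(Arc::clone(&values)),
            // Second argument: number of decimal places to round to.
            ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
        ],
        number_rows: values.len(),
        return_field: Arc::new(Field::new("round", DataType::Float64, true)),
        // The commit passes empty arg_fields and default config options as well.
        arg_fields: vec![],
        config_options: Arc::new(ConfigOptions::default()),
    };
    // Expected output: a Float64 array containing [1.2, 2.6].
    round_udf.invoke_with_args(args)
}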
