Skip to content

Commit f7aad61

Browse files
committed
Df52 migration
1 parent 48032aa commit f7aad61

6 files changed

Lines changed: 93 additions & 97 deletions

File tree

native/core/src/execution/operators/scan.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ impl ScanExec {
9494

9595
// Build schema directly from data types since get_next now always unpacks dictionaries
9696
let schema = schema_from_data_types(&data_types);
97+
dbg!(&schema);
9798

9899
let cache = PlanProperties::new(
99100
EquivalenceProperties::new(Arc::clone(&schema)),
@@ -209,6 +210,8 @@ impl ScanExec {
209210

210211
let array = make_array(array_data);
211212

213+
dbg!(&array, &selection_indices_arrays);
214+
212215
// Apply selection if selection vectors exist (applies to all columns)
213216
let array = if let Some(ref selection_arrays) = selection_indices_arrays {
214217
let indices = &selection_arrays[i];
@@ -487,7 +490,7 @@ impl ScanStream<'_> {
487490
) -> DataFusionResult<RecordBatch, DataFusionError> {
488491
let schema_fields = self.schema.fields();
489492
assert_eq!(columns.len(), schema_fields.len());
490-
493+
dbg!(&columns, &self.schema);
491494
// Cast dictionary-encoded primitive arrays to regular arrays and cast
492495
// Utf8/LargeUtf8/Binary arrays to dictionary-encoded if the schema is
493496
// defined as dictionary-encoded and the data in this batch is not
@@ -507,6 +510,7 @@ impl ScanStream<'_> {
507510
})
508511
.collect::<Result<Vec<_>, _>>()?;
509512
let options = RecordBatchOptions::new().with_row_count(Some(num_rows));
513+
dbg!(&new_columns, &self.schema);
510514
RecordBatch::try_new_with_options(Arc::clone(&self.schema), new_columns, &options)
511515
.map_err(|e| arrow_datafusion_err!(e))
512516
}
@@ -517,6 +521,7 @@ impl Stream for ScanStream<'_> {
517521

518522
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
519523
let mut timer = self.baseline_metrics.elapsed_compute().timer();
524+
dbg!(&self.scan);
520525
let mut scan_batch = self.scan.batch.try_lock().unwrap();
521526

522527
let input_batch = &*scan_batch;

native/core/src/execution/planner.rs

Lines changed: 34 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -667,65 +667,31 @@ impl PhysicalPlanner {
667667
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
668668
let left = self.create_expr(left, Arc::clone(&input_schema))?;
669669
let right = self.create_expr(right, Arc::clone(&input_schema))?;
670-
let left_type = left.data_type(&input_schema);
671-
let right_type = right.data_type(&input_schema);
672-
match (&op, &left_type, &right_type) {
673-
// Handle date arithmetic with Int8/Int16/Int32 by:
674-
// 1. Casting Date32 to Int32 (days since epoch)
675-
// 2. Performing the arithmetic as Int32 +/- Int32
676-
// 3. Casting the result back to Date32 using DataFusion's CastExpr
677-
// Arrow's date arithmetic kernel only supports Date32 +/- Interval types
678-
// Note: We use DataFusion's CastExpr for the final cast because Spark's Cast
679-
// doesn't support Int32 -> Date32 conversion
680-
(
681-
DataFusionOperator::Plus | DataFusionOperator::Minus,
682-
Ok(DataType::Date32),
683-
Ok(DataType::Int8) | Ok(DataType::Int16) | Ok(DataType::Int32),
684-
) => {
685-
// Cast Date32 to Int32 (days since epoch)
686-
let left_as_int = Arc::new(Cast::new(
687-
left,
688-
DataType::Int32,
689-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
690-
));
691-
// Cast Int8/Int16 to Int32 if needed
692-
let right_as_int: Arc<dyn PhysicalExpr> =
693-
if matches!(right_type, Ok(DataType::Int32)) {
694-
right
695-
} else {
696-
Arc::new(Cast::new(
697-
right,
698-
DataType::Int32,
699-
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
700-
))
701-
};
702-
// Perform the arithmetic as Int32 +/- Int32
703-
let result_int = Arc::new(BinaryExpr::new(left_as_int, op, right_as_int));
704-
// Cast the result back to Date32 using DataFusion's CastExpr
705-
// (Spark's Cast doesn't support Int32 -> Date32)
706-
Ok(Arc::new(CastExpr::new(result_int, DataType::Date32, None)))
707-
}
670+
match (
671+
&op,
672+
left.data_type(&input_schema),
673+
right.data_type(&input_schema),
674+
) {
708675
(
709676
DataFusionOperator::Plus | DataFusionOperator::Minus | DataFusionOperator::Multiply,
710677
Ok(DataType::Decimal128(p1, s1)),
711678
Ok(DataType::Decimal128(p2, s2)),
712679
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
713-
&& max(*s1, *s2) as u8 + max(*p1 - *s1 as u8, *p2 - *s2 as u8)
680+
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
714681
>= DECIMAL128_MAX_PRECISION)
715-
|| (op == DataFusionOperator::Multiply
716-
&& *p1 + *p2 >= DECIMAL128_MAX_PRECISION) =>
682+
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
717683
{
718684
let data_type = return_type.map(to_arrow_datatype).unwrap();
719685
// For some Decimal128 operations, we need wider internal digits.
720686
// Cast left and right to Decimal256 and cast the result back to Decimal128
721687
let left = Arc::new(Cast::new(
722688
left,
723-
DataType::Decimal256(*p1, *s1),
689+
DataType::Decimal256(p1, s1),
724690
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
725691
));
726692
let right = Arc::new(Cast::new(
727693
right,
728-
DataType::Decimal256(*p2, *s2),
694+
DataType::Decimal256(p2, s2),
729695
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
730696
));
731697
let child = Arc::new(BinaryExpr::new(left, op, right));
@@ -999,6 +965,7 @@ impl PhysicalPlanner {
999965
))
1000966
}
1001967
OpStruct::NativeScan(scan) => {
968+
dbg!(&scan);
1002969
let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
1003970
let required_schema: SchemaRef =
1004971
convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
@@ -1146,6 +1113,7 @@ impl PhysicalPlanner {
11461113
))
11471114
}
11481115
OpStruct::Scan(scan) => {
1116+
dbg!(&scan);
11491117
let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();
11501118

11511119
// If it is not test execution context for unit test, we should have at least one
@@ -1172,6 +1140,8 @@ impl PhysicalPlanner {
11721140
scan.arrow_ffi_safe,
11731141
)?;
11741142

1143+
dbg!(&scan);
1144+
11751145
Ok((
11761146
vec![scan.clone()],
11771147
Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -4411,12 +4381,10 @@ mod tests {
44114381
fn test_date_sub_with_int8_cast_error() {
44124382
use arrow::array::Date32Array;
44134383

4414-
let session_ctx = SessionContext::new();
4415-
let task_ctx = session_ctx.task_ctx();
4416-
let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0);
4384+
let planner = PhysicalPlanner::default();
4385+
let row_count = 3;
44174386

4418-
// Create a scan operator with Date32 (DATE) and Int8 (TINYINT) columns
4419-
// This simulates the schema from the Scala test where _20 is DATE and _2 is TINYINT
4387+
// Create a Scan operator with Date32 (DATE) and Int8 (TINYINT) columns
44204388
let op_scan = Operator {
44214389
plan_id: 0,
44224390
children: vec![],
@@ -4431,7 +4399,7 @@ mod tests {
44314399
type_info: None,
44324400
},
44334401
],
4434-
source: "test".to_string(),
4402+
source: "".to_string(),
44354403
arrow_ffi_safe: false,
44364404
})),
44374405
};
@@ -4486,22 +4454,27 @@ mod tests {
44864454
let (mut scans, datafusion_plan) =
44874455
planner.create_plan(&projection, &mut vec![], 1).unwrap();
44884456

4489-
// Execute the plan with test data
4457+
// Create test data: Date32 and Int8 columns
4458+
let date_array = Date32Array::from(vec![Some(19000), Some(19001), Some(19002)]);
4459+
let int8_array = Int8Array::from(vec![Some(1i8), Some(2i8), Some(3i8)]);
4460+
4461+
// Set input batch for the scan
4462+
let input_batch = InputBatch::Batch(vec![Arc::new(date_array), Arc::new(int8_array)], row_count);
4463+
scans[0].set_input_batch(input_batch);
4464+
4465+
let session_ctx = SessionContext::new();
4466+
let task_ctx = session_ctx.task_ctx();
44904467
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
44914468

44924469
let runtime = tokio::runtime::Runtime::new().unwrap();
44934470
let (tx, mut rx) = mpsc::channel(1);
44944471

4495-
// Send test data: Date32 values and Int8 values
4472+
// Separate thread to send the EOF signal once we've processed the only input batch
44964473
runtime.spawn(async move {
4497-
// Create Date32 array (days since epoch)
4498-
// 19000 days = approximately 2022-01-01
4474+
// Create test data again for the second batch
44994475
let date_array = Date32Array::from(vec![Some(19000), Some(19001), Some(19002)]);
4500-
// Create Int8 array
45014476
let int8_array = Int8Array::from(vec![Some(1i8), Some(2i8), Some(3i8)]);
4502-
4503-
let input_batch1 =
4504-
InputBatch::Batch(vec![Arc::new(date_array), Arc::new(int8_array)], 3);
4477+
let input_batch1 = InputBatch::Batch(vec![Arc::new(date_array), Arc::new(int8_array)], row_count);
45054478
let input_batch2 = InputBatch::EOF;
45064479

45074480
let batches = vec![input_batch1, input_batch2];
@@ -4511,7 +4484,6 @@ mod tests {
45114484
}
45124485
});
45134486

4514-
// Execute and expect success - the Int8 should be cast to Int32 for date arithmetic
45154487
runtime.block_on(async move {
45164488
loop {
45174489
let batch = rx.recv().await.unwrap();
@@ -4524,10 +4496,13 @@ mod tests {
45244496
"Expected success for date - int8 operation but got error: {:?}",
45254497
result.unwrap_err()
45264498
);
4499+
45274500
let batch = result.unwrap();
4528-
assert_eq!(batch.num_rows(), 3);
4501+
assert_eq!(batch.num_rows(), row_count);
4502+
45294503
// The result should be Date32 type
45304504
assert_eq!(batch.column(0).data_type(), &DataType::Date32);
4505+
45314506
// Verify the values: 19000-1=18999, 19001-2=18999, 19002-3=18999
45324507
let date_array = batch
45334508
.column(0)
@@ -4537,7 +4512,6 @@ mod tests {
45374512
assert_eq!(date_array.value(0), 18999); // 19000 - 1
45384513
assert_eq!(date_array.value(1), 18999); // 19001 - 2
45394514
assert_eq!(date_array.value(2), 18999); // 19002 - 3
4540-
break;
45414515
}
45424516
Poll::Ready(None) => {
45434517
break;

native/core/src/parquet/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
703703
key_unwrapper_obj: JObject,
704704
metrics_node: JObject,
705705
) -> jlong {
706+
dbg!("Java_org_apache_comet_parquet_Native_initRecordBatchReader");
706707
try_unwrap_or_throw(&e, |mut env| unsafe {
707708
JVMClasses::init(&mut env);
708709
let session_config = SessionConfig::new().with_batch_size(batch_size as usize);
@@ -776,6 +777,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
776777
encryption_enabled,
777778
)?;
778779

780+
dbg!(&scan);
781+
779782
let partition_index: usize = 0;
780783
let batch_stream = Some(scan.execute(partition_index, session_ctx.task_ctx())?);
781784

@@ -787,6 +790,9 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
787790
reader_state: ParquetReaderState::Init,
788791
};
789792
let res = Box::new(ctx);
793+
794+
dbg!("end Java_org_apache_comet_parquet_Native_initRecordBatchReader");
795+
790796
Ok(Box::into_raw(res) as i64)
791797
})
792798
}

native/core/src/parquet/parquet_exec.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ pub(crate) fn init_datasource_exec(
9696
TableSchema::from_file_schema(Arc::clone(&required_schema))
9797
};
9898

99+
dbg!(&table_schema);
100+
99101
let mut parquet_source =
100102
ParquetSource::new(table_schema).with_table_parquet_options(table_parquet_options);
101103

@@ -135,12 +137,11 @@ pub(crate) fn init_datasource_exec(
135137
.collect();
136138

137139
let mut file_scan_config_builder =
138-
FileScanConfigBuilder::new(object_store_url, file_source).with_file_groups(file_groups);
140+
FileScanConfigBuilder::new(object_store_url, file_source).with_file_groups(file_groups).with_expr_adapter(Some(expr_adapter_factory));
139141

140142
if let Some(projection_vector) = projection_vector {
141143
file_scan_config_builder = file_scan_config_builder
142-
.with_projection_indices(Some(projection_vector))?
143-
.with_expr_adapter(Some(expr_adapter_factory));
144+
.with_projection_indices(Some(projection_vector))?;
144145
}
145146

146147
let file_scan_config = file_scan_config_builder.build();

native/core/src/parquet/schema_adapter.rs

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
use crate::parquet::parquet_support::{spark_parquet_convert, SparkParquetOptions};
2727
use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
28-
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
28+
use arrow::datatypes::{Field, Schema, SchemaRef};
2929
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
3030
use datafusion::common::{ColumnStatistics, Result as DataFusionResult};
3131
use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
@@ -102,7 +102,6 @@ impl PhysicalExprAdapterFactory for SparkPhysicalExprAdapterFactory {
102102
struct SparkPhysicalExprAdapter {
103103
/// The logical schema expected by the query
104104
logical_file_schema: SchemaRef,
105-
#[allow(dead_code)]
106105
/// The physical schema of the actual file being read
107106
physical_file_schema: SchemaRef,
108107
/// Spark-specific options for type conversions
@@ -161,44 +160,56 @@ impl SparkPhysicalExprAdapter {
161160
Ok(Transformed::no(expr))
162161
}
163162

164-
// Cast expressions that currently not supported in DF
165-
// For example, Arrow's date arithmetic kernel only supports Date32 +/- Int32 (days)
166-
// but Spark may send Int8/Int16 values. We need to cast them to Int32.
163+
/// Cast Column expressions where the physical and logical datatypes differ.
164+
///
165+
/// This function traverses the expression tree and for each Column expression,
166+
/// checks if the physical file schema datatype differs from the logical file schema
167+
/// datatype. If they differ, it wraps the Column with a CastColumnExpr to perform
168+
/// the necessary type conversion.
167169
fn cast_datafusion_unsupported_expr(
168170
&self,
169171
expr: Arc<dyn PhysicalExpr>,
170172
) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
171-
use datafusion::logical_expr::Operator;
172-
use datafusion::physical_expr::expressions::{BinaryExpr, CastColumnExpr};
173+
use datafusion::physical_expr::expressions::CastColumnExpr;
173174

174175
expr.transform(|e| {
175-
// Check if this is a BinaryExpr with date arithmetic
176-
if let Some(binary) = e.as_any().downcast_ref::<BinaryExpr>() {
177-
let op = binary.op();
178-
// Only handle Plus and Minus for date arithmetic
179-
if matches!(op, &Operator::Plus | &Operator::Minus) {
180-
let left = binary.left();
181-
let right = binary.right();
182-
183-
let left_type = left.data_type(&self.logical_file_schema);
184-
let right_type = right.data_type(&self.logical_file_schema);
185-
186-
// Check for Date32 +/- Int8 or Date32 +/- Int16
187-
if let (Ok(DataType::Date32), Ok(ref rt @ (DataType::Int8 | DataType::Int16))) =
188-
(&left_type, &right_type)
189-
{
190-
// Cast the right operand (Int8/Int16) to Int32
191-
let input_field = Arc::new(Field::new("input", rt.clone(), true));
192-
let target_field = Arc::new(Field::new("cast", DataType::Int32, true));
193-
let casted_right: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new(
194-
Arc::clone(right),
176+
// Check if this is a Column expression
177+
if let Some(column) = e.as_any().downcast_ref::<Column>() {
178+
let col_idx = column.index();
179+
180+
// Get the logical datatype (expected by the query)
181+
let logical_field = self.logical_file_schema.fields().get(col_idx);
182+
// Get the physical datatype (actual file schema)
183+
let physical_field = self.physical_file_schema.fields().get(col_idx);
184+
185+
dbg!(&logical_field, &physical_field);
186+
187+
if let (Some(logical_field), Some(physical_field)) = (logical_field, physical_field)
188+
{
189+
let logical_type = logical_field.data_type();
190+
let physical_type = physical_field.data_type();
191+
192+
// If datatypes differ, insert a CastColumnExpr
193+
if logical_type != physical_type || 1==1 {
194+
let input_field = Arc::new(Field::new(
195+
physical_field.name(),
196+
physical_type.clone(),
197+
physical_field.is_nullable(),
198+
));
199+
let target_field = Arc::new(Field::new(
200+
logical_field.name(),
201+
logical_type.clone(),
202+
logical_field.is_nullable(),
203+
));
204+
205+
let cast_expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new(
206+
e.clone(),
195207
input_field,
196208
target_field,
197209
None,
198210
));
199-
let new_binary: Arc<dyn PhysicalExpr> =
200-
Arc::new(BinaryExpr::new(Arc::clone(left), *op, casted_right));
201-
return Ok(Transformed::yes(new_binary));
211+
dbg!(&cast_expr);
212+
return Ok(Transformed::yes(cast_expr));
202213
}
203214
}
204215
}
@@ -459,7 +470,6 @@ impl SchemaMapper for SchemaMapping {
459470
/// columns, so if one needs a RecordBatch with a schema that references columns which are not
460471
/// in the projected, it would be better to use `map_partial_batch`
461472
fn map_batch(&self, batch: RecordBatch) -> datafusion::common::Result<RecordBatch> {
462-
dbg!("map_batch");
463473
let batch_rows = batch.num_rows();
464474
let batch_cols = batch.columns().to_vec();
465475

0 commit comments

Comments
 (0)