@@ -667,31 +667,65 @@ impl PhysicalPlanner {
667667 ) -> Result < Arc < dyn PhysicalExpr > , ExecutionError > {
668668 let left = self . create_expr ( left, Arc :: clone ( & input_schema) ) ?;
669669 let right = self . create_expr ( right, Arc :: clone ( & input_schema) ) ?;
670- match (
671- & op,
672- left. data_type ( & input_schema) ,
673- right. data_type ( & input_schema) ,
674- ) {
670+ let left_type = left. data_type ( & input_schema) ;
671+ let right_type = right. data_type ( & input_schema) ;
672+ match ( & op, & left_type, & right_type) {
673+ // Handle date arithmetic with Int8/Int16/Int32 by:
674+ // 1. Casting Date32 to Int32 (days since epoch)
675+ // 2. Performing the arithmetic as Int32 +/- Int32
676+ // 3. Casting the result back to Date32 using DataFusion's CastExpr
677+ // Arrow's date arithmetic kernel only supports Date32 +/- Interval types
678+ // Note: We use DataFusion's CastExpr for the final cast because Spark's Cast
679+ // doesn't support Int32 -> Date32 conversion
680+ (
681+ DataFusionOperator :: Plus | DataFusionOperator :: Minus ,
682+ Ok ( DataType :: Date32 ) ,
683+ Ok ( DataType :: Int8 ) | Ok ( DataType :: Int16 ) | Ok ( DataType :: Int32 ) ,
684+ ) => {
685+ // Cast Date32 to Int32 (days since epoch)
686+ let left_as_int = Arc :: new ( Cast :: new (
687+ left,
688+ DataType :: Int32 ,
689+ SparkCastOptions :: new_without_timezone ( EvalMode :: Legacy , false ) ,
690+ ) ) ;
691+ // Cast Int8/Int16 to Int32 if needed
692+ let right_as_int: Arc < dyn PhysicalExpr > =
693+ if matches ! ( right_type, Ok ( DataType :: Int32 ) ) {
694+ right
695+ } else {
696+ Arc :: new ( Cast :: new (
697+ right,
698+ DataType :: Int32 ,
699+ SparkCastOptions :: new_without_timezone ( EvalMode :: Legacy , false ) ,
700+ ) )
701+ } ;
702+ // Perform the arithmetic as Int32 +/- Int32
703+ let result_int = Arc :: new ( BinaryExpr :: new ( left_as_int, op, right_as_int) ) ;
704+ // Cast the result back to Date32 using DataFusion's CastExpr
705+ // (Spark's Cast doesn't support Int32 -> Date32)
706+ Ok ( Arc :: new ( CastExpr :: new ( result_int, DataType :: Date32 , None ) ) )
707+ }
675708 (
676709 DataFusionOperator :: Plus | DataFusionOperator :: Minus | DataFusionOperator :: Multiply ,
677710 Ok ( DataType :: Decimal128 ( p1, s1) ) ,
678711 Ok ( DataType :: Decimal128 ( p2, s2) ) ,
679712 ) if ( ( op == DataFusionOperator :: Plus || op == DataFusionOperator :: Minus )
680- && max ( s1, s2) as u8 + max ( p1 - s1 as u8 , p2 - s2 as u8 )
713+ && max ( * s1, * s2) as u8 + max ( * p1 - * s1 as u8 , * p2 - * s2 as u8 )
681714 >= DECIMAL128_MAX_PRECISION )
682- || ( op == DataFusionOperator :: Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION ) =>
715+ || ( op == DataFusionOperator :: Multiply
716+ && * p1 + * p2 >= DECIMAL128_MAX_PRECISION ) =>
683717 {
684718 let data_type = return_type. map ( to_arrow_datatype) . unwrap ( ) ;
685719 // For some Decimal128 operations, we need wider internal digits.
686720 // Cast left and right to Decimal256 and cast the result back to Decimal128
687721 let left = Arc :: new ( Cast :: new (
688722 left,
689- DataType :: Decimal256 ( p1, s1) ,
723+ DataType :: Decimal256 ( * p1, * s1) ,
690724 SparkCastOptions :: new_without_timezone ( EvalMode :: Legacy , false ) ,
691725 ) ) ;
692726 let right = Arc :: new ( Cast :: new (
693727 right,
694- DataType :: Decimal256 ( p2, s2) ,
728+ DataType :: Decimal256 ( * p2, * s2) ,
695729 SparkCastOptions :: new_without_timezone ( EvalMode :: Legacy , false ) ,
696730 ) ) ;
697731 let child = Arc :: new ( BinaryExpr :: new ( left, op, right) ) ;
@@ -3432,7 +3466,9 @@ mod tests {
34323466 use futures:: { poll, StreamExt } ;
34333467 use std:: { sync:: Arc , task:: Poll } ;
34343468
3435- use arrow:: array:: { Array , DictionaryArray , Int32Array , ListArray , RecordBatch , StringArray } ;
3469+ use arrow:: array:: {
3470+ Array , DictionaryArray , Int32Array , Int8Array , ListArray , RecordBatch , StringArray ,
3471+ } ;
34363472 use arrow:: datatypes:: { DataType , Field , FieldRef , Fields , Schema } ;
34373473 use datafusion:: catalog:: memory:: DataSourceExec ;
34383474 use datafusion:: config:: TableParquetOptions ;
@@ -4364,4 +4400,151 @@ mod tests {
43644400
43654401 Ok ( ( ) )
43664402 }
4403+
    /// Test that reproduces the "Cast error: Casting from Int8 to Date32 not supported" error
    /// that occurs when performing date subtraction with Int8 (TINYINT) values.
    /// This corresponds to the Scala test "date_sub with int arrays" in CometExpressionSuite.
    ///
    /// The error occurs because DataFusion's BinaryExpr tries to cast Int8 to Date32
    /// when evaluating date - int8, but this cast is not supported. The planner is
    /// expected to rewrite the expression as Int32 arithmetic on days-since-epoch and
    /// cast the result back to Date32.
    #[test]
    fn test_date_sub_with_int8_cast_error() {
        use arrow::array::Date32Array;

        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();
        let planner = PhysicalPlanner::new(Arc::from(session_ctx), 0);

        // Create a scan operator with Date32 (DATE) and Int8 (TINYINT) columns.
        // This simulates the schema from the Scala test where _20 is DATE and _2 is TINYINT.
        // NOTE(review): type_id values (12 = DATE, 1 = INT8) mirror the spark_expression
        // protobuf enum — confirm against the .proto definition if it changes.
        let op_scan = Operator {
            plan_id: 0,
            children: vec![],
            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
                fields: vec![
                    spark_expression::DataType {
                        type_id: 12, // DATE (Date32)
                        type_info: None,
                    },
                    spark_expression::DataType {
                        type_id: 1, // INT8 (TINYINT)
                        type_info: None,
                    },
                ],
                source: "test".to_string(),
                arrow_ffi_safe: false,
            })),
        };

        // Create bound reference for the DATE column (index 0)
        let date_col = spark_expression::Expr {
            expr_struct: Some(Bound(spark_expression::BoundReference {
                index: 0,
                datatype: Some(spark_expression::DataType {
                    type_id: 12, // DATE
                    type_info: None,
                }),
            })),
        };

        // Create bound reference for the INT8 column (index 1)
        let int8_col = spark_expression::Expr {
            expr_struct: Some(Bound(spark_expression::BoundReference {
                index: 1,
                datatype: Some(spark_expression::DataType {
                    type_id: 1, // INT8
                    type_info: None,
                }),
            })),
        };

        // Create a Subtract expression: date_col - int8_col.
        // This is equivalent to the SQL: SELECT _20 - _2 FROM tbl (date_sub operation).
        // In the protobuf, subtract uses the MathExpr message type.
        let subtract_expr = spark_expression::Expr {
            expr_struct: Some(ExprStruct::Subtract(Box::new(spark_expression::MathExpr {
                left: Some(Box::new(date_col)),
                right: Some(Box::new(int8_col)),
                return_type: Some(spark_expression::DataType {
                    type_id: 12, // DATE - result should be DATE
                    type_info: None,
                }),
                eval_mode: 0, // Legacy mode
            }))),
        };

        // Create a projection operator evaluating the subtract expression over the scan
        let projection = Operator {
            children: vec![op_scan],
            plan_id: 1,
            op_struct: Some(OpStruct::Projection(spark_operator::Projection {
                project_list: vec![subtract_expr],
            })),
        };

        // Create the physical plan; `scans` gives us handles to feed input batches
        let (mut scans, datafusion_plan) =
            planner.create_plan(&projection, &mut vec![], 1).unwrap();

        // Execute the plan; the stream yields output batches as inputs arrive
        let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();

        let runtime = tokio::runtime::Runtime::new().unwrap();
        let (tx, mut rx) = mpsc::channel(1);

        // Producer task: send one data batch followed by EOF
        runtime.spawn(async move {
            // Create Date32 array (days since epoch)
            // 19000 days = approximately 2022-01-01
            let date_array = Date32Array::from(vec![Some(19000), Some(19001), Some(19002)]);
            // Create Int8 array
            let int8_array = Int8Array::from(vec![Some(1i8), Some(2i8), Some(3i8)]);

            let input_batch1 =
                InputBatch::Batch(vec![Arc::new(date_array), Arc::new(int8_array)], 3);
            let input_batch2 = InputBatch::EOF;

            let batches = vec![input_batch1, input_batch2];

            for batch in batches.into_iter() {
                tx.send(batch).await.unwrap();
            }
        });

        // Consumer loop: feed each received batch into the scan, then poll the stream.
        // Execute and expect success - the Int8 should be cast to Int32 for date arithmetic.
        runtime.block_on(async move {
            loop {
                // Relies on the producer sending exactly as many batches as we recv();
                // a missing EOF would make this unwrap panic on a closed channel.
                let batch = rx.recv().await.unwrap();
                scans[0].set_input_batch(batch);
                match poll!(stream.next()) {
                    Poll::Ready(Some(result)) => {
                        // We expect success - the Int8 should be automatically cast to Int32.
                        // (format args are only evaluated on failure, so unwrap_err here
                        // does not conflict with the unwrap below on the success path)
                        assert!(
                            result.is_ok(),
                            "Expected success for date - int8 operation but got error: {:?}",
                            result.unwrap_err()
                        );
                        let batch = result.unwrap();
                        assert_eq!(batch.num_rows(), 3);
                        // The result should be Date32 type
                        assert_eq!(batch.column(0).data_type(), &DataType::Date32);
                        // Verify the values: 19000-1=18999, 19001-2=18999, 19002-3=18999
                        let date_array = batch
                            .column(0)
                            .as_any()
                            .downcast_ref::<Date32Array>()
                            .unwrap();
                        assert_eq!(date_array.value(0), 18999); // 19000 - 1
                        assert_eq!(date_array.value(1), 18999); // 19001 - 2
                        assert_eq!(date_array.value(2), 18999); // 19002 - 3
                        break;
                    }
                    Poll::Ready(None) => {
                        break;
                    }
                    _ => {}
                }
            }
        });
    }
43674550}
0 commit comments