@@ -116,8 +116,8 @@ use datafusion_comet_proto::{
116116 } ,
117117 spark_operator:: {
118118 self , lower_window_frame_bound:: LowerFrameBoundStruct , operator:: OpStruct ,
119- upper_window_frame_bound:: UpperFrameBoundStruct , BuildSide ,
120- CompressionCodec as SparkCompressionCodec , JoinType , Operator , WindowFrameType ,
119+ upper_window_frame_bound:: UpperFrameBoundStruct , AggregateMode as ProtoAggregateMode ,
120+ BuildSide , CompressionCodec as SparkCompressionCodec , JoinType , Operator , WindowFrameType ,
121121 } ,
122122 spark_partitioning:: { partitioning:: PartitioningStruct , Partitioning as SparkPartitioning } ,
123123} ;
@@ -967,26 +967,24 @@ impl PhysicalPlanner {
967967 let group_by = PhysicalGroupBy :: new_single ( group_exprs?) ;
968968 let schema = child. schema ( ) ;
969969
970+ let partial_merge = ProtoAggregateMode :: PartialMerge as i32 ;
971+
970972 let mode = match agg. mode {
971973 0 => DFAggregateMode :: Partial ,
972974 1 => DFAggregateMode :: Final ,
973- 2 = > DFAggregateMode :: Partial , // PartialMerge uses Partial + MergeAsPartial
975+ m if m == partial_merge = > DFAggregateMode :: Partial ,
974976 other => {
975977 return Err ( ExecutionError :: GeneralError ( format ! (
976978 "Unsupported aggregate mode: {other}"
977979 ) ) )
978980 }
979981 } ;
980982
981- // Determine per-expression modes. PartialMerge (2) expressions use
982- // MergeAsPartial wrapper so they run merge semantics in Partial mode.
983- let per_expr_modes: Vec < i32 > = if !agg. expr_modes . is_empty ( ) {
984- agg. expr_modes . clone ( )
985- } else {
986- vec ! [ agg. mode; agg. agg_exprs. len( ) ]
987- } ;
988-
989- let has_partial_merge = per_expr_modes. contains ( & 2 ) ;
983+ // Check if any expression uses PartialMerge mode. When present,
984+ // those expressions are wrapped with MergeAsPartial to get merge
985+ // semantics inside a Partial-mode AggregateExec.
986+ let has_partial_merge =
987+ agg. mode == partial_merge || agg. expr_modes . contains ( & partial_merge) ;
990988
991989 let agg_exprs: PhyAggResult = agg
992990 . agg_exprs
@@ -998,51 +996,57 @@ impl PhysicalPlanner {
998996 // Wrap PartialMerge expressions with MergeAsPartial.
999997 // State fields in the child's output start at initial_input_buffer_offset.
1000998 let mut state_offset = agg. initial_input_buffer_offset as usize ;
1001- let child_schema = child. schema ( ) ;
999+ let per_expr_modes: Vec < i32 > = if !agg. expr_modes . is_empty ( ) {
1000+ agg. expr_modes . clone ( )
1001+ } else {
1002+ vec ! [ agg. mode; agg. agg_exprs. len( ) ]
1003+ } ;
10021004
10031005 agg_exprs?
10041006 . into_iter ( )
10051007 . enumerate ( )
10061008 . map ( |( idx, expr) | {
1007- let expr_mode = per_expr_modes[ idx] ;
1008- if expr_mode == 2 {
1009+ if per_expr_modes[ idx] == partial_merge {
10091010 // PartialMerge: wrap with MergeAsPartial
1010- let state_fields = expr. state_fields ( ) . map_err ( |e| {
1011- ExecutionError :: GeneralError ( e . to_string ( ) )
1012- } ) ?;
1011+ let state_fields = expr
1012+ . state_fields ( )
1013+ . map_err ( |e| ExecutionError :: GeneralError ( e . to_string ( ) ) ) ?;
10131014 let num_state_fields = state_fields. len ( ) ;
10141015
1015- // Create Column refs pointing to state field positions
10161016 let state_cols: Vec < Arc < dyn PhysicalExpr > > = ( 0 ..num_state_fields)
10171017 . map ( |i| {
10181018 let col_idx = state_offset + i;
1019- let field = child_schema . field ( col_idx) ;
1019+ let field = schema . field ( col_idx) ;
10201020 Arc :: new ( Column :: new ( field. name ( ) , col_idx) )
10211021 as Arc < dyn PhysicalExpr >
10221022 } )
10231023 . collect ( ) ;
10241024 state_offset += num_state_fields;
10251025
1026- let merge_udf = crate :: execution:: merge_as_partial:: MergeAsPartialUDF :: new ( & expr)
1026+ let merge_udf =
1027+ crate :: execution:: merge_as_partial:: MergeAsPartialUDF :: new (
1028+ & expr,
1029+ )
10271030 . map_err ( |e| ExecutionError :: DataFusionError ( e. to_string ( ) ) ) ?;
10281031 let merge_udf_arc = Arc :: new (
1029- datafusion:: logical_expr:: AggregateUDF :: new_from_impl ( merge_udf) ,
1032+ datafusion:: logical_expr:: AggregateUDF :: new_from_impl (
1033+ merge_udf,
1034+ ) ,
10301035 ) ;
10311036
1032- let merge_expr = AggregateExprBuilder :: new (
1033- merge_udf_arc,
1034- state_cols ,
1035- )
1036- . schema ( Arc :: clone ( & child_schema ) )
1037- . alias ( format ! ( "col_{idx}" ) )
1038- . with_ignore_nulls ( expr . ignore_nulls ( ) )
1039- . with_distinct ( expr . is_distinct ( ) )
1040- . build ( )
1041- . map_err ( |e| ExecutionError :: DataFusionError ( e . to_string ( ) ) ) ?;
1037+ let merge_expr =
1038+ AggregateExprBuilder :: new ( merge_udf_arc, state_cols )
1039+ . schema ( Arc :: clone ( & schema ) )
1040+ . alias ( format ! ( "col_{idx}" ) )
1041+ . with_ignore_nulls ( expr . ignore_nulls ( ) )
1042+ . with_distinct ( expr . is_distinct ( ) )
1043+ . build ( )
1044+ . map_err ( |e| {
1045+ ExecutionError :: DataFusionError ( e . to_string ( ) )
1046+ } ) ?;
10421047
10431048 Ok ( Arc :: new ( merge_expr) )
10441049 } else {
1045- // Partial: use as-is
10461050 Ok ( Arc :: new ( expr) )
10471051 }
10481052 } )
0 commit comments