1515// specific language governing permissions and limitations
1616// under the License.
1717
18- //! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
18+ //! Accumulator wrappers for implementing Spark's PartialMerge aggregate mode.
1919//!
2020//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
21- //! state (not final values). DataFusion has no equivalent mode — `Partial` calls
22- //! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
23- //! evaluated results.
21+ //! state (not final values). DataFusion's `PartialReduce` mode has the same semantics.
2422//!
25- //! This wrapper bridges the gap: it operates under DataFusion's ` Partial` mode (which
26- //! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
27- //! semantics with state output .
23+ //! For mixed-mode aggregates (some expressions PartialMerge, some Partial in the same
24+ //! operator), we use `PartialReduce` mode for the whole operator and wrap Partial
25+ //! expressions with `UpdateAsReduceUDF` to redirect `merge_batch → update_batch` .
2826
2927use std:: any:: Any ;
3028use std:: fmt:: Debug ;
@@ -42,51 +40,38 @@ use datafusion::logical_expr::{
4240use datafusion:: physical_expr:: aggregate:: AggregateFunctionExpr ;
4341use datafusion:: scalar:: ScalarValue ;
4442
45- /// An AggregateUDF wrapper that gives merge semantics in Partial mode.
43+ /// Wraps a Partial-mode aggregate to work inside a PartialReduce- mode AggregateExec .
4644///
47- /// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
48- /// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
49- /// and redirects it to `merge_batch` on the inner accumulator, effectively
50- /// implementing PartialMerge: merge inputs, output state.
51- ///
52- /// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping
53- /// references to UnboundColumn expressions that would panic if evaluated.
45+ /// PartialReduce calls `merge_batch` on all accumulators. Partial expressions need
46+ /// `update_batch` semantics, so this wrapper redirects `merge_batch → update_batch`.
5447#[ derive( Debug ) ]
55- pub struct MergeAsPartialUDF {
56- /// The inner aggregate UDF, cloned from the original expression.
48+ pub struct UpdateAsReduceUDF {
5749 inner_udf : AggregateUDF ,
58- /// Pre-computed return type from the original expression.
5950 return_type : DataType ,
60- /// Pre-computed state fields from the original expression.
6151 cached_state_fields : Vec < FieldRef > ,
62- /// Cached signature that accepts state field types.
6352 signature : Signature ,
64- /// Name for this wrapper.
6553 name : String ,
6654}
6755
68- impl PartialEq for MergeAsPartialUDF {
56+ impl PartialEq for UpdateAsReduceUDF {
6957 fn eq ( & self , other : & Self ) -> bool {
7058 self . name == other. name
7159 }
7260}
7361
74- impl Eq for MergeAsPartialUDF { }
62+ impl Eq for UpdateAsReduceUDF { }
7563
76- impl Hash for MergeAsPartialUDF {
64+ impl Hash for UpdateAsReduceUDF {
7765 fn hash < H : Hasher > ( & self , state : & mut H ) {
7866 self . name . hash ( state) ;
7967 }
8068}
8169
82- impl MergeAsPartialUDF {
70+ impl UpdateAsReduceUDF {
8371 pub fn new ( inner_expr : & AggregateFunctionExpr ) -> Result < Self > {
84- let name = format ! ( "merge_as_partial_ {}" , inner_expr. name( ) ) ;
72+ let name = format ! ( "update_as_reduce_ {}" , inner_expr. name( ) ) ;
8573 let return_type = inner_expr. field ( ) . data_type ( ) . clone ( ) ;
8674 let cached_state_fields = inner_expr. state_fields ( ) ?;
87-
88- // Use a permissive signature since we accept state field types which
89- // vary per aggregate function.
9075 let signature = Signature :: variadic_any ( Volatility :: Immutable ) ;
9176
9277 Ok ( Self {
@@ -99,7 +84,7 @@ impl MergeAsPartialUDF {
9984 }
10085}
10186
102- impl AggregateUDFImpl for MergeAsPartialUDF {
87+ impl AggregateUDFImpl for UpdateAsReduceUDF {
10388 fn as_any ( & self ) -> & dyn Any {
10489 self
10590 }
@@ -113,23 +98,16 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
11398 }
11499
115100 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
116- // In Partial mode, return_type isn't used for output schema (state_fields is).
117- // Return the inner function's return type for consistency.
118101 Ok ( self . return_type . clone ( ) )
119102 }
120103
121104 fn state_fields ( & self , _args : StateFieldsArgs ) -> Result < Vec < FieldRef > > {
122- // State fields must match the inner aggregate's state fields so that
123- // the output of this PartialMerge stage is compatible with subsequent
124- // Final or PartialMerge stages.
125105 Ok ( self . cached_state_fields . clone ( ) )
126106 }
127107
128108 fn accumulator ( & self , args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
129- // Create the inner accumulator using the provided args (which have the
130- // correct Column refs, not UnboundColumns).
131109 let inner_acc = self . inner_udf . accumulator ( args) ?;
132- Ok ( Box :: new ( MergeAsPartialAccumulator { inner : inner_acc } ) )
110+ Ok ( Box :: new ( UpdateAsReduceAccumulator { inner : inner_acc } ) )
133111 }
134112
135113 fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
@@ -141,7 +119,7 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
141119 args : AccumulatorArgs ,
142120 ) -> Result < Box < dyn GroupsAccumulator > > {
143121 let inner_acc = self . inner_udf . create_groups_accumulator ( args) ?;
144- Ok ( Box :: new ( MergeAsPartialGroupsAccumulator {
122+ Ok ( Box :: new ( UpdateAsReduceGroupsAccumulator {
145123 inner : inner_acc,
146124 } ) )
147125 }
@@ -159,25 +137,23 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
159137 }
160138}
161139
162- /// Accumulator wrapper that redirects update_batch to merge_batch.
163- struct MergeAsPartialAccumulator {
140+ struct UpdateAsReduceAccumulator {
164141 inner : Box < dyn Accumulator > ,
165142}
166143
167- impl Debug for MergeAsPartialAccumulator {
144+ impl Debug for UpdateAsReduceAccumulator {
168145 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
169- f. debug_struct ( "MergeAsPartialAccumulator " ) . finish ( )
146+ f. debug_struct ( "UpdateAsReduceAccumulator " ) . finish ( )
170147 }
171148}
172149
173- impl Accumulator for MergeAsPartialAccumulator {
150+ impl Accumulator for UpdateAsReduceAccumulator {
174151 fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
175- // Redirect update to merge — this is the key trick.
176- self . inner . merge_batch ( values)
152+ self . inner . update_batch ( values)
177153 }
178154
179155 fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
180- self . inner . merge_batch ( states)
156+ self . inner . update_batch ( states)
181157 }
182158
183159 fn evaluate ( & mut self ) -> Result < ScalarValue > {
@@ -193,28 +169,26 @@ impl Accumulator for MergeAsPartialAccumulator {
193169 }
194170}
195171
196- /// GroupsAccumulator wrapper that redirects update_batch to merge_batch.
197- struct MergeAsPartialGroupsAccumulator {
172+ struct UpdateAsReduceGroupsAccumulator {
198173 inner : Box < dyn GroupsAccumulator > ,
199174}
200175
201- impl Debug for MergeAsPartialGroupsAccumulator {
176+ impl Debug for UpdateAsReduceGroupsAccumulator {
202177 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
203- f. debug_struct ( "MergeAsPartialGroupsAccumulator " ) . finish ( )
178+ f. debug_struct ( "UpdateAsReduceGroupsAccumulator " ) . finish ( )
204179 }
205180}
206181
207- impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
182+ impl GroupsAccumulator for UpdateAsReduceGroupsAccumulator {
208183 fn update_batch (
209184 & mut self ,
210185 values : & [ ArrayRef ] ,
211186 group_indices : & [ usize ] ,
212187 opt_filter : Option < & BooleanArray > ,
213188 total_num_groups : usize ,
214189 ) -> Result < ( ) > {
215- // Redirect update to merge — this is the key trick.
216190 self . inner
217- . merge_batch ( values, group_indices, opt_filter, total_num_groups)
191+ . update_batch ( values, group_indices, opt_filter, total_num_groups)
218192 }
219193
220194 fn merge_batch (
@@ -225,7 +199,7 @@ impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
225199 total_num_groups : usize ,
226200 ) -> Result < ( ) > {
227201 self . inner
228- . merge_batch ( values, group_indices, opt_filter, total_num_groups)
202+ . update_batch ( values, group_indices, opt_filter, total_num_groups)
229203 }
230204
231205 fn evaluate ( & mut self , emit_to : EmitTo ) -> Result < ArrayRef > {
0 commit comments