Skip to content

Commit 2eeeb2a

Browse files
committed
feat: support PartialMerge
1 parent 18645c9 commit 2eeeb2a

2 files changed

Lines changed: 247 additions & 9 deletions

File tree

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
19+
//!
20+
//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
21+
//! state (not final values). DataFusion has no equivalent mode — `Partial` calls
22+
//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
23+
//! evaluated results.
24+
//!
25+
//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which
26+
//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
27+
//! semantics with state output.
28+
29+
use std::any::Any;
30+
use std::fmt::Debug;
31+
use std::hash::{Hash, Hasher};
32+
33+
use arrow::array::{ArrayRef, BooleanArray};
34+
use arrow::datatypes::{DataType, FieldRef};
35+
use datafusion::common::Result;
36+
use datafusion::logical_expr::function::AccumulatorArgs;
37+
use datafusion::logical_expr::function::StateFieldsArgs;
38+
use datafusion::logical_expr::{
39+
Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
40+
Signature, Volatility,
41+
};
42+
use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
43+
use datafusion::scalar::ScalarValue;
44+
45+
/// An AggregateUDF wrapper that gives merge semantics in Partial mode.
46+
///
47+
/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
48+
/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
49+
/// and redirects it to `merge_batch` on the inner accumulator, effectively
50+
/// implementing PartialMerge: merge inputs, output state.
51+
///
52+
/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping
53+
/// references to UnboundColumn expressions that would panic if evaluated.
54+
#[derive(Debug)]
55+
pub struct MergeAsPartialUDF {
56+
/// The inner aggregate UDF, cloned from the original expression.
57+
inner_udf: AggregateUDF,
58+
/// Pre-computed return type from the original expression.
59+
return_type: DataType,
60+
/// Pre-computed state fields from the original expression.
61+
cached_state_fields: Vec<FieldRef>,
62+
/// Cached signature that accepts state field types.
63+
signature: Signature,
64+
/// Name for this wrapper.
65+
name: String,
66+
}
67+
68+
impl PartialEq for MergeAsPartialUDF {
69+
fn eq(&self, other: &Self) -> bool {
70+
self.name == other.name
71+
}
72+
}
73+
74+
impl Eq for MergeAsPartialUDF {}
75+
76+
impl Hash for MergeAsPartialUDF {
77+
fn hash<H: Hasher>(&self, state: &mut H) {
78+
self.name.hash(state);
79+
}
80+
}
81+
82+
impl MergeAsPartialUDF {
83+
pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> {
84+
let name = format!("merge_as_partial_{}", inner_expr.name());
85+
let return_type = inner_expr.field().data_type().clone();
86+
let cached_state_fields = inner_expr.state_fields()?;
87+
88+
// Use a permissive signature since we accept state field types which
89+
// vary per aggregate function.
90+
let signature = Signature::variadic_any(Volatility::Immutable);
91+
92+
Ok(Self {
93+
inner_udf: inner_expr.fun().clone(),
94+
return_type,
95+
cached_state_fields,
96+
signature,
97+
name,
98+
})
99+
}
100+
}
101+
102+
impl AggregateUDFImpl for MergeAsPartialUDF {
103+
fn as_any(&self) -> &dyn Any {
104+
self
105+
}
106+
107+
fn name(&self) -> &str {
108+
&self.name
109+
}
110+
111+
fn signature(&self) -> &Signature {
112+
&self.signature
113+
}
114+
115+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116+
// In Partial mode, return_type isn't used for output schema (state_fields is).
117+
// Return the inner function's return type for consistency.
118+
Ok(self.return_type.clone())
119+
}
120+
121+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
122+
// State fields must match the inner aggregate's state fields so that
123+
// the output of this PartialMerge stage is compatible with subsequent
124+
// Final or PartialMerge stages.
125+
Ok(self.cached_state_fields.clone())
126+
}
127+
128+
fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129+
// Create the inner accumulator using the provided args (which have the
130+
// correct Column refs, not UnboundColumns).
131+
let inner_acc = self.inner_udf.accumulator(args)?;
132+
Ok(Box::new(MergeAsPartialAccumulator { inner: inner_acc }))
133+
}
134+
135+
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
136+
self.inner_udf.groups_accumulator_supported(args)
137+
}
138+
139+
fn create_groups_accumulator(
140+
&self,
141+
args: AccumulatorArgs,
142+
) -> Result<Box<dyn GroupsAccumulator>> {
143+
let inner_acc = self.inner_udf.create_groups_accumulator(args)?;
144+
Ok(Box::new(MergeAsPartialGroupsAccumulator {
145+
inner: inner_acc,
146+
}))
147+
}
148+
149+
fn reverse_expr(&self) -> ReversedUDAF {
150+
ReversedUDAF::NotSupported
151+
}
152+
153+
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
154+
ScalarValue::try_from(data_type)
155+
}
156+
157+
fn is_descending(&self) -> Option<bool> {
158+
None
159+
}
160+
}
161+
162+
/// Accumulator wrapper that redirects update_batch to merge_batch.
163+
struct MergeAsPartialAccumulator {
164+
inner: Box<dyn Accumulator>,
165+
}
166+
167+
impl Debug for MergeAsPartialAccumulator {
168+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169+
f.debug_struct("MergeAsPartialAccumulator").finish()
170+
}
171+
}
172+
173+
impl Accumulator for MergeAsPartialAccumulator {
174+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
175+
// Redirect update to merge — this is the key trick.
176+
self.inner.merge_batch(values)
177+
}
178+
179+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
180+
self.inner.merge_batch(states)
181+
}
182+
183+
fn evaluate(&mut self) -> Result<ScalarValue> {
184+
self.inner.evaluate()
185+
}
186+
187+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
188+
self.inner.state()
189+
}
190+
191+
fn size(&self) -> usize {
192+
self.inner.size()
193+
}
194+
}
195+
196+
/// GroupsAccumulator wrapper that redirects update_batch to merge_batch.
197+
struct MergeAsPartialGroupsAccumulator {
198+
inner: Box<dyn GroupsAccumulator>,
199+
}
200+
201+
impl Debug for MergeAsPartialGroupsAccumulator {
202+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203+
f.debug_struct("MergeAsPartialGroupsAccumulator").finish()
204+
}
205+
}
206+
207+
impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
208+
fn update_batch(
209+
&mut self,
210+
values: &[ArrayRef],
211+
group_indices: &[usize],
212+
opt_filter: Option<&BooleanArray>,
213+
total_num_groups: usize,
214+
) -> Result<()> {
215+
// Redirect update to merge — this is the key trick.
216+
self.inner
217+
.merge_batch(values, group_indices, opt_filter, total_num_groups)
218+
}
219+
220+
fn merge_batch(
221+
&mut self,
222+
values: &[ArrayRef],
223+
group_indices: &[usize],
224+
opt_filter: Option<&BooleanArray>,
225+
total_num_groups: usize,
226+
) -> Result<()> {
227+
self.inner
228+
.merge_batch(values, group_indices, opt_filter, total_num_groups)
229+
}
230+
231+
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
232+
self.inner.evaluate(emit_to)
233+
}
234+
235+
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
236+
self.inner.state(emit_to)
237+
}
238+
239+
fn size(&self) -> usize {
240+
self.inner.size()
241+
}
242+
}

native/core/src/execution/planner.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ use datafusion_comet_proto::{
116116
},
117117
spark_operator::{
118118
self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
119-
upper_window_frame_bound::UpperFrameBoundStruct, AggregateMode as ProtoAggregateMode,
120-
BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
119+
upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
121120
},
122121
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
123122
};
@@ -967,24 +966,21 @@ impl PhysicalPlanner {
967966
let group_by = PhysicalGroupBy::new_single(group_exprs?);
968967
let schema = child.schema();
969968

970-
let partial_merge = ProtoAggregateMode::PartialMerge as i32;
971-
972969
let mode = match agg.mode {
973970
0 => DFAggregateMode::Partial,
974971
1 => DFAggregateMode::Final,
975-
m if m == partial_merge => DFAggregateMode::Partial,
972+
2 => DFAggregateMode::Partial, // PartialMerge: Partial + MergeAsPartial
976973
other => {
977974
return Err(ExecutionError::GeneralError(format!(
978975
"Unsupported aggregate mode: {other}"
979976
)))
980977
}
981978
};
982979

983-
// Check if any expression uses PartialMerge mode. When present,
980+
// Check if any expression uses PartialMerge mode (2). When present,
984981
// those expressions are wrapped with MergeAsPartial to get merge
985982
// semantics inside a Partial-mode AggregateExec.
986-
let has_partial_merge =
987-
agg.mode == partial_merge || agg.expr_modes.contains(&partial_merge);
983+
let has_partial_merge = agg.mode == 2 || agg.expr_modes.contains(&2);
988984

989985
let agg_exprs: PhyAggResult = agg
990986
.agg_exprs
@@ -1006,7 +1002,7 @@ impl PhysicalPlanner {
10061002
.into_iter()
10071003
.enumerate()
10081004
.map(|(idx, expr)| {
1009-
if per_expr_modes[idx] == partial_merge {
1005+
if per_expr_modes[idx] == 2 {
10101006
// PartialMerge: wrap with MergeAsPartial
10111007
let state_fields = expr
10121008
.state_fields()

0 commit comments

Comments
 (0)