Skip to content

Commit 9bd5b80

Browse files
committed
update plans
1 parent e31c3bc commit 9bd5b80

3 files changed

Lines changed: 52 additions & 91 deletions

File tree

native/core/src/execution/merge_as_partial.rs

Lines changed: 29 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -15,16 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
18+
//! Accumulator wrappers for implementing Spark's PartialMerge aggregate mode.
1919
//!
2020
//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
21-
//! state (not final values). DataFusion has no equivalent mode — `Partial` calls
22-
//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
23-
//! evaluated results.
21+
//! state (not final values). DataFusion's `PartialReduce` mode has the same semantics.
2422
//!
25-
//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which
26-
//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
27-
//! semantics with state output.
23+
//! For mixed-mode aggregates (some expressions PartialMerge, some Partial in the same
24+
//! operator), we use `PartialReduce` mode for the whole operator and wrap Partial
25+
//! expressions with `UpdateAsReduceUDF` to redirect `merge_batch → update_batch`.
2826
2927
use std::any::Any;
3028
use std::fmt::Debug;
@@ -42,51 +40,38 @@ use datafusion::logical_expr::{
4240
use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
4341
use datafusion::scalar::ScalarValue;
4442

45-
/// An AggregateUDF wrapper that gives merge semantics in Partial mode.
43+
/// Wraps a Partial-mode aggregate to work inside a PartialReduce-mode AggregateExec.
4644
///
47-
/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
48-
/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
49-
/// and redirects it to `merge_batch` on the inner accumulator, effectively
50-
/// implementing PartialMerge: merge inputs, output state.
51-
///
52-
/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping
53-
/// references to UnboundColumn expressions that would panic if evaluated.
45+
/// PartialReduce calls `merge_batch` on all accumulators. Partial expressions need
46+
/// `update_batch` semantics, so this wrapper redirects `merge_batch → update_batch`.
5447
#[derive(Debug)]
55-
pub struct MergeAsPartialUDF {
56-
/// The inner aggregate UDF, cloned from the original expression.
48+
pub struct UpdateAsReduceUDF {
5749
inner_udf: AggregateUDF,
58-
/// Pre-computed return type from the original expression.
5950
return_type: DataType,
60-
/// Pre-computed state fields from the original expression.
6151
cached_state_fields: Vec<FieldRef>,
62-
/// Cached signature that accepts state field types.
6352
signature: Signature,
64-
/// Name for this wrapper.
6553
name: String,
6654
}
6755

68-
impl PartialEq for MergeAsPartialUDF {
56+
impl PartialEq for UpdateAsReduceUDF {
6957
fn eq(&self, other: &Self) -> bool {
7058
self.name == other.name
7159
}
7260
}
7361

74-
impl Eq for MergeAsPartialUDF {}
62+
impl Eq for UpdateAsReduceUDF {}
7563

76-
impl Hash for MergeAsPartialUDF {
64+
impl Hash for UpdateAsReduceUDF {
7765
fn hash<H: Hasher>(&self, state: &mut H) {
7866
self.name.hash(state);
7967
}
8068
}
8169

82-
impl MergeAsPartialUDF {
70+
impl UpdateAsReduceUDF {
8371
pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> {
84-
let name = format!("merge_as_partial_{}", inner_expr.name());
72+
let name = format!("update_as_reduce_{}", inner_expr.name());
8573
let return_type = inner_expr.field().data_type().clone();
8674
let cached_state_fields = inner_expr.state_fields()?;
87-
88-
// Use a permissive signature since we accept state field types which
89-
// vary per aggregate function.
9075
let signature = Signature::variadic_any(Volatility::Immutable);
9176

9277
Ok(Self {
@@ -99,7 +84,7 @@ impl MergeAsPartialUDF {
9984
}
10085
}
10186

102-
impl AggregateUDFImpl for MergeAsPartialUDF {
87+
impl AggregateUDFImpl for UpdateAsReduceUDF {
10388
fn as_any(&self) -> &dyn Any {
10489
self
10590
}
@@ -113,23 +98,16 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
11398
}
11499

115100
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116-
// In Partial mode, return_type isn't used for output schema (state_fields is).
117-
// Return the inner function's return type for consistency.
118101
Ok(self.return_type.clone())
119102
}
120103

121104
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
122-
// State fields must match the inner aggregate's state fields so that
123-
// the output of this PartialMerge stage is compatible with subsequent
124-
// Final or PartialMerge stages.
125105
Ok(self.cached_state_fields.clone())
126106
}
127107

128108
fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
129-
// Create the inner accumulator using the provided args (which have the
130-
// correct Column refs, not UnboundColumns).
131109
let inner_acc = self.inner_udf.accumulator(args)?;
132-
Ok(Box::new(MergeAsPartialAccumulator { inner: inner_acc }))
110+
Ok(Box::new(UpdateAsReduceAccumulator { inner: inner_acc }))
133111
}
134112

135113
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
@@ -141,7 +119,7 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
141119
args: AccumulatorArgs,
142120
) -> Result<Box<dyn GroupsAccumulator>> {
143121
let inner_acc = self.inner_udf.create_groups_accumulator(args)?;
144-
Ok(Box::new(MergeAsPartialGroupsAccumulator {
122+
Ok(Box::new(UpdateAsReduceGroupsAccumulator {
145123
inner: inner_acc,
146124
}))
147125
}
@@ -159,25 +137,23 @@ impl AggregateUDFImpl for MergeAsPartialUDF {
159137
}
160138
}
161139

162-
/// Accumulator wrapper that redirects update_batch to merge_batch.
163-
struct MergeAsPartialAccumulator {
140+
struct UpdateAsReduceAccumulator {
164141
inner: Box<dyn Accumulator>,
165142
}
166143

167-
impl Debug for MergeAsPartialAccumulator {
144+
impl Debug for UpdateAsReduceAccumulator {
168145
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169-
f.debug_struct("MergeAsPartialAccumulator").finish()
146+
f.debug_struct("UpdateAsReduceAccumulator").finish()
170147
}
171148
}
172149

173-
impl Accumulator for MergeAsPartialAccumulator {
150+
impl Accumulator for UpdateAsReduceAccumulator {
174151
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
175-
// Redirect update to merge — this is the key trick.
176-
self.inner.merge_batch(values)
152+
self.inner.update_batch(values)
177153
}
178154

179155
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
180-
self.inner.merge_batch(states)
156+
self.inner.update_batch(states)
181157
}
182158

183159
fn evaluate(&mut self) -> Result<ScalarValue> {
@@ -193,28 +169,26 @@ impl Accumulator for MergeAsPartialAccumulator {
193169
}
194170
}
195171

196-
/// GroupsAccumulator wrapper that redirects update_batch to merge_batch.
197-
struct MergeAsPartialGroupsAccumulator {
172+
struct UpdateAsReduceGroupsAccumulator {
198173
inner: Box<dyn GroupsAccumulator>,
199174
}
200175

201-
impl Debug for MergeAsPartialGroupsAccumulator {
176+
impl Debug for UpdateAsReduceGroupsAccumulator {
202177
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203-
f.debug_struct("MergeAsPartialGroupsAccumulator").finish()
178+
f.debug_struct("UpdateAsReduceGroupsAccumulator").finish()
204179
}
205180
}
206181

207-
impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
182+
impl GroupsAccumulator for UpdateAsReduceGroupsAccumulator {
208183
fn update_batch(
209184
&mut self,
210185
values: &[ArrayRef],
211186
group_indices: &[usize],
212187
opt_filter: Option<&BooleanArray>,
213188
total_num_groups: usize,
214189
) -> Result<()> {
215-
// Redirect update to merge — this is the key trick.
216190
self.inner
217-
.merge_batch(values, group_indices, opt_filter, total_num_groups)
191+
.update_batch(values, group_indices, opt_filter, total_num_groups)
218192
}
219193

220194
fn merge_batch(
@@ -225,7 +199,7 @@ impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
225199
total_num_groups: usize,
226200
) -> Result<()> {
227201
self.inner
228-
.merge_batch(values, group_indices, opt_filter, total_num_groups)
202+
.update_batch(values, group_indices, opt_filter, total_num_groups)
229203
}
230204

231205
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {

native/core/src/execution/planner.rs

Lines changed: 23 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -974,16 +974,15 @@ impl PhysicalPlanner {
974974
let mode = match agg.mode {
975975
0 => {
976976
if has_mixed_partial_merge {
977-
// Mixed {Partial, PartialMerge}: use Partial mode, wrap
978-
// PartialMerge expressions with MergeAsPartial.
979-
DFAggregateMode::Partial
977+
// Mixed {Partial, PartialMerge}: use PartialReduce so
978+
// PartialMerge expressions use native merge semantics.
979+
// Partial expressions are wrapped to redirect merge→update.
980+
DFAggregateMode::PartialReduce
980981
} else {
981982
DFAggregateMode::Partial
982983
}
983984
}
984985
1 => DFAggregateMode::Final,
985-
// Uniform PartialMerge maps directly to DataFusion's PartialReduce
986-
// which has merge input + state output semantics.
987986
2 => DFAggregateMode::PartialReduce,
988987
other => {
989988
return Err(ExecutionError::GeneralError(format!(
@@ -1003,8 +1002,10 @@ impl PhysicalPlanner {
10031002
.collect();
10041003

10051004
let aggr_expr: Vec<Arc<AggregateFunctionExpr>> = if has_partial_merge {
1006-
// Wrap PartialMerge expressions with MergeAsPartial.
1007-
// State fields in the child's output start at initial_input_buffer_offset.
1005+
// Mixed {Partial, PartialMerge} mode uses PartialReduce so
1006+
// PartialMerge expressions get native merge semantics.
1007+
// Partial expressions need UpdateAsReduce wrappers to redirect
1008+
// merge_batch → update_batch since PartialReduce calls merge_batch.
10081009
let mut state_offset = agg.initial_input_buffer_offset as usize;
10091010
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
10101011
agg.expr_modes.clone()
@@ -1017,35 +1018,31 @@ impl PhysicalPlanner {
10171018
.enumerate()
10181019
.map(|(idx, expr)| {
10191020
if per_expr_modes[idx] == 2 {
1020-
// PartialMerge: wrap with MergeAsPartial
1021-
let state_fields = expr
1021+
// PartialMerge: advance state_offset past this
1022+
// expression's state fields (PartialReduce handles
1023+
// merge natively via merge_expressions column refs).
1024+
let num_state_fields = expr
10221025
.state_fields()
1023-
.map_err(|e| ExecutionError::GeneralError(e.to_string()))?;
1024-
let num_state_fields = state_fields.len();
1025-
1026-
let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields)
1027-
.map(|i| {
1028-
let col_idx = state_offset + i;
1029-
let field = schema.field(col_idx);
1030-
Arc::new(Column::new(field.name(), col_idx))
1031-
as Arc<dyn PhysicalExpr>
1032-
})
1033-
.collect();
1026+
.map_err(|e| ExecutionError::GeneralError(e.to_string()))?
1027+
.len();
10341028
state_offset += num_state_fields;
1035-
1036-
let merge_udf =
1037-
crate::execution::merge_as_partial::MergeAsPartialUDF::new(
1029+
Ok(Arc::new(expr))
1030+
} else {
1031+
// Partial: wrap with UpdateAsReduce so merge_batch
1032+
// (called by PartialReduce) redirects to update_batch.
1033+
let update_udf =
1034+
crate::execution::merge_as_partial::UpdateAsReduceUDF::new(
10381035
&expr,
10391036
)
10401037
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
1041-
let merge_udf_arc = Arc::new(
1038+
let update_udf_arc = Arc::new(
10421039
datafusion::logical_expr::AggregateUDF::new_from_impl(
1043-
merge_udf,
1040+
update_udf,
10441041
),
10451042
);
10461043

10471044
let merge_expr =
1048-
AggregateExprBuilder::new(merge_udf_arc, state_cols)
1045+
AggregateExprBuilder::new(update_udf_arc, expr.expressions())
10491046
.schema(Arc::clone(&schema))
10501047
.alias(format!("col_{idx}"))
10511048
.with_ignore_nulls(expr.ignore_nulls())
@@ -1056,8 +1053,6 @@ impl PhysicalPlanner {
10561053
})?;
10571054

10581055
Ok(Arc::new(merge_expr))
1059-
} else {
1060-
Ok(Arc::new(expr))
10611056
}
10621057
})
10631058
.collect::<Result<Vec<_>, ExecutionError>>()?

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1432,20 +1432,12 @@ trait CometBaseAggregate {
14321432
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
14331433
Some(builder.setHashAgg(hashAggBuilder).build())
14341434
} else {
1435-
// Validate mode combinations. We support:
1436-
// - All Partial
1437-
// - All Final
1438-
// - All PartialMerge
1439-
// - Mixed {Partial, PartialMerge} (for distinct aggregate plans)
14401435
val isMixedPartialMerge = modeSet == Set(Partial, PartialMerge)
14411436
if (modes.size > 1 && !isMixedPartialMerge) {
14421437
withInfo(aggregate, s"Unsupported mixed aggregation modes: ${modes.mkString(", ")}")
14431438
return None
14441439
}
14451440

1446-
// Determine the proto mode. For uniform modes, use that mode directly.
1447-
// For mixed {Partial, PartialMerge}, use Partial as the base mode since
1448-
// PartialMerge expressions are wrapped with MergeAsPartial on the native side.
14491441
val mode = if (isMixedPartialMerge) {
14501442
CometAggregateMode.Partial
14511443
} else {

0 commit comments

Comments (0)