Skip to content

Commit 18645c9

Browse files
committed
feat: support PartialMerge
1 parent 336aadd commit 18645c9

2 files changed

Lines changed: 42 additions & 41 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ use datafusion_comet_proto::{
116116
},
117117
spark_operator::{
118118
self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
119-
upper_window_frame_bound::UpperFrameBoundStruct, BuildSide,
120-
CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
119+
upper_window_frame_bound::UpperFrameBoundStruct, AggregateMode as ProtoAggregateMode,
120+
BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType,
121121
},
122122
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
123123
};
@@ -967,26 +967,24 @@ impl PhysicalPlanner {
967967
let group_by = PhysicalGroupBy::new_single(group_exprs?);
968968
let schema = child.schema();
969969

970+
let partial_merge = ProtoAggregateMode::PartialMerge as i32;
971+
970972
let mode = match agg.mode {
971973
0 => DFAggregateMode::Partial,
972974
1 => DFAggregateMode::Final,
973-
2 => DFAggregateMode::Partial, // PartialMerge uses Partial + MergeAsPartial
975+
m if m == partial_merge => DFAggregateMode::Partial,
974976
other => {
975977
return Err(ExecutionError::GeneralError(format!(
976978
"Unsupported aggregate mode: {other}"
977979
)))
978980
}
979981
};
980982

981-
// Determine per-expression modes. PartialMerge (2) expressions use
982-
// MergeAsPartial wrapper so they run merge semantics in Partial mode.
983-
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
984-
agg.expr_modes.clone()
985-
} else {
986-
vec![agg.mode; agg.agg_exprs.len()]
987-
};
988-
989-
let has_partial_merge = per_expr_modes.contains(&2);
983+
// Check if any expression uses PartialMerge mode. When present,
984+
// those expressions are wrapped with MergeAsPartial to get merge
985+
// semantics inside a Partial-mode AggregateExec.
986+
let has_partial_merge =
987+
agg.mode == partial_merge || agg.expr_modes.contains(&partial_merge);
990988

991989
let agg_exprs: PhyAggResult = agg
992990
.agg_exprs
@@ -998,51 +996,57 @@ impl PhysicalPlanner {
998996
// Wrap PartialMerge expressions with MergeAsPartial.
999997
// State fields in the child's output start at initial_input_buffer_offset.
1000998
let mut state_offset = agg.initial_input_buffer_offset as usize;
1001-
let child_schema = child.schema();
999+
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
1000+
agg.expr_modes.clone()
1001+
} else {
1002+
vec![agg.mode; agg.agg_exprs.len()]
1003+
};
10021004

10031005
agg_exprs?
10041006
.into_iter()
10051007
.enumerate()
10061008
.map(|(idx, expr)| {
1007-
let expr_mode = per_expr_modes[idx];
1008-
if expr_mode == 2 {
1009+
if per_expr_modes[idx] == partial_merge {
10091010
// PartialMerge: wrap with MergeAsPartial
1010-
let state_fields = expr.state_fields().map_err(|e| {
1011-
ExecutionError::GeneralError(e.to_string())
1012-
})?;
1011+
let state_fields = expr
1012+
.state_fields()
1013+
.map_err(|e| ExecutionError::GeneralError(e.to_string()))?;
10131014
let num_state_fields = state_fields.len();
10141015

1015-
// Create Column refs pointing to state field positions
10161016
let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields)
10171017
.map(|i| {
10181018
let col_idx = state_offset + i;
1019-
let field = child_schema.field(col_idx);
1019+
let field = schema.field(col_idx);
10201020
Arc::new(Column::new(field.name(), col_idx))
10211021
as Arc<dyn PhysicalExpr>
10221022
})
10231023
.collect();
10241024
state_offset += num_state_fields;
10251025

1026-
let merge_udf = crate::execution::merge_as_partial::MergeAsPartialUDF::new(&expr)
1026+
let merge_udf =
1027+
crate::execution::merge_as_partial::MergeAsPartialUDF::new(
1028+
&expr,
1029+
)
10271030
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
10281031
let merge_udf_arc = Arc::new(
1029-
datafusion::logical_expr::AggregateUDF::new_from_impl(merge_udf),
1032+
datafusion::logical_expr::AggregateUDF::new_from_impl(
1033+
merge_udf,
1034+
),
10301035
);
10311036

1032-
let merge_expr = AggregateExprBuilder::new(
1033-
merge_udf_arc,
1034-
state_cols,
1035-
)
1036-
.schema(Arc::clone(&child_schema))
1037-
.alias(format!("col_{idx}"))
1038-
.with_ignore_nulls(expr.ignore_nulls())
1039-
.with_distinct(expr.is_distinct())
1040-
.build()
1041-
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
1037+
let merge_expr =
1038+
AggregateExprBuilder::new(merge_udf_arc, state_cols)
1039+
.schema(Arc::clone(&schema))
1040+
.alias(format!("col_{idx}"))
1041+
.with_ignore_nulls(expr.ignore_nulls())
1042+
.with_distinct(expr.is_distinct())
1043+
.build()
1044+
.map_err(|e| {
1045+
ExecutionError::DataFusionError(e.to_string())
1046+
})?;
10421047

10431048
Ok(Arc::new(merge_expr))
10441049
} else {
1045-
// Partial: use as-is
10461050
Ok(Arc::new(expr))
10471051
}
10481052
})

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,9 +1356,11 @@ trait CometBaseAggregate {
13561356
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
13571357

13581358
val modes = aggregate.aggregateExpressions.map(_.mode).distinct
1359+
val modeSet = modes.toSet
1360+
val hasPartialMerge = modeSet.contains(PartialMerge)
13591361
// In distinct aggregates there can be a combination of modes.
13601362
// We support {Partial, PartialMerge} mix; other combinations are rejected.
1361-
val multiMode = modes.size > 1 && modes.toSet != Set(Partial, PartialMerge)
1363+
val multiMode = modes.size > 1 && modeSet != Set(Partial, PartialMerge)
13621364
// For a final mode HashAggregate, we only need to transform the HashAggregate
13631365
// if there is Comet partial aggregation.
13641366
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
@@ -1430,9 +1432,6 @@ trait CometBaseAggregate {
14301432
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
14311433
Some(builder.setHashAgg(hashAggBuilder).build())
14321434
} else {
1433-
val modes = aggregateExpressions.map(_.mode).distinct
1434-
val modeSet = modes.toSet
1435-
14361435
// Validate mode combinations. We support:
14371436
// - All Partial
14381437
// - All Final
@@ -1463,8 +1462,7 @@ trait CometBaseAggregate {
14631462
// FIRST/LAST are order-dependent aggregates whose merge result depends on
14641463
// hash table processing order. In PartialMerge mode, DataFusion's hash table
14651464
// may process rows in a different order than Spark's, producing different results.
1466-
val hasPartialMergeMode = modeSet.contains(PartialMerge)
1467-
if (hasPartialMergeMode) {
1465+
if (hasPartialMerge) {
14681466
val unsupportedAggs = aggregateExpressions.filter { a =>
14691467
a.mode == PartialMerge && (a.aggregateFunction.isInstanceOf[First] ||
14701468
a.aggregateFunction.isInstanceOf[Last])
@@ -1514,7 +1512,6 @@ trait CometBaseAggregate {
15141512
hashAggBuilder.setModeValue(mode.getNumber)
15151513

15161514
// Send per-expression modes and buffer offset for PartialMerge handling
1517-
val hasPartialMerge = aggregateExpressions.exists(_.mode == PartialMerge)
15181515
if (hasPartialMerge) {
15191516
val exprModes = aggregateExpressions.map { a =>
15201517
a.mode match {
@@ -1654,7 +1651,7 @@ object CometObjectHashAggregateExec
16541651
* case branch here mapping it to the native state type.
16551652
*/
16561653
private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = {
1657-
// CometBaseAggregate.doConvert guarantees all expressions share the same mode.
1654+
// This adjustment only applies to pure-Partial aggregates (checked below).
16581655
val modes = op.aggregateExpressions.map(_.mode).distinct
16591656
if (modes != Seq(Partial)) {
16601657
return op.output

0 commit comments

Comments
 (0)