Skip to content

Commit 67d24b6

Browse files
committed
feat: support PartialMerge
1 parent 750f6c9 commit 67d24b6

5 files changed

Lines changed: 129 additions & 32 deletions

File tree

native/core/src/execution/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
pub mod columnar_to_row;
2020
pub mod expressions;
2121
pub mod jni_api;
22+
pub(crate) mod merge_as_partial;
2223
pub(crate) mod metrics;
2324
pub mod operators;
2425
pub(crate) mod planner;

native/core/src/execution/planner.rs

Lines changed: 68 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -968,24 +968,87 @@ impl PhysicalPlanner {
968968

969969
let mode = match agg.mode {
970970
0 => DFAggregateMode::Partial,
971-
// Both Final and PartialMerge use merge semantics in DataFusion.
972-
// The output difference (final values vs intermediate buffers) is
973-
// handled by the presence/absence of result_exprs.
974-
1 | 2 => DFAggregateMode::Final,
971+
1 => DFAggregateMode::Final,
972+
2 => DFAggregateMode::Partial, // PartialMerge uses Partial + MergeAsPartial
975973
other => {
976974
return Err(ExecutionError::GeneralError(format!(
977975
"Unsupported aggregate mode: {other}"
978976
)))
979977
}
980978
};
981979

980+
// Determine per-expression modes. PartialMerge (2) expressions use
981+
// MergeAsPartial wrapper so they run merge semantics in Partial mode.
982+
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
983+
agg.expr_modes.clone()
984+
} else {
985+
vec![agg.mode; agg.agg_exprs.len()]
986+
};
987+
988+
let has_partial_merge = per_expr_modes.contains(&2);
989+
982990
let agg_exprs: PhyAggResult = agg
983991
.agg_exprs
984992
.iter()
985993
.map(|expr| self.create_agg_expr(expr, Arc::clone(&schema)))
986994
.collect();
987995

988-
let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();
996+
let aggr_expr: Vec<Arc<AggregateFunctionExpr>> = if has_partial_merge {
997+
// Wrap PartialMerge expressions with MergeAsPartial.
998+
// State fields in the child's output start at initial_input_buffer_offset.
999+
let mut state_offset = agg.initial_input_buffer_offset as usize;
1000+
let child_schema = child.schema();
1001+
1002+
agg_exprs?
1003+
.into_iter()
1004+
.enumerate()
1005+
.map(|(idx, expr)| {
1006+
let expr_mode = per_expr_modes[idx];
1007+
if expr_mode == 2 {
1008+
// PartialMerge: wrap with MergeAsPartial
1009+
let state_fields = expr.state_fields().map_err(|e| {
1010+
ExecutionError::GeneralError(e.to_string())
1011+
})?;
1012+
let num_state_fields = state_fields.len();
1013+
1014+
// Create Column refs pointing to state field positions
1015+
let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields)
1016+
.map(|i| {
1017+
let col_idx = state_offset + i;
1018+
let field = child_schema.field(col_idx);
1019+
Arc::new(Column::new(field.name(), col_idx))
1020+
as Arc<dyn PhysicalExpr>
1021+
})
1022+
.collect();
1023+
state_offset += num_state_fields;
1024+
1025+
let merge_udf = crate::execution::merge_as_partial::MergeAsPartialUDF::new(&expr)
1026+
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
1027+
let merge_udf_arc = Arc::new(
1028+
datafusion::logical_expr::AggregateUDF::new_from_impl(merge_udf),
1029+
);
1030+
1031+
let merge_expr = AggregateExprBuilder::new(
1032+
merge_udf_arc,
1033+
state_cols,
1034+
)
1035+
.schema(Arc::clone(&child_schema))
1036+
.alias(format!("col_{idx}"))
1037+
.with_ignore_nulls(false)
1038+
.with_distinct(false)
1039+
.build()
1040+
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
1041+
1042+
Ok(Arc::new(merge_expr))
1043+
} else {
1044+
// Partial: use as-is
1045+
Ok(Arc::new(expr))
1046+
}
1047+
})
1048+
.collect::<Result<Vec<_>, ExecutionError>>()?
1049+
} else {
1050+
agg_exprs?.into_iter().map(Arc::new).collect()
1051+
};
9891052

9901053
// Build per-aggregate filter expressions from the FILTER (WHERE ...) clause.
9911054
// Filters are only present in Partial mode; Final/PartialMerge always get None.

native/proto/src/proto/operator.proto

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -270,6 +270,13 @@ message HashAggregate {
270270
repeated spark.spark_expression.AggExpr agg_exprs = 2;
271271
repeated spark.spark_expression.Expr result_exprs = 3;
272272
AggregateMode mode = 5;
273+
// Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial).
274+
// When set, each entry corresponds to agg_exprs at the same index.
275+
// When empty, all expressions use the `mode` field.
276+
repeated AggregateMode expr_modes = 6;
277+
// Offset in the child's output where aggregate buffer attributes start.
278+
// Used by PartialMerge to locate state fields in the input.
279+
int32 initial_input_buffer_offset = 7;
273280
}
274281

275282
message Limit {

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 51 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1356,8 +1356,9 @@ trait CometBaseAggregate {
13561356
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {
13571357

13581358
val modes = aggregate.aggregateExpressions.map(_.mode).distinct
1359-
// In distinct aggregates there can be a combination of modes
1360-
val multiMode = modes.size > 1
1359+
// In distinct aggregates there can be a combination of modes.
1360+
// We support {Partial, PartialMerge} mix; other combinations are rejected.
1361+
val multiMode = modes.size > 1 && modes.toSet != Set(Partial, PartialMerge)
13611362
// For a final mode HashAggregate, we only need to transform the HashAggregate
13621363
// if there is Comet partial aggregation.
13631364
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
@@ -1430,34 +1431,42 @@ trait CometBaseAggregate {
14301431
Some(builder.setHashAgg(hashAggBuilder).build())
14311432
} else {
14321433
val modes = aggregateExpressions.map(_.mode).distinct
1433-
1434-
if (modes.size != 1) {
1435-
// This shouldn't happen as all aggregation expressions should share the same mode.
1436-
// Fallback to Spark nevertheless here.
1437-
withInfo(aggregate, "All aggregate expressions do not have the same mode")
1434+
val modeSet = modes.toSet
1435+
1436+
// Validate mode combinations. We support:
1437+
// - All Partial
1438+
// - All Final
1439+
// - All PartialMerge
1440+
// - Mixed {Partial, PartialMerge} (for distinct aggregate plans)
1441+
val isMixedPartialMerge = modeSet == Set(Partial, PartialMerge)
1442+
if (modes.size > 1 && !isMixedPartialMerge) {
1443+
withInfo(aggregate, s"Unsupported mixed aggregation modes: ${modes.mkString(", ")}")
14381444
return None
14391445
}
14401446

1441-
val mode = modes.head match {
1442-
case Partial => CometAggregateMode.Partial
1443-
case Final => CometAggregateMode.Final
1444-
case PartialMerge => CometAggregateMode.PartialMerge
1445-
case _ =>
1446-
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
1447-
return None
1447+
// Determine the proto mode. For uniform modes, use that mode directly.
1448+
// For mixed {Partial, PartialMerge}, use Partial as the base mode since
1449+
// PartialMerge expressions are wrapped with MergeAsPartial on the native side.
1450+
val mode = if (isMixedPartialMerge) {
1451+
CometAggregateMode.Partial
1452+
} else {
1453+
modes.head match {
1454+
case Partial => CometAggregateMode.Partial
1455+
case Final => CometAggregateMode.Final
1456+
case PartialMerge => CometAggregateMode.PartialMerge
1457+
case _ =>
1458+
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
1459+
return None
1460+
}
14481461
}
14491462

1450-
// In final/partial-merge mode, the aggregate expressions are bound to the
1451-
// output of the child and partial aggregate expressions buffer attributes
1452-
// produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453-
// internally. In Comet, we don't have to do this because we don't use the
1454-
// merging expression.
1455-
val binding = mode == CometAggregateMode.Partial
1456-
// `output` is only used when `binding` is true (i.e., non-Final)
1463+
// Per-expression binding: Partial expressions bind to child output,
1464+
// PartialMerge/Final expressions do not (native planner handles their input).
14571465
val output = child.output
1458-
1459-
val aggExprs =
1460-
aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
1466+
val aggExprs = aggregateExpressions.map { a =>
1467+
val exprBinding = a.mode != PartialMerge && a.mode != Final
1468+
aggExprToProto(a, output, exprBinding, aggregate.conf)
1469+
}
14611470

14621471
if (aggExprs.exists(_.isEmpty)) {
14631472
withInfo(
@@ -1485,6 +1494,24 @@ trait CometBaseAggregate {
14851494
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
14861495
}
14871496
hashAggBuilder.setModeValue(mode.getNumber)
1497+
1498+
// Send per-expression modes and buffer offset for PartialMerge handling
1499+
val hasPartialMerge = aggregateExpressions.exists(_.mode == PartialMerge)
1500+
if (hasPartialMerge) {
1501+
val exprModes = aggregateExpressions.map { a =>
1502+
a.mode match {
1503+
case Partial => CometAggregateMode.Partial
1504+
case PartialMerge => CometAggregateMode.PartialMerge
1505+
case Final => CometAggregateMode.Final
1506+
case other =>
1507+
withInfo(aggregate, s"Unsupported aggregation mode $other")
1508+
return None
1509+
}
1510+
}
1511+
hashAggBuilder.addAllExprModes(exprModes.asJava)
1512+
hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset)
1513+
}
1514+
14881515
Some(builder.setHashAgg(hashAggBuilder).build())
14891516
} else {
14901517
val allChildren: Seq[Expression] =

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -638,10 +638,9 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
638638
dictionaryEnabled) {
639639
withView("v") {
640640
sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1")
641-
checkSparkAnswerAndFallbackReason(
641+
checkSparkAnswerAndOperator(
642642
"SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," +
643-
" COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2",
644-
"All aggregate expressions do not have the same mode")
643+
" COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2")
645644
}
646645
}
647646
}

0 commit comments

Comments
 (0)