Skip to content

Commit f12ce3e

Browse files
committed
feat: support PartialMerge
1 parent d6d5f09 commit f12ce3e

6 files changed

Lines changed: 98 additions & 31 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -967,10 +967,17 @@ impl PhysicalPlanner {
967967
let group_by = PhysicalGroupBy::new_single(group_exprs?);
968968
let schema = child.schema();
969969

970-
let mode = if agg.mode == 0 {
971-
DFAggregateMode::Partial
972-
} else {
973-
DFAggregateMode::Final
970+
let mode = match agg.mode {
971+
0 => DFAggregateMode::Partial,
972+
// Both Final and PartialMerge use merge semantics in DataFusion.
973+
// The output difference (final values vs intermediate buffers) is
974+
// handled by the presence/absence of result_exprs.
975+
1 | 2 => DFAggregateMode::Final,
976+
other => {
977+
return Err(ExecutionError::GeneralError(format!(
978+
"Unsupported aggregate mode: {other}"
979+
)))
980+
}
974981
};
975982

976983
let agg_exprs: PhyAggResult = agg

native/proto/src/proto/operator.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ message ParquetWriter {
319319
enum AggregateMode {
320320
Partial = 0;
321321
Final = 1;
322+
PartialMerge = 2;
322323
}
323324

324325
message Expand {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -461,15 +461,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
461461
binding: Boolean,
462462
conf: SQLConf): Option[AggExpr] = {
463463

464-
// Support Count(distinct single_value)
465-
// COUNT(DISTINCT x) - supported
466-
// COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x)
467-
// COUNT(DISTINCT x, y) - not supported
464+
// Distinct aggregates with a single column are supported (e.g., COUNT(DISTINCT x),
465+
// SUM(DISTINCT x), AVG(DISTINCT x)). The multi-stage plan generated by Spark
466+
// guarantees distinct semantics through grouping — the native side does not need
467+
// to handle deduplication.
468+
// Multi-column distinct is only supported for COUNT (e.g., COUNT(DISTINCT x, y)).
468469
if (aggExpr.isDistinct
469-
&&
470-
!(aggExpr.aggregateFunction.prettyName == "count" &&
471-
aggExpr.aggregateFunction.children.length == 1)) {
472-
withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
470+
&& aggExpr.aggregateFunction.children.length > 1
471+
&& aggExpr.aggregateFunction.prettyName != "count") {
472+
withInfo(aggExpr, s"Multi-column distinct aggregate not supported for: $aggExpr")
473473
return None
474474
}
475475

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,16 +1441,18 @@ trait CometBaseAggregate {
14411441
val mode = modes.head match {
14421442
case Partial => CometAggregateMode.Partial
14431443
case Final => CometAggregateMode.Final
1444+
case PartialMerge => CometAggregateMode.PartialMerge
14441445
case _ =>
14451446
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
14461447
return None
14471448
}
14481449

1449-
// In final mode, the aggregate expressions are bound to the output of the
1450-
// child and partial aggregate expressions buffer attributes produced by partial
1451-
// aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
1452-
// we don't have to do this because we don't use the merging expression.
1453-
val binding = mode != CometAggregateMode.Final
1450+
// In final/partial-merge mode, the aggregate expressions are bound to the
1451+
// output of the child and partial aggregate expressions buffer attributes
1452+
// produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453+
// internally. In Comet, we don't have to do this because we don't use the
1454+
// merging expression.
1455+
val binding = mode == CometAggregateMode.Partial
14541456
// `output` is only used when `binding` is true (i.e., non-Final)
14551457
val output = child.output
14561458

@@ -1496,14 +1498,21 @@ trait CometBaseAggregate {
14961498

14971499
/**
14981500
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
1499-
* partial mode, it will return None.
1501+
* partial or partial-merge mode, it will return None.
15001502
*/
15011503
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
1504+
def isPartialOrMerge(mode: AggregateMode): Boolean =
1505+
mode == Partial || mode == PartialMerge
1506+
15021507
plan.collectFirst {
1503-
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
1508+
case agg: CometHashAggregateExec
1509+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15041510
Some(agg)
1505-
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
1506-
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
1511+
case agg: HashAggregateExec
1512+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
1513+
None
1514+
case agg: ObjectHashAggregateExec
1515+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15071516
None
15081517
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
15091518
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
@@ -1642,11 +1651,8 @@ case class CometHashAggregateExec(
16421651

16431652
// The aggExprs could be empty. For example, if the aggregate functions only have
16441653
// distinct aggregate functions or only have group by, the aggExprs is empty and
1645-
// modes is empty too. If aggExprs is not empty, we need to verify all the
1646-
// aggregates have the same mode.
1654+
// modes is empty too.
16471655
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
1648-
assert(modes.length == 1 || modes.isEmpty)
1649-
val mode = modes.headOption
16501656

16511657
override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
16521658

@@ -1663,7 +1669,7 @@ case class CometHashAggregateExec(
16631669
}
16641670

16651671
override def stringArgs: Iterator[Any] =
1666-
Iterator(input, mode, groupingExpressions, aggregateExpressions, child)
1672+
Iterator(input, modes, groupingExpressions, aggregateExpressions, child)
16671673

16681674
override def equals(obj: Any): Boolean = {
16691675
obj match {
@@ -1672,7 +1678,7 @@ case class CometHashAggregateExec(
16721678
this.groupingExpressions == other.groupingExpressions &&
16731679
this.aggregateExpressions == other.aggregateExpressions &&
16741680
this.input == other.input &&
1675-
this.mode == other.mode &&
1681+
this.modes == other.modes &&
16761682
this.child == other.child &&
16771683
this.serializedPlanOpt == other.serializedPlanOpt
16781684
case _ =>
@@ -1681,7 +1687,7 @@ case class CometHashAggregateExec(
16811687
}
16821688

16831689
override def hashCode(): Int =
1684-
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)
1690+
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child)
16851691

16861692
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
16871693
}

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
641641
checkSparkAnswerAndFallbackReason(
642642
"SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," +
643643
" COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2",
644-
"Unsupported aggregation mode PartialMerge")
644+
"All aggregate expressions do not have the same mode")
645645
}
646646
}
647647
}
@@ -650,6 +650,56 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
650650
}
651651
}
652652

653+
test("partialMerge - cnt distinct + sum") {
654+
withTempDir(dir => {
655+
withSQLConf("spark.comet.enabled" -> "false") {
656+
sql("""
657+
CREATE OR REPLACE TEMP VIEW t (v, v1, i) AS
658+
VALUES
659+
('c', 'a', 1),
660+
('c1', 'a1', 1),
661+
('c2', 'a2', 2),
662+
('c3', 'a3', 2),
663+
('c4', 'a4', 2),
664+
('c', 'a', 1),
665+
('c1', 'a1', 1),
666+
('c2', 'a2', 2),
667+
('c3', 'a3', 2),
668+
('c4', 'a4', 2),
669+
('c', 'a', 1),
670+
('c1', 'a1', 1),
671+
('c2', 'a2', 2),
672+
('c3', 'a3', 2),
673+
('c4', 'a4', 2)
674+
""")
675+
sql("select * from t")
676+
.repartition(3)
677+
.write
678+
.mode("overwrite")
679+
.parquet(dir.getAbsolutePath)
680+
}
681+
682+
withSQLConf(
683+
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
684+
"spark.comet.exec.shuffle.fallbackToColumnar" -> "false",
685+
"spark.comet.cast.allowIncompatible" -> "true",
686+
"spark.sql.adaptive.enabled" -> "false",
687+
"spark.comet.explain.native.enabled" -> "true",
688+
"spark.comet.enabled" -> "true",
689+
"spark.comet.expression.Cast.allowIncompatible" -> "true",
690+
"spark.comet.exec.shuffle.enableFastEncoding" -> "true",
691+
"spark.comet.exec.shuffle.enabled" -> "true",
692+
"spark.comet.explainFallback.enabled" -> "true",
693+
CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_iceberg_compat",
694+
"spark.shuffle.manager" ->
695+
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager",
696+
"spark.comet.logFallbackReasons.enabled" -> "true") {
697+
spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("t2")
698+
checkSparkAnswerAndOperator("SELECT i, sum(v1), count(distinct v) FROM t2 group by i")
699+
}
700+
})
701+
}
702+
653703
test("multiple group-by columns + single aggregate column (first/last), with nulls") {
654704
val numValues = 10000
655705

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,12 @@ class CometExecSuite extends CometTestBase {
484484
case s: CometHashAggregateExec => s
485485
}.get
486486

487-
assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
487+
assert(
488+
agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode])
488489
val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
489-
assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
490+
assert(
491+
newAgg.modes.nonEmpty &&
492+
newAgg.modes.headOption.get.isInstanceOf[AggregateMode])
490493
}
491494
}
492495

0 commit comments

Comments
 (0)