Skip to content

Commit d6b7b10

Browse files
committed
feat: support PartialMerge
1 parent a2a3dd3 commit d6b7b10

6 files changed

Lines changed: 98 additions & 31 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -966,10 +966,17 @@ impl PhysicalPlanner {
966966
let group_by = PhysicalGroupBy::new_single(group_exprs?);
967967
let schema = child.schema();
968968

969-
let mode = if agg.mode == 0 {
970-
DFAggregateMode::Partial
971-
} else {
972-
DFAggregateMode::Final
969+
let mode = match agg.mode {
970+
0 => DFAggregateMode::Partial,
971+
// Both Final and PartialMerge use merge semantics in DataFusion.
972+
// The output difference (final values vs intermediate buffers) is
973+
// handled by the presence/absence of result_exprs.
974+
1 | 2 => DFAggregateMode::Final,
975+
other => {
976+
return Err(ExecutionError::GeneralError(format!(
977+
"Unsupported aggregate mode: {other}"
978+
)))
979+
}
973980
};
974981

975982
let agg_exprs: PhyAggResult = agg

native/proto/src/proto/operator.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ message ParquetWriter {
319319
enum AggregateMode {
320320
Partial = 0;
321321
Final = 1;
322+
PartialMerge = 2;
322323
}
323324

324325
message Expand {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -460,15 +460,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
460460
binding: Boolean,
461461
conf: SQLConf): Option[AggExpr] = {
462462

463-
// Support Count(distinct single_value)
464-
// COUNT(DISTINCT x) - supported
465-
// COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x)
466-
// COUNT(DISTINCT x, y) - not supported
463+
// Distinct aggregates with a single column are supported (e.g., COUNT(DISTINCT x),
464+
// SUM(DISTINCT x), AVG(DISTINCT x)). The multi-stage plan generated by Spark
465+
// guarantees distinct semantics through grouping — the native side does not need
466+
// to handle deduplication.
467+
// Multi-column distinct is only supported for COUNT (e.g., COUNT(DISTINCT x, y)).
467468
if (aggExpr.isDistinct
468-
&&
469-
!(aggExpr.aggregateFunction.prettyName == "count" &&
470-
aggExpr.aggregateFunction.children.length == 1)) {
471-
withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
469+
&& aggExpr.aggregateFunction.children.length > 1
470+
&& aggExpr.aggregateFunction.prettyName != "count") {
471+
withInfo(aggExpr, s"Multi-column distinct aggregate not supported for: $aggExpr")
472472
return None
473473
}
474474

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,16 +1441,18 @@ trait CometBaseAggregate {
14411441
val mode = modes.head match {
14421442
case Partial => CometAggregateMode.Partial
14431443
case Final => CometAggregateMode.Final
1444+
case PartialMerge => CometAggregateMode.PartialMerge
14441445
case _ =>
14451446
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
14461447
return None
14471448
}
14481449

1449-
// In final mode, the aggregate expressions are bound to the output of the
1450-
// child and partial aggregate expressions buffer attributes produced by partial
1451-
// aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
1452-
// we don't have to do this because we don't use the merging expression.
1453-
val binding = mode != CometAggregateMode.Final
1450+
// In final/partial-merge mode, the aggregate expressions are bound to the
1451+
// output of the child and partial aggregate expressions buffer attributes
1452+
// produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453+
// internally. In Comet, we don't have to do this because we don't use the
1454+
// merging expression.
1455+
val binding = mode == CometAggregateMode.Partial
14541456
// `output` is only used when `binding` is true (i.e., non-Final)
14551457
val output = child.output
14561458

@@ -1496,14 +1498,21 @@ trait CometBaseAggregate {
14961498

14971499
/**
14981500
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
1499-
* partial mode, it will return None.
1501+
* partial or partial-merge mode, it will return None.
15001502
*/
15011503
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
1504+
def isPartialOrMerge(mode: AggregateMode): Boolean =
1505+
mode == Partial || mode == PartialMerge
1506+
15021507
plan.collectFirst {
1503-
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
1508+
case agg: CometHashAggregateExec
1509+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15041510
Some(agg)
1505-
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
1506-
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
1511+
case agg: HashAggregateExec
1512+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
1513+
None
1514+
case agg: ObjectHashAggregateExec
1515+
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15071516
None
15081517
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
15091518
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
@@ -1605,11 +1614,8 @@ case class CometHashAggregateExec(
16051614

16061615
// The aggExprs could be empty. For example, if the aggregate functions only have
16071616
// distinct aggregate functions or only have group by, the aggExprs is empty and
1608-
// modes is empty too. If aggExprs is not empty, we need to verify all the
1609-
// aggregates have the same mode.
1617+
// modes is empty too.
16101618
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
1611-
assert(modes.length == 1 || modes.isEmpty)
1612-
val mode = modes.headOption
16131619

16141620
override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
16151621

@@ -1626,7 +1632,7 @@ case class CometHashAggregateExec(
16261632
}
16271633

16281634
override def stringArgs: Iterator[Any] =
1629-
Iterator(input, mode, groupingExpressions, aggregateExpressions, child)
1635+
Iterator(input, modes, groupingExpressions, aggregateExpressions, child)
16301636

16311637
override def equals(obj: Any): Boolean = {
16321638
obj match {
@@ -1635,7 +1641,7 @@ case class CometHashAggregateExec(
16351641
this.groupingExpressions == other.groupingExpressions &&
16361642
this.aggregateExpressions == other.aggregateExpressions &&
16371643
this.input == other.input &&
1638-
this.mode == other.mode &&
1644+
this.modes == other.modes &&
16391645
this.child == other.child &&
16401646
this.serializedPlanOpt == other.serializedPlanOpt
16411647
case _ =>
@@ -1644,7 +1650,7 @@ case class CometHashAggregateExec(
16441650
}
16451651

16461652
override def hashCode(): Int =
1647-
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)
1653+
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child)
16481654

16491655
override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
16501656
}

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
641641
checkSparkAnswerAndFallbackReason(
642642
"SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," +
643643
" COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2",
644-
"Unsupported aggregation mode PartialMerge")
644+
"All aggregate expressions do not have the same mode")
645645
}
646646
}
647647
}
@@ -650,6 +650,56 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
650650
}
651651
}
652652

653+
test("partialMerge - cnt distinct + sum") {
654+
withTempDir(dir => {
655+
withSQLConf("spark.comet.enabled" -> "false") {
656+
sql("""
657+
CREATE OR REPLACE TEMP VIEW t (v, v1, i) AS
658+
VALUES
659+
('c', 'a', 1),
660+
('c1', 'a1', 1),
661+
('c2', 'a2', 2),
662+
('c3', 'a3', 2),
663+
('c4', 'a4', 2),
664+
('c', 'a', 1),
665+
('c1', 'a1', 1),
666+
('c2', 'a2', 2),
667+
('c3', 'a3', 2),
668+
('c4', 'a4', 2),
669+
('c', 'a', 1),
670+
('c1', 'a1', 1),
671+
('c2', 'a2', 2),
672+
('c3', 'a3', 2),
673+
('c4', 'a4', 2)
674+
""")
675+
sql("select * from t")
676+
.repartition(3)
677+
.write
678+
.mode("overwrite")
679+
.parquet(dir.getAbsolutePath)
680+
}
681+
682+
withSQLConf(
683+
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
684+
"spark.comet.exec.shuffle.fallbackToColumnar" -> "false",
685+
"spark.comet.cast.allowIncompatible" -> "true",
686+
"spark.sql.adaptive.enabled" -> "false",
687+
"spark.comet.explain.native.enabled" -> "true",
688+
"spark.comet.enabled" -> "true",
689+
"spark.comet.expression.Cast.allowIncompatible" -> "true",
690+
"spark.comet.exec.shuffle.enableFastEncoding" -> "true",
691+
"spark.comet.exec.shuffle.enabled" -> "true",
692+
"spark.comet.explainFallback.enabled" -> "true",
693+
CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_iceberg_compat",
694+
"spark.shuffle.manager" ->
695+
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager",
696+
"spark.comet.logFallbackReasons.enabled" -> "true") {
697+
spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("t2")
698+
checkSparkAnswerAndOperator("SELECT i, sum(v1), count(distinct v) FROM t2 group by i")
699+
}
700+
})
701+
}
702+
653703
test("multiple group-by columns + single aggregate column (first/last), with nulls") {
654704
val numValues = 10000
655705

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,9 +484,12 @@ class CometExecSuite extends CometTestBase {
484484
case s: CometHashAggregateExec => s
485485
}.get
486486

487-
assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
487+
assert(
488+
agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode])
488489
val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
489-
assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
490+
assert(
491+
newAgg.modes.nonEmpty &&
492+
newAgg.modes.headOption.get.isInstanceOf[AggregateMode])
490493
}
491494
}
492495

0 commit comments

Comments
 (0)