Skip to content

Commit f2a8207

Browse files
committed
fix: address review feedback on mixed partial/final aggregate guard
- Restore `convert` scaladoc in `CometAggregateExpressionSerde` that was displaced when `supportsMixedPartialFinal` was added
- Require `aggregateExpressions.nonEmpty` in `findPartialAggInPlan` so intermediate distinct-elimination stages (empty agg, group-by only) are not incorrectly tagged as the Partial to disable
- Document that `canFinalAggregateBeConverted` mirrors the predicate checks in `CometBaseAggregate.doConvert` and must be kept in sync
1 parent f7fa33c commit f2a8207

2 files changed

Lines changed: 19 additions & 10 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
647647
* operator enablement, grouping expressions, aggregate expressions, and result expressions.
648648
* Intentionally skips the sparkFinalMode / child-native checks since those depend on
649649
* transformation state.
650+
*
651+
* WARNING: this intentionally mirrors the predicate checks in `CometBaseAggregate.doConvert`
652+
* (operators.scala). Any change to the convertibility rules there must be reflected here or
653+
* this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
654+
* shared predicate helper would be preferable.
650655
*/
651656
private def canFinalAggregateBeConverted(agg: BaseAggregateExec): Boolean = {
652657
val handler = allExecs.get(agg.getClass)
@@ -690,11 +695,15 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
690695

691696
/**
692697
* Search the child subtree for the first Partial-mode aggregate, traversing through exchanges
693-
* and AQE stages.
698+
* and AQE stages. Requires `aggregateExpressions.nonEmpty` so that intermediate distinct stages
699+
* (group-by-only aggregates with empty aggregateExpressions, where `.forall` vacuously matches)
700+
* are not mistaken for the partial we want to tag.
694701
*/
695702
private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = {
696703
plan.collectFirst {
697-
case agg: BaseAggregateExec if agg.aggregateExpressions.forall(e => e.mode == Partial) =>
704+
case agg: BaseAggregateExec
705+
if agg.aggregateExpressions.nonEmpty &&
706+
agg.aggregateExpressions.forall(e => e.mode == Partial) =>
698707
Some(agg)
699708
case a: AQEShuffleReadExec => findPartialAggInPlan(a.child)
700709
case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan)

spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
4949
*/
5050
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
5151

52+
/**
53+
* Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
54+
* making it safe to run the Partial in one engine and the Final in the other. Aggregates with
55+
* simple single-value buffers (MIN, MAX, COUNT, bitwise) are safe; those with complex or
56+
* differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not.
57+
*/
58+
def supportsMixedPartialFinal: Boolean = false
59+
5260
/**
5361
* Convert a Spark expression into a protocol buffer representation that can be passed into
5462
* native code.
@@ -68,14 +76,6 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
6876
* case it is expected that the input expression will have been tagged with reasons why it
6977
* could not be converted.
7078
*/
71-
/**
72-
* Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
73-
* making it safe to run the Partial in one engine and the Final in the other. Aggregates with
74-
* simple single-value buffers (MIN, MAX, COUNT, bitwise) are safe; those with complex or
75-
* differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not.
76-
*/
77-
def supportsMixedPartialFinal: Boolean = false
78-
7979
def convert(
8080
aggExpr: AggregateExpression,
8181
expr: T,

0 commit comments

Comments
 (0)