Skip to content

Commit 9826403

Browse files
committed
fix: skip partial aggregate tag when partial itself cannot be converted
If the corresponding partial aggregate would also fail conversion to Comet (for example, collect_set on float is incompatible), tagging it early hijacks the more specific natural fallback reason. Only tag the partial when it would otherwise have been converted, so the tag guards genuine buffer-format mismatches rather than masking unrelated fallbacks. Generalize the convertibility predicate to accept an expected mode and mirror the mode-specific result-expression handling in doConvert.
1 parent f2a8207 commit 9826403

File tree

1 file changed

+43
-20
lines changed

1 file changed

+43
-20
lines changed

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer
2323

2424
import org.apache.spark.sql.SparkSession
2525
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
26-
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
26+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}
2727
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2828
import org.apache.spark.sql.catalyst.rules.Rule
2929
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -629,12 +629,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
629629
plan.foreach {
630630
case agg: BaseAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
631631
if (!QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions)) {
632-
if (!canFinalAggregateBeConverted(agg)) {
632+
if (!canAggregateBeConverted(agg, Final)) {
633633
findPartialAggInPlan(agg.child).foreach { partial =>
634-
partial.setTagValue(
635-
CometExecRule.COMET_UNSAFE_PARTIAL,
636-
"Partial aggregate disabled: corresponding final aggregate " +
637-
"cannot be converted to Comet and intermediate buffer formats are incompatible")
634+
// Only tag if the Partial would otherwise have been converted. If the Partial
635+
// itself cannot be converted (e.g. the aggregate function is incompatible for the
636+
// input type), there is no buffer-format mismatch to guard against, and tagging
637+
// would mask the natural, more specific fallback reason.
638+
if (canAggregateBeConverted(partial, Partial)) {
639+
partial.setTagValue(
640+
CometExecRule.COMET_UNSAFE_PARTIAL,
641+
"Partial aggregate disabled: corresponding final aggregate " +
642+
"cannot be converted to Comet and intermediate buffer formats are incompatible")
643+
}
638644
}
639645
}
640646
}
@@ -643,8 +649,8 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
643649
}
644650

645651
/**
646-
* Conservative check for whether a Final-mode aggregate could be converted to Comet. Checks
647-
* operator enablement, grouping expressions, aggregate expressions, and result expressions.
652+
* Conservative check for whether an aggregate could be converted to Comet. Checks operator
653+
* enablement, grouping expressions, aggregate expressions, and result expressions.
648654
* Intentionally skips the sparkFinalMode / child-native checks since those depend on
649655
* transformation state.
650656
*
@@ -653,7 +659,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
653659
* this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
654660
* shared predicate helper would be preferable.
655661
*/
656-
private def canFinalAggregateBeConverted(agg: BaseAggregateExec): Boolean = {
662+
private def canAggregateBeConverted(
663+
agg: BaseAggregateExec,
664+
expectedMode: AggregateMode): Boolean = {
657665
val handler = allExecs.get(agg.getClass)
658666
if (handler.isEmpty) return false
659667
val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
@@ -677,20 +685,35 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
677685
return false
678686
}
679687

680-
if (aggregateExpressions.nonEmpty) {
681-
val modes = aggregateExpressions.map(_.mode).distinct
682-
if (modes.size != 1 || !modes.contains(Final)) return false
688+
if (aggregateExpressions.isEmpty) {
689+
// Result expressions always checked when there are no aggregate expressions
690+
val attributes =
691+
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
692+
return agg.resultExpressions.forall(e =>
693+
QueryPlanSerde.exprToProto(e, attributes).isDefined)
694+
}
683695

684-
val binding = false
685-
if (!aggregateExpressions.forall(e =>
686-
QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
687-
return false
688-
}
696+
val modes = aggregateExpressions.map(_.mode).distinct
697+
if (modes.size != 1 || modes.head != expectedMode) return false
698+
699+
// In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
700+
// it must bind to input attributes. This mirrors the `binding` calculation in
701+
// `CometBaseAggregate.doConvert`.
702+
val binding = expectedMode != Final
703+
if (!aggregateExpressions.forall(e =>
704+
QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
705+
return false
689706
}
690707

691-
val attributes =
692-
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
693-
agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
708+
// doConvert only checks resultExpressions in Final mode when aggregate expressions exist
709+
// (Partial emits the buffer directly). Mirror that here to avoid false negatives.
710+
if (expectedMode == Final) {
711+
val attributes =
712+
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
713+
agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
714+
} else {
715+
true
716+
}
694717
}
695718

696719
/**

0 commit comments

Comments
 (0)