@@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer
2323
2424import org .apache .spark .sql .SparkSession
2525import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
26- import org .apache .spark .sql .catalyst .expressions .aggregate .{Final , Partial }
26+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateMode , Final , Partial }
2727import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
2828import org .apache .spark .sql .catalyst .rules .Rule
2929import org .apache .spark .sql .catalyst .trees .TreeNodeTag
@@ -629,12 +629,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
629629 plan.foreach {
630630 case agg : BaseAggregateExec if agg.aggregateExpressions.exists(_.mode == Final ) =>
631631 if (! QueryPlanSerde .allAggsSupportMixedExecution(agg.aggregateExpressions)) {
632- if (! canFinalAggregateBeConverted (agg)) {
632+ if (! canAggregateBeConverted (agg, Final )) {
633633 findPartialAggInPlan(agg.child).foreach { partial =>
634- partial.setTagValue(
635- CometExecRule .COMET_UNSAFE_PARTIAL ,
636- " Partial aggregate disabled: corresponding final aggregate " +
637- " cannot be converted to Comet and intermediate buffer formats are incompatible" )
634+ // Only tag if the Partial would otherwise have been converted. If the Partial
635+ // itself cannot be converted (e.g. the aggregate function is incompatible for the
636+ // input type), there is no buffer-format mismatch to guard against, and tagging
637+ // would mask the natural, more specific fallback reason.
638+ if (canAggregateBeConverted(partial, Partial )) {
639+ partial.setTagValue(
640+ CometExecRule .COMET_UNSAFE_PARTIAL ,
641+ " Partial aggregate disabled: corresponding final aggregate " +
642+ " cannot be converted to Comet and intermediate buffer formats are incompatible" )
643+ }
638644 }
639645 }
640646 }
@@ -643,8 +649,8 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
643649 }
644650
645651 /**
646- * Conservative check for whether a Final-mode aggregate could be converted to Comet. Checks
647- * operator enablement, grouping expressions, aggregate expressions, and result expressions.
652+ * Conservative check for whether an aggregate could be converted to Comet. Checks operator
653+ * enablement, grouping expressions, aggregate expressions, and result expressions.
648654 * Intentionally skips the sparkFinalMode / child-native checks since those depend on
649655 * transformation state.
650656 *
@@ -653,7 +659,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
653659 * this tagging pass will drift and either crash (missed tag) or over-disable (spurious tag). A
654660 * shared predicate helper would be preferable.
655661 */
656- private def canFinalAggregateBeConverted (agg : BaseAggregateExec ): Boolean = {
662+ private def canAggregateBeConverted (
663+ agg : BaseAggregateExec ,
664+ expectedMode : AggregateMode ): Boolean = {
657665 val handler = allExecs.get(agg.getClass)
658666 if (handler.isEmpty) return false
659667 val serde = handler.get.asInstanceOf [CometOperatorSerde [SparkPlan ]]
@@ -677,20 +685,35 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
677685 return false
678686 }
679687
680- if (aggregateExpressions.nonEmpty) {
681- val modes = aggregateExpressions.map(_.mode).distinct
682- if (modes.size != 1 || ! modes.contains(Final )) return false
688+ if (aggregateExpressions.isEmpty) {
689+ // Result expressions always checked when there are no aggregate expressions
690+ val attributes =
691+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
692+ return agg.resultExpressions.forall(e =>
693+ QueryPlanSerde .exprToProto(e, attributes).isDefined)
694+ }
683695
684- val binding = false
685- if (! aggregateExpressions.forall(e =>
686- QueryPlanSerde .aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
687- return false
688- }
696+ val modes = aggregateExpressions.map(_.mode).distinct
697+ if (modes.size != 1 || modes.head != expectedMode) return false
698+
699+ // In Final mode, exprToProto resolves against the child's output; in Partial/non-Final mode
700+ // it must bind to input attributes. This mirrors the `binding` calculation in
701+ // `CometBaseAggregate.doConvert`.
702+ val binding = expectedMode != Final
703+ if (! aggregateExpressions.forall(e =>
704+ QueryPlanSerde .aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
705+ return false
689706 }
690707
691- val attributes =
692- groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
693- agg.resultExpressions.forall(e => QueryPlanSerde .exprToProto(e, attributes).isDefined)
708+ // doConvert only checks resultExpressions in Final mode when aggregate expressions exist
709+ // (Partial emits the buffer directly). Mirror that here to avoid false negatives.
710+ if (expectedMode == Final ) {
711+ val attributes =
712+ groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
713+ agg.resultExpressions.forall(e => QueryPlanSerde .exprToProto(e, attributes).isDefined)
714+ } else {
715+ true
716+ }
694717 }
695718
696719 /**
0 commit comments