@@ -1441,16 +1441,18 @@ trait CometBaseAggregate {
14411441 val mode = modes.head match {
14421442 case Partial => CometAggregateMode .Partial
14431443 case Final => CometAggregateMode .Final
1444+ case PartialMerge => CometAggregateMode .PartialMerge
14441445 case _ =>
14451446 withInfo(aggregate, s " Unsupported aggregation mode ${modes.head}" )
14461447 return None
14471448 }
14481449
1449- // In final mode, the aggregate expressions are bound to the output of the
1450- // child and partial aggregate expressions buffer attributes produced by partial
1451- // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
1452- // we don't have to do this because we don't use the merging expression.
1453- val binding = mode != CometAggregateMode .Final
1450+ // In final/partial-merge mode, the aggregate expressions are bound to the
1451+ // output of the child and partial aggregate expressions buffer attributes
1452+ // produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453+ // internally. In Comet, we don't have to do this because we don't use the
1454+ // merging expression.
1455+ val binding = mode == CometAggregateMode .Partial
14541456 // `output` is only used when `binding` is true (i.e., Partial mode only)
14551457 val output = child.output
14561458
@@ -1496,14 +1498,21 @@ trait CometBaseAggregate {
14961498
14971499 /**
14981500 * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
1499- * partial mode, it will return None.
1501+ * partial or partial-merge mode, it will return None.
15001502 */
15011503 private def findCometPartialAgg (plan : SparkPlan ): Option [CometHashAggregateExec ] = {
1504+ def isPartialOrMerge (mode : AggregateMode ): Boolean =
1505+ mode == Partial || mode == PartialMerge
1506+
15021507 plan.collectFirst {
1503- case agg : CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
1508+ case agg : CometHashAggregateExec
1509+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15041510 Some (agg)
1505- case agg : HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) => None
1506- case agg : ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
1511+ case agg : HashAggregateExec
1512+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
1513+ None
1514+ case agg : ObjectHashAggregateExec
1515+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15071516 None
15081517 case a : AQEShuffleReadExec => findCometPartialAgg(a.child)
15091518 case s : ShuffleQueryStageExec => findCometPartialAgg(s.plan)
@@ -1605,11 +1614,8 @@ case class CometHashAggregateExec(
16051614
16061615 // The aggExprs could be empty. For example, if the aggregate functions only have
16071616 // distinct aggregate functions or only have group by, the aggExprs is empty and
1608- // modes is empty too. If aggExprs is not empty, we need to verify all the
1609- // aggregates have the same mode.
1617+ // modes is empty too.
16101618 val modes : Seq [AggregateMode ] = aggregateExpressions.map(_.mode).distinct
1611- assert(modes.length == 1 || modes.isEmpty)
1612- val mode = modes.headOption
16131619
16141620 override def producedAttributes : AttributeSet = outputSet ++ AttributeSet (resultExpressions)
16151621
@@ -1626,7 +1632,7 @@ case class CometHashAggregateExec(
16261632 }
16271633
16281634 override def stringArgs : Iterator [Any ] =
1629- Iterator (input, mode , groupingExpressions, aggregateExpressions, child)
1635+ Iterator (input, modes , groupingExpressions, aggregateExpressions, child)
16301636
16311637 override def equals (obj : Any ): Boolean = {
16321638 obj match {
@@ -1635,7 +1641,7 @@ case class CometHashAggregateExec(
16351641 this .groupingExpressions == other.groupingExpressions &&
16361642 this .aggregateExpressions == other.aggregateExpressions &&
16371643 this .input == other.input &&
1638- this .mode == other.mode &&
1644+ this .modes == other.modes &&
16391645 this .child == other.child &&
16401646 this .serializedPlanOpt == other.serializedPlanOpt
16411647 case _ =>
@@ -1644,7 +1650,7 @@ case class CometHashAggregateExec(
16441650 }
16451651
16461652 override def hashCode (): Int =
1647- Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, mode , child)
1653+ Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, modes , child)
16481654
16491655 override protected def outputExpressions : Seq [NamedExpression ] = resultExpressions
16501656}
0 commit comments