@@ -1441,16 +1441,18 @@ trait CometBaseAggregate {
14411441 val mode = modes.head match {
14421442 case Partial => CometAggregateMode .Partial
14431443 case Final => CometAggregateMode .Final
1444+ case PartialMerge => CometAggregateMode .PartialMerge
14441445 case _ =>
14451446 withInfo(aggregate, s " Unsupported aggregation mode ${modes.head}" )
14461447 return None
14471448 }
14481449
1449- // In final mode, the aggregate expressions are bound to the output of the
1450- // child and partial aggregate expressions buffer attributes produced by partial
1451- // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
1452- // we don't have to do this because we don't use the merging expression.
1453- val binding = mode != CometAggregateMode .Final
1450+ // In final/partial-merge mode, the aggregate expressions are bound to the
1451+ // output of the child and partial aggregate expressions buffer attributes
1452+ // produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453+ // internally. In Comet, we don't have to do this because we don't use the
1454+ // merging expression.
1455+ val binding = mode == CometAggregateMode .Partial
14541456 // `output` is only used when `binding` is true (i.e., Partial mode only)
14551457 val output = child.output
14561458
@@ -1496,14 +1498,21 @@ trait CometBaseAggregate {
14961498
14971499 /**
14981500 * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
1499- * partial mode, it will return None.
1501+ * partial or partial-merge mode, it will return None.
15001502 */
15011503 private def findCometPartialAgg (plan : SparkPlan ): Option [CometHashAggregateExec ] = {
1504+ def isPartialOrMerge (mode : AggregateMode ): Boolean =
1505+ mode == Partial || mode == PartialMerge
1506+
15021507 plan.collectFirst {
1503- case agg : CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
1508+ case agg : CometHashAggregateExec
1509+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15041510 Some (agg)
1505- case agg : HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) => None
1506- case agg : ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial ) =>
1511+ case agg : HashAggregateExec
1512+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
1513+ None
1514+ case agg : ObjectHashAggregateExec
1515+ if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
15071516 None
15081517 case a : AQEShuffleReadExec => findCometPartialAgg(a.child)
15091518 case s : ShuffleQueryStageExec => findCometPartialAgg(s.plan)
@@ -1642,11 +1651,8 @@ case class CometHashAggregateExec(
16421651
16431652 // The aggExprs could be empty. For example, if the aggregate functions only have
16441653 // distinct aggregate functions or only have group by, the aggExprs is empty and
1645- // modes is empty too. If aggExprs is not empty, we need to verify all the
1646- // aggregates have the same mode.
1654+ // modes is empty too.
16471655 val modes : Seq [AggregateMode ] = aggregateExpressions.map(_.mode).distinct
1648- assert(modes.length == 1 || modes.isEmpty)
1649- val mode = modes.headOption
16501656
16511657 override def producedAttributes : AttributeSet = outputSet ++ AttributeSet (resultExpressions)
16521658
@@ -1663,7 +1669,7 @@ case class CometHashAggregateExec(
16631669 }
16641670
16651671 override def stringArgs : Iterator [Any ] =
1666- Iterator (input, mode , groupingExpressions, aggregateExpressions, child)
1672+ Iterator (input, modes , groupingExpressions, aggregateExpressions, child)
16671673
16681674 override def equals (obj : Any ): Boolean = {
16691675 obj match {
@@ -1672,7 +1678,7 @@ case class CometHashAggregateExec(
16721678 this .groupingExpressions == other.groupingExpressions &&
16731679 this .aggregateExpressions == other.aggregateExpressions &&
16741680 this .input == other.input &&
1675- this .mode == other.mode &&
1681+ this .modes == other.modes &&
16761682 this .child == other.child &&
16771683 this .serializedPlanOpt == other.serializedPlanOpt
16781684 case _ =>
@@ -1681,7 +1687,7 @@ case class CometHashAggregateExec(
16811687 }
16821688
16831689 override def hashCode (): Int =
1684- Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, mode , child)
1690+ Objects .hashCode(output, groupingExpressions, aggregateExpressions, input, modes , child)
16851691
16861692 override protected def outputExpressions : Seq [NamedExpression ] = resultExpressions
16871693}
0 commit comments