@@ -1356,8 +1356,9 @@ trait CometBaseAggregate {
13561356 childOp : OperatorOuterClass .Operator * ): Option [OperatorOuterClass .Operator ] = {
13571357
13581358 val modes = aggregate.aggregateExpressions.map(_.mode).distinct
1359- // In distinct aggregates there can be a combination of modes
1360- val multiMode = modes.size > 1
1359+ // In distinct aggregates there can be a combination of modes.
1360+ // We support {Partial, PartialMerge} mix; other combinations are rejected.
1361+ val multiMode = modes.size > 1 && modes.toSet != Set (Partial , PartialMerge )
13611362 // For a final mode HashAggregate, we only need to transform the HashAggregate
13621363 // if there is Comet partial aggregation.
13631364 val sparkFinalMode = modes.contains(Final ) && findCometPartialAgg(aggregate.child).isEmpty
@@ -1430,34 +1431,42 @@ trait CometBaseAggregate {
14301431 Some (builder.setHashAgg(hashAggBuilder).build())
14311432 } else {
14321433 val modes = aggregateExpressions.map(_.mode).distinct
1433-
1434- if (modes.size != 1 ) {
1435- // This shouldn't happen as all aggregation expressions should share the same mode.
1436- // Fallback to Spark nevertheless here.
1437- withInfo(aggregate, " All aggregate expressions do not have the same mode" )
1434+ val modeSet = modes.toSet
1435+
1436+ // Validate mode combinations. We support:
1437+ // - All Partial
1438+ // - All Final
1439+ // - All PartialMerge
1440+ // - Mixed {Partial, PartialMerge} (for distinct aggregate plans)
1441+ val isMixedPartialMerge = modeSet == Set (Partial , PartialMerge )
1442+ if (modes.size > 1 && ! isMixedPartialMerge) {
1443+ withInfo(aggregate, s " Unsupported mixed aggregation modes: ${modes.mkString(" , " )}" )
14381444 return None
14391445 }
14401446
1441- val mode = modes.head match {
1442- case Partial => CometAggregateMode .Partial
1443- case Final => CometAggregateMode .Final
1444- case PartialMerge => CometAggregateMode .PartialMerge
1445- case _ =>
1446- withInfo(aggregate, s " Unsupported aggregation mode ${modes.head}" )
1447- return None
1447+ // Determine the proto mode. For uniform modes, use that mode directly.
1448+ // For mixed {Partial, PartialMerge}, use Partial as the base mode since
1449+ // PartialMerge expressions are wrapped with MergeAsPartial on the native side.
1450+ val mode = if (isMixedPartialMerge) {
1451+ CometAggregateMode .Partial
1452+ } else {
1453+ modes.head match {
1454+ case Partial => CometAggregateMode .Partial
1455+ case Final => CometAggregateMode .Final
1456+ case PartialMerge => CometAggregateMode .PartialMerge
1457+ case _ =>
1458+ withInfo(aggregate, s " Unsupported aggregation mode ${modes.head}" )
1459+ return None
1460+ }
14481461 }
14491462
1450- // In final/partial-merge mode, the aggregate expressions are bound to the
1451- // output of the child and partial aggregate expressions buffer attributes
1452- // produced by partial aggregation. This is done in Spark `HashAggregateExec`
1453- // internally. In Comet, we don't have to do this because we don't use the
1454- // merging expression.
1455- val binding = mode == CometAggregateMode .Partial
1456- // `output` is only used when `binding` is true (i.e., non-Final)
1463+ // Per-expression binding: Partial expressions bind to child output,
1464+ // PartialMerge/Final expressions do not (native planner handles their input).
14571465 val output = child.output
1458-
1459- val aggExprs =
1460- aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
1466+ val aggExprs = aggregateExpressions.map { a =>
1467+ val exprBinding = a.mode != PartialMerge && a.mode != Final
1468+ aggExprToProto(a, output, exprBinding, aggregate.conf)
1469+ }
14611470
14621471 if (aggExprs.exists(_.isEmpty)) {
14631472 withInfo(
@@ -1485,6 +1494,24 @@ trait CometBaseAggregate {
14851494 hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
14861495 }
14871496 hashAggBuilder.setModeValue(mode.getNumber)
1497+
1498+ // Send per-expression modes and buffer offset for PartialMerge handling
1499+ val hasPartialMerge = aggregateExpressions.exists(_.mode == PartialMerge )
1500+ if (hasPartialMerge) {
1501+ val exprModes = aggregateExpressions.map { a =>
1502+ a.mode match {
1503+ case Partial => CometAggregateMode .Partial
1504+ case PartialMerge => CometAggregateMode .PartialMerge
1505+ case Final => CometAggregateMode .Final
1506+ case other =>
1507+ withInfo(aggregate, s " Unsupported aggregation mode $other" )
1508+ return None
1509+ }
1510+ }
1511+ hashAggBuilder.addAllExprModes(exprModes.asJava)
1512+ hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset)
1513+ }
1514+
14881515 Some (builder.setHashAgg(hashAggBuilder).build())
14891516 } else {
14901517 val allChildren : Seq [Expression ] =
0 commit comments