Skip to content

Commit f7fa33c

Browse files
andygrove and claude committed
fix: allow safe mixed Spark/Comet partial/final aggregate execution
Previously, when one aggregate stage (Partial or Final) couldn't be converted to Comet, the other was also blocked to avoid crashes from incompatible intermediate buffer formats (issues #1389, #1267). This change introduces per-aggregate `supportsMixedPartialFinal` declarations so that aggregates with simple, compatible buffers (MIN, MAX, COUNT, bitwise) can safely run in mixed mode while unsafe aggregates (SUM, AVG, Variance, CollectSet) continue to be blocked. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5efd972 commit f7fa33c

6 files changed

Lines changed: 211 additions & 7 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ import scala.collection.mutable.ListBuffer
2323

2424
import org.apache.spark.sql.SparkSession
2525
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
26+
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
2627
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
2728
import org.apache.spark.sql.catalyst.rules.Rule
29+
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
2830
import org.apache.spark.sql.catalyst.util.sideBySide
2931
import org.apache.spark.sql.comet._
3032
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
3133
import org.apache.spark.sql.comet.util.Utils
3234
import org.apache.spark.sql.execution._
3335
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
34-
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
36+
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
3537
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
3638
import org.apache.spark.sql.execution.datasources.WriteFilesExec
3739
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -56,6 +58,14 @@ import org.apache.comet.serde.operator._
5658

5759
object CometExecRule {
5860

61+
/**
62+
* Tag applied to Partial-mode aggregate operators that must NOT be converted to Comet because
63+
* the corresponding Final-mode aggregate cannot be converted, and the aggregate functions have
64+
* incompatible intermediate buffer formats between Spark and Comet.
65+
*/
66+
val COMET_UNSAFE_PARTIAL: TreeNodeTag[String] =
67+
TreeNodeTag[String]("comet.unsafePartialAgg")
68+
5969
/**
6070
* Fully native operators.
6171
*/
@@ -388,6 +398,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
388398
normalizedPlan
389399
}
390400

401+
// Tag Partial aggregates that must not be converted to Comet because the
402+
// corresponding Final aggregate cannot be converted and the intermediate buffer
403+
// formats are incompatible. This runs before transform() so the tags are checked
404+
// during the bottom-up conversion. Tags persist through AQE stage creation.
405+
tagUnsafePartialAggregates(planWithJoinRewritten)
406+
391407
var newPlan = transform(planWithJoinRewritten)
392408

393409
// if the plan cannot be run fully natively then explain why (when appropriate
@@ -601,4 +617,88 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
601617
}
602618
}
603619

620+
/**
621+
* Walk the plan to find Final-mode aggregates that cannot be converted to Comet. For each such
622+
* Final, if the aggregate functions have incompatible intermediate buffer formats, tag the
623+
* corresponding Partial-mode aggregate so it will also be skipped during conversion.
624+
*
625+
* This prevents the crash described in issue #1389 where a Comet Partial produces intermediate
626+
* data in a format that the Spark Final cannot interpret.
627+
*/
628+
private def tagUnsafePartialAggregates(plan: SparkPlan): Unit = {
629+
plan.foreach {
630+
case agg: BaseAggregateExec if agg.aggregateExpressions.exists(_.mode == Final) =>
631+
if (!QueryPlanSerde.allAggsSupportMixedExecution(agg.aggregateExpressions)) {
632+
if (!canFinalAggregateBeConverted(agg)) {
633+
findPartialAggInPlan(agg.child).foreach { partial =>
634+
partial.setTagValue(
635+
CometExecRule.COMET_UNSAFE_PARTIAL,
636+
"Partial aggregate disabled: corresponding final aggregate " +
637+
"cannot be converted to Comet and intermediate buffer formats are incompatible")
638+
}
639+
}
640+
}
641+
case _ =>
642+
}
643+
}
644+
645+
/**
646+
* Conservative check for whether a Final-mode aggregate could be converted to Comet. Checks
647+
* operator enablement, grouping expressions, aggregate expressions, and result expressions.
648+
* Intentionally skips the sparkFinalMode / child-native checks since those depend on
649+
* transformation state.
650+
*/
651+
private def canFinalAggregateBeConverted(agg: BaseAggregateExec): Boolean = {
652+
val handler = allExecs.get(agg.getClass)
653+
if (handler.isEmpty) return false
654+
val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
655+
if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false
656+
657+
// ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
658+
agg match {
659+
case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false
660+
case _ =>
661+
}
662+
663+
val aggregateExpressions = agg.aggregateExpressions
664+
val groupingExpressions = agg.groupingExpressions
665+
666+
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) return false
667+
668+
if (groupingExpressions.exists(_.dataType.isInstanceOf[MapType])) return false
669+
670+
if (!groupingExpressions.forall(e =>
671+
QueryPlanSerde.exprToProto(e, agg.child.output).isDefined)) {
672+
return false
673+
}
674+
675+
if (aggregateExpressions.nonEmpty) {
676+
val modes = aggregateExpressions.map(_.mode).distinct
677+
if (modes.size != 1 || !modes.contains(Final)) return false
678+
679+
val binding = false
680+
if (!aggregateExpressions.forall(e =>
681+
QueryPlanSerde.aggExprToProto(e, agg.child.output, binding, agg.conf).isDefined)) {
682+
return false
683+
}
684+
}
685+
686+
val attributes =
687+
groupingExpressions.map(_.toAttribute) ++ agg.aggregateAttributes
688+
agg.resultExpressions.forall(e => QueryPlanSerde.exprToProto(e, attributes).isDefined)
689+
}
690+
691+
/**
692+
* Search the child subtree for the first Partial-mode aggregate, traversing through exchanges
693+
* and AQE stages.
694+
*/
695+
private def findPartialAggInPlan(plan: SparkPlan): Option[BaseAggregateExec] = {
696+
plan.collectFirst {
697+
case agg: BaseAggregateExec if agg.aggregateExpressions.forall(e => e.mode == Partial) =>
698+
Some(agg)
699+
case a: AQEShuffleReadExec => findPartialAggInPlan(a.child)
700+
case s: ShuffleQueryStageExec => findPartialAggInPlan(s.plan)
701+
}.flatten
702+
}
703+
604704
}

spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
6868
* case it is expected that the input expression will have been tagged with reasons why it
6969
* could not be converted.
7070
*/
71+
/**
72+
* Whether this aggregate's intermediate buffer format is compatible between Spark and Comet,
73+
* making it safe to run the Partial in one engine and the Final in the other. Aggregates with
74+
* simple single-value buffers (MIN, MAX, COUNT, bitwise) are safe; those with complex or
75+
* differently-encoded buffers (AVG, SUM with decimals, CollectSet, Variance) are not.
76+
*/
77+
def supportsMixedPartialFinal: Boolean = false
78+
7179
def convert(
7280
aggExpr: AggregateExpression,
7381
expr: T,

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,24 @@ object QueryPlanSerde extends Logging with CometExprShim {
277277
classOf[VariancePop] -> CometVariancePop,
278278
classOf[VarianceSamp] -> CometVarianceSamp)
279279

280+
/**
281+
* Returns true if all aggregate expressions in the list have intermediate buffer formats that
282+
* are compatible between Spark and Comet, making it safe to run Partial in one engine and Final
283+
* in the other.
284+
*/
285+
def allAggsSupportMixedExecution(aggExprs: Seq[AggregateExpression]): Boolean = {
286+
aggExprs.forall { aggExpr =>
287+
val fn = aggExpr.aggregateFunction
288+
aggrSerdeMap.get(fn.getClass) match {
289+
case Some(handler) =>
290+
handler
291+
.asInstanceOf[CometAggregateExpressionSerde[AggregateFunction]]
292+
.supportsMixedPartialFinal
293+
case None => false
294+
}
295+
}
296+
}
297+
280298
// A unique id for each expression. ~used to look up QueryContext during error creation.
281299
private val exprIdCounter = new AtomicLong(0)
282300

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import org.apache.comet.shims.CometEvalModeUtil
3434

3535
object CometMin extends CometAggregateExpressionSerde[Min] {
3636

37+
override def supportsMixedPartialFinal: Boolean = true
38+
3739
override def convert(
3840
aggExpr: AggregateExpression,
3941
expr: Min,
@@ -81,6 +83,8 @@ object CometMin extends CometAggregateExpressionSerde[Min] {
8183

8284
object CometMax extends CometAggregateExpressionSerde[Max] {
8385

86+
override def supportsMixedPartialFinal: Boolean = true
87+
8488
override def convert(
8589
aggExpr: AggregateExpression,
8690
expr: Max,
@@ -127,6 +131,8 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
127131
}
128132

129133
object CometCount extends CometAggregateExpressionSerde[Count] {
134+
override def supportsMixedPartialFinal: Boolean = true
135+
130136
override def convert(
131137
aggExpr: AggregateExpression,
132138
expr: Count,
@@ -306,6 +312,8 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
306312
}
307313

308314
object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
315+
override def supportsMixedPartialFinal: Boolean = true
316+
309317
override def convert(
310318
aggExpr: AggregateExpression,
311319
bitAnd: BitAndAgg,
@@ -340,6 +348,8 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
340348
}
341349

342350
object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
351+
override def supportsMixedPartialFinal: Boolean = true
352+
343353
override def convert(
344354
aggExpr: AggregateExpression,
345355
bitOr: BitOrAgg,
@@ -374,6 +384,8 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
374384
}
375385

376386
object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {
387+
override def supportsMixedPartialFinal: Boolean = true
388+
377389
override def convert(
378390
aggExpr: AggregateExpression,
379391
bitXor: BitXorAgg,

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ import com.google.protobuf.CodedOutputStream
5454
import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry}
5555
import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo}
5656
import org.apache.comet.parquet.CometParquetUtils
57+
import org.apache.comet.rules.CometExecRule
5758
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported}
5859
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
60+
import org.apache.comet.serde.QueryPlanSerde
5961
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType}
6062
import org.apache.comet.serde.operator.CometSink
6163

@@ -1359,10 +1361,24 @@ trait CometBaseAggregate {
13591361
// In distinct aggregates there can be a combination of modes
13601362
val multiMode = modes.size > 1
13611363
// For a final mode HashAggregate, we only need to transform the HashAggregate
1362-
// if there is Comet partial aggregation.
1364+
// if there is Comet partial aggregation, unless all aggregates have compatible
1365+
// intermediate buffer formats (safe for mixed Spark/Comet execution).
13631366
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
13641367

1365-
if (multiMode || sparkFinalMode) {
1368+
if (multiMode) {
1369+
return None
1370+
}
1371+
1372+
if (sparkFinalMode &&
1373+
!QueryPlanSerde.allAggsSupportMixedExecution(aggregate.aggregateExpressions)) {
1374+
return None
1375+
}
1376+
1377+
// Check if this aggregate has been tagged as unsafe for mixed execution
1378+
// (Comet partial + Spark final with incompatible intermediate buffers)
1379+
val unsafeReason = aggregate.getTagValue(CometExecRule.COMET_UNSAFE_PARTIAL)
1380+
if (unsafeReason.isDefined) {
1381+
withInfo(aggregate, unsafeReason.get)
13661382
return None
13671383
}
13681384

spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,8 @@ class CometExecRuleSuite extends CometTestBase {
131131
}
132132
}
133133

134-
// TODO this test exposes the bug described in
135-
// https://github.com/apache/datafusion-comet/issues/1389
136-
ignore("CometExecRule should not allow Comet partial and Spark final hash aggregate") {
134+
// Regression test for https://github.com/apache/datafusion-comet/issues/1389
135+
test("CometExecRule should not allow Comet partial and Spark final hash aggregate") {
137136
withTempView("test_data") {
138137
createTestDataFrame.createOrReplaceTempView("test_data")
139138

@@ -149,7 +148,8 @@ class CometExecRuleSuite extends CometTestBase {
149148
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
150149
val transformedPlan = applyCometExecRule(sparkPlan)
151150

152-
// if the final aggregate cannot be converted to Comet, then neither should be
151+
// SUM has incompatible intermediate buffers, so if the final aggregate cannot
152+
// be converted to Comet, neither should be
153153
assert(
154154
countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount)
155155
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
@@ -181,6 +181,56 @@ class CometExecRuleSuite extends CometTestBase {
181181
}
182182
}
183183

184+
test("CometExecRule should allow safe Comet partial and Spark final hash aggregate") {
185+
withTempView("test_data") {
186+
createTestDataFrame.createOrReplaceTempView("test_data")
187+
188+
// Query uses only safe aggregates (MIN, MAX, COUNT) with compatible intermediate buffers
189+
val sparkPlan =
190+
createSparkPlan(
191+
spark,
192+
"SELECT COUNT(*), MIN(id), MAX(id) FROM test_data GROUP BY (id % 3)")
193+
194+
val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec])
195+
assert(originalHashAggCount == 2)
196+
197+
withSQLConf(
198+
CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
199+
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
200+
val transformedPlan = applyCometExecRule(sparkPlan)
201+
202+
// Safe aggregates allow mixed execution: partial can be Comet, final stays Spark
203+
assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) // final only
204+
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) // partial
205+
}
206+
}
207+
}
208+
209+
test("CometExecRule should allow safe Spark partial and Comet final hash aggregate") {
210+
withTempView("test_data") {
211+
createTestDataFrame.createOrReplaceTempView("test_data")
212+
213+
// Query uses only safe aggregates (MIN, MAX, COUNT) with compatible intermediate buffers
214+
val sparkPlan =
215+
createSparkPlan(
216+
spark,
217+
"SELECT COUNT(*), MIN(id), MAX(id) FROM test_data GROUP BY (id % 3)")
218+
219+
val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec])
220+
assert(originalHashAggCount == 2)
221+
222+
withSQLConf(
223+
CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
224+
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
225+
val transformedPlan = applyCometExecRule(sparkPlan)
226+
227+
// Safe aggregates allow mixed execution: partial stays Spark, final can be Comet
228+
assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) // partial only
229+
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) // final
230+
}
231+
}
232+
}
233+
184234
test("CometExecRule should apply broadcast exchange transformations") {
185235
withTempView("test_data") {
186236
createTestDataFrame.createOrReplaceTempView("test_data")

0 commit comments

Comments (0)