Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ case class CometBroadcastExchangeExec(
@transient
private lazy val maxBroadcastRows = 512000000

/**
 * Serializes the columnar output of `plan`, producing per-partition
 * (count, serialized bytes) pairs via `Utils.serializeBatches`.
 * Executing through the given plan (rather than unwrapping it) preserves
 * any wrapper semantics such as AQE partition coalescing.
 */
private def getByteArrayRdd(plan: SparkPlan): RDD[(Long, ChunkedByteBuffer)] = {
  // Pass the serializer directly instead of wrapping it in an explicit closure.
  plan.executeColumnar().mapPartitionsInternal(Utils.serializeBatches)
}

/** Returns the partition count of the child's columnar RDD. */
def getNumPartitions(): Int = {
  val columnarRdd = child.executeColumnar()
  columnarRdd.getNumPartitions
}
Expand All @@ -125,16 +131,18 @@ case class CometBroadcastExchangeExec(

val countsAndBytes = child match {
case c: CometPlan => CometExec.getByteArrayRdd(c).collect()
case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
// Execute through AQEShuffleReadExec to respect AQE partition coalescing
case aqe @ AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
if s.plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
getByteArrayRdd(aqe).collect()
case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
if plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
case aqe @ AQEShuffleReadExec(
ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _),
_) if plan.isInstanceOf[CometPlan] =>
getByteArrayRdd(aqe).collect()
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _)
if plan.isInstanceOf[CometPlan] =>
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
Expand Down
47 changes: 47 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
Expand Down Expand Up @@ -517,4 +518,50 @@ class CometJoinSuite extends CometTestBase {
}
}
}

test("Broadcast exchange respects AQE shuffle partition coalescing") {
  // Regression guard: when a shuffle feeds a broadcast exchange, AQE may coalesce
  // the shuffle partitions. The broadcast collect must execute through the
  // AQEShuffleReadExec wrapper so the coalesced partitioning is honored rather
  // than bypassed.
  val numPartitions = 200
  withSQLConf(
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
    SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString,
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
    SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
    withParquetTable((0 until 100).map(i => (i, i % 5)), "small_tbl") {
      withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") {
        val query =
          """SELECT /*+ BROADCAST(a) */ *
            |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a
            |JOIN large_tbl b ON a._1 = b._1""".stripMargin

        val (_, cometPlan) = checkSparkAnswerAndOperator(
          sql(query),
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))

        // AQEShuffleReadExec.executeColumnar() lazily builds its shuffleRDD and,
        // as a side effect, records partitionSpecs.length in the "numPartitions"
        // driver metric. If the broadcast collect skipped the wrapper (the bug this
        // test guards against), executeColumnar would never run and the metric
        // would stay at its initial 0.
        val aqeReads = collect(cometPlan) { case read: AQEShuffleReadExec => read }
        assert(aqeReads.nonEmpty, "Expected AQEShuffleReadExec in plan")
        for (read <- aqeReads) {
          val observed = read.metrics("numPartitions").value
          assert(
            observed > 0,
            "AQEShuffleReadExec.numPartitions metric was never updated; the " +
              "broadcast collect likely bypassed AQEShuffleReadExec")
          assert(
            observed < numPartitions,
            s"Expected AQE to coalesce shuffle partitions below $numPartitions, " +
              s"got $observed")
        }
      }
    }
  }
}
Loading