Commit 6bcc2c3

Passes tests with reuse.

1 parent 03530e9

6 files changed: 229 additions & 7 deletions

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Lines changed: 97 additions & 2 deletions

@@ -53,6 +53,7 @@ import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.rules.CometExecRule.allExecs
 import org.apache.comet.serde._
 import org.apache.comet.serde.operator._
+import org.apache.comet.shims.ShimSubqueryBroadcast

 object CometExecRule {

@@ -93,7 +94,9 @@ object CometExecRule {
 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
  */
-case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
+case class CometExecRule(session: SparkSession)
+    extends Rule[SparkPlan]
+    with ShimSubqueryBroadcast {

   private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()

@@ -298,8 +301,100 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       }
     }

+    // scalastyle:off println
     plan.transformUp { case op =>
-      convertNode(op)
+      val hasSubqueryExpr = op.expressions.exists(_.exists {
+        case _: InSubqueryExec => true
+        case _ => false
+      })
+      if (hasSubqueryExpr) {
+        println(s"[RULE-DEBUG] convertNode on ${op.getClass.getSimpleName} " +
+          s"which HAS InSubqueryExec expressions")
+        op.expressions.foreach { expr =>
+          expr.foreach {
+            case sub: InSubqueryExec =>
+              println(s"[RULE-DEBUG] InSubqueryExec.plan: ${sub.plan.getClass.getSimpleName}")
+              sub.plan match {
+                case sb: SubqueryBroadcastExec =>
+                  println(s"[RULE-DEBUG] SubqueryBroadcast.child: " +
+                    s"${sb.child.getClass.getSimpleName}")
+                  sb.child match {
+                    case b: BroadcastExchangeExec =>
+                      println(s"[RULE-DEBUG] BroadcastExchange.child: " +
+                        s"${b.child.getClass.getSimpleName}")
+                      println(s"[RULE-DEBUG] BroadcastExchange.child is CometNative? " +
+                        s"${b.child.isInstanceOf[CometNativeExec]}")
+                      println(s"[RULE-DEBUG] BroadcastExchange.children all CometNative? " +
+                        s"${b.children.forall(_.isInstanceOf[CometNativeExec])}")
+                    case other =>
+                      println(s"[RULE-DEBUG] SubqueryBroadcast.child is: " +
+                        s"${other.getClass.getSimpleName}")
+                  }
+                case other =>
+                  println(s"[RULE-DEBUG] sub.plan is: ${other.getClass.getSimpleName}")
+              }
+            case _ =>
+          }
+        }
+      }
+      val converted = convertNode(op)
+      // Replace SubqueryBroadcastExec with CometSubqueryBroadcastExec in DPP expressions
+      // when the broadcast child has a Comet plan underneath. This enables exchange reuse
+      // between the DPP subquery and the join's CometBroadcastExchangeExec because both
+      // will have the same CometBroadcastExchangeExec type and canonical form.
+      convertSubqueryBroadcasts(converted)
+    }
+    // scalastyle:on println
+  }
+
+  /**
+   * Replace SubqueryBroadcastExec with CometSubqueryBroadcastExec in a node's expressions.
+   *
+   * When CometExecRule converts BroadcastExchangeExec to CometBroadcastExchangeExec on the
+   * join side, the DPP subquery still references the original BroadcastExchangeExec.
+   * ReuseExchangeAndSubquery (which runs after Comet rules) can't match them because they
+   * have different types. By replacing SubqueryBroadcastExec with CometSubqueryBroadcastExec
+   * (which wraps a CometBroadcastExchangeExec), both sides have the same exchange type and
+   * reuse works.
+   *
+   * The BroadcastExchangeExec in the subquery has a CometNativeColumnarToRowExec child
+   * (inserted by ApplyColumnarRulesAndInsertTransitions because BroadcastExchangeExec expects
+   * row input). We strip this transition and create CometBroadcastExchangeExec with the
+   * underlying Comet plan directly.
+   */
+  private def convertSubqueryBroadcasts(plan: SparkPlan): SparkPlan = {
+    plan.transformExpressionsUp {
+      case inSub: InSubqueryExec =>
+        inSub.plan match {
+          case sub: SubqueryBroadcastExec =>
+            sub.child match {
+              case b: BroadcastExchangeExec =>
+                // The BroadcastExchangeExec child is CometNativeColumnarToRowExec wrapping
+                // a Comet plan. Strip the row transition to get the columnar Comet plan.
+                val cometChild = b.child match {
+                  case c2r: CometNativeColumnarToRowExec => c2r.child
+                  case other => other
+                }
+                if (cometChild.isInstanceOf[CometNativeExec]) {
+                  // scalastyle:off println
+                  println(s"[RULE-DEBUG] Converting SubqueryBroadcastExec to " +
+                    s"CometSubqueryBroadcastExec, cometChild=${cometChild.getClass.getSimpleName}")
+                  // scalastyle:on println
+                  val cometBroadcast = CometBroadcastExchangeExec(
+                    b, b.output, b.mode, cometChild)
+                  val cometSub = CometSubqueryBroadcastExec(
+                    sub.name,
+                    getSubqueryBroadcastExecIndices(sub),
+                    sub.buildKeys,
+                    cometBroadcast)
+                  inSub.withNewPlan(cometSub)
+                } else {
+                  inSub
+                }
+              case _ => inSub
+            }
+          case _ => inSub
+        }
     }
   }

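Note on the reuse mechanism the comment above refers to: Spark deduplicates exchanges by canonical form, so reuse only fires when both sides canonicalize to equal plans of the same class. Below is a simplified sketch of that matching logic, not Comet code and not Spark's actual rule (ReuseExchangeAndSubquery, which also handles subqueries), just the exchange half under standard Spark APIs:

import scala.collection.mutable
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}

// Sketch: replace an exchange with ReusedExchangeExec when an equal
// canonical form was already seen. This is why BroadcastExchangeExec and
// CometBroadcastExchangeExec can never match each other: canonicalized
// plans of different operator classes are never equal.
def reuseExchanges(plan: SparkPlan): SparkPlan = {
  val cache = mutable.Map.empty[SparkPlan, Exchange]
  plan.transformUp { case exchange: Exchange =>
    cache.get(exchange.canonicalized) match {
      case Some(cached) => ReusedExchangeExec(exchange.output, cached)
      case None =>
        cache(exchange.canonicalized) = exchange
        exchange
    }
  }
}

This is the reason convertSubqueryBroadcasts rewrites the DPP side to wrap a CometBroadcastExchangeExec: once the join and the subquery hold the same exchange class, their canonical forms can match and the broadcast is built only once.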

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala
Lines changed: 40 additions & 1 deletion

@@ -176,15 +176,54 @@ case class CometNativeScanExec(
    * partition's files (lazily, as tasks are scheduled).
    */
   @transient private lazy val serializedPartitionData: (Array[Byte], Array[Array[Byte]]) = {
+    // scalastyle:off println
+    println(s"[DPP-DEBUG] serializedPartitionData: checking partitionFilters")
+    partitionFilters.foreach {
+      case DynamicPruningExpression(e: InSubqueryExec) =>
+        println(s"[DPP-DEBUG] InSubqueryExec plan=${e.plan.getClass.getSimpleName} " +
+          s"values empty=${e.values().isEmpty}")
+      case other =>
+        println(s"[DPP-DEBUG] filter: ${other.getClass.getSimpleName}")
+    }
+    // scalastyle:on println
     // Ensure DPP subqueries are resolved before accessing file partitions.
     // serializedPartitionData can be triggered from findAllPlanData (via commonData) on a
     // different execution path than the standard prepare() -> executeSubqueries() flow
     // (e.g., from a BroadcastExchangeExec thread). We must resolve DPP here explicitly.
     partitionFilters.foreach {
       case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty =>
-        e.updateResult()
+        // scalastyle:off println
+        println(s"[DPP-DEBUG] calling updateResult on InSubqueryExec " +
+          s"plan=${e.plan.getClass.getSimpleName} id=${System.identityHashCode(e)}")
+        // scalastyle:on println
+        try {
+          e.updateResult()
+          // scalastyle:off println
+          println(s"[DPP-DEBUG] updateResult succeeded, values empty=${e.values().isEmpty}")
+          // scalastyle:on println
+        } catch {
+          // scalastyle:off println
+          case ex: Exception =>
+            println(s"[DPP-DEBUG] updateResult FAILED: ${ex.getMessage}")
+            throw ex
+          // scalastyle:on println
+        }
       case _ =>
     }
+    // Also resolve DPP in CometScanExec's partitionFilters, which may reference
+    // a different InSubqueryExec instance (with the original SubqueryBroadcastExec).
+    // CometScanExec.dynamicallySelectedPartitions evaluates these filters.
+    if (scan != null) {
+      scan.partitionFilters.foreach {
+        case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty =>
+          // scalastyle:off println
+          println(s"[DPP-DEBUG] also resolving scan's InSubqueryExec " +
+            s"plan=${e.plan.getClass.getSimpleName} id=${System.identityHashCode(e)}")
+          // scalastyle:on println
+          e.updateResult()
+        case _ =>
+      }
+    }
     // Extract common data from nativeOp
     val commonBytes = nativeOp.getNativeScan.getCommon.toByteArray

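The resolution pattern above (check values().isEmpty, then call updateResult()) appears twice because the native scan and the wrapped CometScanExec can hold distinct InSubqueryExec instances. Once the debug output is removed, it could be factored into a helper along these lines; this is a hypothetical refactor, not part of the commit:

import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.execution.InSubqueryExec

// Hypothetical helper (not in this commit): force any unresolved DPP
// subqueries in a filter list. InSubqueryExec.values() is empty until
// updateResult() has run, so resolved filters are skipped and the call
// is safe from any execution path.
def resolveDppFilters(filters: Seq[Expression]): Unit = {
  filters.foreach {
    case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty =>
      e.updateResult()
    case _ =>
  }
}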

spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala
Lines changed: 6 additions & 1 deletion

@@ -19,7 +19,7 @@

 package org.apache.comet.shims

-import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec}

 trait ShimSubqueryBroadcast {

@@ -30,4 +30,9 @@ trait ShimSubqueryBroadcast {
   def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = {
     Seq(sab.index)
   }
+
+  /** Same-version shim for SubqueryBroadcastExec. */
+  def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = {
+    Seq(sub.index)
+  }
 }

spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala
Lines changed: 6 additions & 1 deletion

@@ -19,7 +19,7 @@

 package org.apache.comet.shims

-import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec}

 trait ShimSubqueryBroadcast {

@@ -30,4 +30,9 @@ trait ShimSubqueryBroadcast {
   def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = {
     Seq(sab.index)
   }
+
+  /** Same-version shim for SubqueryBroadcastExec. */
+  def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = {
+    Seq(sub.index)
+  }
 }

spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala
Lines changed: 6 additions & 1 deletion

@@ -19,7 +19,7 @@

 package org.apache.comet.shims

-import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec}

 trait ShimSubqueryBroadcast {

@@ -30,4 +30,9 @@ trait ShimSubqueryBroadcast {
   def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = {
     sab.indices
   }
+
+  /** Same-version shim for SubqueryBroadcastExec. */
+  def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = {
+    sub.indices
+  }
 }
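The three per-version copies exist because SubqueryBroadcastExec exposes a single index: Int through Spark 3.5 but indices: Seq[Int] in Spark 4.0. Callers mix in whichever ShimSubqueryBroadcast the build selects and stay version-agnostic; a minimal usage sketch follows, where the class name is illustrative and not part of the commit:

import org.apache.spark.sql.execution.SubqueryBroadcastExec

// Hypothetical caller: the shim hides the 3.x/4.0 API difference.
class DppHelper extends ShimSubqueryBroadcast {
  def buildKeyIndices(sub: SubqueryBroadcastExec): Seq[Int] =
    getSubqueryBroadcastExecIndices(sub) // Seq(sub.index) on 3.4/3.5; sub.indices on 4.0
}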

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Lines changed: 74 additions & 1 deletion

@@ -204,7 +204,9 @@ class CometExecSuite extends CometTestBase {
       spark.read.parquet(dimPath).createOrReplaceTempView("dpp_dim_bhj")
       val df = spark.sql(
         "select * from dpp_fact_bhj join dpp_dim_bhj on fact_date = dim_date where dim_id > 7")
-      val (_, cometPlan) = checkSparkAnswerAndOperator(df)
+      // Exclude ReusedExchangeExec; it appears inside the DPP subquery after exchange reuse
+      val (_, cometPlan) = checkSparkAnswerAndOperator(
+        df, classOf[ReusedExchangeExec])

       val nativeScans = cometPlan.collect { case s: CometNativeScanExec => s }
       assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan")

@@ -263,6 +265,77 @@ class CometExecSuite extends CometTestBase {
     }
   }

+  test("DPP broadcast exchange reuse investigation") {
+    withTempDir { dir =>
+      val path = s"${dir.getAbsolutePath}/data"
+      withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") {
+        spark.range(100).selectExpr(
+          "id % 10 as store_id", "cast(id * 2 as int) as date_id",
+          "cast(id * 3 as int) as product_id", "cast(id as int) as units_sold")
+          .write.partitionBy("store_id").parquet(s"$path/fact")
+        spark.range(10).selectExpr(
+          "cast(id as int) as store_id", "cast(id as string) as country")
+          .write.parquet(s"$path/dim")
+      }
+
+      withSQLConf(
+        SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+        SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
+        spark.read.parquet(s"$path/fact").createOrReplaceTempView("fact_reuse")
+        spark.read.parquet(s"$path/dim").createOrReplaceTempView("dim_reuse")
+
+        val df = spark.sql(
+          """SELECT f.date_id, f.store_id
+            |FROM fact_reuse f JOIN dim_reuse d
+            |ON f.store_id = d.store_id
+            |WHERE d.country = 'DE'""".stripMargin)
+        df.collect()
+        val plan = df.queryExecution.executedPlan
+        // scalastyle:off println
+        println(s"[REUSE-DEBUG] Plan:\n${plan.treeString}")
+
+        // Walk into subquery expressions to see what's inside
+        plan.foreach { node =>
+          node.expressions.foreach { expr =>
+            expr.foreach {
+              case sub: InSubqueryExec =>
+                println(s"[REUSE-DEBUG] Found InSubqueryExec in ${node.getClass.getSimpleName}")
+                println(s"[REUSE-DEBUG] sub.plan class: ${sub.plan.getClass.getSimpleName}")
+                println(s"[REUSE-DEBUG] sub.plan tree:\n${sub.plan.treeString}")
+                sub.plan match {
+                  case sb: SubqueryBroadcastExec =>
+                    println(s"[REUSE-DEBUG] SubqueryBroadcast child: " +
+                      s"${sb.child.getClass.getSimpleName}")
+                    println(s"[REUSE-DEBUG] SubqueryBroadcast child tree:\n" +
+                      s"${sb.child.treeString}")
+                  case other =>
+                    println(s"[REUSE-DEBUG] sub.plan is: ${other.getClass.getSimpleName}")
+                }
+              case _ =>
+            }
+          }
+        }
+
+        val reused = collectWithSubqueries(plan) {
+          case e: ReusedExchangeExec => e
+        }
+        println(s"[REUSE-DEBUG] ReusedExchangeExec count: ${reused.size}")
+
+        val broadcasts = collectWithSubqueries(plan) {
+          case e: BroadcastExchangeExec => ("BroadcastExchangeExec", e: SparkPlan)
+          case e: CometBroadcastExchangeExec => ("CometBroadcastExchangeExec", e: SparkPlan)
+        }
+        println(s"[REUSE-DEBUG] Broadcast exchange count: ${broadcasts.size}")
+        broadcasts.foreach { case (typ, e) =>
+          println(s"[REUSE-DEBUG] $typ hash=${e.canonicalized.hashCode()}")
+          println(s"[REUSE-DEBUG] $typ child: ${e.children.map(_.getClass.getSimpleName)}")
+        }
+        // scalastyle:on println
+      }
+    }
+  }
+
   test("ShuffleQueryStageExec could be direct child node of CometBroadcastExchangeExec") {
     withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
       val table = "src"
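The canonicalized.hashCode() comparison at the end of the new test is the crux of the investigation: two exchanges are reusable only when their canonical plans are equal, which requires the same operator class on both sides. Once the CometExecRule conversion works, the check this debug test is building toward might look like the following assertion; it is hypothetical and not in this commit, but uses only helpers the test already relies on:

// Hypothetical follow-up assertion: after SubqueryBroadcastExec is converted
// on the DPP side, the subquery should reuse the join's broadcast exchange
// instead of executing its own copy.
val reusedExchanges = collectWithSubqueries(plan) { case e: ReusedExchangeExec => e }
assert(reusedExchanges.nonEmpty, "expected DPP subquery to reuse the join broadcast")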
