
Commit 0837d9e

Passes tests with reuse.
1 parent 6bcc2c3 commit 0837d9e

1 file changed

Lines changed: 163 additions & 0 deletions

@@ -0,0 +1,163 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import java.util.concurrent.{Future => JFuture}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.io.ChunkedByteBuffer

/**
 * Comet replacement for SubqueryBroadcastExec that consumes Arrow broadcast data from a
 * CometBroadcastExchangeExec instead of HashedRelation from BroadcastExchangeExec.
 *
 * This enables broadcast exchange reuse between DPP subqueries and broadcast hash joins
 * when CometExecRule converts BroadcastExchangeExec to CometBroadcastExchangeExec.
 * Without this, the two exchanges have different types and canonical forms, so Spark's
 * ReuseExchangeAndSubquery (which runs after Comet rules) cannot match them.
 *
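 * For illustration only, a DPP subquery over a (possibly reused) Comet broadcast
 * would be constructed roughly as follows, where `partKey` and `cometExchange` are
 * hypothetical placeholders:
 * {{{
 *   // cometExchange: a CometBroadcastExchangeExec, or a ReusedExchangeExec over one
 *   CometSubqueryBroadcastExec("dpp", Seq(0), Seq(partKey), cometExchange)
 * }}}
 *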
 * @param name the name of the subquery
 * @param indices the indices of the join keys in the list of keys from the build side
 * @param buildKeys the join keys from the build side of the join
 * @param child the CometBroadcastExchangeExec (or ReusedExchangeExec after reuse)
 */
case class CometSubqueryBroadcastExec(
    name: String,
    indices: Seq[Int],
    buildKeys: Seq[Expression],
    child: SparkPlan)
    extends BaseSubqueryExec
    with UnaryExecNode {
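
  // Example (hypothetical attributes): with buildKeys = Seq(a#1, Cast(b#2, LongType))
  // and indices = Seq(1), `output` is a single AttributeReference named "b" of
  // LongType, i.e. the schema of the pruning-key values this subquery returns.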
  override def output: Seq[Attribute] = {
    indices.map { idx =>
      val key = buildKeys(idx)
      val attrName = key match {
        case n: NamedExpression => n.name
        case Cast(n: NamedExpression, _, _, _) => n.name
        case _ => s"key_$idx"
      }
      AttributeReference(attrName, key.dataType, key.nullable)()
    }
  }

  override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))
  override def doCanonicalize(): SparkPlan = {
    val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output))
    CometSubqueryBroadcastExec("dpp", indices, keys, child.canonicalized)
  }

  @transient
  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
      session,
      CometSubqueryBroadcastExec.executionContext) {
      SQLExecution.withExecutionId(session, executionId) {
        val beforeCollect = System.nanoTime()

        // Get the Arrow broadcast from CometBroadcastExchangeExec
        val broadcasted = child.executeBroadcast[Array[ChunkedByteBuffer]]()
        val arrowBatches = broadcasted.value

        // Decode Arrow batches and extract key column values
        val keyIndices = indices.map { idx =>
          val key = buildKeys(idx)
          // Find the column index in the broadcast output that matches the build key
          key match {
            case attr: Attribute =>
              child.output.indexWhere(_.exprId == attr.exprId)
            case Cast(attr: Attribute, _, _, _) =>
              child.output.indexWhere(_.exprId == attr.exprId)
            // Fall back to the positional index when the key is not a plain
            // (possibly cast) attribute reference
            case _ => idx
          }
        }
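
        // Decode each broadcast Arrow batch and project its rows down to the key
        // columns, keeping only the distinct key rows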
        val rows = arrowBatches.iterator
          .flatMap(Utils.decodeBatches(_, this.getClass.getSimpleName))
          .flatMap { batch =>
            val numRows = batch.numRows()
            (0 until numRows).iterator.map { rowIdx =>
              val row = batch.getRow(rowIdx)
              val projected = new GenericInternalRow(keyIndices.length)
              keyIndices.zipWithIndex.foreach { case (colIdx, outIdx) =>
                projected.update(outIdx, row.get(colIdx, output(outIdx).dataType))
              }
              projected.asInstanceOf[InternalRow].copy()
            }
          }
          .toArray
          .distinct

        val beforeBuild = System.nanoTime()
        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
        longMetric("numOutputRows") += rows.length
        // Convert to UnsafeRow for consistent size metric and to match SubqueryBroadcastExec
        val unsafeProj = UnsafeProjection.create(output.map(_.dataType).toArray)
        val unsafeRows = rows.map(r => unsafeProj(r).copy())
        val dataSize = unsafeRows.map(_.getSizeInBytes.toLong).sum
        longMetric("dataSize") += dataSize
        SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)

        unsafeRows.asInstanceOf[Array[InternalRow]]
      }
    }
  }
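
  // doPrepare kicks off the asynchronous collection above; executeCollect blocks on
  // the same future, so the subquery result is computed at most once per plan instance.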
  protected override def doPrepare(): Unit = {
    relationFuture
  }

  protected override def doExecute(): RDD[InternalRow] = {
    throw QueryExecutionErrors.executeCodePathUnsupportedError("CometSubqueryBroadcastExec")
  }

  override def executeCollect(): Array[InternalRow] = {
    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
  }

  override def stringArgs: Iterator[Any] = super.stringArgs ++ Iterator(s"[id=#$id]")

  override protected def withNewChildInternal(
      newChild: SparkPlan): CometSubqueryBroadcastExec =
    copy(child = newChild)
}

object CometSubqueryBroadcastExec {
  // Shared driver-side thread pool for computing DPP subquery results, bounded by the
  // same config that caps broadcast exchange threads
  private[comet] val executionContext = ExecutionContext.fromExecutorService(
    ThreadUtils.newDaemonCachedThreadPool(
      "comet-dynamicpruning",
      SQLConf.get.getConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD)))
}
