
Commit 5efd972

chore: Improve shuffle fallback logic (#3989)
1 parent d6d5f09 commit 5efd972

File tree

7 files changed: +210 −295 lines


docs/source/contributor-guide/adding_a_new_operator.md

Lines changed: 8 additions & 2 deletions
````diff
@@ -553,8 +553,14 @@ For operators that run in the JVM:
 Example pattern from `CometExecRule.scala`:
 
 ```scala
-case s: ShuffleExchangeExec if nativeShuffleSupported(s) =>
-  CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
+case s: ShuffleExchangeExec =>
+  CometShuffleExchangeExec.shuffleSupported(s) match {
+    case Some(CometNativeShuffle) =>
+      CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
+    case Some(CometColumnarShuffle) =>
+      CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
+    case None => s
+  }
 ```
 
 ## Common Patterns and Helpers
````
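The refactor above replaces two guarded `case` clauses with a single helper that returns an `Option`, so the rule has exactly one place where the fallback decision is made. A minimal, self-contained sketch of that dispatch pattern follows; the types here are simplified stand-ins for illustration, not the real Spark/Comet classes:

```scala
// Sketch of the Option-returning dispatch pattern, with hypothetical
// stand-in types (the real ShuffleExchangeExec is a Spark plan node).
sealed trait ShuffleType
case object CometNativeShuffle extends ShuffleType
case object CometColumnarShuffle extends ShuffleType

case class ShuffleExchangeExec(supportsNative: Boolean, supportsColumnar: Boolean)

object CometShuffleExchangeExec {
  // Native shuffle is preferred, columnar shuffle is the fallback,
  // and None means the exchange stays a plain Spark shuffle.
  def shuffleSupported(s: ShuffleExchangeExec): Option[ShuffleType] =
    if (s.supportsNative) Some(CometNativeShuffle)
    else if (s.supportsColumnar) Some(CometColumnarShuffle)
    else None
}

// A rule would match on the Option exactly once, as in the diff above.
def describe(s: ShuffleExchangeExec): String =
  CometShuffleExchangeExec.shuffleSupported(s) match {
    case Some(CometNativeShuffle)   => "native"
    case Some(CometColumnarShuffle) => "columnar"
    case None                       => "spark"
  }
```

The benefit of the single `match` is that the `None` branch is explicit: the "keep the Spark operator" path is written out rather than being the implicit result of no case matching.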

spark/src/main/scala/org/apache/comet/CometFallback.scala

Lines changed: 0 additions & 67 deletions
This file was deleted.

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 39 additions & 28 deletions
```diff
@@ -200,22 +200,29 @@ object CometSparkSessionExtensions extends Logging {
   }
 
   /**
-   * Attaches explain information to a TreeNode, rolling up the corresponding information tags
-   * from any child nodes. For now, we are using this to attach the reasons why certain Spark
-   * operators or expressions are disabled.
+   * Record a fallback reason on a `TreeNode` (a Spark operator or expression) explaining why
+   * Comet cannot accelerate it. Reasons recorded here are surfaced in extended explain output
+   * (see `ExtendedExplainInfo`) and, when `COMET_LOG_FALLBACK_REASONS` is enabled, logged as
+   * warnings. The reasons are also rolled up from child nodes so that the operator that remains
+   * in the Spark plan carries the reasons from its converted-away subtree.
+   *
+   * Call this in any code path where Comet decides not to convert a given node - serde `convert`
+   * methods returning `None`, unsupported data types, disabled configs, etc. Do not use this for
+   * informational messages that are not fallback reasons: anything tagged here is treated by the
+   * rules as a signal that the node falls back to Spark.
    *
    * @param node
-   *   The node to attach the explain information to. Typically a SparkPlan
+   *   The Spark operator or expression that is falling back to Spark.
    * @param info
-   *   Information text. Optional, may be null or empty. If not provided, then only information
-   *   from child nodes will be included.
+   *   The fallback reason. Optional, may be null or empty - pass empty only when the call is used
+   *   purely to roll up reasons from `exprs`.
    * @param exprs
-   *   Child nodes. Information attached in these nodes will be be included in the information
-   *   attached to @node
+   *   Child nodes whose own fallback reasons should be rolled up into `node`. Pass the
+   *   sub-expressions or child operators whose failure caused `node` to fall back.
    * @tparam T
-   *   The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
+   *   The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
    * @return
-   *   The node with information (if any) attached
+   *   `node` with fallback reasons attached (as a side effect on its tag map).
    */
   def withInfo[T <: TreeNode[_]](node: T, info: String, exprs: T*): T = {
     // support existing approach of passing in multiple infos in a newline-delimited string
@@ -228,22 +235,24 @@ object CometSparkSessionExtensions extends Logging {
   }
 
   /**
-   * Attaches explain information to a TreeNode, rolling up the corresponding information tags
-   * from any child nodes. For now, we are using this to attach the reasons why certain Spark
-   * operators or expressions are disabled.
+   * Record one or more fallback reasons on a `TreeNode` and roll up reasons from any child nodes.
+   * This is the set-valued form of [[withInfo]]; see that overload for the full contract.
+   *
+   * Reasons are accumulated (never overwritten) on the node's `EXTENSION_INFO` tag and are
+   * surfaced in extended explain output. When `COMET_LOG_FALLBACK_REASONS` is enabled, each new
+   * reason is also emitted as a warning.
    *
    * @param node
-   *   The node to attach the explain information to. Typically a SparkPlan
+   *   The Spark operator or expression that is falling back to Spark.
    * @param info
-   *   Information text. May contain zero or more strings. If not provided, then only information
-   *   from child nodes will be included.
+   *   The fallback reasons for this node. May be empty when the call is used purely to roll up
+   *   child reasons.
    * @param exprs
-   *   Child nodes. Information attached in these nodes will be be included in the information
-   *   attached to @node
+   *   Child nodes whose own fallback reasons should be rolled up into `node`.
    * @tparam T
-   *   The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
+   *   The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
    * @return
-   *   The node with information (if any) attached
+   *   `node` with fallback reasons attached (as a side effect on its tag map).
    */
   def withInfos[T <: TreeNode[_]](node: T, info: Set[String], exprs: T*): T = {
     if (CometConf.COMET_LOG_FALLBACK_REASONS.get()) {
@@ -259,25 +268,27 @@ object CometSparkSessionExtensions extends Logging {
   }
 
   /**
-   * Attaches explain information to a TreeNode, rolling up the corresponding information tags
-   * from any child nodes
+   * Roll up fallback reasons from `exprs` onto `node` without adding a new reason of its own. Use
+   * this when a parent operator is itself falling back and wants to preserve the reasons recorded
+   * on its child expressions/operators so they appear together in explain output.
    *
    * @param node
-   *   The node to attach the explain information to. Typically a SparkPlan
+   *   The parent operator or expression falling back to Spark.
    * @param exprs
-   *   Child nodes. Information attached in these nodes will be be included in the information
-   *   attached to @node
+   *   Child nodes whose fallback reasons should be aggregated onto `node`.
    * @tparam T
-   *   The type of the TreeNode. Typically SparkPlan, AggregateExpression, or Expression
+   *   The type of the TreeNode. Typically `SparkPlan`, `AggregateExpression`, or `Expression`.
    * @return
-   *   The node with information (if any) attached
+   *   `node` with the rolled-up reasons attached (as a side effect on its tag map).
    */
   def withInfo[T <: TreeNode[_]](node: T, exprs: T*): T = {
     withInfos(node, Set.empty, exprs: _*)
   }
 
   /**
-   * Checks whether a TreeNode has any explain information attached
+   * True if any fallback reason has been recorded on `node` (via [[withInfo]] / [[withInfos]]).
+   * Callers that need to short-circuit when a prior rule pass has already decided a node falls
+   * back can use this as the sticky signal.
    */
   def hasExplainInfo(node: TreeNode[_]): Boolean = {
     node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty)
```
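The contract the new scaladocs describe (reasons are accumulated, never overwritten, and child reasons are rolled up into the parent) can be sketched with stand-in types. This is a hypothetical simplification for illustration, not the real Comet/Spark API; the real implementation stores reasons in Spark's `TreeNode` tag map under `CometExplainInfo.EXTENSION_INFO`:

```scala
import scala.collection.mutable

// Stand-in for Spark's TreeNode with its tag map (illustrative only).
class TreeNode(val children: Seq[TreeNode] = Nil) {
  private val tags = mutable.Map.empty[String, Set[String]]
  def getTag: Set[String] = tags.getOrElse("EXTENSION_INFO", Set.empty)
  def setTag(v: Set[String]): Unit = tags("EXTENSION_INFO") = v
}

// Accumulate `info` onto `node` and union in the reasons already
// recorded on `exprs` - nothing is ever overwritten or dropped.
def withInfos(node: TreeNode, info: Set[String], exprs: TreeNode*): TreeNode = {
  val rolledUp = exprs.flatMap(_.getTag).toSet
  node.setTag(node.getTag ++ info.filter(_.nonEmpty) ++ rolledUp)
  node
}

// Sticky signal: true once any fallback reason has been recorded.
def hasExplainInfo(node: TreeNode): Boolean = node.getTag.nonEmpty
```

Because the tag is a set union, calling `withInfos` repeatedly on the same node is safe: later rule passes can add reasons without clobbering earlier ones, which is what makes `hasExplainInfo` usable as a sticky fallback marker.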

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 12 additions & 11 deletions
```diff
@@ -98,17 +98,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
 
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
-    plan.transformUp {
-      case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
-        // Switch to use Decimal128 regardless of precision, since Arrow native execution
-        // doesn't support Decimal32 and Decimal64 yet.
-        conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-        CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
-
-      case s: ShuffleExchangeExec if CometShuffleExchangeExec.columnarShuffleSupported(s) =>
-        // Columnar shuffle for regular Spark operators (not Comet) and Comet operators
-        // (if configured)
-        CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
+    plan.transformUp { case s: ShuffleExchangeExec =>
+      CometShuffleExchangeExec.shuffleSupported(s) match {
+        case Some(CometNativeShuffle) =>
+          // Switch to use Decimal128 regardless of precision, since Arrow native execution
+          // doesn't support Decimal32 and Decimal64 yet.
+          conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
+          CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
+        case Some(CometColumnarShuffle) =>
+          CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
+        case None =>
+          s
+      }
     }
   }
```
114115
