From ba05a8ede1e02f629e88cc3ed81d77fd1ae1fc4e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 22 Apr 2026 12:26:47 -0400 Subject: [PATCH 1/5] AQE DPP broadcast reuse for Iceberg native scans. This replaces the reflection hack with a columnar rule that wires CometSubqueryBroadcastExec to reuse the join's already-materialized broadcast exchange, eliminating the double execution of the dim table. --- .../comet/CometSparkSessionExtensions.scala | 8 +- ...metPlanAdaptiveDynamicPruningFilters.scala | 212 +++++++++ .../comet/CometIcebergNativeScanExec.scala | 167 +++---- .../ShimCometSparkSessionExtensions.scala | 0 .../comet/shims/ShimSubqueryBroadcast.scala | 71 ++- .../ShimCometSparkSessionExtensions.scala | 46 ++ .../comet/shims/ShimSubqueryBroadcast.scala | 30 +- .../comet/shims/ShimSubqueryBroadcast.scala | 29 +- .../comet/CometIcebergNativeSuite.scala | 448 +++++++++++++++++- 9 files changed, 880 insertions(+), 131 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala rename spark/src/main/{spark-3.x => spark-3.4}/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala (100%) create mode 100644 spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index b89e57422b..aff4a1a473 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf._ -import org.apache.comet.rules.{CometExecRule, CometScanRule, EliminateRedundantTransitions} +import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometScanRule, EliminateRedundantTransitions} 
import org.apache.comet.shims.ShimCometSparkSessionExtensions /** @@ -47,6 +47,7 @@ class CometSparkSessionExtensions override def apply(extensions: SparkSessionExtensions): Unit = { extensions.injectColumnar { session => CometScanColumnar(session) } extensions.injectColumnar { session => CometExecColumnar(session) } + extensions.injectColumnar { session => CometDPPColumnar(session) } extensions.injectQueryStagePrepRule { session => CometScanRule(session) } extensions.injectQueryStagePrepRule { session => CometExecRule(session) } } @@ -61,6 +62,11 @@ class CometSparkSessionExtensions override def postColumnarTransitions: Rule[SparkPlan] = EliminateRedundantTransitions(session) } + + case class CometDPPColumnar(session: SparkSession) extends ColumnarRule { + override def postColumnarTransitions: Rule[SparkPlan] = + CometPlanAdaptiveDynamicPruningFilters(session) + } } object CometSparkSessionExtensions extends Logging { diff --git a/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala b/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala new file mode 100644 index 0000000000..7047987d93 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.rules + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression, Literal} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, BroadcastQueryStageExec} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec + +import org.apache.comet.shims.ShimSubqueryBroadcast + +/** + * Converts SubqueryAdaptiveBroadcastExec (AQE DPP) to CometSubqueryBroadcastExec or + * SubqueryBroadcastExec inside CometIcebergNativeScanExec's runtimeFilters. + * + * Spark's PlanAdaptiveDynamicPruningFilters performs this conversion for BatchScanExec, but + * CometIcebergNativeScanExec wraps BatchScanExec and hides its runtimeFilters from the plan's + * expression tree. This rule accesses them directly via originalPlan.runtimeFilters. + * + * Registered as postColumnarTransitions (not queryStageOptimizerRule) because CometExecRule runs + * in preColumnarTransitions and recreates CometIcebergNativeScanExec instances, which would + * discard earlier modifications. Running after ensures we see the final scan instances. 
+ * + * @see + * PlanAdaptiveDynamicPruningFilters (Spark's equivalent for visible DPP expressions) + * @see + * CometExecRule.convertSubqueryBroadcasts (non-AQE DPP conversion, PR #4011) + */ +case class CometPlanAdaptiveDynamicPruningFilters(session: SparkSession) + extends Rule[SparkPlan] + with AdaptiveSparkPlanHelper + with ShimSubqueryBroadcast + with Logging { + + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.dynamicPartitionPruningEnabled) { + return plan + } + + // Short-circuit: only process plans containing Iceberg scans with DPP runtime filters. + val hasIcebergDPP = plan.find { + case scan: CometIcebergNativeScanExec => + scan.originalPlan != null && scan.originalPlan.runtimeFilters.exists { + case DynamicPruningExpression(_: InSubqueryExec) => true + case _ => false + } + case _ => false + }.isDefined + + if (!hasIcebergDPP) return plan + + logDebug("Processing plan with Iceberg DPP runtime filters") + + plan.transformUp { + case scan: CometIcebergNativeScanExec if scan.originalPlan != null => + val runtimeFilters = scan.originalPlan.runtimeFilters + val newFilters = runtimeFilters.map(transformFilter(_, plan)) + if (newFilters != runtimeFilters) { + val newBatchScan = scan.originalPlan.copy(runtimeFilters = newFilters) + scan.originalPlan.logicalLink.foreach(newBatchScan.setLogicalLink) + scan.copy(originalPlan = newBatchScan) + } else { + scan + } + } + } + + private def transformFilter(filter: Expression, fullPlan: SparkPlan): Expression = { + filter.transformUp { case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => + inSub.plan match { + case sab: SubqueryAdaptiveBroadcastExec => + logDebug(s"Converting SubqueryAdaptiveBroadcastExec '${sab.name}'") + convertSAB(inSub, sab, fullPlan).getOrElse { + // No matching broadcast join found (e.g., SortMergeJoin, or join optimized away). + // Spark's PlanAdaptiveDynamicPruningFilters handles onlyInBroadcast=false with an + // aggregate subquery fallback. 
We use TrueLiteral for both cases: correct results + // (scans all partitions), avoids replicating Spark's aggregate planning internals. + logInfo(s"No matching broadcast join for DPP subquery '${sab.name}', disabling DPP") + DynamicPruningExpression(Literal.TrueLiteral) + } + case _ => dpe + } + } + } + + /** + * Converts a SubqueryAdaptiveBroadcastExec by finding the matching broadcast join and wiring + * the subquery to reuse its already-materialized broadcast exchange. + * + * The subquery type depends on the actual broadcast exchange type (not the join type): + * - CometBroadcastExchangeExec -> CometSubqueryBroadcastExec (decodes Arrow broadcast data) + * - BroadcastExchangeExec -> SubqueryBroadcastExec (decodes HashedRelation) + * + * CometExecRule converts joins and their broadcast exchanges together, so the join type and + * broadcast type should always agree. The assert in extractBroadcastChild enforces this. + */ + private def convertSAB( + inSub: InSubqueryExec, + sab: SubqueryAdaptiveBroadcastExec, + fullPlan: SparkPlan): Option[DynamicPruningExpression] = { + val buildKeys = sab.buildKeys + val indices = getSubqueryBroadcastIndices(sab) + val sabKeyIds: Set[Any] = sab.buildKeys.flatMap(_.references.map(_.exprId)).toSet + + findMatchingBroadcastJoin(sabKeyIds, fullPlan).map { + case (broadcastChild: SparkPlan, isComet: Boolean) => + logDebug( + s"Matched DPP subquery '${sab.name}' to " + + s"${if (isComet) "Comet" else "Spark"} broadcast exchange") + val subquery = if (isComet) { + CometSubqueryBroadcastExec(sab.name, indices, buildKeys, broadcastChild) + } else { + createSubqueryBroadcastExec(sab.name, indices, buildKeys, broadcastChild) + } + DynamicPruningExpression(inSub.withNewPlan(subquery)) + } + } + + /** + * Finds a broadcast hash join whose build-side keys match the given exprIds. 
Searches both + * CometBroadcastHashJoinExec and BroadcastHashJoinExec to handle cases where the join fell back + * to Spark (e.g., unsupported expression, disabled config). + */ + private def findMatchingBroadcastJoin( + sabKeyIds: Set[Any], + plan: SparkPlan): Option[(SparkPlan, Boolean)] = { + + def extractBroadcastChild( + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + isCometJoin: Boolean): Option[(SparkPlan, Boolean)] = { + val joinBuildKeys = buildSide match { + case BuildLeft => leftKeys + case BuildRight => rightKeys + } + val joinKeyIds = joinBuildKeys.flatMap(_.references.map(_.exprId)).toSet + if (sabKeyIds.nonEmpty && sabKeyIds == joinKeyIds) { + val bc = buildSide match { + case BuildLeft => left + case BuildRight => right + } + val isCometBroadcast = isCometBroadcastExchange(bc) + + // CometExecRule converts joins and their broadcast exchanges together. + // A mismatch would cause a ClassCastException (Arrow vs HashedRelation). + assert( + isCometJoin == isCometBroadcast, + s"Join/broadcast type mismatch: join isComet=$isCometJoin, broadcast isComet=" + + s"$isCometBroadcast. 
CometExecRule should convert both or neither.") + + Some((bc, isCometBroadcast)) + } else { + None + } + } + + var result: Option[(SparkPlan, Boolean)] = None + find(plan) { + case join: CometBroadcastHashJoinExec if result.isEmpty => + result = extractBroadcastChild( + join.buildSide, + join.left, + join.right, + join.leftKeys, + join.rightKeys, + isCometJoin = true) + result.isDefined + case join: BroadcastHashJoinExec if result.isEmpty => + result = extractBroadcastChild( + join.buildSide, + join.left, + join.right, + join.leftKeys, + join.rightKeys, + isCometJoin = false) + result.isDefined + case _ => false + } + result + } + + private def isCometBroadcastExchange(plan: SparkPlan): Boolean = plan match { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case _ => false + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala index 36085b6329..4dd644bf87 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala @@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, DynamicPruningExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.{InSubqueryExec, SubqueryAdaptiveBroadcastExec} @@ -86,8 +86,8 @@ case class CometIcebergNativeScanExec( * Lazy partition serialization - deferred until execution time for DPP support. 
* * Entry points: This lazy val may be triggered from either doExecuteColumnar() (via - * commonData/perPartitionData) or capturedMetricValues (for Iceberg metrics). Lazy val - * semantics ensure single evaluation regardless of entry point. + * commonData/perPartitionData). Lazy val semantics ensure single evaluation regardless of entry + * point. * * DPP (Dynamic Partition Pruning) Flow: * @@ -103,50 +103,24 @@ case class CometIcebergNativeScanExec( * - Subquery plans are set up (but not yet executed) * * 2. Spark calls doExecuteColumnar() (or metrics are accessed) - * - Accesses perPartitionData (or capturedMetricValues) + * - Accesses perPartitionData * - Forces serializedPartitionData evaluation (here) - * - Waits for DPP values (updateResult or reflection) + * - Waits for DPP values (updateResult) * - Calls serializePartitions with DPP-filtered inputRDD * - Only matching partitions are serialized * }}} */ @transient private lazy val serializedPartitionData: (Array[Byte], Array[Array[Byte]]) = { - // Ensure DPP subqueries are resolved before accessing inputRDD. originalPlan.runtimeFilters.foreach { case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty => e.plan match { case sab: SubqueryAdaptiveBroadcastExec => - // SubqueryAdaptiveBroadcastExec.executeCollect() throws, so we call - // child.executeCollect() directly. We use the index from SAB to find the - // right buildKey, then locate that key's column in child.output. - val rows = sab.child.executeCollect() - val indices = getSubqueryBroadcastIndices(sab) - - // SPARK-46946 changed index: Int to indices: Seq[Int] as a preparatory refactor - // for future features (Null Safe Equality DPP, multiple equality predicates). - // Currently indices always has one element. CometScanRule checks for multi-index - // DPP and falls back, so this assertion should never fail. - assert( - indices.length == 1, - s"Multi-index DPP not supported: indices=$indices. 
See SPARK-46946.") - val buildKeyIndex = indices.head - val buildKey = sab.buildKeys(buildKeyIndex) - - // Find column index in child.output by matching buildKey's exprId - val colIndex = buildKey match { - case attr: Attribute => - sab.child.output.indexWhere(_.exprId == attr.exprId) - // DPP may cast partition column to match join key type - case Cast(attr: Attribute, _, _, _) => - sab.child.output.indexWhere(_.exprId == attr.exprId) - case _ => buildKeyIndex - } - if (colIndex < 0) { - throw new IllegalStateException( - s"DPP build key '$buildKey' not found in ${sab.child.output.map(_.name)}") - } - - setInSubqueryResult(e, rows.map(_.get(colIndex, e.child.dataType))) + // On 3.5+, CometPlanAdaptiveDynamicPruningFilters should have converted + // all SABs before execution. This path is a fallback for Spark 3.4. + logWarning( + "SubqueryAdaptiveBroadcastExec found at execution time " + + "(expected conversion by CometPlanAdaptiveDynamicPruningFilters)") + resolveSubqueryAdaptiveBroadcast(e, sab) case _ => e.updateResult() } @@ -156,29 +130,6 @@ case class CometIcebergNativeScanExec( CometIcebergNativeScan.serializePartitions(originalPlan, output, nativeIcebergScanMetadata) } - /** - * Sets InSubqueryExec's private result field via reflection. - * - * Reflection is required because: - * - SubqueryAdaptiveBroadcastExec.executeCollect() throws UnsupportedOperationException - * - InSubqueryExec has no public setter for result, only updateResult() which calls - * executeCollect() - * - We can't replace e.plan since it's a val - */ - private def setInSubqueryResult(e: InSubqueryExec, result: Array[_]): Unit = { - val fields = e.getClass.getDeclaredFields - // Field name is mangled by Scala compiler, e.g. "org$apache$...$InSubqueryExec$$result" - val resultField = fields - .find(f => f.getName.endsWith("$result") && !f.getName.contains("Broadcast")) - .getOrElse { - throw new IllegalStateException( - s"Cannot find 'result' field in ${e.getClass.getName}. 
" + - "Spark version may be incompatible with Comet's DPP implementation.") - } - resultField.setAccessible(true) - resultField.set(e, result) - } - def commonData: Array[Byte] = serializedPartitionData._1 def perPartitionData: Array[Array[Byte]] = serializedPartitionData._2 @@ -190,10 +141,6 @@ case class CometIcebergNativeScanExec( override lazy val outputOrdering: Seq[SortOrder] = Nil - // Capture metric VALUES and TYPES (not objects!) in a serializable case class - // This survives serialization while SQLMetric objects get reset to 0 - private case class MetricValue(name: String, value: Long, metricType: String) - /** * Maps Iceberg V2 custom metric types to standard Spark metric types for better UI formatting. * @@ -225,71 +172,60 @@ case class CometIcebergNativeScanExec( } /** - * Captures Iceberg planning metrics for display in Spark UI. + * Immutable SQLMetric for Iceberg planning metrics. + * + * Reading `value` lazily triggers serializedPartitionData to ensure Iceberg planning has run + * and the metric is populated. This avoids the side effect during metrics MAP construction + * (which SparkPlanInfo accesses before AQE runs), while still producing correct values when the + * metric VALUE is actually read (e.g., by tests or Spark UI after execution). * - * This lazy val intentionally triggers serializedPartitionData evaluation because Iceberg - * populates metrics during planning (when inputRDD is accessed). Both this and - * doExecuteColumnar() may trigger serializedPartitionData, but lazy val semantics ensure it's - * evaluated only once. + * Overrides merge/reset to prevent accumulator merges from executor (which carry 0) from + * overwriting the driver-side planning values. 
*/ - @transient private lazy val capturedMetricValues: Seq[MetricValue] = { - // Guard against null originalPlan (from doCanonicalize) - if (originalPlan == null) { - Seq.empty - } else { - // Trigger serializedPartitionData to ensure Iceberg planning has run and - // metrics are populated + private class LazyIcebergMetric(metricType: String, metricName: String) + extends SQLMetric(metricType, 0) { + + override def value: Long = { val _ = serializedPartitionData + originalPlan.metrics.get(metricName).map(_.value).getOrElse(0L) + } + + override def merge(other: AccumulatorV2[Long, Long]): Unit = {} + override def reset(): Unit = {} + } + + /** + * Iceberg planning metrics, created eagerly from originalPlan.metrics names and types. + * + * SparkPlanInfo reads the metrics MAP (names, types, ids) before AQE runs. This is safe because + * constructing the map has no side effects. Metric VALUES are lazily resolved via + * LazyIcebergMetric when actually read (after execution). + */ + @transient private lazy val icebergPlanningMetrics: Map[String, LazyIcebergMetric] = { + if (originalPlan == null) { + Map.empty + } else { originalPlan.metrics .filterNot { case (name, _) => - // Filter out metrics that are now runtime metrics incremented on the native side name == "numOutputRows" || name == "numDeletes" || name == "numSplits" } .map { case (name, metric) => val mappedType = mapMetricType(name, metric.metricType) - MetricValue(name, metric.value, mappedType) + val lazyMetric = new LazyIcebergMetric(mappedType, name) + sparkContext.register(lazyMetric, name) + name -> lazyMetric } - .toSeq } } - /** - * Immutable SQLMetric for planning metrics that don't change during execution. - * - * Regular SQLMetric extends AccumulatorV2, which means when execution completes, accumulator - * updates from executors (which are 0 since they don't update planning metrics) get merged back - * to the driver, overwriting the driver's values with 0. 
- * - * This class overrides the accumulator methods to make the metric truly immutable once set. - */ - private class ImmutableSQLMetric(metricType: String) extends SQLMetric(metricType, 0) { - - override def merge(other: AccumulatorV2[Long, Long]): Unit = {} - - override def reset(): Unit = {} - } - override lazy val metrics: Map[String, SQLMetric] = { val baseMetrics = Map( "output_rows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - // Create IMMUTABLE metrics with captured values AND types - // these won't be affected by accumulator merges - val icebergMetrics = capturedMetricValues.map { mv => - // Create the immutable metric with initValue = 0 (Spark 4 requires initValue <= 0) - val metric = new ImmutableSQLMetric(mv.metricType) - // Set the actual value after creation - metric.set(mv.value) - // Register it with SparkContext to assign metadata (name, etc.) - sparkContext.register(metric, mv.name) - mv.name -> metric - }.toMap - - // Add num_splits as a runtime metric (incremented on the native side during execution) val numSplitsMetric = SQLMetrics.createMetric(sparkContext, "number of file splits processed") - baseMetrics ++ icebergMetrics + ("num_splits" -> numSplitsMetric) + baseMetrics ++ icebergPlanningMetrics + ("num_splits" -> numSplitsMetric) } /** Executes using CometExecRDD - planning data is computed lazily on first access. */ @@ -356,17 +292,28 @@ case class CometIcebergNativeScanExec( Iterator(output, s"$metadataLocation, $scanDesc$runtimeFiltersStr", taskCount) } + // runtimeFilters must be included in equals so that transformUp detects changes when + // CometPlanAdaptiveDynamicPruningFilters replaces SubqueryAdaptiveBroadcastExec with + // CometSubqueryBroadcastExec. Without this, the new scan would be "equal" to the old + // one and transformUp would return the original (unconverted) plan tree. 
override def equals(obj: Any): Boolean = { obj match { case other: CometIcebergNativeScanExec => this.metadataLocation == other.metadataLocation && this.output == other.output && - this.serializedPlanOpt == other.serializedPlanOpt + this.serializedPlanOpt == other.serializedPlanOpt && + this.runtimeFiltersEqual(other) case _ => false } } + private def runtimeFiltersEqual(other: CometIcebergNativeScanExec): Boolean = { + if (this.originalPlan eq other.originalPlan) return true + if (this.originalPlan == null || other.originalPlan == null) return false + this.originalPlan.runtimeFilters == other.originalPlan.runtimeFilters + } + override def hashCode(): Int = Objects.hashCode(metadataLocation, output.asJava, serializedPlanOpt) } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala similarity index 100% rename from spark/src/main/spark-3.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala rename to spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 292ed2cb18..c0011afa4b 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,14 +19,11 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression} +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. 
Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { Seq(sab.index) } @@ -35,4 +32,68 @@ trait ShimSubqueryBroadcast { def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { Seq(sub.index) } + + /** Creates a SubqueryBroadcastExec with version-appropriate index parameter. */ + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + assert(indices.length == 1, s"Multi-index DPP not supported: indices=$indices") + SubqueryBroadcastExec(name, indices.head, buildKeys, child) + } + + /** + * Resolves a SubqueryAdaptiveBroadcastExec DPP filter via reflection. Required on Spark 3.4 + * where injectQueryStageOptimizerRule is unavailable, so CometPlanAdaptiveDynamicPruningFilters + * cannot convert the SAB before execution. + * + * On Spark 3.5+ this is dead code — the rule converts SAB to CometSubqueryBroadcastExec before + * serializedPartitionData runs, so the SAB case never matches. + */ + def resolveSubqueryAdaptiveBroadcast( + e: InSubqueryExec, + sab: SubqueryAdaptiveBroadcastExec): Unit = { + val rows = sab.child.executeCollect() + val indices = getSubqueryBroadcastIndices(sab) + + assert( + indices.length == 1, + s"Multi-index DPP not supported: indices=$indices. 
See SPARK-46946.") + val buildKeyIndex = indices.head + val buildKey = sab.buildKeys(buildKeyIndex) + + val colIndex = buildKey match { + case attr: Attribute => + sab.child.output.indexWhere(_.exprId == attr.exprId) + case Cast(attr: Attribute, _, _, _) => + sab.child.output.indexWhere(_.exprId == attr.exprId) + case _ => buildKeyIndex + } + if (colIndex < 0) { + throw new IllegalStateException( + s"DPP build key '$buildKey' not found in ${sab.child.output.map(_.name)}") + } + + val result = rows.map(_.get(colIndex, e.child.dataType)) + setInSubqueryResult(e, result) + } + + /** + * Sets InSubqueryExec's private result field via reflection. Required because + * SubqueryAdaptiveBroadcastExec.executeCollect() throws UnsupportedOperationException and + * InSubqueryExec has no public setter for result. + */ + private def setInSubqueryResult(e: InSubqueryExec, result: Array[_]): Unit = { + val fields = e.getClass.getDeclaredFields + val resultField = fields + .find(f => f.getName.endsWith("$result") && !f.getName.contains("Broadcast")) + .getOrElse { + throw new IllegalStateException( + s"Cannot find 'result' field in ${e.getClass.getName}. " + + "Spark version may be incompatible with Comet's DPP implementation.") + } + resultField.setAccessible(true) + resultField.set(e, result) + } } diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala new file mode 100644 index 0000000000..0dd783201a --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} + +trait ShimCometSparkSessionExtensions { + + /** + * TODO: delete after dropping Spark 3.x support and directly call + * SQLConf.EXTENDED_EXPLAIN_PROVIDERS.key + */ + protected val EXTENDED_EXPLAIN_PROVIDERS_KEY = "spark.sql.extendedExplainProviders" + + // Extended info is available only since Spark 4.0.0 + // (https://issues.apache.org/jira/browse/SPARK-47289) + def supportsExtendedExplainInfo(qe: QueryExecution): Boolean = { + try { + // Look for QueryExecution.extendedExplainInfo(scala.Function1[String, Unit], SparkPlan) + qe.getClass.getDeclaredMethod( + "extendedExplainInfo", + classOf[String => Unit], + classOf[SparkPlan]) + } catch { + case _: NoSuchMethodException | _: SecurityException => return false + } + true + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 292ed2cb18..49e63e0339 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,20 +19,38 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import 
org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { Seq(sab.index) } - /** Same version shim for SubqueryBroadcastExec. */ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { Seq(sub.index) } + + /** Creates a SubqueryBroadcastExec with version-appropriate index parameter. */ + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + assert(indices.length == 1, s"Multi-index DPP not supported: indices=$indices") + SubqueryBroadcastExec(name, indices.head, buildKeys, child) + } + + // CometPlanAdaptiveDynamicPruningFilters converts all SABs before execution -- to + // CometSubqueryBroadcastExec (broadcast reuse) or Literal.TrueLiteral (no broadcast). + // Reaching this method means the rule didn't run, which is a configuration error. + def resolveSubqueryAdaptiveBroadcast( + e: InSubqueryExec, + sab: SubqueryAdaptiveBroadcastExec): Unit = { + throw new IllegalStateException( + "SubqueryAdaptiveBroadcastExec should have been converted by " + + "CometPlanAdaptiveDynamicPruningFilters. 
This indicates the AQE optimizer rule " + + "was not registered.") + } } diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 73d9e53c4a..1d8fab83b7 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,20 +19,37 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { sab.indices } - /** Same version shim for SubqueryBroadcastExec. */ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { sub.indices } + + /** Creates a SubqueryBroadcastExec with version-appropriate index parameter. */ + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + SubqueryBroadcastExec(name, indices, buildKeys, child) + } + + // CometPlanAdaptiveDynamicPruningFilters converts all SABs before execution -- to + // CometSubqueryBroadcastExec (broadcast reuse) or Literal.TrueLiteral (no broadcast). + // Reaching this method means the rule didn't run, which is a configuration error. + def resolveSubqueryAdaptiveBroadcast( + e: InSubqueryExec, + sab: SubqueryAdaptiveBroadcastExec): Unit = { + throw new IllegalStateException( + "SubqueryAdaptiveBroadcastExec should have been converted by " + + "CometPlanAdaptiveDynamicPruningFilters. 
This indicates the AQE optimizer rule " + + "was not registered.") + } } diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index 62c8844f72..19a7ca448e 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -25,8 +25,10 @@ import java.nio.file.Files import scala.jdk.CollectionConverters._ import org.apache.spark.sql.CometTestBase -import org.apache.spark.sql.comet.CometIcebergNativeScanExec -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} +import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, BroadcastQueryStageExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, TimestampType} @@ -38,7 +40,10 @@ import org.apache.comet.testing.{FuzzDataGenerator, SchemaGenOptions} * * Note: Requires Iceberg dependencies to be added to pom.xml */ -class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { +class CometIcebergNativeSuite + extends CometTestBase + with RESTCatalogHelper + with AdaptiveSparkPlanHelper { // Skip these tests if Iceberg is not available in classpath private def icebergAvailable: Boolean = { @@ -2530,6 +2535,15 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { val numPartitions = icebergScans.head.numPartitions assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but got $numPartitions") + // Verify AQE DPP used CometSubqueryBroadcastExec with broadcast reuse + val subqueries = collectIcebergDPPSubqueries(cometPlan) + assert(subqueries.size == 2, s"Expected 2 DPP subqueries 
but got ${subqueries.size}") + subqueries.foreach { sub => + assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + } + spark.sql("DROP TABLE runtime_cat.db.multi_dpp_fact") } } @@ -2608,6 +2622,20 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { val numPartitions = icebergScans.head.numPartitions assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but got $numPartitions") + // Verify AQE DPP used CometSubqueryBroadcastExec with broadcast reuse + val subqueries = collectIcebergDPPSubqueries(cometPlan) + assert(subqueries.nonEmpty, s"Expected DPP subqueries in plan:\n$cometPlan") + subqueries.foreach { sub => + assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + val csb = sub.asInstanceOf[CometSubqueryBroadcastExec] + assert( + csb.child.isInstanceOf[BroadcastQueryStageExec], + "Expected BroadcastQueryStageExec child (broadcast reuse) but got " + + s"${csb.child.getClass.getSimpleName}") + } + spark.sql("DROP TABLE runtime_cat.db.fact_table") } } @@ -2764,4 +2792,418 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { } } } + + // ---- AQE DPP broadcast reuse tests ---- + + /** + * Collects DPP subquery plans from CometIcebergNativeScanExec's originalPlan.runtimeFilters. + * These are hidden from the normal plan expression tree (unlike BatchScanExec). 
+ */ + private def collectIcebergDPPSubqueries(plan: SparkPlan): Seq[SparkPlan] = { + collect(plan) { case scan: CometIcebergNativeScanExec => scan } + .filter(_.originalPlan != null) + .flatMap(_.originalPlan.runtimeFilters) + .collect { case DynamicPruningExpression(e: InSubqueryExec) => + e.plan + } + } + + test("AQE DPP - CometSubqueryBroadcastExec replaces SubqueryAdaptiveBroadcastExec") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + spark.sql(""" + CREATE TABLE aqe_cat.db.dpp_reuse_fact ( + id BIGINT, data STRING, date DATE + ) USING iceberg PARTITIONED BY (date) + """) + spark.sql(""" + INSERT INTO aqe_cat.db.dpp_reuse_fact VALUES + (1, 'a', DATE '1970-01-01'), (2, 'b', DATE '1970-01-02'), + (3, 'c', DATE '1970-01-02'), (4, 'd', DATE '1970-01-03') + """) + + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02")))) + .toDF("id", "date") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("aqe_dim") + + val query = + """SELECT /*+ BROADCAST(d) */ f.* FROM aqe_cat.db.dpp_reuse_fact f + |JOIN aqe_dim d ON f.date = d.date AND d.id = 1""".stripMargin + val (_, cometPlan) = checkSparkAnswer(query) + + // Verify CometSubqueryBroadcastExec replaced SubqueryAdaptiveBroadcastExec + val subqueries = collectIcebergDPPSubqueries(cometPlan) + assert(subqueries.nonEmpty, s"Expected DPP subqueries in plan:\n$cometPlan") + subqueries.foreach { sub => 
+ assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + } + + // Verify broadcast reuse: child should be BroadcastQueryStageExec (the join's + // already-materialized broadcast), not a standalone exchange + subqueries.foreach { + case csb: CometSubqueryBroadcastExec => + assert( + csb.child.isInstanceOf[BroadcastQueryStageExec], + "Expected BroadcastQueryStageExec child (broadcast reuse) but got " + + s"${csb.child.getClass.getSimpleName}") + case _ => + } + + // Verify only 1 CometBroadcastExchangeExec (shared between join and DPP) + val broadcasts = collectWithSubqueries(cometPlan) { case e: CometBroadcastExchangeExec => + e + } + assert( + broadcasts.size == 1, + s"Expected 1 CometBroadcastExchangeExec (reused) but got ${broadcasts.size}") + + // Verify correct results and partition pruning + val icebergScans = collectIcebergNativeScans(cometPlan) + assert(icebergScans.nonEmpty, "Expected CometIcebergNativeScanExec in plan") + assert( + icebergScans.head.numPartitions == 1, + s"Expected DPP to prune to 1 partition but got ${icebergScans.head.numPartitions}") + + spark.sql("DROP TABLE aqe_cat.db.dpp_reuse_fact") + } + } + } + + test("AQE DPP - multiple DPP filters reuse same broadcast") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + spark.sql(""" + CREATE TABLE aqe_cat.db.multi_dpp_reuse ( + id BIGINT, data STRING, date DATE, ts 
TIMESTAMP + ) USING iceberg PARTITIONED BY (data, bucket(8, id)) + """) + val df = spark + .range(1, 100) + .selectExpr( + "id", + "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS STRING) as data", + "DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) as date", + "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS TIMESTAMP) as ts") + df.coalesce(1) + .write + .format("iceberg") + .option("fanout-enabled", "true") + .mode("append") + .saveAsTable("aqe_cat.db.multi_dpp_reuse") + + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02"), "1970-01-02"))) + .toDF("id", "date", "data") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("aqe_multi_dim") + + val query = + """SELECT /*+ BROADCAST(d) */ f.* + |FROM aqe_cat.db.multi_dpp_reuse f + |JOIN aqe_multi_dim d ON f.id = d.id AND f.data = d.data + |WHERE d.date = DATE '1970-01-02'""".stripMargin + val (_, cometPlan) = checkSparkAnswer(query) + + // Both DPP filters should use CometSubqueryBroadcastExec + val subqueries = collectIcebergDPPSubqueries(cometPlan) + assert(subqueries.size == 2, s"Expected 2 DPP subqueries but got ${subqueries.size}") + subqueries.foreach { sub => + assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + } + + // Both should reuse the same BroadcastQueryStageExec + val stages = subqueries.collect { case csb: CometSubqueryBroadcastExec => + csb.child + } + assert( + stages.forall(_.isInstanceOf[BroadcastQueryStageExec]), + "All DPP subqueries should reuse the join's BroadcastQueryStageExec") + + // Only 1 broadcast exchange in the plan + val broadcasts = collectWithSubqueries(cometPlan) { case e: CometBroadcastExchangeExec => + e + } + assert( + broadcasts.size == 1, + s"Expected 1 CometBroadcastExchangeExec but got ${broadcasts.size}") + + spark.sql("DROP TABLE aqe_cat.db.multi_dpp_reuse") + } + } + } + + test("AQE DPP - 
two separate broadcast joins disambiguated by buildKeys") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dim1Dir = new File(warehouseDir, "dim1_parquet") + val dim2Dir = new File(warehouseDir, "dim2_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + // Fact table partitioned by TWO columns: date and category + spark.sql(""" + CREATE TABLE aqe_cat.db.two_join_fact ( + id BIGINT, date DATE, category STRING, value INT + ) USING iceberg PARTITIONED BY (date, category) + """) + spark.sql(""" + INSERT INTO aqe_cat.db.two_join_fact VALUES + (1, DATE '2024-01-01', 'A', 10), + (2, DATE '2024-01-01', 'B', 20), + (3, DATE '2024-01-02', 'A', 30), + (4, DATE '2024-01-02', 'B', 40), + (5, DATE '2024-01-03', 'A', 50), + (6, DATE '2024-01-03', 'C', 60) + """) + + // Dim1: filters on date + spark + .createDataFrame(Seq((java.sql.Date.valueOf("2024-01-02"), "keep"))) + .toDF("date", "label") + .write + .parquet(dim1Dir.getAbsolutePath) + spark.read.parquet(dim1Dir.getAbsolutePath).createOrReplaceTempView("date_dim") + + // Dim2: filters on category + spark + .createDataFrame(Seq(("A", "keep"))) + .toDF("category", "label") + .write + .parquet(dim2Dir.getAbsolutePath) + spark.read.parquet(dim2Dir.getAbsolutePath).createOrReplaceTempView("cat_dim") + + // Two separate broadcast joins — each creates its own DPP filter + val query = + """SELECT /*+ BROADCAST(d1), BROADCAST(d2) */ f.* + |FROM aqe_cat.db.two_join_fact f + |JOIN date_dim d1 ON f.date = d1.date + |JOIN cat_dim d2 ON f.category = d2.category + |WHERE d1.label = 
'keep' AND d2.label = 'keep'""".stripMargin + + val (_, cometPlan) = checkSparkAnswer(query) + + // Should have DPP subqueries for both joins + val subqueries = collectIcebergDPPSubqueries(cometPlan) + assert(subqueries.nonEmpty, s"Expected DPP subqueries in plan:\n$cometPlan") + + // Each should be CometSubqueryBroadcastExec with BroadcastQueryStageExec child + subqueries.foreach { sub => + assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + } + + // The two subqueries should reference DIFFERENT broadcast stages + // (one for date_dim, one for cat_dim) + val stages = subqueries.collect { case csb: CometSubqueryBroadcastExec => + System.identityHashCode(csb.child) + }.distinct + if (subqueries.size >= 2) { + assert( + stages.size == subqueries.size, + s"Expected ${subqueries.size} distinct broadcast stages but got ${stages.size}. " + + "buildKeys disambiguation may not be working.") + } + + // Verify correct results: date=2024-01-02 AND category=A → row (3, 2024-01-02, A, 30) + val icebergScans = collectIcebergNativeScans(cometPlan) + assert(icebergScans.nonEmpty, "Expected CometIcebergNativeScanExec in plan") + assert( + icebergScans.head.numPartitions == 1, + s"Expected DPP to prune to 1 partition but got ${icebergScans.head.numPartitions}") + + spark.sql("DROP TABLE aqe_cat.db.two_join_fact") + } + } + } + + test("AQE DPP - graceful fallback when broadcast join is not Comet") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + 
CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true", + // Disable Comet BHJ so the join stays as Spark's BroadcastHashJoinExec. + // The rule can't find CometBroadcastHashJoinExec and must handle gracefully. + CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.key -> "false", + CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.key -> "false") { + + spark.sql(""" + CREATE TABLE aqe_cat.db.fallback_fact ( + id BIGINT, data STRING, date DATE + ) USING iceberg PARTITIONED BY (date) + """) + spark.sql(""" + INSERT INTO aqe_cat.db.fallback_fact VALUES + (1, 'a', DATE '1970-01-01'), (2, 'b', DATE '1970-01-02'), + (3, 'c', DATE '1970-01-02'), (4, 'd', DATE '1970-01-03') + """) + + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02")))) + .toDF("id", "date") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("fallback_dim") + + // Query should still produce correct results even without Comet BHJ + val query = + """SELECT /*+ BROADCAST(d) */ f.* FROM aqe_cat.db.fallback_fact f + |JOIN fallback_dim d ON f.date = d.date AND d.id = 1 + |ORDER BY f.id""".stripMargin + checkSparkAnswer(query) + + spark.sql("DROP TABLE aqe_cat.db.fallback_fact") + } + } + } + + test("AQE DPP - empty broadcast result prunes all partitions") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + spark.sql(""" + CREATE TABLE 
aqe_cat.db.empty_dpp_fact ( + id BIGINT, data STRING, date DATE + ) USING iceberg PARTITIONED BY (date) + """) + spark.sql(""" + INSERT INTO aqe_cat.db.empty_dpp_fact VALUES + (1, 'a', DATE '1970-01-01'), (2, 'b', DATE '1970-01-02'), + (3, 'c', DATE '1970-01-03') + """) + + // Dim table with a value that matches NO fact partitions + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("2099-12-31")))) + .toDF("id", "date") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("empty_dim") + + val query = + """SELECT /*+ BROADCAST(d) */ f.* FROM aqe_cat.db.empty_dpp_fact f + |JOIN empty_dim d ON f.date = d.date AND d.id = 1""".stripMargin + + // Should return empty result — DPP prunes all partitions + val result = spark.sql(query).collect() + assert(result.isEmpty, s"Expected empty result but got ${result.length} rows") + + val (_, cometPlan) = checkSparkAnswer(query) + + // Verify the rule still converted the SAB + val subqueries = collectIcebergDPPSubqueries(cometPlan) + subqueries.foreach { sub => + assert( + sub.isInstanceOf[CometSubqueryBroadcastExec], + s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") + } + + spark.sql("DROP TABLE aqe_cat.db.empty_dpp_fact") + } + } + } + + test("AQE DPP - no broadcast join (SMJ) handles SAB gracefully") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.aqe_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.aqe_cat.type" -> "hadoop", + "spark.sql.catalog.aqe_cat.warehouse" -> warehouseDir.getAbsolutePath, + // Disable broadcast to force sort-merge join — no broadcast join for DPP to reuse + "spark.sql.autoBroadcastJoinThreshold" -> "-1", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + 
CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + spark.sql(""" + CREATE TABLE aqe_cat.db.smj_fact ( + id BIGINT, data STRING, date DATE + ) USING iceberg PARTITIONED BY (date) + """) + spark.sql(""" + INSERT INTO aqe_cat.db.smj_fact VALUES + (1, 'a', DATE '1970-01-01'), (2, 'b', DATE '1970-01-02'), + (3, 'c', DATE '1970-01-02'), (4, 'd', DATE '1970-01-03') + """) + + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02")))) + .toDF("id", "date") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("smj_dim") + + // No BROADCAST hint + threshold=-1 forces SMJ. DPP may still create SABs + // but there's no broadcast join for our rule to find. + val query = + """SELECT f.* FROM aqe_cat.db.smj_fact f + |JOIN smj_dim d ON f.date = d.date + |WHERE d.id = 1 + |ORDER BY f.id""".stripMargin + + // Should produce correct results regardless of DPP path + checkSparkAnswer(query) + + spark.sql("DROP TABLE aqe_cat.db.smj_fact") + } + } + } } From dd619105534ed4b28fe5da4649468cc196f2bb36 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 22 Apr 2026 13:14:58 -0400 Subject: [PATCH 2/5] Fix AQE DPP - empty broadcast result prunes all partitions on Spark 3.4
--- .../comet/serde/operator/CometIcebergNativeScan.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala index 3f240b11f8..4fea4b463a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala @@ -805,7 +805,8 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass) toJsonMethod.setAccessible(true) - // Access inputRDD - safe now, DPP is resolved + // Access inputRDD - safe now, DPP is resolved. + // When DPP prunes all partitions, inputRDD may be an EmptyRDD (not DataSourceRDD). scanExec.inputRDD match { case rdd: DataSourceRDD => val partitions = rdd.partitions @@ -986,7 +987,8 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit perPartitionBuilders += partitionBuilder.build() } case _ => - throw new IllegalStateException("Expected DataSourceRDD from BatchScanExec") + // Empty inputRDD (e.g., DPP pruned all partitions) — return empty serialization data + logDebug("BatchScanExec inputRDD is not DataSourceRDD (likely empty after DPP pruning)") } // Log deduplication summary From dffcb42a91b0da300a0a41ed42fe7ce58c459e01 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 22 Apr 2026 14:51:36 -0400 Subject: [PATCH 3/5] strengthen tests --- .../comet/CometIcebergNativeSuite.scala | 91 +++++++++++-------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index 19a7ca448e..054e86ced1 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -26,7 +26,8 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, BroadcastQueryStageExec} import org.apache.spark.sql.internal.SQLConf @@ -2795,10 +2796,6 @@ class CometIcebergNativeSuite // ---- AQE DPP broadcast reuse tests ---- - /** - * Collects DPP subquery plans from CometIcebergNativeScanExec's originalPlan.runtimeFilters. - * These are hidden from the normal plan expression tree (unlike BatchScanExec). - */ private def collectIcebergDPPSubqueries(plan: SparkPlan): Seq[SparkPlan] = { collect(plan) { case scan: CometIcebergNativeScanExec => scan } .filter(_.originalPlan != null) @@ -2808,6 +2805,14 @@ class CometIcebergNativeSuite } } + /** Extracts the broadcast-side child from a CometBroadcastHashJoinExec. 
*/ + private def broadcastChild(join: CometBroadcastHashJoinExec): SparkPlan = { + join.buildSide match { + case BuildLeft => join.left + case BuildRight => join.right + } + } + test("AQE DPP - CometSubqueryBroadcastExec replaces SubqueryAdaptiveBroadcastExec") { assume(icebergAvailable, "Iceberg not available") withTempIcebergDir { warehouseDir => @@ -2854,25 +2859,21 @@ class CometIcebergNativeSuite s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") } - // Verify broadcast reuse: child should be BroadcastQueryStageExec (the join's - // already-materialized broadcast), not a standalone exchange + // Verify broadcast reuse: subquery child should be the same BroadcastQueryStageExec + // instance that the join uses (reference equality, not just same type) + val joinBroadcastStage = collect(cometPlan) { + case j: CometBroadcastHashJoinExec => broadcastChild(j) + }.collectFirst { case b: BroadcastQueryStageExec => b }.get + subqueries.foreach { case csb: CometSubqueryBroadcastExec => assert( - csb.child.isInstanceOf[BroadcastQueryStageExec], - "Expected BroadcastQueryStageExec child (broadcast reuse) but got " + - s"${csb.child.getClass.getSimpleName}") + csb.child eq joinBroadcastStage, + "DPP subquery child should be the same BroadcastQueryStageExec instance " + + "as the join's broadcast side (eq check failed)") case _ => } - // Verify only 1 CometBroadcastExchangeExec (shared between join and DPP) - val broadcasts = collectWithSubqueries(cometPlan) { case e: CometBroadcastExchangeExec => - e - } - assert( - broadcasts.size == 1, - s"Expected 1 CometBroadcastExchangeExec (reused) but got ${broadcasts.size}") - // Verify correct results and partition pruning val icebergScans = collectIcebergNativeScans(cometPlan) assert(icebergScans.nonEmpty, "Expected CometIcebergNativeScanExec in plan") @@ -2941,21 +2942,19 @@ class CometIcebergNativeSuite s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") } - // Both should 
reuse the same BroadcastQueryStageExec - val stages = subqueries.collect { case csb: CometSubqueryBroadcastExec => - csb.child - } - assert( - stages.forall(_.isInstanceOf[BroadcastQueryStageExec]), - "All DPP subqueries should reuse the join's BroadcastQueryStageExec") + // Both should reuse the exact same BroadcastQueryStageExec instance from the join + val joinBroadcastStage = collect(cometPlan) { + case j: CometBroadcastHashJoinExec => broadcastChild(j) + }.collectFirst { case b: BroadcastQueryStageExec => b }.get - // Only 1 broadcast exchange in the plan - val broadcasts = collectWithSubqueries(cometPlan) { case e: CometBroadcastExchangeExec => - e + subqueries.foreach { + case csb: CometSubqueryBroadcastExec => + assert( + csb.child eq joinBroadcastStage, + "Both DPP subqueries should reuse the same BroadcastQueryStageExec " + + "instance as the join (eq check failed)") + case _ => } - assert( - broadcasts.size == 1, - s"Expected 1 CometBroadcastExchangeExec but got ${broadcasts.size}") spark.sql("DROP TABLE aqe_cat.db.multi_dpp_reuse") } @@ -3030,16 +3029,30 @@ class CometIcebergNativeSuite s"Expected CometSubqueryBroadcastExec but got ${sub.getClass.getSimpleName}") } - // The two subqueries should reference DIFFERENT broadcast stages + // Each subquery should reuse the broadcast stage from the CORRECT join + // (not mixed up). Collect join broadcast stages keyed by their broadcast's exprId set. 
+ val joinStages = collect(cometPlan) { + case j: CometBroadcastHashJoinExec => j + }.collect { case j if broadcastChild(j).isInstanceOf[BroadcastQueryStageExec] => + broadcastChild(j).asInstanceOf[BroadcastQueryStageExec] + } + + val subqueryCsbs = subqueries.collect { case csb: CometSubqueryBroadcastExec => csb } + subqueryCsbs.foreach { csb => + assert( + joinStages.exists(_ eq csb.child), + s"DPP subquery child should be eq to one of the join's BroadcastQueryStageExec " + + s"instances, but was not found") + } + + // The subqueries should reference DIFFERENT broadcast stages // (one for date_dim, one for cat_dim) - val stages = subqueries.collect { case csb: CometSubqueryBroadcastExec => - System.identityHashCode(csb.child) - }.distinct - if (subqueries.size >= 2) { + if (subqueryCsbs.size >= 2) { + val distinctChildren = subqueryCsbs.map(_.child).distinct assert( - stages.size == subqueries.size, - s"Expected ${subqueries.size} distinct broadcast stages but got ${stages.size}. " + - "buildKeys disambiguation may not be working.") + distinctChildren.size == subqueryCsbs.size, + s"Expected ${subqueryCsbs.size} distinct broadcast stages but got " + + s"${distinctChildren.size}. 
buildKeys disambiguation may not be working.") } // Verify correct results: date=2024-01-02 AND category=A → row (3, 2024-01-02, A, 30) From 52157da17ad76b16322d7b429d495658e3311352 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 22 Apr 2026 15:19:18 -0400 Subject: [PATCH 4/5] format --- .../apache/comet/CometIcebergNativeSuite.scala | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index 054e86ced1..5e28c2db94 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -2861,8 +2861,8 @@ class CometIcebergNativeSuite // Verify broadcast reuse: subquery child should be the same BroadcastQueryStageExec // instance that the join uses (reference equality, not just same type) - val joinBroadcastStage = collect(cometPlan) { - case j: CometBroadcastHashJoinExec => broadcastChild(j) + val joinBroadcastStage = collect(cometPlan) { case j: CometBroadcastHashJoinExec => + broadcastChild(j) }.collectFirst { case b: BroadcastQueryStageExec => b }.get subqueries.foreach { @@ -2943,8 +2943,8 @@ class CometIcebergNativeSuite } // Both should reuse the exact same BroadcastQueryStageExec instance from the join - val joinBroadcastStage = collect(cometPlan) { - case j: CometBroadcastHashJoinExec => broadcastChild(j) + val joinBroadcastStage = collect(cometPlan) { case j: CometBroadcastHashJoinExec => + broadcastChild(j) }.collectFirst { case b: BroadcastQueryStageExec => b }.get subqueries.foreach { @@ -3031,10 +3031,11 @@ class CometIcebergNativeSuite // Each subquery should reuse the broadcast stage from the CORRECT join // (not mixed up). Collect join broadcast stages keyed by their broadcast's exprId set. 
- val joinStages = collect(cometPlan) { - case j: CometBroadcastHashJoinExec => j - }.collect { case j if broadcastChild(j).isInstanceOf[BroadcastQueryStageExec] => - broadcastChild(j).asInstanceOf[BroadcastQueryStageExec] + val joinStages = collect(cometPlan) { case j: CometBroadcastHashJoinExec => + j + }.collect { + case j if broadcastChild(j).isInstanceOf[BroadcastQueryStageExec] => + broadcastChild(j).asInstanceOf[BroadcastQueryStageExec] } val subqueryCsbs = subqueries.collect { case csb: CometSubqueryBroadcastExec => csb } From 9c71d9a43d375ac142f23e5e23da3c43219c8c24 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 22 Apr 2026 16:06:32 -0400 Subject: [PATCH 5/5] format --- .../scala/org/apache/comet/CometIcebergNativeSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index 5e28c2db94..77b6aad71b 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} -import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} +import org.apache.spark.sql.comet.{CometBroadcastHashJoinExec, CometIcebergNativeScanExec, CometSubqueryBroadcastExec} import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, BroadcastQueryStageExec} import org.apache.spark.sql.internal.SQLConf @@ -3042,8 +3042,8 @@ class CometIcebergNativeSuite subqueryCsbs.foreach { csb => assert( joinStages.exists(_ eq csb.child), - 
s"DPP subquery child should be eq to one of the join's BroadcastQueryStageExec " + - s"instances, but was not found") + "DPP subquery child should be eq to one of the join's BroadcastQueryStageExec " + + "instances, but was not found") } // The subqueries should reference DIFFERENT broadcast stages