@@ -2234,7 +2234,7 @@ index a3cfdc5a240..3793b6191bf 100644
22342234 })
22352235 checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0)))
22362236diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
2237- index 272be70f9fe..06957694002 100644
2237+ index 272be70f9fe..d38a6d41a47 100644
22382238--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
22392239+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
22402240@@ -28,12 +28,14 @@ import org.apache.spark.SparkException
@@ -2261,7 +2261,7 @@ index 272be70f9fe..06957694002 100644
22612261 }
22622262 }
22632263
2264- @@ -131,30 +134,39 @@ class AdaptiveQueryExecSuite
2264+ @@ -131,36 +134,46 @@ class AdaptiveQueryExecSuite
22652265 private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
22662266 collect(plan) {
22672267 case j: SortMergeJoinExec => j
@@ -2301,15 +2301,22 @@ index 272be70f9fe..06957694002 100644
23012301 }
23022302 }
23032303
2304- @@ -204,6 +216,7 @@ class AdaptiveQueryExecSuite
2304+ private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
2305+ collect(plan) {
2306+ case l: CollectLimitExec => l
2307+ + case l: CometCollectLimitExec => l.originalPlan.asInstanceOf[CollectLimitExec]
2308+ }
2309+ }
2310+
2311+ @@ -204,6 +217,7 @@ class AdaptiveQueryExecSuite
23052312 val parts = rdd.partitions
23062313 assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
23072314 }
23082315+
23092316 assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead))
23102317 }
23112318
2312- @@ -212,7 +225 ,7 @@ class AdaptiveQueryExecSuite
2319+ @@ -212,7 +226 ,7 @@ class AdaptiveQueryExecSuite
23132320 val plan = df.queryExecution.executedPlan
23142321 assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
23152322 val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect {
@@ -2318,7 +2325,7 @@ index 272be70f9fe..06957694002 100644
23182325 }
23192326 assert(shuffle.size == 1)
23202327 assert(shuffle(0).outputPartitioning.numPartitions == numPartition)
2321- @@ -228,7 +241 ,8 @@ class AdaptiveQueryExecSuite
2328+ @@ -228,7 +242 ,8 @@ class AdaptiveQueryExecSuite
23222329 assert(smj.size == 1)
23232330 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
23242331 assert(bhj.size == 1)
@@ -2328,7 +2335,7 @@ index 272be70f9fe..06957694002 100644
23282335 }
23292336 }
23302337
2331- @@ -255,7 +269 ,8 @@ class AdaptiveQueryExecSuite
2338+ @@ -255,7 +270 ,8 @@ class AdaptiveQueryExecSuite
23322339 }
23332340 }
23342341
@@ -2338,7 +2345,7 @@ index 272be70f9fe..06957694002 100644
23382345 withSQLConf(
23392346 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
23402347 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
2341- @@ -287,7 +302 ,8 @@ class AdaptiveQueryExecSuite
2348+ @@ -287,7 +303 ,8 @@ class AdaptiveQueryExecSuite
23422349 }
23432350 }
23442351
@@ -2348,7 +2355,7 @@ index 272be70f9fe..06957694002 100644
23482355 withSQLConf(
23492356 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
23502357 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
2351- @@ -301,7 +317 ,8 @@ class AdaptiveQueryExecSuite
2358+ @@ -301,7 +318 ,8 @@ class AdaptiveQueryExecSuite
23522359 val localReads = collect(adaptivePlan) {
23532360 case read: AQEShuffleReadExec if read.isLocalRead => read
23542361 }
@@ -2358,7 +2365,7 @@ index 272be70f9fe..06957694002 100644
23582365 val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
23592366 val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
23602367 // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
2361- @@ -326,7 +343 ,9 @@ class AdaptiveQueryExecSuite
2368+ @@ -326,7 +344 ,9 @@ class AdaptiveQueryExecSuite
23622369 .groupBy($"a").count()
23632370 checkAnswer(testDf, Seq())
23642371 val plan = testDf.queryExecution.executedPlan
@@ -2369,7 +2376,7 @@ index 272be70f9fe..06957694002 100644
23692376 val coalescedReads = collect(plan) {
23702377 case r: AQEShuffleReadExec => r
23712378 }
2372- @@ -340,7 +359 ,9 @@ class AdaptiveQueryExecSuite
2379+ @@ -340,7 +360 ,9 @@ class AdaptiveQueryExecSuite
23732380 .groupBy($"a").count()
23742381 checkAnswer(testDf, Seq())
23752382 val plan = testDf.queryExecution.executedPlan
@@ -2380,7 +2387,7 @@ index 272be70f9fe..06957694002 100644
23802387 val coalescedReads = collect(plan) {
23812388 case r: AQEShuffleReadExec => r
23822389 }
2383- @@ -350,7 +371 ,7 @@ class AdaptiveQueryExecSuite
2390+ @@ -350,7 +372 ,7 @@ class AdaptiveQueryExecSuite
23842391 }
23852392 }
23862393
@@ -2389,7 +2396,7 @@ index 272be70f9fe..06957694002 100644
23892396 withSQLConf(
23902397 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
23912398 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2392- @@ -365,7 +386 ,7 @@ class AdaptiveQueryExecSuite
2399+ @@ -365,7 +387 ,7 @@ class AdaptiveQueryExecSuite
23932400 }
23942401 }
23952402
@@ -2398,7 +2405,7 @@ index 272be70f9fe..06957694002 100644
23982405 withSQLConf(
23992406 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24002407 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2401- @@ -381,7 +402 ,7 @@ class AdaptiveQueryExecSuite
2408+ @@ -381,7 +403 ,7 @@ class AdaptiveQueryExecSuite
24022409 }
24032410 }
24042411
@@ -2407,7 +2414,7 @@ index 272be70f9fe..06957694002 100644
24072414 withSQLConf(
24082415 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24092416 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2410- @@ -426,7 +447 ,7 @@ class AdaptiveQueryExecSuite
2417+ @@ -426,7 +448 ,7 @@ class AdaptiveQueryExecSuite
24112418 }
24122419 }
24132420
@@ -2416,7 +2423,7 @@ index 272be70f9fe..06957694002 100644
24162423 withSQLConf(
24172424 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24182425 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2419- @@ -471,7 +492 ,7 @@ class AdaptiveQueryExecSuite
2426+ @@ -471,7 +493 ,7 @@ class AdaptiveQueryExecSuite
24202427 }
24212428 }
24222429
@@ -2425,7 +2432,7 @@ index 272be70f9fe..06957694002 100644
24252432 withSQLConf(
24262433 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24272434 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
2428- @@ -517,7 +538 ,7 @@ class AdaptiveQueryExecSuite
2435+ @@ -517,7 +539 ,7 @@ class AdaptiveQueryExecSuite
24292436 }
24302437 }
24312438
@@ -2434,7 +2441,7 @@ index 272be70f9fe..06957694002 100644
24342441 withSQLConf(
24352442 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24362443 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2437- @@ -536,7 +557 ,7 @@ class AdaptiveQueryExecSuite
2444+ @@ -536,7 +558 ,7 @@ class AdaptiveQueryExecSuite
24382445 }
24392446 }
24402447
@@ -2443,7 +2450,7 @@ index 272be70f9fe..06957694002 100644
24432450 withSQLConf(
24442451 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24452452 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
2446- @@ -567,7 +588 ,9 @@ class AdaptiveQueryExecSuite
2453+ @@ -567,7 +589 ,9 @@ class AdaptiveQueryExecSuite
24472454 assert(smj.size == 1)
24482455 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
24492456 assert(bhj.size == 1)
@@ -2454,7 +2461,7 @@ index 272be70f9fe..06957694002 100644
24542461 // Even with local shuffle read, the query stage reuse can also work.
24552462 val ex = findReusedExchange(adaptivePlan)
24562463 assert(ex.nonEmpty)
2457- @@ -588,7 +611 ,9 @@ class AdaptiveQueryExecSuite
2464+ @@ -588,7 +612 ,9 @@ class AdaptiveQueryExecSuite
24582465 assert(smj.size == 1)
24592466 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
24602467 assert(bhj.size == 1)
@@ -2465,7 +2472,7 @@ index 272be70f9fe..06957694002 100644
24652472 // Even with local shuffle read, the query stage reuse can also work.
24662473 val ex = findReusedExchange(adaptivePlan)
24672474 assert(ex.isEmpty)
2468- @@ -597,7 +622 ,8 @@ class AdaptiveQueryExecSuite
2475+ @@ -597,7 +623 ,8 @@ class AdaptiveQueryExecSuite
24692476 }
24702477 }
24712478
@@ -2475,7 +2482,7 @@ index 272be70f9fe..06957694002 100644
24752482 withSQLConf(
24762483 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
24772484 SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
2478- @@ -692,7 +718 ,8 @@ class AdaptiveQueryExecSuite
2485+ @@ -692,7 +719 ,8 @@ class AdaptiveQueryExecSuite
24792486 val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
24802487 assert(bhj.size == 1)
24812488 // There is still a SMJ, and its two shuffles can't apply local read.
@@ -2485,7 +2492,7 @@ index 272be70f9fe..06957694002 100644
24852492 }
24862493 }
24872494
2488- @@ -814,7 +841 ,8 @@ class AdaptiveQueryExecSuite
2495+ @@ -814,7 +842 ,8 @@ class AdaptiveQueryExecSuite
24892496 }
24902497 }
24912498
@@ -2495,7 +2502,7 @@ index 272be70f9fe..06957694002 100644
24952502 Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
24962503 def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") {
24972504 findTopLevelSortMergeJoin(plan)
2498- @@ -1087,7 +1115 ,8 @@ class AdaptiveQueryExecSuite
2505+ @@ -1087,7 +1116 ,8 @@ class AdaptiveQueryExecSuite
24992506 }
25002507 }
25012508
@@ -2505,7 +2512,7 @@ index 272be70f9fe..06957694002 100644
25052512 withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
25062513 val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
25072514 "SELECT key FROM testData GROUP BY key")
2508- @@ -1721,7 +1750 ,7 @@ class AdaptiveQueryExecSuite
2515+ @@ -1721,7 +1751 ,7 @@ class AdaptiveQueryExecSuite
25092516 val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
25102517 "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
25112518 assert(collect(adaptivePlan) {
@@ -2514,7 +2521,7 @@ index 272be70f9fe..06957694002 100644
25142521 }.length == 1)
25152522 }
25162523 }
2517- @@ -1801,7 +1830 ,8 @@ class AdaptiveQueryExecSuite
2524+ @@ -1801,7 +1831 ,8 @@ class AdaptiveQueryExecSuite
25182525 }
25192526 }
25202527
@@ -2524,7 +2531,7 @@ index 272be70f9fe..06957694002 100644
25242531 def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
25252532 find(plan) {
25262533 case s: ShuffleExchangeLike =>
2527- @@ -1986,6 +2016 ,9 @@ class AdaptiveQueryExecSuite
2534+ @@ -1986,6 +2017 ,9 @@ class AdaptiveQueryExecSuite
25282535 def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
25292536 assert(collect(ds.queryExecution.executedPlan) {
25302537 case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
@@ -2534,7 +2541,7 @@ index 272be70f9fe..06957694002 100644
25342541 }.size == 1)
25352542 ds.collect()
25362543 val plan = ds.queryExecution.executedPlan
2537- @@ -1994,6 +2027 ,9 @@ class AdaptiveQueryExecSuite
2544+ @@ -1994,6 +2028 ,9 @@ class AdaptiveQueryExecSuite
25382545 }.isEmpty)
25392546 assert(collect(plan) {
25402547 case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
@@ -2544,7 +2551,7 @@ index 272be70f9fe..06957694002 100644
25442551 }.size == 1)
25452552 checkAnswer(ds, testData)
25462553 }
2547- @@ -2150,7 +2186 ,8 @@ class AdaptiveQueryExecSuite
2554+ @@ -2150,7 +2187 ,8 @@ class AdaptiveQueryExecSuite
25482555 }
25492556 }
25502557
@@ -2554,7 +2561,7 @@ index 272be70f9fe..06957694002 100644
25542561 withTempView("t1", "t2") {
25552562 def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
25562563 Seq("100", "100000").foreach { size =>
2557- @@ -2236,7 +2273 ,8 @@ class AdaptiveQueryExecSuite
2564+ @@ -2236,7 +2274 ,8 @@ class AdaptiveQueryExecSuite
25582565 }
25592566 }
25602567
@@ -2564,7 +2571,7 @@ index 272be70f9fe..06957694002 100644
25642571 withTempView("v") {
25652572 withSQLConf(
25662573 SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
2567- @@ -2335,7 +2373 ,7 @@ class AdaptiveQueryExecSuite
2574+ @@ -2335,7 +2374 ,7 @@ class AdaptiveQueryExecSuite
25682575 runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
25692576 s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
25702577 val shuffles1 = collect(adaptive1) {
@@ -2573,7 +2580,7 @@ index 272be70f9fe..06957694002 100644
25732580 }
25742581 assert(shuffles1.size == 3)
25752582 // shuffles1.head is the top-level shuffle under the Aggregate operator
2576- @@ -2348,7 +2386 ,7 @@ class AdaptiveQueryExecSuite
2583+ @@ -2348,7 +2387 ,7 @@ class AdaptiveQueryExecSuite
25772584 runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
25782585 s"JOIN skewData2 ON key1 = key2")
25792586 val shuffles2 = collect(adaptive2) {
@@ -2582,7 +2589,7 @@ index 272be70f9fe..06957694002 100644
25822589 }
25832590 if (hasRequiredDistribution) {
25842591 assert(shuffles2.size == 3)
2585- @@ -2382,7 +2420 ,8 @@ class AdaptiveQueryExecSuite
2592+ @@ -2382,7 +2421 ,8 @@ class AdaptiveQueryExecSuite
25862593 }
25872594 }
25882595
@@ -2592,16 +2599,6 @@ index 272be70f9fe..06957694002 100644
25922599 CostEvaluator.instantiate(
25932600 classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
25942601 intercept[IllegalArgumentException] {
2595- @@ -2513,7 +2552,8 @@ class AdaptiveQueryExecSuite
2596- }
2597-
2598- test("SPARK-48037: Fix SortShuffleWriter lacks shuffle write related metrics " +
2599- - "resulting in potentially inaccurate data") {
2600- + "resulting in potentially inaccurate data",
2601- + IgnoreComet("too many shuffle partitions causes Java heap OOM")) {
2602- withTable("t3") {
2603- withSQLConf(
2604- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
26052602@@ -2548,6 +2588,7 @@ class AdaptiveQueryExecSuite
26062603 val (_, adaptive) = runAdaptiveAndVerifyResult(query)
26072604 assert(adaptive.collect {
0 commit comments