diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeRDD.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeRDD.scala
index 11c62185f..f2629241a 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeRDD.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeRDD.scala
@@ -67,10 +67,12 @@ class NativeRDD(
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
     val computingNativePlan = nativePlanWrapper.plan(split, context)
-    // SPARK-44605: Spark 4+ refines ShuffleWriteProcessor API (early execution of NativeRDD.ShuffleWrite iterator)
-    // Adaptation for Spark 4.x: Defer NativeRDD.ShuffleWrite execution to ShuffleWriteProcessor.write() to align with Spark 3.x logic
-    if (SparkVersionUtil.isSparkV40OrGreater &&
-      computingNativePlan.getPhysicalPlanTypeCase == PhysicalPlanNode.PhysicalPlanTypeCase.SHUFFLE_WRITER) {
+    // SPARK-44605: Spark 4+ refines ShuffleWriteProcessor API (early execution of native
+    // shuffle-writer iterators). Adaptation for Spark 4.x: defer both native shuffle-writer
+    // plan types (SHUFFLE_WRITER and RSS_SHUFFLE_WRITER) to ShuffleWriteProcessor.write() to
+    // align with Spark 3.x logic.
+    if (SparkVersionUtil.isSparkV40OrGreater && (computingNativePlan.getPhysicalPlanTypeCase == PhysicalPlanNode.PhysicalPlanTypeCase.SHUFFLE_WRITER
+        || computingNativePlan.getPhysicalPlanTypeCase == PhysicalPlanNode.PhysicalPlanTypeCase.RSS_SHUFFLE_WRITER)) {
       Iterator.empty
     } else {
       NativeHelper.executeNativePlan(computingNativePlan, metrics, split, Some(context))
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
index ef8267467..3d124bfd9 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/AuronRssShuffleWriterBase.scala
@@ -78,7 +78,7 @@ abstract class AuronRssShuffleWriterBase[K, V](metrics: ShuffleWriteMetricsRepor
 
   def rssStop(success: Boolean): Option[MapStatus]
 
-  @sparkver("3.2 / 3.3 / 3.4 / 3.5")
+  @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0")
   override def getPartitionLengths(): Array[Long] = rpw.getPartitionLengthMap
 
   override def write(records: Iterator[Product2[K, V]]): Unit = {
diff --git a/thirdparty/auron-celeborn-0.6/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/celeborn/AuronCelebornShuffleWriter.scala b/thirdparty/auron-celeborn-0.6/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/celeborn/AuronCelebornShuffleWriter.scala
index ba5841b16..3ae5b79e8 100644
--- a/thirdparty/auron-celeborn-0.6/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/celeborn/AuronCelebornShuffleWriter.scala
+++ b/thirdparty/auron-celeborn-0.6/src/main/scala/org/apache/spark/sql/execution/auron/shuffle/celeborn/AuronCelebornShuffleWriter.scala
@@ -63,7 +63,7 @@ class AuronCelebornShuffleWriter[K, V](
     celebornPartitionWriter
   }
 
-  @sparkver("3.2 / 3.3 / 3.4 / 3.5")
+  @sparkver("3.2 / 3.3 / 3.4 / 3.5 / 4.0")
   override def getPartitionLengths(): Array[Long] = {
     celebornPartitionWriter.getPartitionLengthMap
   }