diff --git a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorterAsync.java b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorterAsync.java index 33bcbc9cac..f5e3d9b686 100644 --- a/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorterAsync.java +++ b/spark/src/main/java/org/apache/spark/shuffle/sort/CometShuffleExternalSorterAsync.java @@ -96,7 +96,7 @@ public final class CometShuffleExternalSorterAsync private final LinkedList spills = new LinkedList<>(); /** Peak memory used by this sorter so far, in bytes. */ - private long peakMemoryUsedBytes; + private volatile long peakMemoryUsedBytes; // Checksum calculator for each partition. Empty when shuffle checksum disabled. private final long[] partitionChecksums; @@ -152,8 +152,16 @@ public CometShuffleExternalSorterAsync( this.tracingEnabled = (boolean) CometConf$.MODULE$.COMET_TRACING_ENABLED().get(); this.threadNum = (int) CometConf$.MODULE$.COMET_COLUMNAR_SHUFFLE_ASYNC_THREAD_NUM().get(); - assert (this.threadNum > 0); + if (this.threadNum <= 0) { + throw new IllegalArgumentException( + "spark.comet.columnar.shuffle.async.thread.num must be positive, got: " + this.threadNum); + } this.threadPool = ShuffleThreadPool.getThreadPool(); + if (this.threadPool == null) { + throw new IllegalStateException( + "Async shuffle thread pool is not initialized. 
" + + "Ensure spark.comet.columnar.shuffle.async.enabled is true."); + } this.preferDictionaryRatio = (double) CometConf$.MODULE$.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO().get(); @@ -215,10 +223,21 @@ public void spill() throws IOException { SpillSorter spillingSorter = activeSpillSorter; Callable task = () -> { - spillingSorter.writeSortedFileNative(false, tracingEnabled); - final long spillSize = spillingSorter.freeMemory(); - spillingSorter.freeArray(); - spillingSorters.remove(spillingSorter); + long spillSize = 0; + try { + spillingSorter.writeSortedFileNative(false, tracingEnabled); + spillSize = spillingSorter.freeMemory(); + } finally { + // Ensure cleanup happens even if writeSortedFileNative() throws. + // freeMemory() may have already been called above, but it's safe to call again + // (returns 0 if already freed). freeArray() must be called to release the pointer + // array. + if (spillSize == 0) { + spillSize = spillingSorter.freeMemory(); + } + spillingSorter.freeArray(); + spillingSorters.remove(spillingSorter); + } // Reset the in-memory sorter's pointer array only after freeing up the memory pages // holding the records. Otherwise, if the task is over allocated memory, then without @@ -233,11 +252,20 @@ public void spill() throws IOException { spillingSorters.add(spillingSorter); asyncSpillTasks.add(threadPool.submit(task)); - while (asyncSpillTasks.size() == threadNum) { - for (Future spillingTask : asyncSpillTasks) { - if (spillingTask.isDone()) { - asyncSpillTasks.remove(spillingTask); - break; + // If we've reached the max concurrent spill tasks, block until one completes. + // This provides backpressure to avoid unbounded memory growth. 
+ while (asyncSpillTasks.size() >= threadNum) { + Future oldestTask = asyncSpillTasks.peek(); + if (oldestTask != null) { + try { + oldestTask.get(); // Block until the oldest task completes + asyncSpillTasks.poll(); // Remove the completed task + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while waiting for spill task", e); + } catch (ExecutionException e) { + asyncSpillTasks.poll(); // Remove the failed task + throw new IOException("Async spill task failed", e.getCause()); } } } @@ -288,6 +316,23 @@ private long freeMemory() { /** Force all memory and spill files to be deleted; called by shuffle error-handling code. */ @Override public void cleanupResources() { + // Cancel any pending async spill tasks to stop background work. + // The tasks have try-finally blocks that will clean up their SpillSorter resources. + for (Future task : asyncSpillTasks) { + task.cancel(true); + } + + // Best-effort wait (100 ms per task) for spill tasks to finish their cleanup. + // NOTE(review): get() on a cancelled Future throws immediately; cleanup may still be running. + for (Future task : asyncSpillTasks) { + try { + task.get(100, TimeUnit.MILLISECONDS); + } catch (Exception e) { + // Ignore - task was cancelled or failed, we're cleaning up anyway + } + } + asyncSpillTasks.clear(); + freeMemory(); for (SpillInfo spill : spills) { @@ -383,23 +428,38 @@ public SpillInfo[] closeAndGetSpills() throws IOException { final TempShuffleBlockId blockId = spilledFileInfo._1(); final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); - // Waits for all async tasks to finish. + // Waits for all async tasks to finish, collecting any exceptions. + // We wait for all tasks even if some fail to ensure proper cleanup. 
+ IOException firstException = null; for (Future task : asyncSpillTasks) { try { task.get(); - } catch (Exception e) { - throw new IOException(e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (firstException == null) { + firstException = new IOException("Interrupted while waiting for spill tasks", e); + } + } catch (ExecutionException e) { + if (firstException == null) { + firstException = new IOException("Async spill task failed", e.getCause()); + } else { + firstException.addSuppressed(e.getCause()); + } } } asyncSpillTasks.clear(); + if (firstException != null) { + throw firstException; + } + activeSpillSorter.setSpillInfo(spillInfo); activeSpillSorter.writeSortedFileNative(true, tracingEnabled); freeMemory(); } - return spills.toArray(new SpillInfo[spills.size()]); + return spills.toArray(new SpillInfo[0]); } }