diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 046ccf0b1c..95901254ec 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -534,6 +534,20 @@ object CometConf extends ShimCometConf { .checkValue(v => v > 0, "Write buffer size must be positive") .createWithDefault(1) + val COMET_SHUFFLE_BATCH_SPILL_LIMIT: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.batchSpillLimit") + .category(CATEGORY_SHUFFLE) + .doc( + "Maximum number of input batches buffered before the native shuffle writer " + + "spills to disk, regardless of available memory. This prevents the shuffle writer " + + "from buffering too much data, which can degrade throughput due to poor cache " + + "locality during the final write phase. A value of 0 disables this threshold, " + + "meaning spills only occur when the memory pool is full. " + + "The default is 100.") + .intConf + .checkValue(v => v >= 0, "Batch spill limit must be non-negative") + .createWithDefault(100) + val COMET_SHUFFLE_PREFER_DICTIONARY_RATIO: ConfigEntry[Double] = conf( "spark.comet.shuffle.preferDictionary.ratio") .category(CATEGORY_SHUFFLE) diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md index ff9acee1f4..b4658f91a5 100644 --- a/docs/source/user-guide/latest/tuning.md +++ b/docs/source/user-guide/latest/tuning.md @@ -154,6 +154,24 @@ partitioning keys. Columns that are not partitioning keys may contain complex ty Comet Columnar shuffle is JVM-based and supports `HashPartitioning`, `RoundRobinPartitioning`, `RangePartitioning`, and `SinglePartitioning`. This shuffle implementation supports complex data types as partitioning keys. +### Shuffle Spill Tuning + +The native shuffle writer buffers input batches in memory and periodically spills them to disk. Two mechanisms +control when spilling occurs: + +1. **Memory pressure**: When the memory pool rejects an allocation, the writer spills its buffered data to disk. + +2. **Batch spill limit**: The writer also spills after buffering a fixed number of input batches, regardless of + memory availability. This prevents the writer from accumulating too much data, which can degrade throughput + due to poor cache locality during the final write phase. + +The batch spill limit is configured via `spark.comet.exec.shuffle.batchSpillLimit` (default: 100). Setting it +to 0 disables this threshold, meaning spills only occur under memory pressure. + +In most cases, the default value of 100 provides good performance. If you observe that shuffle throughput +decreases when more memory is available to Comet, try lowering this value. If you observe excessive spilling +with small data, try increasing it or disabling it with 0. + ### Shuffle Compression By default, Spark compresses shuffle files using LZ4 compression. Comet overrides this behavior with ZSTD compression. diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ac35925ace..aabe02b4b8 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1371,6 +1371,7 @@ impl PhysicalPlanner { }?; let write_buffer_size = writer.write_buffer_size as usize; + let batch_spill_limit = writer.batch_spill_limit as usize; let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( Arc::clone(&child.native_plan), partitioning, @@ -1379,6 +1380,7 @@ impl PhysicalPlanner { writer.output_index_file.clone(), writer.tracing_enabled, write_buffer_size, + batch_spill_limit, )?); Ok(( diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index fb438b26a4..c95885cf56 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,6 +294,10 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // Maximum number of buffered batches before the shuffle writer spills to disk, + // regardless of available memory. A value of 0 disables this threshold + // (spills only when the memory pool is full). + int32 batch_spill_limit = 9; } message ParquetWriter { diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs index 27abd919fa..04de8cc297 100644 --- a/native/shuffle/benches/shuffle_writer.rs +++ b/native/shuffle/benches/shuffle_writer.rs @@ -153,6 +153,7 @@ fn create_shuffle_writer_exec( "/tmp/index.out".to_string(), false, 1024 * 1024, + 0, ) .unwrap() } diff --git a/native/shuffle/src/bin/shuffle_bench.rs b/native/shuffle/src/bin/shuffle_bench.rs index bb8c2a0380..ac9ee82e63 100644 --- a/native/shuffle/src/bin/shuffle_bench.rs +++ b/native/shuffle/src/bin/shuffle_bench.rs @@ -114,6 +114,10 @@ struct Args { /// Each task reads the same input and writes to its own output files. #[arg(long, default_value_t = 1)] concurrent_tasks: usize, + + /// Maximum number of buffered batches before spilling (0 = disabled) + #[arg(long, default_value_t = 0)] + batch_spill_limit: usize, } fn main() { @@ -413,6 +417,7 @@ fn run_shuffle_write( args.limit, data_file.to_string(), index_file.to_string(), + args.batch_spill_limit, ) .await .unwrap(); @@ -436,6 +441,7 @@ async fn execute_shuffle_write( limit: usize, data_file: String, index_file: String, + batch_spill_limit: usize, ) -> datafusion::common::Result<(MetricsSet, MetricsSet)> { let config = SessionConfig::new().with_batch_size(batch_size); let mut runtime_builder = RuntimeEnvBuilder::new(); @@ -477,6 +483,7 @@ async fn execute_shuffle_write( index_file, false, write_buffer_size, + batch_spill_limit, ) .expect("Failed to create ShuffleWriterExec"); @@ -541,6 +548,7 @@ fn run_concurrent_shuffle_writes( let memory_limit = args.memory_limit; let write_buffer_size = args.write_buffer_size; let limit = args.limit; + let batch_spill_limit = args.batch_spill_limit; handles.push(tokio::spawn(async move { execute_shuffle_write( @@ -553,6 +561,7 @@ fn run_concurrent_shuffle_writes( limit, data_file, index_file, + batch_spill_limit, ) .await .unwrap() diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 7de9314f54..53c724bd0b 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -125,6 +125,8 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Maximum number of buffered batches before spilling, 0 = disabled + batch_spill_limit: usize, } impl MultiPartitionShuffleRepartitioner { @@ -141,6 +143,7 @@ impl MultiPartitionShuffleRepartitioner { codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + batch_spill_limit: usize, ) -> datafusion::common::Result { let num_output_partitions = partitioning.partition_count(); assert_ne!( @@ -190,6 +193,7 @@ impl MultiPartitionShuffleRepartitioner { reservation, tracing_enabled, write_buffer_size, + batch_spill_limit, }) } @@ -427,7 +431,9 @@ impl MultiPartitionShuffleRepartitioner { mem_growth += after_size.saturating_sub(before_size); } - if self.reservation.try_grow(mem_growth).is_err() { + let over_batch_limit = + self.batch_spill_limit > 0 && self.buffered_batches.len() >= self.batch_spill_limit; + if over_batch_limit || self.reservation.try_grow(mem_growth).is_err() { self.spill()?; } diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 8502c79624..94318c648e 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -67,6 +67,8 @@ pub struct ShuffleWriterExec { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Maximum number of buffered batches before spilling, 0 = disabled + batch_spill_limit: usize, } impl ShuffleWriterExec { @@ -80,6 +82,7 @@ impl ShuffleWriterExec { output_index_file: String, tracing_enabled: bool, write_buffer_size: usize, + batch_spill_limit: usize, ) -> Result { let cache = Arc::new(PlanProperties::new( EquivalenceProperties::new(Arc::clone(&input.schema())), @@ -98,6 +101,7 @@ impl ShuffleWriterExec { codec, tracing_enabled, write_buffer_size, + batch_spill_limit, }) } } @@ -158,6 +162,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.output_index_file.clone(), self.tracing_enabled, self.write_buffer_size, + self.batch_spill_limit, )?)), _ => panic!("ShuffleWriterExec wrong number of children"), } @@ -185,6 +190,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.codec.clone(), self.tracing_enabled, self.write_buffer_size, + self.batch_spill_limit, ) .map_err(|e| ArrowError::ExternalError(Box::new(e))), ) @@ -205,6 +211,7 @@ async fn external_shuffle( codec: CompressionCodec, tracing_enabled: bool, write_buffer_size: usize, + batch_spill_limit: usize, ) -> Result { let schema = input.schema(); @@ -241,6 +248,7 @@ async fn external_shuffle( codec, tracing_enabled, write_buffer_size, + batch_spill_limit, )?), }; @@ -363,6 +371,7 @@ mod test { CompressionCodec::Lz4Frame, false, 1024 * 1024, // write_buffer_size: 1MB default + 0, // batch_spill_limit: disabled ) .unwrap(); @@ -467,6 +476,7 @@ mod test { "/tmp/index.out".to_string(), false, 1024 * 1024, // write_buffer_size: 1MB default + 0, // batch_spill_limit: disabled ) .unwrap(); @@ -526,6 +536,7 @@ mod test { index_file.clone(), false, 1024 * 1024, + 0, ) .unwrap(); @@ -730,6 +741,7 @@ mod test { index_file.to_str().unwrap().to_string(), false, 1024 * 1024, + 0, ) .unwrap(); @@ -818,6 +830,7 @@ mod test { index_file.to_str().unwrap().to_string(), false, 1024 * 1024, + 0, ) .unwrap(); diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..af117d070a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -192,6 +192,7 @@ class CometNativeShuffleWriter[K, V]( CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) shuffleWriterBuilder.setWriteBufferSize( CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) + shuffleWriterBuilder.setBatchSpillLimit(CometConf.COMET_SHUFFLE_BATCH_SPILL_LIMIT.get()) outputPartitioning match { case p if isSinglePartitioning(p) =>