@@ -374,25 +374,20 @@ mod test {
374374
375375 repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
376376
377- {
378- let partition_writers = repartitioner. partition_writers ( ) ;
379- assert_eq ! ( partition_writers. len( ) , 2 ) ;
380-
381- assert ! ( !partition_writers[ 0 ] . has_spill_file( ) ) ;
382- assert ! ( !partition_writers[ 1 ] . has_spill_file( ) ) ;
383- }
377+ // before spill, no spill files should exist
378+ assert_eq ! ( repartitioner. spill_count_files( ) , 0 ) ;
384379
385380 repartitioner. spill ( ) . unwrap ( ) ;
386381
387- // after spill, there should be spill files
388- {
389- let partition_writers = repartitioner. partition_writers ( ) ;
390- assert ! ( partition_writers[ 0 ] . has_spill_file( ) ) ;
391- assert ! ( partition_writers[ 1 ] . has_spill_file( ) ) ;
392- }
382+ // after spill, exactly one combined spill file should exist (not one per partition)
383+ assert_eq ! ( repartitioner. spill_count_files( ) , 1 ) ;
393384
394385 // insert another batch after spilling
395386 repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
387+
388+ // spill again -- should create a second combined spill file
389+ repartitioner. spill ( ) . unwrap ( ) ;
390+ assert_eq ! ( repartitioner. spill_count_files( ) , 2 ) ;
396391 }
397392
398393 fn create_runtime ( memory_limit : usize ) -> Arc < RuntimeEnv > {
@@ -701,8 +696,6 @@ mod test {
701696 total_rows
702697 }
703698
704- #[ test]
705- #[ cfg_attr( miri, ignore) ]
706699 fn test_empty_schema_shuffle_writer ( ) {
707700 use std:: fs;
708701 use std:: io:: Read ;
@@ -857,3 +850,257 @@ mod test {
857850 }
858851 }
859852}
853+
854+ /// Verify that spilling an empty repartitioner produces no spill files.
855+ #[ tokio:: test]
856+ async fn spill_empty_buffers_produces_no_file ( ) {
857+ let batch = create_batch ( 100 ) ;
858+ let memory_limit = 512 * 1024 ;
859+ let num_partitions = 4 ;
860+ let runtime_env = create_runtime ( memory_limit) ;
861+ let metrics_set = ExecutionPlanMetricsSet :: new ( ) ;
862+ let mut repartitioner = MultiPartitionShuffleRepartitioner :: try_new (
863+ 0 ,
864+ "/tmp/spill_empty_data.out" . to_string ( ) ,
865+ "/tmp/spill_empty_index.out" . to_string ( ) ,
866+ batch. schema ( ) ,
867+ CometPartitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , num_partitions) ,
868+ ShufflePartitionerMetrics :: new ( & metrics_set, 0 ) ,
869+ runtime_env,
870+ 1024 ,
871+ CompressionCodec :: Lz4Frame ,
872+ false ,
873+ 1024 * 1024 ,
874+ )
875+ . unwrap ( ) ;
876+
877+ // Spill with no data inserted -- should be a no-op
878+ repartitioner. spill ( ) . unwrap ( ) ;
879+ assert_eq ! ( repartitioner. spill_count_files( ) , 0 ) ;
880+ }
881+
/// Spilling with many partitions, most of them empty, must still create
/// exactly one spill file per spill event, and shuffle_write must succeed.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_with_sparse_partitions() {
    // 100-row batches spread over 50 partitions leave most partitions empty;
    // the tiny 10 KiB memory limit forces spilling during the writes.
    let batch_size = 100;
    let num_batches = 5;
    let num_partitions = 50;
    let memory_limit = Some(10 * 1024);
    shuffle_write_test(batch_size, num_batches, num_partitions, memory_limit);
}
890+
/// Verify that the output of a spill-based shuffle contains the same total
/// row count and per-partition row counts as a non-spill shuffle.
///
/// The shuffle pipeline itself is identical for both runs; only the memory
/// limit (and the output file paths) differ, so the setup is factored into
/// the `run_shuffle` closure instead of being duplicated.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_output_matches_non_spill() {
    use std::fs;

    let batch_size = 1000;
    let num_batches = 10;
    let num_partitions = 8;
    let total_rows = batch_size * num_batches;

    let batch = create_batch(batch_size);
    let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();

    // Decode the little-endian i64 partition offsets stored in an index file.
    // `chunks_exact` is used so a malformed (non-multiple-of-8) file cannot
    // panic inside `try_into` on a short trailing chunk.
    let parse_offsets = |index_data: &[u8]| -> Vec<i64> {
        index_data
            .chunks_exact(8)
            .map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    };

    // Count the rows stored for one partition (the IPC blocks between two
    // consecutive offsets). Empty partitions have start == end.
    let count_rows_in_partition = |data: &[u8], start: i64, end: i64| -> usize {
        if start == end {
            return 0;
        }
        read_all_ipc_blocks(&data[start as usize..end as usize])
    };

    // Run the full shuffle of `batches` under `memory_limit` bytes, writing
    // to the given data/index files. A tight limit forces spilling inside
    // insert_batch; a generous one keeps everything in memory.
    let run_shuffle = |data_file: &str, index_file: &str, memory_limit: usize| {
        let partitions = std::slice::from_ref(&batches);
        let exec = ShuffleWriterExec::try_new(
            Arc::new(DataSourceExec::new(Arc::new(
                MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
            ))),
            CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
            CompressionCodec::Zstd(1),
            data_file.to_string(),
            index_file.to_string(),
            false,
            1024 * 1024,
        )
        .unwrap();

        let config = SessionConfig::new();
        let runtime_env = Arc::new(
            RuntimeEnvBuilder::new()
                .with_memory_limit(memory_limit, 1.0)
                .build()
                .unwrap(),
        );
        let ctx = SessionContext::new_with_config_rt(config, runtime_env);
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let rt = Runtime::new().unwrap();
        rt.block_on(collect(stream)).unwrap();
    };

    // Run 1: large memory limit -- no spilling.
    run_shuffle(
        "/tmp/no_spill_data.out",
        "/tmp/no_spill_index.out",
        100 * 1024 * 1024,
    );
    // Run 2: tight memory limit -- spilling is forced during insert_batch.
    run_shuffle("/tmp/spill_data.out", "/tmp/spill_index.out", 512 * 1024);

    let no_spill_index = fs::read("/tmp/no_spill_index.out").unwrap();
    let spill_index = fs::read("/tmp/spill_index.out").unwrap();

    assert_eq!(
        no_spill_index.len(),
        spill_index.len(),
        "Index files should have the same number of entries"
    );

    let no_spill_offsets = parse_offsets(&no_spill_index);
    let spill_offsets = parse_offsets(&spill_index);

    let no_spill_data = fs::read("/tmp/no_spill_data.out").unwrap();
    let spill_data = fs::read("/tmp/spill_data.out").unwrap();

    // Verify row counts per partition match between spill and non-spill runs.
    let mut no_spill_total_rows = 0;
    let mut spill_total_rows = 0;
    for i in 0..num_partitions {
        let ns_rows = count_rows_in_partition(
            &no_spill_data,
            no_spill_offsets[i],
            no_spill_offsets[i + 1],
        );
        let s_rows =
            count_rows_in_partition(&spill_data, spill_offsets[i], spill_offsets[i + 1]);
        assert_eq!(
            ns_rows, s_rows,
            "Partition {i} row count mismatch: no_spill={ns_rows}, spill={s_rows}"
        );
        no_spill_total_rows += ns_rows;
        spill_total_rows += s_rows;
    }

    assert_eq!(
        no_spill_total_rows, total_rows,
        "Non-spill total row count mismatch"
    );
    assert_eq!(
        spill_total_rows, total_rows,
        "Spill total row count mismatch"
    );

    // Cleanup (best-effort; failures here should not fail the test).
    let _ = fs::remove_file("/tmp/no_spill_data.out");
    let _ = fs::remove_file("/tmp/no_spill_index.out");
    let _ = fs::remove_file("/tmp/spill_data.out");
    let _ = fs::remove_file("/tmp/spill_index.out");
}
1029+
1030+ /// Verify multiple spill events with subsequent insert_batch calls
1031+ /// produce correct output.
1032+ #[ tokio:: test]
1033+ #[ cfg_attr( miri, ignore) ]
1034+ async fn test_multiple_spills_then_write ( ) {
1035+ let batch = create_batch ( 500 ) ;
1036+ let memory_limit = 512 * 1024 ;
1037+ let num_partitions = 4 ;
1038+ let runtime_env = create_runtime ( memory_limit) ;
1039+ let metrics_set = ExecutionPlanMetricsSet :: new ( ) ;
1040+ let mut repartitioner = MultiPartitionShuffleRepartitioner :: try_new (
1041+ 0 ,
1042+ "/tmp/multi_spill_data.out" . to_string ( ) ,
1043+ "/tmp/multi_spill_index.out" . to_string ( ) ,
1044+ batch. schema ( ) ,
1045+ CometPartitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , num_partitions) ,
1046+ ShufflePartitionerMetrics :: new ( & metrics_set, 0 ) ,
1047+ runtime_env,
1048+ 1024 ,
1049+ CompressionCodec :: Lz4Frame ,
1050+ false ,
1051+ 1024 * 1024 ,
1052+ )
1053+ . unwrap ( ) ;
1054+
1055+ // Insert -> spill -> insert -> spill -> insert (3 inserts, 2 spills)
1056+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
1057+ repartitioner. spill ( ) . unwrap ( ) ;
1058+ assert_eq ! ( repartitioner. spill_count_files( ) , 1 ) ;
1059+
1060+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
1061+ repartitioner. spill ( ) . unwrap ( ) ;
1062+ assert_eq ! ( repartitioner. spill_count_files( ) , 2 ) ;
1063+
1064+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
1065+ // Final shuffle_write merges 2 spill files + in-memory data
1066+ repartitioner. shuffle_write ( ) . unwrap ( ) ;
1067+
1068+ // Verify output files exist and are non-empty
1069+ let data = std:: fs:: read ( "/tmp/multi_spill_data.out" ) . unwrap ( ) ;
1070+ assert ! ( !data. is_empty( ) , "Output data file should be non-empty" ) ;
1071+
1072+ let index = std:: fs:: read ( "/tmp/multi_spill_index.out" ) . unwrap ( ) ;
1073+ // Index should have (num_partitions + 1) * 8 bytes
1074+ assert_eq ! (
1075+ index. len( ) ,
1076+ ( num_partitions + 1 ) * 8 ,
1077+ "Index file should have correct number of offset entries"
1078+ ) ;
1079+
1080+ // Parse offsets and verify they are monotonically non-decreasing
1081+ let offsets: Vec < i64 > = index
1082+ . chunks ( 8 )
1083+ . map ( |chunk| i64:: from_le_bytes ( chunk. try_into ( ) . unwrap ( ) ) )
1084+ . collect ( ) ;
1085+ assert_eq ! ( offsets[ 0 ] , 0 , "First offset should be 0" ) ;
1086+ for i in 1 ..offsets. len ( ) {
1087+ assert ! (
1088+ offsets[ i] >= offsets[ i - 1 ] ,
1089+ "Offsets must be monotonically non-decreasing: offset[{}]={} < offset[{}]={}" ,
1090+ i,
1091+ offsets[ i] ,
1092+ i - 1 ,
1093+ offsets[ i - 1 ]
1094+ ) ;
1095+ }
1096+ assert_eq ! (
1097+ * offsets. last( ) . unwrap( ) as usize ,
1098+ data. len( ) ,
1099+ "Last offset should equal data file size"
1100+ ) ;
1101+
1102+ // Cleanup
1103+ let _ = std:: fs:: remove_file ( "/tmp/multi_spill_data.out" ) ;
1104+ let _ = std:: fs:: remove_file ( "/tmp/multi_spill_index.out" ) ;
1105+ }
1106+ }
0 commit comments