@@ -367,25 +367,20 @@ mod test {
367367
368368 repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
369369
370- {
371- let partition_writers = repartitioner. partition_writers ( ) ;
372- assert_eq ! ( partition_writers. len( ) , 2 ) ;
373-
374- assert ! ( !partition_writers[ 0 ] . has_spill_file( ) ) ;
375- assert ! ( !partition_writers[ 1 ] . has_spill_file( ) ) ;
376- }
370+ // before spill, no spill files should exist
371+ assert_eq ! ( repartitioner. spill_count_files( ) , 0 ) ;
377372
378373 repartitioner. spill ( ) . unwrap ( ) ;
379374
380- // after spill, there should be spill files
381- {
382- let partition_writers = repartitioner. partition_writers ( ) ;
383- assert ! ( partition_writers[ 0 ] . has_spill_file( ) ) ;
384- assert ! ( partition_writers[ 1 ] . has_spill_file( ) ) ;
385- }
375+ // after spill, exactly one combined spill file should exist (not one per partition)
376+ assert_eq ! ( repartitioner. spill_count_files( ) , 1 ) ;
386377
387378 // insert another batch after spilling
388379 repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
380+
381+ // spill again -- should create a second combined spill file
382+ repartitioner. spill ( ) . unwrap ( ) ;
383+ assert_eq ! ( repartitioner. spill_count_files( ) , 2 ) ;
389384 }
390385
391386 fn create_runtime ( memory_limit : usize ) -> Arc < RuntimeEnv > {
@@ -693,4 +688,257 @@ mod test {
693688 }
694689 total_rows
695690 }
691+
692+ /// Verify that spilling an empty repartitioner produces no spill files.
693+ #[ tokio:: test]
694+ async fn spill_empty_buffers_produces_no_file ( ) {
695+ let batch = create_batch ( 100 ) ;
696+ let memory_limit = 512 * 1024 ;
697+ let num_partitions = 4 ;
698+ let runtime_env = create_runtime ( memory_limit) ;
699+ let metrics_set = ExecutionPlanMetricsSet :: new ( ) ;
700+ let mut repartitioner = MultiPartitionShuffleRepartitioner :: try_new (
701+ 0 ,
702+ "/tmp/spill_empty_data.out" . to_string ( ) ,
703+ "/tmp/spill_empty_index.out" . to_string ( ) ,
704+ batch. schema ( ) ,
705+ CometPartitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , num_partitions) ,
706+ ShufflePartitionerMetrics :: new ( & metrics_set, 0 ) ,
707+ runtime_env,
708+ 1024 ,
709+ CompressionCodec :: Lz4Frame ,
710+ false ,
711+ 1024 * 1024 ,
712+ )
713+ . unwrap ( ) ;
714+
715+ // Spill with no data inserted -- should be a no-op
716+ repartitioner. spill ( ) . unwrap ( ) ;
717+ assert_eq ! ( repartitioner. spill_count_files( ) , 0 ) ;
718+ }
719+
/// Verify that spilling with many partitions (some empty) still creates
/// exactly one spill file per spill event, and that shuffle_write succeeds.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_with_sparse_partitions() {
    // 100 rows per batch, 5 batches, hashed into 50 partitions -- with far
    // more partitions than distinct rows, many partitions receive no data.
    // The small 10 KiB memory limit forces spilling during the write, so this
    // exercises the sparse-partition spill path end to end.
    shuffle_write_test(100, 5, 50, Some(10 * 1024));
}
728+
/// Verify that the output of a spill-based shuffle contains the same total
/// row count and valid partition structure as a non-spill shuffle.
///
/// The same input is shuffled twice -- once with a memory limit large enough
/// that nothing spills, once with a limit small enough to force spilling --
/// and the per-partition row counts of the two outputs are compared.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_output_matches_non_spill() {
    use std::fs;

    let batch_size = 1000;
    let num_batches = 10;
    let num_partitions = 8;
    let total_rows = batch_size * num_batches;

    let batch = create_batch(batch_size);
    let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();

    // Decode an index file: a flat sequence of little-endian i64 byte offsets
    // into the data file (num_partitions + 1 partition boundaries).
    let parse_offsets = |index_data: &[u8]| -> Vec<i64> {
        index_data
            .chunks(8)
            .map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    };

    // Count rows in one partition's byte range of the data file.
    let count_rows_in_partition = |data: &[u8], start: i64, end: i64| -> usize {
        if start == end {
            return 0;
        }
        read_all_ipc_blocks(&data[start as usize..end as usize])
    };

    // Execute one shuffle over the shared input. The two runs below were
    // previously ~30 duplicated lines each, differing only in memory limit
    // and output paths; the closure keeps every other knob guaranteed
    // identical between them.
    let run_shuffle = |memory_limit: usize, data_path: &str, index_path: &str| {
        let partitions = std::slice::from_ref(&batches);
        let exec = ShuffleWriterExec::try_new(
            Arc::new(DataSourceExec::new(Arc::new(
                MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
            ))),
            CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
            CompressionCodec::Zstd(1),
            data_path.to_string(),
            index_path.to_string(),
            false,
            1024 * 1024,
        )
        .unwrap();

        let config = SessionConfig::new();
        let runtime_env = Arc::new(
            RuntimeEnvBuilder::new()
                .with_memory_limit(memory_limit, 1.0)
                .build()
                .unwrap(),
        );
        let ctx = SessionContext::new_with_config_rt(config, runtime_env);
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let rt = Runtime::new().unwrap();
        rt.block_on(collect(stream)).unwrap();
    };

    // Run 1: large memory limit -- no spilling occurs.
    run_shuffle(
        100 * 1024 * 1024,
        "/tmp/no_spill_data.out",
        "/tmp/no_spill_index.out",
    );
    // Run 2: small memory limit -- forces spilling during insert_batch.
    run_shuffle(512 * 1024, "/tmp/spill_data.out", "/tmp/spill_index.out");

    let no_spill_index = fs::read("/tmp/no_spill_index.out").unwrap();
    let spill_index = fs::read("/tmp/spill_index.out").unwrap();

    assert_eq!(
        no_spill_index.len(),
        spill_index.len(),
        "Index files should have the same number of entries"
    );

    let no_spill_offsets = parse_offsets(&no_spill_index);
    let spill_offsets = parse_offsets(&spill_index);

    let no_spill_data = fs::read("/tmp/no_spill_data.out").unwrap();
    let spill_data = fs::read("/tmp/spill_data.out").unwrap();

    // Verify row counts per partition match between spill and non-spill runs.
    // (Hash partitioning is deterministic, so each partition must receive the
    // same rows regardless of whether spilling occurred.)
    let mut no_spill_total_rows = 0;
    let mut spill_total_rows = 0;
    for i in 0..num_partitions {
        let ns_rows = count_rows_in_partition(
            &no_spill_data,
            no_spill_offsets[i],
            no_spill_offsets[i + 1],
        );
        let s_rows =
            count_rows_in_partition(&spill_data, spill_offsets[i], spill_offsets[i + 1]);
        assert_eq!(
            ns_rows, s_rows,
            "Partition {i} row count mismatch: no_spill={ns_rows}, spill={s_rows}"
        );
        no_spill_total_rows += ns_rows;
        spill_total_rows += s_rows;
    }

    assert_eq!(
        no_spill_total_rows, total_rows,
        "Non-spill total row count mismatch"
    );
    assert_eq!(
        spill_total_rows, total_rows,
        "Spill total row count mismatch"
    );

    // Cleanup
    let _ = fs::remove_file("/tmp/no_spill_data.out");
    let _ = fs::remove_file("/tmp/no_spill_index.out");
    let _ = fs::remove_file("/tmp/spill_data.out");
    let _ = fs::remove_file("/tmp/spill_index.out");
}
867+
868+ /// Verify multiple spill events with subsequent insert_batch calls
869+ /// produce correct output.
870+ #[ tokio:: test]
871+ #[ cfg_attr( miri, ignore) ]
872+ async fn test_multiple_spills_then_write ( ) {
873+ let batch = create_batch ( 500 ) ;
874+ let memory_limit = 512 * 1024 ;
875+ let num_partitions = 4 ;
876+ let runtime_env = create_runtime ( memory_limit) ;
877+ let metrics_set = ExecutionPlanMetricsSet :: new ( ) ;
878+ let mut repartitioner = MultiPartitionShuffleRepartitioner :: try_new (
879+ 0 ,
880+ "/tmp/multi_spill_data.out" . to_string ( ) ,
881+ "/tmp/multi_spill_index.out" . to_string ( ) ,
882+ batch. schema ( ) ,
883+ CometPartitioning :: Hash ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , num_partitions) ,
884+ ShufflePartitionerMetrics :: new ( & metrics_set, 0 ) ,
885+ runtime_env,
886+ 1024 ,
887+ CompressionCodec :: Lz4Frame ,
888+ false ,
889+ 1024 * 1024 ,
890+ )
891+ . unwrap ( ) ;
892+
893+ // Insert -> spill -> insert -> spill -> insert (3 inserts, 2 spills)
894+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
895+ repartitioner. spill ( ) . unwrap ( ) ;
896+ assert_eq ! ( repartitioner. spill_count_files( ) , 1 ) ;
897+
898+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
899+ repartitioner. spill ( ) . unwrap ( ) ;
900+ assert_eq ! ( repartitioner. spill_count_files( ) , 2 ) ;
901+
902+ repartitioner. insert_batch ( batch. clone ( ) ) . await . unwrap ( ) ;
903+ // Final shuffle_write merges 2 spill files + in-memory data
904+ repartitioner. shuffle_write ( ) . unwrap ( ) ;
905+
906+ // Verify output files exist and are non-empty
907+ let data = std:: fs:: read ( "/tmp/multi_spill_data.out" ) . unwrap ( ) ;
908+ assert ! ( !data. is_empty( ) , "Output data file should be non-empty" ) ;
909+
910+ let index = std:: fs:: read ( "/tmp/multi_spill_index.out" ) . unwrap ( ) ;
911+ // Index should have (num_partitions + 1) * 8 bytes
912+ assert_eq ! (
913+ index. len( ) ,
914+ ( num_partitions + 1 ) * 8 ,
915+ "Index file should have correct number of offset entries"
916+ ) ;
917+
918+ // Parse offsets and verify they are monotonically non-decreasing
919+ let offsets: Vec < i64 > = index
920+ . chunks ( 8 )
921+ . map ( |chunk| i64:: from_le_bytes ( chunk. try_into ( ) . unwrap ( ) ) )
922+ . collect ( ) ;
923+ assert_eq ! ( offsets[ 0 ] , 0 , "First offset should be 0" ) ;
924+ for i in 1 ..offsets. len ( ) {
925+ assert ! (
926+ offsets[ i] >= offsets[ i - 1 ] ,
927+ "Offsets must be monotonically non-decreasing: offset[{}]={} < offset[{}]={}" ,
928+ i,
929+ offsets[ i] ,
930+ i - 1 ,
931+ offsets[ i - 1 ]
932+ ) ;
933+ }
934+ assert_eq ! (
935+ * offsets. last( ) . unwrap( ) as usize ,
936+ data. len( ) ,
937+ "Last offset should equal data file size"
938+ ) ;
939+
940+ // Cleanup
941+ let _ = std:: fs:: remove_file ( "/tmp/multi_spill_data.out" ) ;
942+ let _ = std:: fs:: remove_file ( "/tmp/multi_spill_index.out" ) ;
943+ }
696944}
0 commit comments