Skip to content

Commit 9fa440b

Browse files
committed
Add unit tests for single spill file shuffle behavior
1 parent dc5d494 commit 9fa440b

File tree

1 file changed

+261
-13
lines changed

1 file changed

+261
-13
lines changed

native/shuffle/src/shuffle_writer.rs

Lines changed: 261 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -367,25 +367,20 @@ mod test {
367367

368368
repartitioner.insert_batch(batch.clone()).await.unwrap();
369369

370-
{
371-
let partition_writers = repartitioner.partition_writers();
372-
assert_eq!(partition_writers.len(), 2);
373-
374-
assert!(!partition_writers[0].has_spill_file());
375-
assert!(!partition_writers[1].has_spill_file());
376-
}
370+
// before spill, no spill files should exist
371+
assert_eq!(repartitioner.spill_count_files(), 0);
377372

378373
repartitioner.spill().unwrap();
379374

380-
// after spill, there should be spill files
381-
{
382-
let partition_writers = repartitioner.partition_writers();
383-
assert!(partition_writers[0].has_spill_file());
384-
assert!(partition_writers[1].has_spill_file());
385-
}
375+
// after spill, exactly one combined spill file should exist (not one per partition)
376+
assert_eq!(repartitioner.spill_count_files(), 1);
386377

387378
// insert another batch after spilling
388379
repartitioner.insert_batch(batch.clone()).await.unwrap();
380+
381+
// spill again -- should create a second combined spill file
382+
repartitioner.spill().unwrap();
383+
assert_eq!(repartitioner.spill_count_files(), 2);
389384
}
390385

391386
fn create_runtime(memory_limit: usize) -> Arc<RuntimeEnv> {
@@ -693,4 +688,257 @@ mod test {
693688
}
694689
total_rows
695690
}
691+
692+
/// Verify that spilling an empty repartitioner produces no spill files.
693+
#[tokio::test]
694+
async fn spill_empty_buffers_produces_no_file() {
695+
let batch = create_batch(100);
696+
let memory_limit = 512 * 1024;
697+
let num_partitions = 4;
698+
let runtime_env = create_runtime(memory_limit);
699+
let metrics_set = ExecutionPlanMetricsSet::new();
700+
let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new(
701+
0,
702+
"/tmp/spill_empty_data.out".to_string(),
703+
"/tmp/spill_empty_index.out".to_string(),
704+
batch.schema(),
705+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
706+
ShufflePartitionerMetrics::new(&metrics_set, 0),
707+
runtime_env,
708+
1024,
709+
CompressionCodec::Lz4Frame,
710+
false,
711+
1024 * 1024,
712+
)
713+
.unwrap();
714+
715+
// Spill with no data inserted -- should be a no-op
716+
repartitioner.spill().unwrap();
717+
assert_eq!(repartitioner.spill_count_files(), 0);
718+
}
719+
720+
/// Spilling with far more partitions than rows (so many partitions stay
/// empty) must still produce one spill file per spill event, and the
/// end-to-end shuffle_write must succeed.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_with_sparse_partitions() {
    // 100 rows spread over 50 partitions: most partitions receive no data.
    // The small memory limit forces spilling during the write.
    shuffle_write_test(100, 5, 50, Some(10 * 1024));
}
728+
729+
/// A shuffle that spills must produce the same per-partition and total row
/// counts as an identical shuffle that never spills.
#[test]
#[cfg_attr(miri, ignore)]
fn test_spill_output_matches_non_spill() {
    use std::fs;

    let batch_size = 1000;
    let num_batches = 10;
    let num_partitions = 8;
    let total_rows = batch_size * num_batches;

    let batch = create_batch(batch_size);
    let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();

    // Decode an index file: a flat array of little-endian i64 byte offsets.
    fn parse_offsets(index_data: &[u8]) -> Vec<i64> {
        index_data
            .chunks(8)
            .map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
            .collect()
    }

    // Count the rows stored in one partition's byte range of the data file.
    fn count_rows_in_partition(data: &[u8], start: i64, end: i64) -> usize {
        if start == end {
            return 0;
        }
        read_all_ipc_blocks(&data[start as usize..end as usize])
    }

    // Runs one full shuffle write of `batches` under the given memory limit,
    // blocking until the output stream is drained. Both runs below differ
    // only in memory limit and output paths.
    let run_shuffle = |memory_limit: usize, data_path: &str, index_path: &str| {
        let partitions = std::slice::from_ref(&batches);
        let exec = ShuffleWriterExec::try_new(
            Arc::new(DataSourceExec::new(Arc::new(
                MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
            ))),
            CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
            CompressionCodec::Zstd(1),
            data_path.to_string(),
            index_path.to_string(),
            false,
            1024 * 1024,
        )
        .unwrap();

        let runtime_env = Arc::new(
            RuntimeEnvBuilder::new()
                .with_memory_limit(memory_limit, 1.0)
                .build()
                .unwrap(),
        );
        let ctx = SessionContext::new_with_config_rt(SessionConfig::new(), runtime_env);
        let stream = exec.execute(0, ctx.task_ctx()).unwrap();
        Runtime::new().unwrap().block_on(collect(stream)).unwrap();
    };

    // Run 1: generous memory limit -- no spilling expected.
    run_shuffle(
        100 * 1024 * 1024,
        "/tmp/no_spill_data.out",
        "/tmp/no_spill_index.out",
    );

    // Run 2: tight memory limit -- insert_batch is forced to spill.
    run_shuffle(512 * 1024, "/tmp/spill_data.out", "/tmp/spill_index.out");

    let no_spill_index = fs::read("/tmp/no_spill_index.out").unwrap();
    let spill_index = fs::read("/tmp/spill_index.out").unwrap();

    assert_eq!(
        no_spill_index.len(),
        spill_index.len(),
        "Index files should have the same number of entries"
    );

    let no_spill_offsets = parse_offsets(&no_spill_index);
    let spill_offsets = parse_offsets(&spill_index);

    let no_spill_data = fs::read("/tmp/no_spill_data.out").unwrap();
    let spill_data = fs::read("/tmp/spill_data.out").unwrap();

    // Row counts must agree partition-by-partition between the two runs.
    let mut no_spill_total_rows = 0;
    let mut spill_total_rows = 0;
    for i in 0..num_partitions {
        let ns_rows = count_rows_in_partition(
            &no_spill_data,
            no_spill_offsets[i],
            no_spill_offsets[i + 1],
        );
        let s_rows =
            count_rows_in_partition(&spill_data, spill_offsets[i], spill_offsets[i + 1]);
        assert_eq!(
            ns_rows, s_rows,
            "Partition {i} row count mismatch: no_spill={ns_rows}, spill={s_rows}"
        );
        no_spill_total_rows += ns_rows;
        spill_total_rows += s_rows;
    }

    assert_eq!(
        no_spill_total_rows, total_rows,
        "Non-spill total row count mismatch"
    );
    assert_eq!(
        spill_total_rows, total_rows,
        "Spill total row count mismatch"
    );

    // Best-effort cleanup of the temp output files.
    let _ = fs::remove_file("/tmp/no_spill_data.out");
    let _ = fs::remove_file("/tmp/no_spill_index.out");
    let _ = fs::remove_file("/tmp/spill_data.out");
    let _ = fs::remove_file("/tmp/spill_index.out");
}
867+
868+
/// Verify multiple spill events with subsequent insert_batch calls
869+
/// produce correct output.
870+
#[tokio::test]
871+
#[cfg_attr(miri, ignore)]
872+
async fn test_multiple_spills_then_write() {
873+
let batch = create_batch(500);
874+
let memory_limit = 512 * 1024;
875+
let num_partitions = 4;
876+
let runtime_env = create_runtime(memory_limit);
877+
let metrics_set = ExecutionPlanMetricsSet::new();
878+
let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new(
879+
0,
880+
"/tmp/multi_spill_data.out".to_string(),
881+
"/tmp/multi_spill_index.out".to_string(),
882+
batch.schema(),
883+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
884+
ShufflePartitionerMetrics::new(&metrics_set, 0),
885+
runtime_env,
886+
1024,
887+
CompressionCodec::Lz4Frame,
888+
false,
889+
1024 * 1024,
890+
)
891+
.unwrap();
892+
893+
// Insert -> spill -> insert -> spill -> insert (3 inserts, 2 spills)
894+
repartitioner.insert_batch(batch.clone()).await.unwrap();
895+
repartitioner.spill().unwrap();
896+
assert_eq!(repartitioner.spill_count_files(), 1);
897+
898+
repartitioner.insert_batch(batch.clone()).await.unwrap();
899+
repartitioner.spill().unwrap();
900+
assert_eq!(repartitioner.spill_count_files(), 2);
901+
902+
repartitioner.insert_batch(batch.clone()).await.unwrap();
903+
// Final shuffle_write merges 2 spill files + in-memory data
904+
repartitioner.shuffle_write().unwrap();
905+
906+
// Verify output files exist and are non-empty
907+
let data = std::fs::read("/tmp/multi_spill_data.out").unwrap();
908+
assert!(!data.is_empty(), "Output data file should be non-empty");
909+
910+
let index = std::fs::read("/tmp/multi_spill_index.out").unwrap();
911+
// Index should have (num_partitions + 1) * 8 bytes
912+
assert_eq!(
913+
index.len(),
914+
(num_partitions + 1) * 8,
915+
"Index file should have correct number of offset entries"
916+
);
917+
918+
// Parse offsets and verify they are monotonically non-decreasing
919+
let offsets: Vec<i64> = index
920+
.chunks(8)
921+
.map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
922+
.collect();
923+
assert_eq!(offsets[0], 0, "First offset should be 0");
924+
for i in 1..offsets.len() {
925+
assert!(
926+
offsets[i] >= offsets[i - 1],
927+
"Offsets must be monotonically non-decreasing: offset[{}]={} < offset[{}]={}",
928+
i,
929+
offsets[i],
930+
i - 1,
931+
offsets[i - 1]
932+
);
933+
}
934+
assert_eq!(
935+
*offsets.last().unwrap() as usize,
936+
data.len(),
937+
"Last offset should equal data file size"
938+
);
939+
940+
// Cleanup
941+
let _ = std::fs::remove_file("/tmp/multi_spill_data.out");
942+
let _ = std::fs::remove_file("/tmp/multi_spill_index.out");
943+
}
696944
}

0 commit comments

Comments (0)