Skip to content

Commit 4ca6eac

Browse files
committed
Add unit tests for single spill file shuffle behavior
1 parent 9086a0a commit 4ca6eac

1 file changed

Lines changed: 262 additions & 15 deletions

File tree

native/shuffle/src/shuffle_writer.rs

Lines changed: 262 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -374,25 +374,20 @@ mod test {
374374

375375
repartitioner.insert_batch(batch.clone()).await.unwrap();
376376

377-
{
378-
let partition_writers = repartitioner.partition_writers();
379-
assert_eq!(partition_writers.len(), 2);
380-
381-
assert!(!partition_writers[0].has_spill_file());
382-
assert!(!partition_writers[1].has_spill_file());
383-
}
377+
// before spill, no spill files should exist
378+
assert_eq!(repartitioner.spill_count_files(), 0);
384379

385380
repartitioner.spill().unwrap();
386381

387-
// after spill, there should be spill files
388-
{
389-
let partition_writers = repartitioner.partition_writers();
390-
assert!(partition_writers[0].has_spill_file());
391-
assert!(partition_writers[1].has_spill_file());
392-
}
382+
// after spill, exactly one combined spill file should exist (not one per partition)
383+
assert_eq!(repartitioner.spill_count_files(), 1);
393384

394385
// insert another batch after spilling
395386
repartitioner.insert_batch(batch.clone()).await.unwrap();
387+
388+
// spill again -- should create a second combined spill file
389+
repartitioner.spill().unwrap();
390+
assert_eq!(repartitioner.spill_count_files(), 2);
396391
}
397392

398393
fn create_runtime(memory_limit: usize) -> Arc<RuntimeEnv> {
@@ -701,8 +696,6 @@ mod test {
701696
total_rows
702697
}
703698

704-
#[test]
705-
#[cfg_attr(miri, ignore)]
706699
fn test_empty_schema_shuffle_writer() {
707700
use std::fs;
708701
use std::io::Read;
@@ -857,3 +850,257 @@ mod test {
857850
}
858851
}
859852
}
853+
854+
/// Verify that spilling an empty repartitioner produces no spill files.
855+
#[tokio::test]
856+
async fn spill_empty_buffers_produces_no_file() {
857+
let batch = create_batch(100);
858+
let memory_limit = 512 * 1024;
859+
let num_partitions = 4;
860+
let runtime_env = create_runtime(memory_limit);
861+
let metrics_set = ExecutionPlanMetricsSet::new();
862+
let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new(
863+
0,
864+
"/tmp/spill_empty_data.out".to_string(),
865+
"/tmp/spill_empty_index.out".to_string(),
866+
batch.schema(),
867+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
868+
ShufflePartitionerMetrics::new(&metrics_set, 0),
869+
runtime_env,
870+
1024,
871+
CompressionCodec::Lz4Frame,
872+
false,
873+
1024 * 1024,
874+
)
875+
.unwrap();
876+
877+
// Spill with no data inserted -- should be a no-op
878+
repartitioner.spill().unwrap();
879+
assert_eq!(repartitioner.spill_count_files(), 0);
880+
}
881+
882+
/// Verify that spilling with many partitions (some empty) still creates
883+
/// exactly one spill file per spill event, and that shuffle_write succeeds.
884+
#[test]
885+
#[cfg_attr(miri, ignore)]
886+
fn test_spill_with_sparse_partitions() {
887+
// 100 rows across 50 partitions -- many partitions will be empty
888+
shuffle_write_test(100, 5, 50, Some(10 * 1024));
889+
}
890+
891+
/// Verify that the output of a spill-based shuffle contains the same total
892+
/// row count and valid partition structure as a non-spill shuffle.
893+
#[test]
894+
#[cfg_attr(miri, ignore)]
895+
fn test_spill_output_matches_non_spill() {
896+
use std::fs;
897+
898+
let batch_size = 1000;
899+
let num_batches = 10;
900+
let num_partitions = 8;
901+
let total_rows = batch_size * num_batches;
902+
903+
let batch = create_batch(batch_size);
904+
let batches = (0..num_batches).map(|_| batch.clone()).collect::<Vec<_>>();
905+
906+
let parse_offsets = |index_data: &[u8]| -> Vec<i64> {
907+
index_data
908+
.chunks(8)
909+
.map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
910+
.collect()
911+
};
912+
913+
let count_rows_in_partition = |data: &[u8], start: i64, end: i64| -> usize {
914+
if start == end {
915+
return 0;
916+
}
917+
read_all_ipc_blocks(&data[start as usize..end as usize])
918+
};
919+
920+
// Run 1: no spilling (large memory limit)
921+
{
922+
let partitions = std::slice::from_ref(&batches);
923+
let exec = ShuffleWriterExec::try_new(
924+
Arc::new(DataSourceExec::new(Arc::new(
925+
MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
926+
))),
927+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
928+
CompressionCodec::Zstd(1),
929+
"/tmp/no_spill_data.out".to_string(),
930+
"/tmp/no_spill_index.out".to_string(),
931+
false,
932+
1024 * 1024,
933+
)
934+
.unwrap();
935+
936+
let config = SessionConfig::new();
937+
let runtime_env = Arc::new(
938+
RuntimeEnvBuilder::new()
939+
.with_memory_limit(100 * 1024 * 1024, 1.0)
940+
.build()
941+
.unwrap(),
942+
);
943+
let ctx = SessionContext::new_with_config_rt(config, runtime_env);
944+
let task_ctx = ctx.task_ctx();
945+
let stream = exec.execute(0, task_ctx).unwrap();
946+
let rt = Runtime::new().unwrap();
947+
rt.block_on(collect(stream)).unwrap();
948+
}
949+
950+
// Run 2: with spilling (memory limit forces spilling during insert_batch)
951+
{
952+
let partitions = std::slice::from_ref(&batches);
953+
let exec = ShuffleWriterExec::try_new(
954+
Arc::new(DataSourceExec::new(Arc::new(
955+
MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(),
956+
))),
957+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
958+
CompressionCodec::Zstd(1),
959+
"/tmp/spill_data.out".to_string(),
960+
"/tmp/spill_index.out".to_string(),
961+
false,
962+
1024 * 1024,
963+
)
964+
.unwrap();
965+
966+
let config = SessionConfig::new();
967+
let runtime_env = Arc::new(
968+
RuntimeEnvBuilder::new()
969+
.with_memory_limit(512 * 1024, 1.0)
970+
.build()
971+
.unwrap(),
972+
);
973+
let ctx = SessionContext::new_with_config_rt(config, runtime_env);
974+
let task_ctx = ctx.task_ctx();
975+
let stream = exec.execute(0, task_ctx).unwrap();
976+
let rt = Runtime::new().unwrap();
977+
rt.block_on(collect(stream)).unwrap();
978+
}
979+
980+
let no_spill_index = fs::read("/tmp/no_spill_index.out").unwrap();
981+
let spill_index = fs::read("/tmp/spill_index.out").unwrap();
982+
983+
assert_eq!(
984+
no_spill_index.len(),
985+
spill_index.len(),
986+
"Index files should have the same number of entries"
987+
);
988+
989+
let no_spill_offsets = parse_offsets(&no_spill_index);
990+
let spill_offsets = parse_offsets(&spill_index);
991+
992+
let no_spill_data = fs::read("/tmp/no_spill_data.out").unwrap();
993+
let spill_data = fs::read("/tmp/spill_data.out").unwrap();
994+
995+
// Verify row counts per partition match between spill and non-spill runs
996+
let mut no_spill_total_rows = 0;
997+
let mut spill_total_rows = 0;
998+
for i in 0..num_partitions {
999+
let ns_rows = count_rows_in_partition(
1000+
&no_spill_data,
1001+
no_spill_offsets[i],
1002+
no_spill_offsets[i + 1],
1003+
);
1004+
let s_rows =
1005+
count_rows_in_partition(&spill_data, spill_offsets[i], spill_offsets[i + 1]);
1006+
assert_eq!(
1007+
ns_rows, s_rows,
1008+
"Partition {i} row count mismatch: no_spill={ns_rows}, spill={s_rows}"
1009+
);
1010+
no_spill_total_rows += ns_rows;
1011+
spill_total_rows += s_rows;
1012+
}
1013+
1014+
assert_eq!(
1015+
no_spill_total_rows, total_rows,
1016+
"Non-spill total row count mismatch"
1017+
);
1018+
assert_eq!(
1019+
spill_total_rows, total_rows,
1020+
"Spill total row count mismatch"
1021+
);
1022+
1023+
// Cleanup
1024+
let _ = fs::remove_file("/tmp/no_spill_data.out");
1025+
let _ = fs::remove_file("/tmp/no_spill_index.out");
1026+
let _ = fs::remove_file("/tmp/spill_data.out");
1027+
let _ = fs::remove_file("/tmp/spill_index.out");
1028+
}
1029+
1030+
/// Verify multiple spill events with subsequent insert_batch calls
1031+
/// produce correct output.
1032+
#[tokio::test]
1033+
#[cfg_attr(miri, ignore)]
1034+
async fn test_multiple_spills_then_write() {
1035+
let batch = create_batch(500);
1036+
let memory_limit = 512 * 1024;
1037+
let num_partitions = 4;
1038+
let runtime_env = create_runtime(memory_limit);
1039+
let metrics_set = ExecutionPlanMetricsSet::new();
1040+
let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new(
1041+
0,
1042+
"/tmp/multi_spill_data.out".to_string(),
1043+
"/tmp/multi_spill_index.out".to_string(),
1044+
batch.schema(),
1045+
CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions),
1046+
ShufflePartitionerMetrics::new(&metrics_set, 0),
1047+
runtime_env,
1048+
1024,
1049+
CompressionCodec::Lz4Frame,
1050+
false,
1051+
1024 * 1024,
1052+
)
1053+
.unwrap();
1054+
1055+
// Insert -> spill -> insert -> spill -> insert (3 inserts, 2 spills)
1056+
repartitioner.insert_batch(batch.clone()).await.unwrap();
1057+
repartitioner.spill().unwrap();
1058+
assert_eq!(repartitioner.spill_count_files(), 1);
1059+
1060+
repartitioner.insert_batch(batch.clone()).await.unwrap();
1061+
repartitioner.spill().unwrap();
1062+
assert_eq!(repartitioner.spill_count_files(), 2);
1063+
1064+
repartitioner.insert_batch(batch.clone()).await.unwrap();
1065+
// Final shuffle_write merges 2 spill files + in-memory data
1066+
repartitioner.shuffle_write().unwrap();
1067+
1068+
// Verify output files exist and are non-empty
1069+
let data = std::fs::read("/tmp/multi_spill_data.out").unwrap();
1070+
assert!(!data.is_empty(), "Output data file should be non-empty");
1071+
1072+
let index = std::fs::read("/tmp/multi_spill_index.out").unwrap();
1073+
// Index should have (num_partitions + 1) * 8 bytes
1074+
assert_eq!(
1075+
index.len(),
1076+
(num_partitions + 1) * 8,
1077+
"Index file should have correct number of offset entries"
1078+
);
1079+
1080+
// Parse offsets and verify they are monotonically non-decreasing
1081+
let offsets: Vec<i64> = index
1082+
.chunks(8)
1083+
.map(|chunk| i64::from_le_bytes(chunk.try_into().unwrap()))
1084+
.collect();
1085+
assert_eq!(offsets[0], 0, "First offset should be 0");
1086+
for i in 1..offsets.len() {
1087+
assert!(
1088+
offsets[i] >= offsets[i - 1],
1089+
"Offsets must be monotonically non-decreasing: offset[{}]={} < offset[{}]={}",
1090+
i,
1091+
offsets[i],
1092+
i - 1,
1093+
offsets[i - 1]
1094+
);
1095+
}
1096+
assert_eq!(
1097+
*offsets.last().unwrap() as usize,
1098+
data.len(),
1099+
"Last offset should equal data file size"
1100+
);
1101+
1102+
// Cleanup
1103+
let _ = std::fs::remove_file("/tmp/multi_spill_data.out");
1104+
let _ = std::fs::remove_file("/tmp/multi_spill_index.out");
1105+
}
1106+
}

0 commit comments

Comments (0)