Commit 2ef0217

fix: SortMergeJoin full outer join incorrectly matches rows when filter evaluates to NULL (#21660)
## Which issue does this PR close?

- Closes #.

## Rationale for this change

While enabling `SortMergeJoinExec` with filters in [DataFusion Comet](https://github.com/apache/datafusion-comet), I hit a correctness bug in DataFusion's `SortMergeJoinExec` for full outer joins with nullable filter columns. The bug was originally surfaced via the [SPARK-43113](https://issues.apache.org/jira/browse/SPARK-43113) reproducer.

When a join filter expression evaluates to `NULL` (e.g., `l.b < (r.b + 1)` where `r.b` is `NULL`), the full outer join treats the row pair as **matched** instead of **unmatched**. Per SQL semantics, `NULL` in a boolean filter context means "not satisfied" (it excludes the row just as `false` does), so the rows should be emitted as separate unmatched rows.

The bug has been present since filtered full outer join support was added to SMJ in #12764 / #13369. It was never caught because:

1. The join fuzz tests generate filter column data with `Int32Array::from_iter_values()`, which never produces `NULL` values.
2. No existing unit test or sqllogictest exercised a full outer join filter that evaluates to `NULL`.

## What changes are included in this PR?

**Root cause:** The full outer join code path had a special case that preserved raw `NULL` values from the filter expression result (`pre_mask`) instead of converting them to `false` via `prep_null_mask_filter`, as the left/right outer join paths do. This caused two problems:

1. **Streamed (left) side:** In `get_corrected_filter_mask()`, `NULL` entries in `filter_mask` are treated as "pass through" (for pre-null-joined rows from `append_nulls()`). But when the filter expression itself returns `NULL`, those entries also appear as `NULL` in the mask and are incorrectly treated as matched. This produced wrong join output (matched rows instead of unmatched).
2. **Buffered (right) side:** `BooleanArray::value()` was called on `NULL` positions in `pre_mask` to update `FilterState`. At `NULL` positions, the values buffer contains a deterministic but semantically meaningless result (computed from the default zero-storage of `NULL` inputs). For some rows this value happens to be `true`, which incorrectly marks unmatched buffered rows as `SomePassed` and silently drops them from the output.

**Fix:** Remove the full outer join special case in `materializing_stream.rs`. All outer join types now uniformly use the null-corrected `mask` (where `NULL` → `false` via `prep_null_mask_filter`) for both the deferred filtering metadata and the `FilterState` tracking. Semi/anti/mark joins are unaffected: they use `BitwiseSortMergeJoinStream`, which already converts `NULL`s to `false`.

**Tests:**

- New unit test `join_full_null_filter_result` reproducing the SPARK-43113 scenario with a nullable right-side column.
- Modified `make_staggered_batches_i32` in `join_fuzz.rs` to inject ~10% `NULL` values into the filter column (`x`), so the fuzz tests exercise `NULL` filter handling across all join types.

## Are these changes tested?

Yes.

- The new unit test (`join_full_null_filter_result`) directly reproduces the bug.
- All 57 existing SMJ unit tests pass.
- All 41 join fuzz tests pass with the new nullable filter column data, including `test_full_join_1k_filtered`, which compares `HashJoinExec` against `SortMergeJoinExec` and would have caught this bug if the fuzz data had included `NULL`s.
- Will run 100 iterations of the fuzz tests overnight to shake out any remaining nondeterministic issues.
- Testing in Comet CI (all Spark SQL tests): apache/datafusion-comet#3916

## Are there any user-facing changes?

Full outer sort-merge joins with filters involving nullable columns now produce correct results. Previously, rows where the filter evaluated to `NULL` were incorrectly included as matched; they are now correctly emitted as unmatched (null-joined) rows.
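The SQL semantics the fix relies on can be sketched outside DataFusion entirely. The snippet below is an illustrative model (not DataFusion code): `Option<bool>` stands in for SQL's three-valued boolean, and the final `unwrap_or(false)` step plays the role of `prep_null_mask_filter`. The function names are invented for the sketch.

```rust
/// SQL `<` on nullable i32: NULL on either side yields NULL.
fn sql_lt(l: Option<i32>, r: Option<i32>) -> Option<bool> {
    match (l, r) {
        (Some(l), Some(r)) => Some(l < r),
        _ => None,
    }
}

/// SQL `+`: NULL propagates through arithmetic.
fn sql_add(l: Option<i32>, r: Option<i32>) -> Option<i32> {
    match (l, r) {
        (Some(l), Some(r)) => Some(l + r),
        _ => None,
    }
}

/// Filter context: NULL collapses to false (the `prep_null_mask_filter` step).
fn filter_passes(v: Option<bool>) -> bool {
    v.unwrap_or(false)
}

fn main() {
    // The SPARK-43113 shape: l.b < (r.b + 1) with r.b = NULL.
    let verdict = sql_lt(Some(1), sql_add(None, Some(1)));
    assert_eq!(verdict, None); // filter evaluates to NULL...
    assert!(!filter_passes(verdict)); // ...so the row pair is unmatched
    println!("NULL filter -> unmatched: ok");
}
```

The buggy path effectively skipped `filter_passes` for full outer joins and let the `NULL` verdict count as a match.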
1 parent 9873357 commit 2ef0217

File tree: 3 files changed (+122 lines, -19 lines)


datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 8 additions & 1 deletion

```diff
@@ -1276,7 +1276,14 @@ fn make_staggered_batches_i32(len: usize, with_extra_column: bool) -> Vec<Record
     input12.sort_unstable();
     let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0));
     let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1));
-    let input3 = Int32Array::from_iter_values(input3);
+    let input3 = Int32Array::from_iter(input3.into_iter().map(|v| {
+        // ~10% NULLs in filter column to exercise NULL filter handling
+        if rng.random_range(0..10) == 0 {
+            None
+        } else {
+            Some(v)
+        }
+    }));
    let input4 = Int32Array::from_iter_values(input4);

    let mut columns = vec![
```
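The injection pattern above depends on `arrow` and `rand`; its shape can be sketched std-only. The snippet below is an illustrative stand-in (the `XorShift` generator and `inject_nulls` helper are invented for the sketch; the real test uses `rng.random_range(0..10)` and builds an `Int32Array`).

```rust
// Tiny deterministic xorshift64 generator, so the sketch needs no crates.
struct XorShift(u64);

impl XorShift {
    fn next(&mut self) -> u64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        self.0
    }
}

/// Turn ~10% of the values into None, mirroring `rng.random_range(0..10) == 0`.
/// `None` models a NULL slot in the eventual Arrow array.
fn inject_nulls(values: Vec<i32>, rng: &mut XorShift) -> Vec<Option<i32>> {
    values
        .into_iter()
        .map(|v| if rng.next() % 10 == 0 { None } else { Some(v) })
        .collect()
}

fn main() {
    let mut rng = XorShift(0x9E3779B97F4A7C15);
    let data = inject_nulls((0..1000).collect(), &mut rng);
    let nulls = data.iter().filter(|v| v.is_none()).count();
    // Roughly 10% of 1000 slots should be NULL.
    assert!(nulls > 50 && nulls < 150);
    println!("{nulls} NULLs out of 1000");
}
```

The point of the fuzz change is coverage, not the exact rate: any nonzero NULL fraction in the filter column would have exposed the full outer join bug.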

datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs

Lines changed: 9 additions & 14 deletions

```diff
@@ -1423,27 +1423,22 @@ impl MaterializingSortMergeJoinStream {
             .evaluate(&filter_batch)?
             .into_array(filter_batch.num_rows())?;

-        let pre_mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
+        let filter_result_mask =
+            datafusion_common::cast::as_boolean_array(&filter_result)?;

-        let mask = if pre_mask.null_count() > 0 {
-            compute::prep_null_mask_filter(pre_mask)
+        // Convert NULL filter results to false — NULL means "not satisfied"
+        // per SQL semantics, same as Left/Right outer joins.
+        let mask = if filter_result_mask.null_count() > 0 {
+            compute::prep_null_mask_filter(filter_result_mask)
         } else {
-            pre_mask.clone()
+            filter_result_mask.clone()
         };

         if needs_deferred_filtering(&self.filter, self.join_type) {
-            // Full join uses pre_mask (preserving nulls) for
-            // get_corrected_filter_mask; other outer joins use mask.
-            let mask_to_use = if self.join_type != JoinType::Full {
-                &mask
-            } else {
-                pre_mask
-            };
-
             self.joined_record_batches.push_batch_with_filter_metadata(
                 output_batch,
                 &combined_left_indices,
-                mask_to_use,
+                &mask,
                 self.streamed_batch_counter.load(Relaxed),
                 self.join_type,
             );
@@ -1468,7 +1463,7 @@ impl MaterializingSortMergeJoinStream {
             let idx = right.value(i) as usize;
             match buffered_batch.join_filter_status[idx] {
                 FilterState::SomePassed => {}
-                _ if pre_mask.value(offset + i) => {
+                _ if mask.value(offset + i) => {
                     buffered_batch.join_filter_status[idx] =
                         FilterState::SomePassed;
                 }
```
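The buffered-side half of the fix can be modeled in isolation. This sketch uses invented types (a two-variant `FilterState`, `Vec<Option<bool>>` for the mask), not the actual DataFusion signatures: the point is that only a definite `true` may promote a buffered row to `SomePassed`, whereas reading the raw values buffer at a NULL slot can yield an arbitrary `true`.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum FilterState {
    NonePassed,
    SomePassed,
}

/// Update per-buffered-row status from a null-corrected mask.
/// `mask[i]` is the filter verdict for output row i; `rows[i]` is the
/// buffered-row index that output row refers to. None models NULL.
fn update_filter_status(
    states: &mut [FilterState],
    mask: &[Option<bool>],
    rows: &[usize],
) {
    for (i, &idx) in rows.iter().enumerate() {
        // Only a definite `true` marks the buffered row as passed;
        // NULL (None) and false both leave it unmatched.
        if states[idx] != FilterState::SomePassed && mask[i] == Some(true) {
            states[idx] = FilterState::SomePassed;
        }
    }
}

fn main() {
    let mut states = [FilterState::NonePassed; 2];
    // Buffered row 0's filter evaluated to NULL, row 1's to true.
    let mask = [None, Some(true)];
    update_filter_status(&mut states, &mask, &[0, 1]);
    assert_eq!(states[0], FilterState::NonePassed); // NULL does not pass
    assert_eq!(states[1], FilterState::SomePassed);
    println!("{states:?}");
}
```

The pre-fix code asked `pre_mask.value(offset + i)` at NULL positions, which reads the zero-initialized values buffer and can return `true`, silently dropping the buffered row from the null-joined output.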

datafusion/physical-plan/src/joins/sort_merge_join/tests.rs

Lines changed: 105 additions & 4 deletions

```diff
@@ -55,7 +55,7 @@ use datafusion_common::{
     test_util::{batches_to_sort_string, batches_to_string},
 };
 use datafusion_common::{
-    JoinType, NullEquality, Result, assert_batches_eq, assert_contains,
+    JoinType, NullEquality, Result, ScalarValue, assert_batches_eq, assert_contains,
 };
 use datafusion_common_runtime::JoinSet;
 use datafusion_execution::config::SessionConfig;
@@ -65,6 +65,7 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::BinaryExpr;
+use datafusion_physical_expr::expressions::Literal;
 use datafusion_physical_expr_common::physical_expr::PhysicalExprRef;
 use futures::{Stream, StreamExt};
 use insta::assert_snapshot;
@@ -2049,6 +2050,108 @@ async fn join_full_multiple_batches() -> Result<()> {
     Ok(())
 }

+/// Full outer join where the filter evaluates to NULL due to a nullable column.
+/// NULL filter results must be treated as unmatched, not matched.
+/// Reproducer for SPARK-43113.
+#[tokio::test]
+async fn join_full_null_filter_result() -> Result<()> {
+    // Left: (a, b) all non-null, sorted on a
+    let left = build_table_two_cols(
+        ("a1", &vec![1, 1, 2, 2, 3, 3]),
+        ("b1", &vec![1, 2, 1, 2, 1, 2]),
+    );
+
+    // Right: (a, b) with b nullable, sorted on a
+    let right_schema = Arc::new(Schema::new(vec![
+        Field::new("a2", DataType::Int32, false),
+        Field::new("b2", DataType::Int32, true),
+    ]));
+    let right_batch = RecordBatch::try_new(
+        Arc::clone(&right_schema),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2])),
+            Arc::new(Int32Array::from(vec![None, Some(2)])),
+        ],
+    )?;
+    let right =
+        TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None).unwrap();
+
+    let on = vec![(
+        Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+        Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+    )];
+
+    // Filter: b1 < (b2 + 1) AND b1 < (a2 + 1)
+    // When b2 is NULL, (b2 + 1) is NULL, so b1 < NULL is NULL → unmatched.
+    let lit_1: PhysicalExprRef = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
+    let b1_lt_b2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("b1", 0)),
+        Operator::Lt,
+        Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("b2", 1)),
+            Operator::Plus,
+            Arc::clone(&lit_1),
+        )),
+    ));
+    let b1_lt_a2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("b1", 0)),
+        Operator::Lt,
+        Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("a2", 2)),
+            Operator::Plus,
+            Arc::clone(&lit_1),
+        )),
+    ));
+    let filter_expr: PhysicalExprRef = Arc::new(BinaryExpr::new(
+        b1_lt_b2_plus_1,
+        Operator::And,
+        b1_lt_a2_plus_1,
+    ));
+
+    let filter = JoinFilter::new(
+        filter_expr,
+        vec![
+            ColumnIndex {
+                index: 1,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 1,
+                side: JoinSide::Right,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+        ],
+        Arc::new(Schema::new(vec![
+            Field::new("b1", DataType::Int32, true),
+            Field::new("b2", DataType::Int32, true),
+            Field::new("a2", DataType::Int32, true),
+        ])),
+    );
+
+    let (_, batches) = join_collect_with_filter(left, right, on, filter, Full).await?;
+
+    // r=(1,NULL): b2 is NULL → b1 < (NULL+1) is NULL → all a=1 rows unmatched
+    // r=(2,2): b1 < 3 AND b1 < 3 → both l=(2,1) and l=(2,2) match
+    // l=(3,*): no right row with a=3 → unmatched
+    assert_snapshot!(batches_to_sort_string(&batches), @r"
+    +----+----+----+----+
+    | a1 | b1 | a2 | b2 |
+    +----+----+----+----+
+    |    |    | 1  |    |
+    | 1  | 1  |    |    |
+    | 1  | 2  |    |    |
+    | 2  | 1  | 2  | 2  |
+    | 2  | 2  | 2  | 2  |
+    | 3  | 1  |    |    |
+    | 3  | 2  |    |    |
+    +----+----+----+----+
+    ");
+    Ok(())
+}
+
 #[tokio::test]
 async fn overallocation_single_batch_no_spill() -> Result<()> {
     let left = build_table(
@@ -3589,9 +3692,7 @@ async fn join_filtered_with_multiple_buffered_batches() -> Result<()> {
             Arc::new(Column::new("val_r", 1)),
         )),
         Operator::Lt,
-        Arc::new(datafusion_physical_expr::expressions::Literal::new(
-            datafusion_common::ScalarValue::Int32(Some(350)),
-        )),
+        Arc::new(Literal::new(ScalarValue::Int32(Some(350)))),
     )),
     vec![
         ColumnIndex {
```
