Skip to content

Commit 023d912

Browse files
authored
feat: fix array_compact for Spark 4.0 and correct return type metadata (#3796)
1 parent 964e578 commit 023d912

4 files changed

Lines changed: 63 additions & 20 deletions

File tree

native/spark-expr/src/array_funcs/array_compact.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ fn compact_list<OffsetSize: OffsetSizeTrait>(
132132
);
133133
let mut valid = NullBufferBuilder::new(list_array.len());
134134

135+
// Use logical_nulls() instead of is_null() to correctly handle NullArray.
136+
// NullArray::nulls() returns None (which makes is_null() return false),
137+
// but logical_nulls() correctly reports all elements as null.
138+
let value_nulls = values.logical_nulls();
139+
135140
for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
136141
if list_array.is_null(row_index) {
137142
offsets.push(offsets[row_index]);
@@ -144,7 +149,8 @@ fn compact_list<OffsetSize: OffsetSizeTrait>(
144149
let mut copied = 0usize;
145150

146151
for i in start..end {
147-
if !values.is_null(i) {
152+
let is_null = value_nulls.as_ref().map(|n| n.is_null(i)).unwrap_or(false);
153+
if !is_null {
148154
mutable.extend(0, i, i + 1);
149155
copied += 1;
150156
}

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ object CometArrayRepeat extends CometExpressionSerde[ArrayRepeat] {
295295
}
296296

297297
object CometArrayCompact extends CometExpressionSerde[Expression] {
298+
298299
override def convert(
299300
expr: Expression,
300301
inputs: Seq[Attribute],

spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
2424
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.internal.types.StringTypeWithCollation
27-
import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType}
27+
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType}
2828

2929
import org.apache.comet.CometSparkSessionExtensions.withInfo
3030
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3131
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
3232
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
33-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
33+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
3434

3535
/**
3636
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
@@ -56,6 +56,28 @@ trait CometExprShim extends CommonStringExprs {
5656
inputs: Seq[Attribute],
5757
binding: Boolean): Option[Expr] = {
5858
expr match {
59+
case knc: KnownNotContainsNull =>
60+
// On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)).
61+
// Strip the wrapper and serialize the inner ArrayFilter as spark_array_compact.
62+
knc.child match {
63+
case filter: ArrayFilter =>
64+
filter.function.children.headOption match {
65+
case Some(_: IsNotNull) =>
66+
val arrayChild = filter.left
67+
val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType
68+
val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding)
69+
val returnType = ArrayType(elementType)
70+
val scalarExpr = scalarFunctionExprToProtoWithReturnType(
71+
"spark_array_compact",
72+
returnType,
73+
false,
74+
arrayExprProto)
75+
optExprWithInfo(scalarExpr, knc, arrayChild)
76+
case _ => exprToProtoInternal(knc.child, inputs, binding)
77+
}
78+
case _ => exprToProtoInternal(knc.child, inputs, binding)
79+
}
80+
5981
case s: StaticInvoke
6082
if s.staticObject == classOf[StringDecode] &&
6183
s.dataType.isInstanceOf[StringType] &&
@@ -109,12 +131,6 @@ trait CometExprShim extends CommonStringExprs {
109131
val optExpr = scalarFunctionExprToProto("width_bucket", childExprs: _*)
110132
optExprWithInfo(optExpr, wb, wb.children: _*)
111133

112-
// KnownNotContainsNull is a TaggingExpression added in Spark 4.0 that only
113-
// changes schema metadata (containsNull = false). It has no runtime effect,
114-
// so we pass through to the child expression.
115-
case k: KnownNotContainsNull =>
116-
exprToProtoInternal(k.child, inputs, binding)
117-
118134
// In Spark 4.0, StructsToJson is a RuntimeReplaceable whose replacement is
119135
// Invoke(Literal(StructsToJsonEvaluator), "evaluate", ...). Reconstruct the
120136
// original StructsToJson and recurse so support-level checks apply.

spark/src/test/resources/sql-tests/expressions/array/array_compact.sql

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,48 @@
1717

1818

1919
statement
20-
CREATE TABLE test_array_compact(arr array<int>) USING parquet
20+
CREATE TABLE test_array_compact(
21+
ints array<int>,
22+
strs array<string>,
23+
dbls array<double>,
24+
nested array<array<int>>
25+
) USING parquet
2126

2227
statement
23-
INSERT INTO test_array_compact VALUES (array(1, NULL, 2, NULL, 3)), (array()), (NULL), (array(NULL, NULL)), (array(1, 2, 3))
28+
INSERT INTO test_array_compact VALUES
29+
(array(1, NULL, 2, NULL, 3), array('a', NULL, 'b', NULL, 'c'), array(1.0, NULL, 2.0), array(array(1, NULL, 3), NULL, array(4, NULL, 6))),
30+
(array(), array(), array(), array()),
31+
(NULL, NULL, NULL, NULL),
32+
(array(NULL, NULL), array(NULL, NULL), array(NULL, NULL), array(NULL, NULL)),
33+
(array(1, 2, 3), array('x', 'y', 'z'), array(1.5, 2.5), array(array(1, 2), array(3, 4)))
2434

25-
-- column argument
35+
-- integer column
2636
query
27-
SELECT array_compact(arr) FROM test_array_compact
37+
SELECT array_compact(ints) FROM test_array_compact
38+
39+
-- string column
40+
query
41+
SELECT array_compact(strs) FROM test_array_compact
42+
43+
-- double column
44+
query
45+
SELECT array_compact(dbls) FROM test_array_compact
46+
47+
-- nested array column: outer nulls removed, inner nulls preserved
48+
query
49+
SELECT array_compact(nested) FROM test_array_compact
2850

2951
-- literal arguments
3052
query
3153
SELECT array_compact(array(1, NULL, 2, NULL, 3))
3254

33-
-- string element type
34-
statement
35-
CREATE TABLE test_array_compact_str(arr array<string>) USING parquet
36-
37-
statement
38-
INSERT INTO test_array_compact_str VALUES (array('a', NULL, 'b', NULL, 'c')), (array()), (NULL), (array(NULL, NULL)), (array('', NULL, '', NULL))
55+
-- literal string array
56+
query
57+
SELECT array_compact(array('a', NULL, 'b'))
3958

59+
-- all-null literal array
4060
query
41-
SELECT array_compact(arr) FROM test_array_compact_str
61+
SELECT array_compact(array(NULL, NULL, NULL))
4262

4363
-- double element type
4464
query

0 commit comments

Comments (0)