Commit f635cad

chor: enable array_distinct (#3987)
1 parent: 395e900

5 files changed: 41 additions & 66 deletions

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -241,7 +241,7 @@ Comet supports using the following aggregate functions within window contexts wi
 | ArrayAppend | Yes | |
 | ArrayCompact | No | |
 | ArrayContains | Yes | |
-| ArrayDistinct | No | Behaves differently than spark. Comet first sorts then removes duplicates while Spark preserves the original order. |
+| ArrayDistinct | Yes | |
 | ArrayExcept | No | |
 | ArrayFilter | Yes | Only supports case where function is `IsNotNull` |
 | ArrayInsert | No | |
```
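For context on the row that changed: the old note described a real semantic gap. A minimal plain-Scala illustration of the two behaviors (ordinary collection code, not Comet or DataFusion internals):

```scala
val input = Seq(3, 1, 3, 2)

// Spark's array_distinct keeps elements in first-occurrence order:
val sparkOrder = input.distinct           // List(3, 1, 2)

// The previous Comet/DataFusion path sorted before deduplicating:
val oldCometOrder = input.sorted.distinct // List(1, 2, 3)
```

Flipping the table entry to `Yes` with no notes records that the two engines now agree.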

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -50,7 +50,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[ArrayAppend] -> CometArrayAppend,
     classOf[ArrayCompact] -> CometArrayCompact,
     classOf[ArrayContains] -> CometArrayContains,
-    classOf[ArrayDistinct] -> CometArrayDistinct,
+    classOf[ArrayDistinct] -> CometScalarFunction("array_distinct"),
     classOf[ArrayExcept] -> CometArrayExcept,
     classOf[ArrayFilter] -> CometArrayFilter,
     classOf[ArrayInsert] -> CometArrayInsert,
```
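With `ArrayDistinct` routed through the generic `CometScalarFunction` serde, the expression no longer needs the incompatible-expression opt-in that the test suite below previously set. A hedged usage sketch (`spark.comet.enabled` is Comet's enable flag; the session setup is assumed, and the expected value follows Spark's documented semantics):

```scala
// Assumes a running SparkSession `spark` (e.g. spark-shell) with the
// Comet plugin on the classpath.
spark.conf.set("spark.comet.enabled", "true")

// array_distinct now runs natively and should match Spark exactly,
// including element order (first occurrence wins):
spark.sql("SELECT array_distinct(array(1, 2, 2, 3, 3))").collect()
// => a single row containing [1, 2, 3]
```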

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 1 addition & 18 deletions
```diff
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -120,23 +120,6 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
   }
 }
 
-object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
-
-  override def getSupportLevel(expr: ArrayDistinct): SupportLevel =
-    Incompatible(Some("Output elements are sorted rather than preserving insertion order"))
-
-  override def convert(
-      expr: ArrayDistinct,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
-
-    val arrayDistinctScalarExpr =
-      scalarFunctionExprToProto("array_distinct", arrayExprProto)
-    optExprWithInfo(arrayDistinctScalarExpr, expr)
-  }
-}
-
 object CometSortArray extends CometExpressionSerde[SortArray] {
   private def containsFloatingPoint(dt: DataType): Boolean = {
     dt match {
```
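The deleted object shows the pattern that `CometScalarFunction` generalizes: serialize the child expression to protobuf, wrap it in a named DataFusion scalar-function call, and report any conversion failure against the original expression. A minimal sketch of that generalization, reusing the helper names visible in the removed code (`CometExpressionSerde`, `exprToProto`, `scalarFunctionExprToProto`, `optExprWithInfo`); the parameterized shape is an assumption, not Comet's actual `CometScalarFunction`:

```scala
// Hypothetical sketch, assumed to live in org.apache.comet.serde next to the
// helpers it calls. It follows the removed CometArrayDistinct.convert, but
// takes the DataFusion function name as a parameter and forwards all children
// rather than just the first. Assumes scalarFunctionExprToProto accepts
// varargs of Option[Expr] (its one-argument use above suggests this), and
// that getSupportLevel defaults to a compatible level when not overridden.
case class ScalarFunctionSerdeSketch[T <: Expression](funcName: String)
    extends CometExpressionSerde[T] {

  override def convert(
      expr: T,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Serialize each child expression to its protobuf form.
    val childProtos = expr.children.map(exprToProto(_, inputs, binding))
    // Wrap the children in a named DataFusion scalar-function call.
    val scalarExpr = scalarFunctionExprToProto(funcName, childProtos: _*)
    // Attach fallback/diagnostic info keyed to the original expression.
    optExprWithInfo(scalarExpr, expr)
  }
}
```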

spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql

Lines changed: 20 additions & 22 deletions
```diff
@@ -15,8 +15,6 @@
 -- specific language governing permissions and limitations
 -- under the License.
 
--- ConfigMatrix: parquet.enable.dictionary=false,true
-
 -- ===== INT arrays =====
 
 statement
@@ -34,23 +32,23 @@ INSERT INTO test_array_distinct_int VALUES
 (array(0, -1, -1, 0, 1))
 
 -- column argument
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_int
 
 -- literal arguments
-query spark_answer_only
+query
 SELECT array_distinct(array(1, 2, 2, 3, 3))
 
 -- all NULLs
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT)))
 
 -- NULL input
-query spark_answer_only
+query
 SELECT array_distinct(CAST(NULL AS array<int>))
 
 -- boundary values
-query spark_answer_only
+query
 SELECT array_distinct(array(-2147483648, 2147483647, -2147483648, 2147483647, 0))
 
 -- ===== LONG arrays =====
@@ -65,11 +63,11 @@ INSERT INTO test_array_distinct_long VALUES
 (array(NULL, 1, NULL, 2)),
 (array(-9223372036854775808, 9223372036854775807, -9223372036854775808))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_long
 
 -- boundary values
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST(-9223372036854775808 AS BIGINT), CAST(9223372036854775807 AS BIGINT), CAST(-9223372036854775808 AS BIGINT)))
 
 -- ===== STRING arrays =====
@@ -86,11 +84,11 @@ INSERT INTO test_array_distinct_string VALUES
 (array('', '', NULL, '')),
 (array('hello', 'world', 'hello'))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_string
 
 -- empty string and NULL distinction
-query spark_answer_only
+query
 SELECT array_distinct(array('', NULL, '', NULL, 'a'))
 
 -- ===== BOOLEAN arrays =====
@@ -105,7 +103,7 @@ INSERT INTO test_array_distinct_bool VALUES
 (NULL),
 (array(NULL, true, NULL, false))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_bool
 
 -- ===== DOUBLE arrays =====
@@ -119,23 +117,23 @@ INSERT INTO test_array_distinct_double VALUES
 (NULL),
 (array(NULL, 1.0, NULL, 2.0))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_double
 
 -- NaN deduplication
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), 1.0, 1.0))
 
 -- NaN with NULL
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS DOUBLE), NULL, CAST('NaN' AS DOUBLE), NULL, 1.0))
 
 -- Infinity
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('Infinity' AS DOUBLE), 0.0))
 
 -- negative zero
-query spark_answer_only
+query
 SELECT array_distinct(array(0.0, -0.0, 1.0))
 
 -- ===== FLOAT arrays =====
@@ -149,11 +147,11 @@ INSERT INTO test_array_distinct_float VALUES
 (NULL),
 (array(CAST(NULL AS FLOAT), CAST(1.0 AS FLOAT), CAST(NULL AS FLOAT)))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_float
 
 -- Float NaN deduplication
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS FLOAT), CAST('NaN' AS FLOAT), CAST(1.0 AS FLOAT)))
 
 -- ===== DECIMAL arrays =====
@@ -167,13 +165,13 @@ INSERT INTO test_array_distinct_decimal VALUES
 (NULL),
 (array(NULL, 1.10, NULL, 1.10))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_decimal
 
 -- ===== Nested array (array of arrays) =====
 
-query spark_answer_only
+query
 SELECT array_distinct(array(array(1, 2), array(3, 4), array(1, 2), array(3, 4)))
 
-query spark_answer_only
+query
 SELECT array_distinct(array(array(1, 2), CAST(NULL AS array<int>), array(1, 2), CAST(NULL AS array<int>)))
```
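These queries drop the `spark_answer_only` marker, reflecting that Comet's output is now expected to agree with Spark's, insertion order included. For reference, expected Spark results for two of the queries above (values follow Spark's documented `array_distinct` semantics; the session setup is assumed):

```scala
// First-occurrence order is preserved:
spark.sql("SELECT array_distinct(array(1, 2, 2, 3, 3))").collect()
// => one row: [1, 2, 3]

// Empty string and NULL stay distinct; repeated NULLs collapse to one:
spark.sql("SELECT array_distinct(array('', NULL, '', NULL, 'a'))").collect()
// => one row: ["", NULL, "a"]
```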

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 18 additions & 24 deletions
```diff
@@ -23,7 +23,7 @@ import scala.util.Random
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
 import org.apache.spark.sql.catalyst.expressions.{ArrayContains, ArrayRemove}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
@@ -403,29 +403,23 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
   }
 
   test("array_distinct") {
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArrayDistinct]) -> "true") {
-      Seq(true, false).foreach { dictionaryEnabled =>
-        withTempDir { dir =>
-          withTempView("t1") {
-            val path = new Path(dir.toURI.toString, "test.parquet")
-            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
-            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
-            // The result needs to be in ascending order for checkSparkAnswerAndOperator to pass
-            // because datafusion array_distinct sorts the elements and then removes the duplicates
-            checkSparkAnswerAndOperator(
-              spark.sql("SELECT array_distinct(array(_2, _2, _3, _4, _4)) FROM t1"))
-            checkSparkAnswerAndOperator(
-              spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
-            checkSparkAnswerAndOperator(spark.sql(
-              "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1"))
-            // NULL needs to be the first element for checkSparkAnswerAndOperator to pass because
-            // datafusion array_distinct sorts the elements and then removes the duplicates
-            checkSparkAnswerAndOperator(
-              spark.sql(
-                "SELECT array_distinct(array(CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
-            checkSparkAnswerAndOperator(spark.sql(
-              "SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT), _2, _2, _3, _4, _4)) FROM t1"))
-          }
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        withTempView("t1") {
+          val path = new Path(dir.toURI.toString, "test.parquet")
+          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, n = 10000)
+          spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_distinct(array(_3, _2, _4, _2, _4)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql("SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_4) END)) FROM t1"))
+          checkSparkAnswerAndOperator(spark.sql(
+            "SELECT array_distinct((CASE WHEN _2 =_3 THEN array(_2, _2, _4, _4, _5) END)) FROM t1"))
+          checkSparkAnswerAndOperator(
+            spark.sql(
+              "SELECT array_distinct(array(_2, _2, CAST(NULL AS INT), _3, _4, _4)) FROM t1"))
+          checkSparkAnswerAndOperator(spark.sql(
+            "SELECT array_distinct(array(_2, _2, CAST(NULL AS INT), CAST(NULL AS INT), _3, _4, _4)) FROM t1"))
        }
       }
     }
```
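With the ordering workarounds gone, the test can exercise arbitrary element order (note `array(_3, _2, _4, _2, _4)` and the mid-array NULLs above) and no longer needs the allow-incompat config wrapper. A rough sketch of what a comparison helper like `checkSparkAnswerAndOperator` does conceptually (illustrative only, not the real helper, which is provided by `CometTestBase` and additionally asserts that Comet operators appear in the physical plan):

```scala
import org.apache.spark.sql.{Row, SparkSession}

// Illustrative only: run the same SQL with Comet off and then on, and
// require identical, order-sensitive results. "spark.comet.enabled" is
// Comet's real enable flag; everything else here is a simplification.
def checkAgainstSparkSketch(spark: SparkSession, sql: String): Unit = {
  spark.conf.set("spark.comet.enabled", "false")
  val sparkRows: Array[Row] = spark.sql(sql).collect()

  spark.conf.set("spark.comet.enabled", "true")
  val cometRows: Array[Row] = spark.sql(sql).collect()

  // Order-sensitive comparison, which is why the old sorted-output
  // behavior of array_distinct previously required workarounds.
  assert(cometRows.sameElements(sparkRows), s"Comet and Spark differ for: $sql")
}
```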
