
Commit 6d7dae9

feat: support collect_set

1 parent 88c1ffc commit 6d7dae9

File tree

7 files changed: +259 −5 lines

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@
 - [x] bool_and
 - [x] bool_or
 - [ ] collect_list
-- [ ] collect_set
+- [x] collect_set
 - [ ] corr
 - [x] count
 - [x] count_if

native/Cargo.lock

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default.

native/core/src/execution/planner.rs

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,7 @@ use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
     BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv,
 };
+use datafusion_spark::function::aggregate::collect::SparkCollectSet;
 use iceberg::expr::Bind;
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -2266,6 +2267,11 @@ impl PhysicalPlanner {
                 ));
                 Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
             }
+            AggExprStruct::CollectSet(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
+                let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
+                Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
+            }
         }
     }

native/proto/src/proto/expr.proto

Lines changed: 6 additions & 0 deletions
@@ -139,6 +139,7 @@ message AggExpr {
     Stddev stddev = 14;
     Correlation correlation = 15;
     BloomFilterAgg bloomFilterAgg = 16;
+    CollectSet collectSet = 17;
   }
 
   // Optional filter expression for SQL FILTER (WHERE ...) clause.
@@ -247,6 +248,11 @@ message BloomFilterAgg {
   DataType datatype = 4;
 }
 
+message CollectSet {
+  Expr child = 1;
+  DataType datatype = 2;
+}
+
 enum EvalMode {
   LEGACY = 0;
   TRY = 1;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[BitOrAgg] -> CometBitOrAgg,
     classOf[BitXorAgg] -> CometBitXOrAgg,
     classOf[BloomFilterAggregate] -> CometBloomFilterAggregate,
+    classOf[CollectSet] -> CometCollectSet,
     classOf[Corr] -> CometCorr,
     classOf[Count] -> CometCount,
     classOf[CovPopulation] -> CometCovPopulation,

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 32 additions & 1 deletion
@@ -22,7 +22,7 @@ package org.apache.comet.serde
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
 
@@ -671,6 +671,37 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
   }
 }
 
+object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {
+  override def convert(
+      aggExpr: AggregateExpression,
+      expr: CollectSet,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
+    val child = expr.children.head
+    val childExpr = exprToProto(child, inputs, binding)
+    val dataType = serializeDataType(expr.dataType)
+
+    if (childExpr.isDefined && dataType.isDefined) {
+      val builder = ExprOuterClass.CollectSet.newBuilder()
+      builder.setChild(childExpr.get)
+      builder.setDatatype(dataType.get)
+
+      Some(
+        ExprOuterClass.AggExpr
+          .newBuilder()
+          .setCollectSet(builder)
+          .build())
+    } else if (dataType.isEmpty) {
+      withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
+      None
+    } else {
+      withInfo(aggExpr, child)
+      None
+    }
+  }
+}
+
 object AggSerde {
   import org.apache.spark.sql.types._
Lines changed: 212 additions & 0 deletions

@@ -0,0 +1,212 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- Config: spark.comet.expression.CollectSet.allowIncompatible=true
-- ConfigMatrix: parquet.enable.dictionary=false,true

-- ============================================================
-- Setup: tables
-- ============================================================

statement
CREATE TABLE test_collect_set_int(i int, grp string) USING parquet

statement
INSERT INTO test_collect_set_int VALUES
  (1, 'a'), (2, 'a'), (1, 'a'), (3, 'a'),
  (4, 'b'), (4, 'b'), (NULL, 'b'), (5, 'b')

statement
CREATE TABLE test_collect_set_types(
  b boolean, bi bigint, d double, s string, dc decimal(10,2), dt date, grp string
) USING parquet

statement
INSERT INTO test_collect_set_types VALUES
  (true, 10, 1.1, 'x', 1.50, DATE '2024-01-01', 'a'),
  (false, 20, 2.2, 'y', 2.50, DATE '2024-01-02', 'a'),
  (true, 10, 1.1, 'x', 1.50, DATE '2024-01-01', 'a'),
  (NULL, 30, 3.3, 'z', 3.50, DATE '2024-01-03', 'b'),
  (true, 30, 3.3, 'z', 3.50, DATE '2024-01-03', 'b')

statement
CREATE TABLE test_collect_set_nulls(val int, grp string) USING parquet

statement
INSERT INTO test_collect_set_nulls VALUES
  (NULL, 'a'), (NULL, 'a'), (NULL, 'b'), (1, 'b')

statement
CREATE TABLE test_collect_set_empty(val int) USING parquet

statement
CREATE TABLE test_collect_set_single(val int) USING parquet

statement
INSERT INTO test_collect_set_single VALUES (42)

-- ============================================================
-- Note: collect_set result ordering is non-deterministic.
-- We materialize aggregate results via CTAS and then sort
-- the arrays in a separate query to avoid sort_array in the
-- aggregate result expressions (which would cause the Final
-- aggregate to fall back to Spark).
-- ============================================================

-- ============================================================
-- Operator coverage: verify collect_set runs natively
-- (use size() which is supported, avoids array ordering issues)
-- ============================================================

query
SELECT grp, size(collect_set(i)) FROM test_collect_set_int GROUP BY grp ORDER BY grp

-- ============================================================
-- Basic: integer dedup
-- ============================================================

statement
CREATE TABLE cs_basic USING parquet AS
SELECT collect_set(i) as cs FROM test_collect_set_int

query spark_answer_only
SELECT sort_array(cs) FROM cs_basic

-- ============================================================
-- GROUP BY: integer dedup per group
-- ============================================================

statement
CREATE TABLE cs_grp_int USING parquet AS
SELECT grp, collect_set(i) as cs FROM test_collect_set_int GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_grp_int ORDER BY grp

-- ============================================================
-- NULLs: all NULLs in a group returns empty array
-- ============================================================

statement
CREATE TABLE cs_nulls USING parquet AS
SELECT grp, collect_set(val) as cs FROM test_collect_set_nulls GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_nulls ORDER BY grp

-- ============================================================
-- Empty table: returns empty array
-- ============================================================

statement
CREATE TABLE cs_empty USING parquet AS
SELECT collect_set(val) as cs FROM test_collect_set_empty

query spark_answer_only
SELECT sort_array(cs) FROM cs_empty

-- ============================================================
-- Single value
-- ============================================================

statement
CREATE TABLE cs_single USING parquet AS
SELECT collect_set(val) as cs FROM test_collect_set_single

query spark_answer_only
SELECT sort_array(cs) FROM cs_single

-- ============================================================
-- Multiple data types
-- ============================================================

-- boolean
statement
CREATE TABLE cs_bool USING parquet AS
SELECT grp, collect_set(b) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_bool ORDER BY grp

-- bigint
statement
CREATE TABLE cs_bigint USING parquet AS
SELECT grp, collect_set(bi) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_bigint ORDER BY grp

-- double
statement
CREATE TABLE cs_double USING parquet AS
SELECT grp, collect_set(d) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_double ORDER BY grp

-- string
statement
CREATE TABLE cs_string USING parquet AS
SELECT grp, collect_set(s) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_string ORDER BY grp

-- decimal
statement
CREATE TABLE cs_decimal USING parquet AS
SELECT grp, collect_set(dc) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_decimal ORDER BY grp

-- date
statement
CREATE TABLE cs_date USING parquet AS
SELECT grp, collect_set(dt) as cs FROM test_collect_set_types GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_date ORDER BY grp

-- ============================================================
-- Mixed with other aggregates
-- ============================================================

statement
CREATE TABLE cs_mixed USING parquet AS
SELECT grp, collect_set(i) as cs, count(*) as cnt, sum(i) as total
FROM test_collect_set_int GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs), cnt, total FROM cs_mixed ORDER BY grp

-- ============================================================
-- All duplicates in a group
-- ============================================================

statement
CREATE TABLE test_collect_set_dupes(val int, grp string) USING parquet

statement
INSERT INTO test_collect_set_dupes VALUES (7, 'a'), (7, 'a'), (7, 'a'), (8, 'b'), (9, 'b')

statement
CREATE TABLE cs_dupes USING parquet AS
SELECT grp, collect_set(val) as cs FROM test_collect_set_dupes GROUP BY grp

query spark_answer_only
SELECT grp, sort_array(cs) FROM cs_dupes ORDER BY grp
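For context, a hedged end-to-end sketch of what these tests exercise from the Spark side. The config key is taken from the test header above; the session setup, table name, and data are illustrative, and a Comet-enabled Spark build is assumed.

```scala
import org.apache.spark.sql.SparkSession

object CollectSetDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("collect_set demo")
      // Gate from the test header above; native collect_set sits behind it.
      .config("spark.comet.expression.CollectSet.allowIncompatible", "true")
      .getOrCreate()

    spark.sql("CREATE TABLE t(i INT, grp STRING) USING parquet")
    spark.sql("INSERT INTO t VALUES (1, 'a'), (2, 'a'), (1, 'a'), (4, 'b')")

    // size() sidesteps collect_set's non-deterministic array ordering,
    // the same trick the operator-coverage query in the tests uses.
    spark
      .sql("SELECT grp, size(collect_set(i)) FROM t GROUP BY grp ORDER BY grp")
      .show()

    spark.stop()
  }
}
```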
