Skip to content

Commit d6d5f09

Browse files
authored
feat: support collect_set (#3954)
1 parent a2a3dd3 commit d6d5f09

12 files changed

Lines changed: 426 additions & 28 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ Expressions that are not 100% Spark-compatible will fall back to Spark by defaul
5858
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
5959
the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
6060

61+
### Aggregate Expressions
62+
63+
- **CollectSet**: Comet deduplicates NaN values (treats `NaN == NaN`) while Spark treats each NaN as a distinct value. When `spark.comet.exec.strictFloatingPoint=true`, `collect_set` on floating-point types falls back to Spark unless `spark.comet.expression.CollectSet.allowIncompatible=true` is set.
66+
6167
### Array Expressions
6268

6369
- **ArraysOverlap**: Inconsistent behavior when arrays contain NULL values.

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
203203
| BitXorAgg | | Yes | |
204204
| BoolAnd | `bool_and` | Yes | |
205205
| BoolOr | `bool_or` | Yes | |
206+
| CollectSet | | No | NaN dedup differs from Spark. See compatibility guide. |
206207
| Corr | | Yes | |
207208
| Count | | Yes | |
208209
| CovPopulation | | Yes | |

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
- [x] bool_and
3434
- [x] bool_or
3535
- [ ] collect_list
36-
- [ ] collect_set
36+
- [x] collect_set
3737
- [ ] corr
3838
- [x] count
3939
- [x] count_if

native/core/src/execution/planner.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ use datafusion_comet_spark_expr::{
7070
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
7171
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv,
7272
};
73+
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
7374
use iceberg::expr::Bind;
7475

7576
use crate::execution::operators::ExecutionError::GeneralError;
@@ -2266,6 +2267,11 @@ impl PhysicalPlanner {
22662267
));
22672268
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
22682269
}
2270+
AggExprStruct::CollectSet(expr) => {
2271+
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2272+
let func = AggregateUDF::new_from_impl(SparkCollectSet::new());
2273+
Self::create_aggr_func_expr("collect_set", schema, vec![child], func)
2274+
}
22692275
}
22702276
}
22712277

native/proto/src/proto/expr.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ message AggExpr {
140140
Stddev stddev = 14;
141141
Correlation correlation = 15;
142142
BloomFilterAgg bloomFilterAgg = 16;
143+
CollectSet collectSet = 17;
143144
}
144145

145146
// Optional filter expression for SQL FILTER (WHERE ...) clause.
@@ -248,6 +249,11 @@ message BloomFilterAgg {
248249
DataType datatype = 4;
249250
}
250251

252+
// Serialized form of Spark's CollectSet aggregate expression (collect_set).
message CollectSet {
  // Input expression whose distinct values are collected.
  Expr child = 1;
  // Result data type of the aggregate (serialized from the Spark expression's dataType).
  DataType datatype = 2;
}
256+
251257
enum EvalMode {
252258
LEGACY = 0;
253259
TRY = 1;

spark/src/main/scala/org/apache/comet/serde/CometSortOrder.scala

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.apache.comet.serde
2121

2222
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, NullsFirst, NullsLast, SortOrder}
23-
import org.apache.spark.sql.types._
2423

2524
import org.apache.comet.CometConf
2625
import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -30,19 +29,8 @@ object CometSortOrder extends CometExpressionSerde[SortOrder] {
3029

3130
override def getSupportLevel(expr: SortOrder): SupportLevel = {
3231

33-
def containsFloatingPoint(dt: DataType): Boolean = {
34-
dt match {
35-
case DataTypes.FloatType | DataTypes.DoubleType => true
36-
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
37-
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
38-
case MapType(keyType, valueType, _) =>
39-
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
40-
case _ => false
41-
}
42-
}
43-
4432
if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
45-
containsFloatingPoint(expr.child.dataType)) {
33+
SupportLevel.containsFloatingPoint(expr.child.dataType)) {
4634
// https://github.com/apache/datafusion-comet/issues/2626
4735
Incompatible(
4836
Some(

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
262262
classOf[BitOrAgg] -> CometBitOrAgg,
263263
classOf[BitXorAgg] -> CometBitXOrAgg,
264264
classOf[BloomFilterAggregate] -> CometBloomFilterAggregate,
265+
classOf[CollectSet] -> CometCollectSet,
265266
classOf[Corr] -> CometCorr,
266267
classOf[Count] -> CometCount,
267268
classOf[CovPopulation] -> CometCovPopulation,

spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
package org.apache.comet.serde
2121

22+
import org.apache.spark.sql.types._
23+
2224
sealed trait SupportLevel
2325

2426
/**
@@ -40,3 +42,18 @@ case class Incompatible(notes: Option[String] = None) extends SupportLevel
4042

4143
/** Comet does not support this feature */
4244
case class Unsupported(notes: Option[String] = None) extends SupportLevel
45+
46+
object SupportLevel {

  /**
   * Recursively determines whether a data type contains a floating-point type
   * (`FloatType` or `DoubleType`) at any nesting depth, descending into array
   * elements, struct fields, and map keys/values.
   *
   * @param dt the data type to inspect
   * @return true if a float or double occurs anywhere inside `dt`
   */
  def containsFloatingPoint(dt: DataType): Boolean = {
    dt match {
      case FloatType => true
      case DoubleType => true
      case ArrayType(element, _) =>
        containsFloatingPoint(element)
      case StructType(fields) =>
        fields.exists(field => containsFloatingPoint(field.dataType))
      case MapType(key, value, _) =>
        containsFloatingPoint(key) || containsFloatingPoint(value)
      case _ => false
    }
  }
}

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.comet.serde
2222
import scala.jdk.CollectionConverters._
2323

2424
import org.apache.spark.sql.catalyst.expressions.Attribute
25-
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
25+
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, CollectSet, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
2626
import org.apache.spark.sql.internal.SQLConf
2727
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
2828

@@ -664,6 +664,52 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
664664
}
665665
}
666666

667+
object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {

  /**
   * Reports compatibility for `collect_set`. Comet deduplicates NaN values (treats
   * `NaN == NaN`) while Spark keeps each NaN as a distinct element, so under
   * strict floating-point mode the expression is marked Incompatible and falls
   * back to Spark unless the user explicitly allows it.
   */
  override def getSupportLevel(expr: CollectSet): SupportLevel = {
    // Qualified with `CometConf.` for consistency with CometSortOrder / CometSortArray.
    if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
      SupportLevel.containsFloatingPoint(expr.children.head.dataType)) {
      Incompatible(
        Some(
          "collect_set on floating-point types is not 100% compatible with Spark " +
            "(Comet deduplicates NaN values while Spark treats each NaN as distinct), " +
            s"and Comet is running with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    } else {
      Compatible()
    }
  }

  /**
   * Serializes a `CollectSet` aggregate into the Comet protobuf representation.
   *
   * Returns None — with diagnostic info attached via `withInfo` — when either the
   * child expression or the result data type cannot be serialized.
   */
  override def convert(
      aggExpr: AggregateExpression,
      expr: CollectSet,
      inputs: Seq[Attribute],
      binding: Boolean,
      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
    val child = expr.children.head
    val childExpr = exprToProto(child, inputs, binding)
    val dataType = serializeDataType(expr.dataType)

    if (childExpr.isDefined && dataType.isDefined) {
      val builder = ExprOuterClass.CollectSet.newBuilder()
      builder.setChild(childExpr.get)
      builder.setDatatype(dataType.get)

      Some(
        ExprOuterClass.AggExpr
          .newBuilder()
          .setCollectSet(builder)
          .build())
    } else if (dataType.isEmpty) {
      withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", child)
      None
    } else {
      // Child expression failed to serialize; record it for the fallback explanation.
      withInfo(aggExpr, child)
      None
    }
  }
}
712+
667713
object AggSerde {
668714
import org.apache.spark.sql.types._
669715

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,6 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
121121
}
122122

123123
object CometSortArray extends CometExpressionSerde[SortArray] {
124-
private def containsFloatingPoint(dt: DataType): Boolean = {
125-
dt match {
126-
case FloatType | DoubleType => true
127-
case ArrayType(elementType, _) => containsFloatingPoint(elementType)
128-
case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
129-
case MapType(keyType, valueType, _) =>
130-
containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
131-
case _ => false
132-
}
133-
}
134124

135125
private def supportedSortArrayElementType(
136126
dt: DataType,
@@ -156,7 +146,7 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
156146
if (!supportedSortArrayElementType(elementType)) {
157147
Unsupported(Some(s"Sort on array element type $elementType is not supported"))
158148
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
159-
containsFloatingPoint(elementType)) {
149+
SupportLevel.containsFloatingPoint(elementType)) {
160150
Incompatible(
161151
Some(
162152
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +

0 commit comments

Comments
 (0)