@@ -22,7 +22,7 @@ package org.apache.comet.serde
2222import scala .jdk .CollectionConverters ._
2323
2424import org .apache .spark .sql .catalyst .expressions .Attribute
25- import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , Corr , Count , Covariance , CovPopulation , CovSample , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
25+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , CollectSet , Corr , Count , Covariance , CovPopulation , CovSample , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
2626import org .apache .spark .sql .internal .SQLConf
2727import org .apache .spark .sql .types .{ByteType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType }
2828
@@ -664,6 +664,52 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
664664 }
665665}
666666
/**
 * Serde support for Spark's `collect_set` aggregate expression.
 *
 * NOTE(review): Comet's native implementation deduplicates NaN values, while Spark treats
 * each NaN as distinct, so floating-point inputs are only conditionally supported — see
 * [[getSupportLevel]].
 */
object CometCollectSet extends CometAggregateExpressionSerde[CollectSet] {

  /**
   * Declares `collect_set` incompatible when strict floating-point mode is enabled and the
   * aggregated column is (or contains) a floating-point type; compatible otherwise.
   */
  override def getSupportLevel(expr: CollectSet): SupportLevel = {
    val strictFpMode = COMET_EXEC_STRICT_FLOATING_POINT.get()
    val fpInput = SupportLevel.containsFloatingPoint(expr.children.head.dataType)
    if (!(strictFpMode && fpInput)) {
      Compatible()
    } else {
      Incompatible(
        Some(
          "collect_set on floating-point types is not 100% compatible with Spark " +
            "(Comet deduplicates NaN values while Spark treats each NaN as distinct), " +
            s"and Comet is running with ${COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
            s"${CometConf.COMPAT_GUIDE}"))
    }
  }

  /**
   * Converts a `collect_set` aggregate into its protobuf representation.
   *
   * Returns `None` (after recording a fallback reason via `withInfo`) when either the child
   * expression or the result data type cannot be serialized.
   */
  override def convert(
      aggExpr: AggregateExpression,
      expr: CollectSet,
      inputs: Seq[Attribute],
      binding: Boolean,
      conf: SQLConf): Option[ExprOuterClass.AggExpr] = {
    val input = expr.children.head
    // Both the child expression and the result type must serialize for native execution.
    val childProto = exprToProto(input, inputs, binding)
    val typeProto = serializeDataType(expr.dataType)

    (childProto, typeProto) match {
      case (Some(child), Some(dt)) =>
        val collectSet = ExprOuterClass.CollectSet
          .newBuilder()
          .setChild(child)
          .setDatatype(dt)
        Some(
          ExprOuterClass.AggExpr
            .newBuilder()
            .setCollectSet(collectSet)
            .build())
      case (_, None) =>
        // Result type has no protobuf mapping; fall back to Spark.
        withInfo(aggExpr, s"datatype ${expr.dataType} is not supported", input)
        None
      case _ =>
        // Child expression itself failed to convert; its reason was already attached.
        withInfo(aggExpr, input)
        None
    }
  }
}
712+
667713object AggSerde {
668714 import org .apache .spark .sql .types ._
669715
0 commit comments