@@ -21,11 +21,12 @@ package org.apache.comet.serde
2121
2222import scala .annotation .tailrec
2323
24- import org .apache .spark .sql .catalyst .expressions .{ArrayAppend , ArrayContains , ArrayDistinct , ArrayExcept , ArrayFilter , ArrayInsert , ArrayIntersect , ArrayJoin , ArrayMax , ArrayMin , ArrayRemove , ArrayRepeat , ArraysOverlap , ArrayUnion , Attribute , CreateArray , ElementAt , Expression , Flatten , GetArrayItem , IsNotNull , Literal , Reverse , Size }
24+ import org .apache .spark .sql .catalyst .expressions .{ArrayAppend , ArrayContains , ArrayDistinct , ArrayExcept , ArrayFilter , ArrayInsert , ArrayIntersect , ArrayJoin , ArrayMax , ArrayMin , ArrayRemove , ArrayRepeat , ArraysOverlap , ArrayUnion , Attribute , CreateArray , ElementAt , Expression , Flatten , GetArrayItem , IsNotNull , Literal , Reverse , Size , SortArray }
2525import org .apache .spark .sql .catalyst .util .GenericArrayData
2626import org .apache .spark .sql .internal .SQLConf
2727import org .apache .spark .sql .types ._
2828
29+ import org .apache .comet .CometConf
2930import org .apache .comet .CometSparkSessionExtensions .withInfo
3031import org .apache .comet .serde .QueryPlanSerde ._
3132import org .apache .comet .shims .CometExprShim
@@ -136,6 +137,80 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
136137 }
137138}
138139
140+ object CometSortArray extends CometExpressionSerde [SortArray ] {
141+ private def containsFloatingPoint (dt : DataType ): Boolean = {
142+ dt match {
143+ case FloatType | DoubleType => true
144+ case ArrayType (elementType, _) => containsFloatingPoint(elementType)
145+ case StructType (fields) => fields.exists(f => containsFloatingPoint(f.dataType))
146+ case MapType (keyType, valueType, _) =>
147+ containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
148+ case _ => false
149+ }
150+ }
151+
152+ private def supportedSortArrayElementType (
153+ dt : DataType ,
154+ nestedInArray : Boolean = false ): Boolean = {
155+ dt match {
156+ // DataFusion's array_sort compares nested arrays through Arrow's rank kernel.
157+ // That kernel does not support Struct or Null child values,
158+ // so array<array<struct<...>>> and array<array<null>> would fail at runtime.
159+ case _ : NullType if ! nestedInArray =>
160+ true
161+ case ArrayType (elementType, _) =>
162+ supportedSortArrayElementType(elementType, nestedInArray = true )
163+ case StructType (fields) if ! nestedInArray =>
164+ fields.forall(f => supportedSortArrayElementType(f.dataType))
165+ case _ =>
166+ supportedScalarSortElementType(dt)
167+ }
168+ }
169+
170+ override def getSupportLevel (expr : SortArray ): SupportLevel = {
171+ val elementType = expr.base.dataType.asInstanceOf [ArrayType ].elementType
172+
173+ if (! supportedSortArrayElementType(elementType)) {
174+ Unsupported (Some (s " Sort on array element type $elementType is not supported " ))
175+ } else if (CometConf .COMET_EXEC_STRICT_FLOATING_POINT .get() &&
176+ containsFloatingPoint(elementType)) {
177+ Incompatible (
178+ Some (
179+ " Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
180+ s " with ${CometConf .COMET_EXEC_STRICT_FLOATING_POINT .key}=true. " +
181+ s " ${CometConf .COMPAT_GUIDE }" ))
182+ } else {
183+ Compatible ()
184+ }
185+ }
186+
187+ override def convert (
188+ expr : SortArray ,
189+ inputs : Seq [Attribute ],
190+ binding : Boolean ): Option [ExprOuterClass .Expr ] = {
191+ val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding)
192+ val (sortDirectionExprProto, nullOrderingExprProto) = expr.ascendingOrder match {
193+ case Literal (value : Boolean , BooleanType ) =>
194+ val direction = if (value) " ASC" else " DESC"
195+ val nullOrdering = if (value) " NULLS FIRST" else " NULLS LAST"
196+ (
197+ exprToProtoInternal(Literal (direction), inputs, binding),
198+ exprToProtoInternal(Literal (nullOrdering), inputs, binding))
199+ case other =>
200+ withInfo(expr, s " ascendingOrder must be a boolean literal: $other" )
201+ (None , None )
202+ }
203+
204+ val sortArrayScalarExpr =
205+ scalarFunctionExprToProto(
206+ " array_sort" ,
207+ arrayExprProto,
208+ sortDirectionExprProto,
209+ nullOrderingExprProto)
210+ optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _* )
211+ }
212+ }
213+
139214object CometArrayIntersect extends CometExpressionSerde [ArrayIntersect ] {
140215
141216 override def getSupportLevel (expr : ArrayIntersect ): SupportLevel = Incompatible (None )
0 commit comments