Skip to content

Commit 915d5bd

Browse files
authored
feat: support sort_array expression (#3706)
1 parent c427bc1 commit 915d5bd

5 files changed

Lines changed: 493 additions & 22 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
6666
- **ArrayUnion**: Sorts input arrays before performing the union, while Spark preserves the order of the first array
6767
and appends unique elements from the second.
6868
[#3644](https://github.com/apache/datafusion-comet/issues/3644)
69+
- **SortArray**: Arrays whose elements are themselves arrays of `Struct` or `Null` values (for example `array<array<struct<...>>>` or `array<array<null>>`) are not supported natively and will fall back to Spark.
6970

7071
### Date/Time Expressions
7172

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@
105105
- [ ] sequence
106106
- [ ] shuffle
107107
- [ ] slice
108-
- [ ] sort_array
108+
- [x] sort_array
109109

110110
### bitwise_funcs
111111

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
6060
classOf[ArrayMin] -> CometArrayMin,
6161
classOf[ArrayRemove] -> CometArrayRemove,
6262
classOf[ArrayRepeat] -> CometArrayRepeat,
63+
classOf[SortArray] -> CometSortArray,
6364
classOf[ArraysOverlap] -> CometArraysOverlap,
6465
classOf[ArrayUnion] -> CometArrayUnion,
6566
classOf[CreateArray] -> CometCreateArray,
@@ -796,30 +797,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
796797
* TODO: Include SparkSQL's [[YearMonthIntervalType]] and [[DayTimeIntervalType]]
797798
*/
798799
// scalastyle:on
799-
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
800-
def canRank(dt: DataType): Boolean = {
801-
dt match {
802-
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
803-
_: DoubleType | _: DecimalType =>
804-
true
805-
case _: DateType | _: TimestampType | _: TimestampNTZType =>
806-
true
807-
case _: BooleanType | _: BinaryType | _: StringType => true
808-
case _ => false
809-
}
800+
/**
 * Whether `dt` is one of the scalar types accepted for sorting.
 *
 * Accepted: integral, floating-point and decimal numerics, date/timestamp types, and boolean,
 * binary and string. Every other type (including arrays, maps and structs) yields `false`.
 *
 * @param dt
 *   the data type to check
 * @return
 *   `true` if `dt` is one of the scalar types listed above
 */
def supportedScalarSortElementType(dt: DataType): Boolean = dt match {
  // Numeric types.
  case _: ByteType | _: ShortType | _: IntegerType | _: LongType => true
  case _: FloatType | _: DoubleType | _: DecimalType => true
  // Temporal types.
  case _: DateType | _: TimestampType | _: TimestampNTZType => true
  // Remaining scalar types.
  case _: BooleanType | _: BinaryType | _: StringType => true
  case _ => false
}
811810

811+
def supportedSortType(op: SparkPlan, sortOrder: Seq[SortOrder]): Boolean = {
812812
if (sortOrder.length == 1) {
813813
val canSort = sortOrder.head.dataType match {
814-
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
815-
_: DoubleType | _: DecimalType =>
816-
true
817-
case _: DateType | _: TimestampType | _: TimestampNTZType =>
818-
true
819-
case _: BooleanType | _: BinaryType | _: StringType => true
820-
case ArrayType(elementType, _) => canRank(elementType)
821-
case MapType(_, valueType, _) => canRank(valueType)
822-
case _ => false
814+
case ArrayType(elementType, _) => supportedScalarSortElementType(elementType)
815+
case MapType(_, valueType, _) => supportedScalarSortElementType(valueType)
816+
case _ => supportedScalarSortElementType(sortOrder.head.dataType)
823817
}
824818
if (!canSort) {
825819
withInfo(op, s"Sort on single column of type ${sortOrder.head.dataType} is not supported")

spark/src/main/scala/org/apache/comet/serde/arrays.scala

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ package org.apache.comet.serde
2121

2222
import scala.annotation.tailrec
2323

24-
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
24+
import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray}
2525
import org.apache.spark.sql.catalyst.util.GenericArrayData
2626
import org.apache.spark.sql.internal.SQLConf
2727
import org.apache.spark.sql.types._
2828

29+
import org.apache.comet.CometConf
2930
import org.apache.comet.CometSparkSessionExtensions.withInfo
3031
import org.apache.comet.serde.QueryPlanSerde._
3132
import org.apache.comet.shims.CometExprShim
@@ -136,6 +137,80 @@ object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
136137
}
137138
}
138139

140+
/**
 * Serde for Spark's [[SortArray]] expression, serialized as a call to the `array_sort` scalar
 * function.
 */
object CometSortArray extends CometExpressionSerde[SortArray] {

  /** Returns `true` if `dt` is, or transitively contains, `FloatType` or `DoubleType`. */
  private def containsFloatingPoint(dt: DataType): Boolean = {
    dt match {
      case FloatType | DoubleType => true
      case ArrayType(elementType, _) => containsFloatingPoint(elementType)
      case StructType(fields) => fields.exists(f => containsFloatingPoint(f.dataType))
      case MapType(keyType, valueType, _) =>
        containsFloatingPoint(keyType) || containsFloatingPoint(valueType)
      case _ => false
    }
  }

  /**
   * Whether arrays with element type `dt` can be sorted natively.
   *
   * @param dt
   *   the element type (or a type nested inside it) to check
   * @param nestedInArray
   *   `true` once the recursion has descended through an `ArrayType`; `Struct` and `Null` are
   *   rejected from that point on
   */
  private def supportedSortArrayElementType(
      dt: DataType,
      nestedInArray: Boolean = false): Boolean = {
    dt match {
      // DataFusion's array_sort compares nested arrays through Arrow's rank kernel.
      // That kernel does not support Struct or Null child values,
      // so array<array<struct<...>>> and array<array<null>> would fail at runtime.
      case _: NullType if !nestedInArray =>
        true
      case ArrayType(elementType, _) =>
        supportedSortArrayElementType(elementType, nestedInArray = true)
      case StructType(fields) if !nestedInArray =>
        // Struct fields are checked with the default nestedInArray = false.
        fields.forall(f => supportedSortArrayElementType(f.dataType))
      case _ =>
        supportedScalarSortElementType(dt)
    }
  }

  /**
   * Classifies native support for `expr` based on the element type of its input array.
   *
   *   - `Unsupported` when the element type cannot be sorted natively, or when the input is not
   *     an array at all.
   *   - `Incompatible` when strict floating-point mode is enabled and the element type contains
   *     a floating-point type.
   *   - `Compatible` otherwise.
   */
  override def getSupportLevel(expr: SortArray): SupportLevel = {
    // Pattern match instead of an unchecked asInstanceOf: a non-array input is reported as
    // Unsupported rather than surfacing as a ClassCastException.
    expr.base.dataType match {
      case ArrayType(elementType, _) if !supportedSortArrayElementType(elementType) =>
        Unsupported(Some(s"Sort on array element type $elementType is not supported"))
      case ArrayType(elementType, _)
          if CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
            containsFloatingPoint(elementType) =>
        Incompatible(
          Some(
            "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
              s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
              s"${CometConf.COMPAT_GUIDE}"))
      case _: ArrayType =>
        Compatible()
      case other =>
        Unsupported(Some(s"Sort on non-array type $other is not supported"))
    }
  }

  /**
   * Serializes `expr` as `array_sort(base, direction, nullOrdering)`.
   *
   * `ascendingOrder` must be a boolean literal: `true` maps to `"ASC"` / `"NULLS FIRST"` and
   * `false` maps to `"DESC"` / `"NULLS LAST"`. For any other expression a fallback reason is
   * recorded on `expr` and no direction or null-ordering argument is produced.
   */
  override def convert(
      expr: SortArray,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val arrayExprProto = exprToProtoInternal(expr.base, inputs, binding)
    val (sortDirectionExprProto, nullOrderingExprProto) = expr.ascendingOrder match {
      case Literal(value: Boolean, BooleanType) =>
        val direction = if (value) "ASC" else "DESC"
        val nullOrdering = if (value) "NULLS FIRST" else "NULLS LAST"
        (
          exprToProtoInternal(Literal(direction), inputs, binding),
          exprToProtoInternal(Literal(nullOrdering), inputs, binding))
      case other =>
        withInfo(expr, s"ascendingOrder must be a boolean literal: $other")
        (None, None)
    }

    val sortArrayScalarExpr =
      scalarFunctionExprToProto(
        "array_sort",
        arrayExprProto,
        sortDirectionExprProto,
        nullOrderingExprProto)
    optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*)
  }
}
213+
139214
object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {
140215

141216
override def getSupportLevel(expr: ArrayIntersect): SupportLevel = Incompatible(None)

0 commit comments

Comments
 (0)