diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 4823153f7b..586a31813c 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -343,6 +343,7 @@ jobs:
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
+              org.apache.comet.serde.SupportConditionSuite
           - name: "sql"
             value: |
               org.apache.spark.sql.CometToPrettyStringSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 6d6ac14ec9..2c8102d54f 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -219,6 +219,7 @@ jobs:
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
+              org.apache.comet.serde.SupportConditionSuite
           - name: "sql"
             value: |
               org.apache.spark.sql.CometToPrettyStringSuite
diff --git a/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala
index 20c0343037..bd8f4fbbb9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/CometExpressionSerde.scala
@@ -37,15 +37,40 @@ trait CometExpressionSerde[T <: Expression] {
    */
   def getExprConfigName(expr: T): String = expr.getClass.getSimpleName
 
+  /**
+   * Declarative support conditions for this expression. First match wins.
+   *
+   * Prefer declaring `conditions` over overriding [[getSupportLevel]]: the list is enumerable for
+   * documentation and tests, and each condition carries a stable id and static description.
+   *
+   * Subclasses should declare `conditions` as a `val` so the list is built once, not on every
+   * `getSupportLevel` call.
+   */
+  def conditions: Seq[SupportCondition[T]] = Seq.empty
+
   /**
    * Determine the support level of the expression based on its attributes.
    *
+   * The default implementation is derived from [[conditions]]. Subclasses may still override this
+   * during migration, but new serdes should prefer declaring `conditions`.
+   *
    * @param expr
    *   The Spark expression.
    * @return
    *   Support level (Compatible, Incompatible, or Unsupported).
    */
-  def getSupportLevel(expr: T): SupportLevel = Compatible(None)
+  def getSupportLevel(expr: T): SupportLevel =
+    conditions.find(_.fires(expr)) match {
+      case Some(c) =>
+        // A null or empty message means "no note", so normalise both to None.
+        val msg = Option(c.message(expr)).filter(_.nonEmpty)
+        c.level match {
+          case SupportLevelKind.Compatible => Compatible(msg)
+          case SupportLevelKind.Incompatible => Incompatible(msg)
+          case SupportLevelKind.Unsupported => Unsupported(msg)
+        }
+      case None => Compatible(None)
+    }
 
   /**
    * Convert a Spark expression into a protocol buffer representation that can be passed into
diff --git a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
index cb78c7d2d4..6786f9e65c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
@@ -19,8 +19,13 @@
 
 package org.apache.comet.serde
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.types._
 
+/**
+ * @see
+ *   [[SupportLevelKind]]
+ */
 sealed trait SupportLevel
 
 /**
@@ -57,3 +62,101 @@ object SupportLevel {
     case _ => false
   }
 }
+
+/**
+ * The kind of support outcome produced by a [[SupportCondition]].
+ *
+ * The member names mirror the [[SupportLevel]] case classes on purpose so that a condition's
+ * `level` reads the same as the resulting `SupportLevel`. To disambiguate references in this file
+ * or in wildcard imports of the `serde` package, qualify as `SupportLevelKind.Compatible` etc.
+ */
+sealed trait SupportLevelKind
+object SupportLevelKind {
+  case object Compatible extends SupportLevelKind
+  case object Incompatible extends SupportLevelKind
+  case object Unsupported extends SupportLevelKind
+}
+
+/**
+ * A single support condition: a predicate that, when matched against an expression, yields a
+ * [[SupportLevel]] with an optional message.
+ *
+ * Conditions are declared statically per serde so that they can be enumerated at build time for
+ * documentation and tests. They evaluate at runtime by calling `fires(expr)`.
+ *
+ * Ordering within a serde's `conditions` list is significant: the first condition whose `fires`
+ * predicate matches determines the outcome. If no condition matches, the expression is treated as
+ * `Compatible(None)`.
+ */
+trait SupportCondition[-T <: Expression] {
+
+  /** Stable, machine-readable id, unique per serde. Used in docs and tests. */
+  def id: String
+
+  /** Static prose describing when this fires. For example, "Child is BinaryType". */
+  def description: String
+
+  /** The outcome if this condition matches. */
+  def level: SupportLevelKind
+
+  /** Runtime predicate. May consult `CometConf` or other dynamic state. */
+  def fires(expr: T): Boolean
+
+  /** Runtime message; null or empty means no note. May interpolate from the expression. */
+  def message(expr: T): String
+
+  /** Optional issue links for doc output. */
+  def issues: Seq[String] = Nil
+}
+
+object SupportCondition {
+
+  private final case class Impl[T <: Expression](
+      id: String,
+      description: String,
+      level: SupportLevelKind,
+      firesFn: T => Boolean,
+      messageFn: T => String,
+      override val issues: Seq[String])
+      extends SupportCondition[T] {
+    override def fires(expr: T): Boolean = firesFn(expr)
+    override def message(expr: T): String = messageFn(expr)
+  }
+
+  /** Generic builder. Use this when `message` depends on the expression. */
+  def apply[T <: Expression](
+      id: String,
+      description: String,
+      level: SupportLevelKind,
+      fires: T => Boolean,
+      message: T => String,
+      issues: Seq[String] = Nil): SupportCondition[T] =
+    Impl(id, description, level, fires, message, issues)
+
+  /** Convenience: unsupported with a static message. */
+  def unsupported[T <: Expression](
+      id: String,
+      description: String,
+      fires: T => Boolean,
+      message: String,
+      issues: Seq[String] = Nil): SupportCondition[T] =
+    Impl(id, description, SupportLevelKind.Unsupported, fires, (_: T) => message, issues)
+
+  /** Convenience: incompatible with a static message. */
+  def incompatible[T <: Expression](
+      id: String,
+      description: String,
+      fires: T => Boolean,
+      message: String,
+      issues: Seq[String] = Nil): SupportCondition[T] =
+    Impl(id, description, SupportLevelKind.Incompatible, fires, (_: T) => message, issues)
+
+  /** Convenience: compatible-with-note (a caveat on an otherwise supported path). */
+  def compatibleWithNote[T <: Expression](
+      id: String,
+      description: String,
+      fires: T => Boolean,
+      message: String,
+      issues: Seq[String] = Nil): SupportCondition[T] =
+    Impl(id, description, SupportLevelKind.Compatible, fires, (_: T) => message, issues)
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 14d4536fc1..7f476a4907 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -141,21 +141,27 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
     }
   }
 
-  override def getSupportLevel(expr: SortArray): SupportLevel = {
-    val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType
-
-    if (!supportedSortArrayElementType(elementType)) {
-      Unsupported(Some(s"Sort on array element type $elementType is not supported"))
-    } else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
-      SupportLevel.containsFloatingPoint(elementType)) {
-      Incompatible(
-        Some(
+  override val conditions: Seq[SupportCondition[SortArray]] = {
+    def elementType(e: SortArray): DataType =
+      e.base.dataType.asInstanceOf[ArrayType].elementType
+    Seq(
+      SupportCondition[SortArray](
+        id = "unsupported-element-type",
+        description = "Array element type is not supported for sorting",
+        level = SupportLevelKind.Unsupported,
+        fires = e => !supportedSortArrayElementType(elementType(e)),
+        message = e => s"Sort on array element type ${elementType(e)} is not supported"),
+      SupportCondition[SortArray](
+        id = "strict-floating-point",
+        description = "Strict floating-point mode is on and element type contains float/double",
+        level = SupportLevelKind.Incompatible,
+        fires = e =>
+          CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
+            SupportLevel.containsFloatingPoint(elementType(e)),
+        message = _ =>
           "Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
             s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
             s"${CometConf.COMPAT_GUIDE}"))
-    } else {
-      Compatible()
-    }
   }
 
   override def convert(
diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala
index 50621fc389..7f54a3388d 100644
--- a/spark/src/main/scala/org/apache/comet/serde/strings.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala
@@ -68,10 +68,12 @@ object CometUpper extends CometCaseConversionBase[Upper]("upper")
 object CometLower extends CometCaseConversionBase[Lower]("lower")
 
 object CometLength extends CometScalarFunction[Length]("length") {
-  override def getSupportLevel(expr: Length): SupportLevel = expr.child.dataType match {
-    case _: BinaryType => Unsupported(Some("Length on BinaryType is not supported"))
-    case _ => Compatible()
-  }
+  override val conditions: Seq[SupportCondition[Length]] = Seq(
+    SupportCondition.unsupported[Length](
+      id = "binary-child",
+      description = "Child is BinaryType",
+      fires = _.child.dataType.isInstanceOf[BinaryType],
+      message = "Length on BinaryType is not supported"))
 }
 
 object CometInitCap extends CometScalarFunction[InitCap]("initcap") {
diff --git a/spark/src/test/scala/org/apache/comet/serde/SupportConditionSuite.scala b/spark/src/test/scala/org/apache/comet/serde/SupportConditionSuite.scala
new file mode 100644
index 0000000000..3fa598debf
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/serde/SupportConditionSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Length, Literal}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, LongType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class SupportConditionSuite extends AnyFunSuite {
+
+  /** Serde with no conditions: default path yields Compatible(None). */
+  private object EmptySerde extends CometExpressionSerde[Literal] {
+    override def convert(
+        expr: Literal,
+        inputs: Seq[Attribute],
+        binding: Boolean): Option[ExprOuterClass.Expr] = None
+  }
+
+  /** Serde with three ordered conditions exercising each level. */
+  private object OrderedSerde extends CometExpressionSerde[Literal] {
+    override val conditions: Seq[SupportCondition[Literal]] = Seq(
+      SupportCondition.unsupported[Literal](
+        id = "null-literal",
+        description = "Literal value is null",
+        fires = _.value == null,
+        message = "null literal not supported"),
+      SupportCondition.incompatible[Literal](
+        id = "long-literal",
+        description = "Literal is LongType",
+        fires = _.dataType == LongType,
+        message = "Long literals are incompatible"),
+      SupportCondition.compatibleWithNote[Literal](
+        id = "string-with-note",
+        description = "Literal is StringType",
+        fires = _.dataType == StringType,
+        message = "string literal carries a note"))
+
+    override def convert(
+        expr: Literal,
+        inputs: Seq[Attribute],
+        binding: Boolean): Option[ExprOuterClass.Expr] = None
+  }
+
+  /** Serde with an expression-dependent message. */
+  private object DynamicMessageSerde extends CometExpressionSerde[Literal] {
+    override val conditions: Seq[SupportCondition[Literal]] = Seq(
+      SupportCondition[Literal](
+        id = "dynamic",
+        description = "Always fires, message includes data type",
+        level = SupportLevelKind.Unsupported,
+        fires = _ => true,
+        message = (e: Literal) => s"unsupported dtype ${e.dataType.simpleString}"))
+
+    override def convert(
+        expr: Literal,
+        inputs: Seq[Attribute],
+        binding: Boolean): Option[ExprOuterClass.Expr] = None
+  }
+
+  test("empty conditions returns Compatible(None)") {
+    val result = EmptySerde.getSupportLevel(Literal(1, IntegerType))
+    assert(result == Compatible(None))
+  }
+
+  test("no matching condition returns Compatible(None)") {
+    val result = OrderedSerde.getSupportLevel(Literal(1, IntegerType))
+    assert(result == Compatible(None))
+  }
+
+  test("first matching condition wins when multiple could fire") {
+    // null Long literal matches both "null-literal" (Unsupported) and
+    // "long-literal" (Incompatible). Unsupported must win because it is first.
+    val nullLong = Literal(null, LongType)
+    val result = OrderedSerde.getSupportLevel(nullLong)
+    assert(result == Unsupported(Some("null literal not supported")))
+  }
+
+  test("non-first matching condition produces its own message") {
+    val longLit = Literal(42L, LongType)
+    val result = OrderedSerde.getSupportLevel(longLit)
+    assert(result == Incompatible(Some("Long literals are incompatible")))
+  }
+
+  test("compatible-with-note produces Compatible(Some(msg))") {
+    val stringLit = Literal(UTF8String.fromString("hello"), StringType)
+    val result = OrderedSerde.getSupportLevel(stringLit)
+    assert(result == Compatible(Some("string literal carries a note")))
+  }
+
+  test("message can depend on the expression instance") {
+    val result =
+      DynamicMessageSerde.getSupportLevel(Literal(UTF8String.fromString("x"), StringType))
+    assert(result == Unsupported(Some("unsupported dtype string")))
+  }
+
+  test("CometLength: BinaryType child returns Unsupported") {
+    val expr = Length(Literal(Array[Byte](1, 2, 3), BinaryType))
+    val result = CometLength.getSupportLevel(expr)
+    assert(result == Unsupported(Some("Length on BinaryType is not supported")))
+  }
+
+  test("CometLength: StringType child returns Compatible") {
+    val expr = Length(Literal(UTF8String.fromString("hello"), StringType))
+    val result = CometLength.getSupportLevel(expr)
+    assert(result == Compatible(None))
+  }
+
+  test("CometLength: declared condition pins id, level, and description") {
+    val conditions = CometLength.conditions
+    assert(conditions.size == 1)
+    val c = conditions.head
+    assert(c.id == "binary-child")
+    assert(c.level == SupportLevelKind.Unsupported)
+    assert(c.description == "Child is BinaryType")
+  }
+
+  test("CometSortArray: unsupported element type returns Unsupported") {
+    import org.apache.spark.sql.catalyst.expressions.SortArray
+    import org.apache.spark.sql.types.{ArrayType, MapType}
+    val mapArray = Literal.create(null, ArrayType(MapType(StringType, StringType)))
+    val expr = SortArray(mapArray, Literal(true))
+    val result = CometSortArray.getSupportLevel(expr)
+    assert(result.isInstanceOf[Unsupported])
+    val msg = result.asInstanceOf[Unsupported].notes.getOrElse("")
+    assert(msg.startsWith("Sort on array element type"))
+  }
+
+  test("CometSortArray: non-float element type returns Compatible") {
+    import org.apache.spark.sql.catalyst.expressions.SortArray
+    import org.apache.spark.sql.types.{ArrayType, IntegerType}
+    val arr = Literal.create(null, ArrayType(IntegerType))
+    val expr = SortArray(arr, Literal(true))
+    val result = CometSortArray.getSupportLevel(expr)
+    assert(result == Compatible(None))
+  }
+
+  test("CometSortArray: declared conditions order is element-type then floating-point") {
+    val conditions = CometSortArray.conditions
+    assert(conditions.size == 2)
+    assert(conditions(0).id == "unsupported-element-type")
+    assert(conditions(0).level == SupportLevelKind.Unsupported)
+    assert(conditions(1).id == "strict-floating-point")
+    assert(conditions(1).level == SupportLevelKind.Incompatible)
+  }
+}