Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr_build_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ jobs:
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
org.apache.comet.serde.SupportConditionSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pr_build_macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ jobs:
org.apache.comet.expressions.conditional.CometIfSuite
org.apache.comet.expressions.conditional.CometCoalesceSuite
org.apache.comet.expressions.conditional.CometCaseWhenSuite
org.apache.comet.serde.SupportConditionSuite
- name: "sql"
value: |
org.apache.spark.sql.CometToPrettyStringSuite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,39 @@ trait CometExpressionSerde[T <: Expression] {
*/
def getExprConfigName(expr: T): String = expr.getClass.getSimpleName

/**
 * Declarative support conditions for this expression, evaluated in order; first match wins.
 *
 * Prefer declaring `conditions` over overriding [[getSupportLevel]]: the list is enumerable for
 * documentation and tests, and each condition carries a stable id and static description.
 *
 * Defaults to empty, in which case the derived [[getSupportLevel]] reports `Compatible(None)`.
 *
 * Subclasses should declare `conditions` as a `val` so the list is built once, not on every
 * `getSupportLevel` call.
 */
def conditions: Seq[SupportCondition[T]] = Seq.empty

/**
* Determine the support level of the expression based on its attributes.
*
* The default implementation is derived from [[conditions]]. Subclasses may still override this
* during migration, but new serdes should prefer declaring `conditions`.
*
* @param expr
* The Spark expression.
* @return
* Support level (Compatible, Incompatible, or Unsupported).
*/
def getSupportLevel(expr: T): SupportLevel = Compatible(None)
def getSupportLevel(expr: T): SupportLevel = {
  // Scan the declared conditions in order; the first predicate that matches decides the
  // outcome. No match at all means the expression is treated as fully supported.
  val firing = conditions.find(_.fires(expr))
  firing.fold[SupportLevel](Compatible(None)) { cond =>
    // An empty message is normalized to "no note" on the resulting SupportLevel.
    val note = Some(cond.message(expr)).filter(_.nonEmpty)
    cond.level match {
      case SupportLevelKind.Unsupported => Unsupported(note)
      case SupportLevelKind.Incompatible => Incompatible(note)
      case SupportLevelKind.Compatible => Compatible(note)
    }
  }
}

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
Expand Down
103 changes: 103 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/SupportLevel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types._

/**
* @see
* [[SupportLevelKind]]
*/
sealed trait SupportLevel

/**
Expand Down Expand Up @@ -57,3 +62,101 @@ object SupportLevel {
case _ => false
}
}

/**
* The kind of support outcome produced by a [[SupportCondition]].
*
* The member names mirror the [[SupportLevel]] case classes on purpose so that a condition's
* `level` reads the same as the resulting `SupportLevel`. To disambiguate references in this file
* or in wildcard imports of the `serde` package, qualify as `SupportLevelKind.Compatible` etc.
*/
sealed trait SupportLevelKind
object SupportLevelKind {
  // Maps to SupportLevel Compatible when the condition fires (possibly carrying a note).
  case object Compatible extends SupportLevelKind
  // Maps to SupportLevel Incompatible when the condition fires.
  case object Incompatible extends SupportLevelKind
  // Maps to SupportLevel Unsupported when the condition fires.
  case object Unsupported extends SupportLevelKind
}

/**
* A single support condition: a predicate that, when matched against an expression, yields a
* [[SupportLevel]] with an optional message.
*
* Conditions are declared statically per serde so that they can be enumerated at build time for
* documentation and tests. They evaluate at runtime by calling `fires(expr)`.
*
* Ordering within a serde's `conditions` list is significant: the first condition whose `fires`
* predicate matches determines the outcome. If no condition matches, the expression is treated as
* `Compatible(None)`.
*/
trait SupportCondition[-T <: Expression] {

  /** Stable, machine-readable id, unique per serde. Used in docs and tests. */
  def id: String

  /** Static prose describing when this fires. For example, "Child is BinaryType". */
  def description: String

  /** The outcome kind if this condition matches; mapped to the corresponding [[SupportLevel]]. */
  def level: SupportLevelKind

  /** Runtime predicate. May consult `CometConf` or other dynamic state. */
  def fires(expr: T): Boolean

  /**
   * Runtime message, usually constant. May interpolate from the expression. An empty string is
   * treated as "no message" by the derived `getSupportLevel`.
   */
  def message(expr: T): String

  /** Optional issue links for doc output. */
  def issues: Seq[String] = Nil
}

object SupportCondition {

  /** Sole concrete implementation; every factory below funnels through it. */
  private final case class Condition[T <: Expression](
      id: String,
      description: String,
      level: SupportLevelKind,
      predicate: T => Boolean,
      render: T => String,
      override val issues: Seq[String])
      extends SupportCondition[T] {
    override def fires(expr: T): Boolean = predicate(expr)
    override def message(expr: T): String = render(expr)
  }

  /** Generic builder. Use this when `message` depends on the expression. */
  def apply[T <: Expression](
      id: String,
      description: String,
      level: SupportLevelKind,
      fires: T => Boolean,
      message: T => String,
      issues: Seq[String] = Nil): SupportCondition[T] =
    Condition(id, description, level, fires, message, issues)

  // Shared plumbing for the static-message conveniences below: lifts the constant
  // message into a function so all paths reuse the same backing case class.
  private def static[T <: Expression](
      level: SupportLevelKind,
      id: String,
      description: String,
      fires: T => Boolean,
      message: String,
      issues: Seq[String]): SupportCondition[T] =
    Condition(id, description, level, fires, (_: T) => message, issues)

  /** Convenience: unsupported with a static message. */
  def unsupported[T <: Expression](
      id: String,
      description: String,
      fires: T => Boolean,
      message: String,
      issues: Seq[String] = Nil): SupportCondition[T] =
    static(SupportLevelKind.Unsupported, id, description, fires, message, issues)

  /** Convenience: incompatible with a static message. */
  def incompatible[T <: Expression](
      id: String,
      description: String,
      fires: T => Boolean,
      message: String,
      issues: Seq[String] = Nil): SupportCondition[T] =
    static(SupportLevelKind.Incompatible, id, description, fires, message, issues)

  /** Convenience: compatible-with-note (a caveat on an otherwise supported path). */
  def compatibleWithNote[T <: Expression](
      id: String,
      description: String,
      fires: T => Boolean,
      message: String,
      issues: Seq[String] = Nil): SupportCondition[T] =
    static(SupportLevelKind.Compatible, id, description, fires, message, issues)
}
30 changes: 18 additions & 12 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,27 @@ object CometSortArray extends CometExpressionSerde[SortArray] {
}
}

override def getSupportLevel(expr: SortArray): SupportLevel = {
val elementType = expr.base.dataType.asInstanceOf[ArrayType].elementType

if (!supportedSortArrayElementType(elementType)) {
Unsupported(Some(s"Sort on array element type $elementType is not supported"))
} else if (CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(elementType)) {
Incompatible(
Some(
override val conditions: Seq[SupportCondition[SortArray]] = {
def elementType(e: SortArray): DataType =
e.base.dataType.asInstanceOf[ArrayType].elementType
Seq(
SupportCondition[SortArray](
id = "unsupported-element-type",
description = "Array element type is not supported for sorting",
level = SupportLevelKind.Unsupported,
fires = e => !supportedSortArrayElementType(elementType(e)),
message = e => s"Sort on array element type ${elementType(e)} is not supported"),
SupportCondition[SortArray](
id = "strict-floating-point",
description = "Strict floating-point mode is on and element type contains float/double",
level = SupportLevelKind.Incompatible,
fires = e =>
CometConf.COMET_EXEC_STRICT_FLOATING_POINT.get() &&
SupportLevel.containsFloatingPoint(elementType(e)),
message = _ =>
"Sorting on floating-point is not 100% compatible with Spark, and Comet is running " +
s"with ${CometConf.COMET_EXEC_STRICT_FLOATING_POINT.key}=true. " +
s"${CometConf.COMPAT_GUIDE}"))
} else {
Compatible()
}
}

override def convert(
Expand Down
10 changes: 6 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/strings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ object CometUpper extends CometCaseConversionBase[Upper]("upper")
object CometLower extends CometCaseConversionBase[Lower]("lower")

object CometLength extends CometScalarFunction[Length]("length") {
override def getSupportLevel(expr: Length): SupportLevel = expr.child.dataType match {
case _: BinaryType => Unsupported(Some("Length on BinaryType is not supported"))
case _ => Compatible()
}
override val conditions: Seq[SupportCondition[Length]] = Seq(
SupportCondition.unsupported[Length](
id = "binary-child",
description = "Child is BinaryType",
fires = _.child.dataType.isInstanceOf[BinaryType],
message = "Length on BinaryType is not supported"))
}

object CometInitCap extends CometScalarFunction[InitCap]("initcap") {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.serde

import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.sql.catalyst.expressions.{Attribute, Length, Literal}
import org.apache.spark.sql.types.{BinaryType, IntegerType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class SupportConditionSuite extends AnyFunSuite {

  /** Serde declaring no conditions at all: the default path must yield Compatible(None). */
  private object NoConditionSerde extends CometExpressionSerde[Literal] {
    override def convert(
        expr: Literal,
        inputs: Seq[Attribute],
        binding: Boolean): Option[ExprOuterClass.Expr] = None
  }

  /** Serde whose three ordered conditions cover each support level kind. */
  private object ThreeLevelSerde extends CometExpressionSerde[Literal] {
    override val conditions: Seq[SupportCondition[Literal]] = Seq(
      SupportCondition.unsupported[Literal](
        id = "null-literal",
        description = "Literal value is null",
        fires = lit => lit.value == null,
        message = "null literal not supported"),
      SupportCondition.incompatible[Literal](
        id = "long-literal",
        description = "Literal is LongType",
        fires = lit => lit.dataType == LongType,
        message = "Long literals are incompatible"),
      SupportCondition.compatibleWithNote[Literal](
        id = "string-with-note",
        description = "Literal is StringType",
        fires = lit => lit.dataType == StringType,
        message = "string literal carries a note"))

    override def convert(
        expr: Literal,
        inputs: Seq[Attribute],
        binding: Boolean): Option[ExprOuterClass.Expr] = None
  }

  /** Serde whose message is computed from the expression instance. */
  private object InterpolatingSerde extends CometExpressionSerde[Literal] {
    override val conditions: Seq[SupportCondition[Literal]] = Seq(
      SupportCondition[Literal](
        id = "dynamic",
        description = "Always fires, message includes data type",
        level = SupportLevelKind.Unsupported,
        fires = _ => true,
        message = (lit: Literal) => s"unsupported dtype ${lit.dataType.simpleString}"))

    override def convert(
        expr: Literal,
        inputs: Seq[Attribute],
        binding: Boolean): Option[ExprOuterClass.Expr] = None
  }

  test("empty conditions returns Compatible(None)") {
    assert(NoConditionSerde.getSupportLevel(Literal(1, IntegerType)) == Compatible(None))
  }

  test("no matching condition returns Compatible(None)") {
    assert(ThreeLevelSerde.getSupportLevel(Literal(1, IntegerType)) == Compatible(None))
  }

  test("first matching condition wins when multiple could fire") {
    // A null Long literal satisfies both "null-literal" (Unsupported) and
    // "long-literal" (Incompatible); the earlier condition must decide the outcome.
    val outcome = ThreeLevelSerde.getSupportLevel(Literal(null, LongType))
    assert(outcome == Unsupported(Some("null literal not supported")))
  }

  test("non-first matching condition produces its own message") {
    val outcome = ThreeLevelSerde.getSupportLevel(Literal(42L, LongType))
    assert(outcome == Incompatible(Some("Long literals are incompatible")))
  }

  test("compatible-with-note produces Compatible(Some(msg))") {
    val outcome =
      ThreeLevelSerde.getSupportLevel(Literal(UTF8String.fromString("hello"), StringType))
    assert(outcome == Compatible(Some("string literal carries a note")))
  }

  test("message can depend on the expression instance") {
    val outcome =
      InterpolatingSerde.getSupportLevel(Literal(UTF8String.fromString("x"), StringType))
    assert(outcome == Unsupported(Some("unsupported dtype string")))
  }

  test("CometLength: BinaryType child returns Unsupported") {
    val outcome = CometLength.getSupportLevel(Length(Literal(Array[Byte](1, 2, 3), BinaryType)))
    assert(outcome == Unsupported(Some("Length on BinaryType is not supported")))
  }

  test("CometLength: StringType child returns Compatible") {
    val outcome =
      CometLength.getSupportLevel(Length(Literal(UTF8String.fromString("hello"), StringType)))
    assert(outcome == Compatible(None))
  }

  test("CometLength: declared condition pins id, level, and description") {
    assert(CometLength.conditions.size == 1)
    val only = CometLength.conditions.head
    assert(only.id == "binary-child")
    assert(only.level == SupportLevelKind.Unsupported)
    assert(only.description == "Child is BinaryType")
  }

  test("CometSortArray: unsupported element type returns Unsupported") {
    import org.apache.spark.sql.catalyst.expressions.SortArray
    import org.apache.spark.sql.types.{ArrayType, MapType}
    val arrayOfMaps = Literal.create(null, ArrayType(MapType(StringType, StringType)))
    val outcome = CometSortArray.getSupportLevel(SortArray(arrayOfMaps, Literal(true)))
    assert(outcome.isInstanceOf[Unsupported])
    val note = outcome.asInstanceOf[Unsupported].notes.getOrElse("")
    assert(note.startsWith("Sort on array element type"))
  }

  test("CometSortArray: non-float element type returns Compatible") {
    import org.apache.spark.sql.catalyst.expressions.SortArray
    import org.apache.spark.sql.types.{ArrayType, IntegerType}
    val intArray = Literal.create(null, ArrayType(IntegerType))
    assert(CometSortArray.getSupportLevel(SortArray(intArray, Literal(true))) == Compatible(None))
  }

  test("CometSortArray: declared conditions order is element-type then floating-point") {
    val declared = CometSortArray.conditions
    assert(declared.size == 2)
    assert(declared.head.id == "unsupported-element-type")
    assert(declared.head.level == SupportLevelKind.Unsupported)
    assert(declared(1).id == "strict-floating-point")
    assert(declared(1).level == SupportLevelKind.Incompatible)
  }
}
Loading