
Commit 88c1ffc

fix: make tan and atan2 compatible (#3849)
## Which issue does this PR close?

Closes: #1897

## Rationale for this change

#1897 claims `tan` is incompatible; however, what is really incompatible is `atan2`, which fails in the same test (#1896). The correct results are:

```
atan2(-0.0, -0.0) = -π    <= Comet answer
atan2(-0.0, +0.0) = -0.0
atan2(+0.0, -0.0) = +π
atan2(+0.0, +0.0) = +0.0  <= Spark answer
```

It looks like Spark adds `+0.0` to the inputs in order to convert `atan2(-0.0, -0.0)` into `atan2(+0.0, +0.0)`.

## What changes are included in this PR?

Fixed `atan2` and enabled `tan`.

## How are these changes tested?

New SQL test cases covering signed zeros, NULL, NaN, and infinities for `atan2` and `tan`, plus the re-enabled `-0.0` value in `CometExpressionSuite`.
1 parent 204d7e4 commit 88c1ffc
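The signed-zero quadrant behavior above can be reproduced directly on the JVM, where `Math.atan2` implements the IEEE 754 special cases that Spark inherits. A minimal sketch (class name is illustrative, not from the PR) of both the quadrant results and the `+0.0` normalization trick the commit describes:

```java
public class Atan2SignedZero {
    public static void main(String[] args) {
        // IEEE 754 atan2 distinguishes the sign of zero in both arguments:
        System.out.println(Math.atan2(-0.0, -0.0)); // -3.141592653589793 (-pi)
        System.out.println(Math.atan2(-0.0, +0.0)); // -0.0
        System.out.println(Math.atan2(+0.0, -0.0)); // 3.141592653589793 (+pi)
        System.out.println(Math.atan2(+0.0, +0.0)); // 0.0

        // Adding +0.0 turns -0.0 into +0.0 and leaves every other double
        // unchanged, so both arguments land in the +0.0 quadrant:
        System.out.println(Math.atan2(-0.0 + 0.0, -0.0 + 0.0)); // 0.0
    }
}
```

This is why Comet's native `atan2` must see the same normalized inputs Spark produces, rather than the raw column values.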

7 files changed

Lines changed: 102 additions & 113 deletions


docs/source/user-guide/latest/compatibility.md

Lines changed: 0 additions & 5 deletions
```diff
@@ -76,11 +76,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
   timezone is UTC.
   [#2649](https://github.com/apache/datafusion-comet/issues/2649)
 
-### Math Expressions
-
-- **Tan**: `tan(-0.0)` produces `0.0` instead of `-0.0`.
-  [#1897](https://github.com/apache/datafusion-comet/issues/1897)
-
 ### Aggregate Expressions
 
 - **Corr**: Returns null instead of NaN in some edge cases.
```

docs/source/user-guide/latest/expressions.md

Lines changed: 42 additions & 42 deletions
The hunk (`@@ -127,48 +127,48 @@`) rewrites the entire `## Math Expressions` table with narrower column padding; the only content change is the Tan row, previously marked incompatible:

```diff
-| Tan | `tan` | No  | tan(-0.0) produces incorrect result ([#1897](https://github.com/apache/datafusion-comet/issues/1897)) |
+| Tan | `tan` | Yes |                                   |
```

The updated table:

| Expression     | SQL       | Spark-Compatible? | Compatibility Notes               |
| -------------- | --------- | ----------------- | --------------------------------- |
| Abs            | `abs`     | Yes               |                                   |
| Acos           | `acos`    | Yes               |                                   |
| Add            | `+`       | Yes               |                                   |
| Asin           | `asin`    | Yes               |                                   |
| Atan           | `atan`    | Yes               |                                   |
| Atan2          | `atan2`   | Yes               |                                   |
| BRound         | `bround`  | Yes               |                                   |
| Ceil           | `ceil`    | Yes               |                                   |
| Cos            | `cos`     | Yes               |                                   |
| Cosh           | `cosh`    | Yes               |                                   |
| Cot            | `cot`     | Yes               |                                   |
| Divide         | `/`       | Yes               |                                   |
| Exp            | `exp`     | Yes               |                                   |
| Expm1          | `expm1`   | Yes               |                                   |
| Floor          | `floor`   | Yes               |                                   |
| Hex            | `hex`     | Yes               |                                   |
| IntegralDivide | `div`     | Yes               |                                   |
| IsNaN          | `isnan`   | Yes               |                                   |
| Log            | `log`     | Yes               |                                   |
| Log2           | `log2`    | Yes               |                                   |
| Log10          | `log10`   | Yes               |                                   |
| Multiply       | `*`       | Yes               |                                   |
| Pow            | `power`   | Yes               |                                   |
| Rand           | `rand`    | Yes               |                                   |
| Randn          | `randn`   | Yes               |                                   |
| Remainder      | `%`       | Yes               |                                   |
| Round          | `round`   | Yes               |                                   |
| Signum         | `signum`  | Yes               |                                   |
| Sin            | `sin`     | Yes               |                                   |
| Sinh           | `sinh`    | Yes               |                                   |
| Sqrt           | `sqrt`    | Yes               |                                   |
| Subtract       | `-`       | Yes               |                                   |
| Tan            | `tan`     | Yes               |                                   |
| Tanh           | `tanh`    | Yes               |                                   |
| TryAdd         | `try_add` | Yes               | Only integer inputs are supported |
| TryDivide      | `try_div` | Yes               | Only integer inputs are supported |
| TryMultiply    | `try_mul` | Yes               | Only integer inputs are supported |
| TrySubtract    | `try_sub` | Yes               | Only integer inputs are supported |
| UnaryMinus     | `-`       | Yes               |                                   |
| Unhex          | `unhex`   | Yes               |                                   |

(The `## Hashing Functions` section follows the table unchanged.)

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -117,7 +117,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Sinh] -> CometScalarFunction("sinh"),
     classOf[Sqrt] -> CometScalarFunction("sqrt"),
     classOf[Subtract] -> CometSubtract,
-    classOf[Tan] -> CometTan,
+    classOf[Tan] -> CometScalarFunction("tan"),
     classOf[Tanh] -> CometScalarFunction("tanh"),
     classOf[Cot] -> CometScalarFunction("cot"),
     classOf[UnaryMinus] -> CometUnaryMinus,
```

spark/src/main/scala/org/apache/comet/serde/math.scala

Lines changed: 6 additions & 21 deletions
```diff
@@ -19,7 +19,7 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Unhex}
 import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -30,8 +30,11 @@ object CometAtan2 extends CometExpressionSerde[Atan2] {
       expr: Atan2,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+    // Spark adds +0.0 to inputs in order to convert -0.0 to +0.0
+    val left = Add(expr.left, Literal.default(expr.left.dataType))
+    val right = Add(expr.right, Literal.default(expr.right.dataType))
+    val leftExpr = exprToProtoInternal(left, inputs, binding)
+    val rightExpr = exprToProtoInternal(right, inputs, binding)
     val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
     optExprWithInfo(optExpr, expr, expr.left, expr.right)
   }
@@ -189,24 +192,6 @@ object CometAbs extends CometExpressionSerde[Abs] with MathExprBase {
   }
 }
 
-object CometTan extends CometExpressionSerde[Tan] {
-
-  override def getSupportLevel(expr: Tan): SupportLevel =
-    Incompatible(
-      Some(
-        "tan(-0.0) produces incorrect result" +
-          " (https://github.com/apache/datafusion-comet/issues/1897)"))
-
-  override def convert(
-      expr: Tan,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
-    val optExpr = scalarFunctionExprToProto("tan", childExpr: _*)
-    optExprWithInfo(optExpr, expr, expr.children: _*)
-  }
-}
-
 sealed trait MathExprBase {
   protected def nullIfNegative(expression: Expression): Expression = {
     val zero = Literal.default(expression.dataType)
```
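The `Add(expr, Literal.default(...))` rewrite above relies on `x + 0.0` being an identity for every double except `-0.0`. A quick check over edge values — plain JVM doubles stand in for the Spark expressions here (class name is illustrative), but the IEEE 754 addition is the same:

```java
public class NormalizeNegZero {
    static boolean isNegZero(double d) {
        return Double.doubleToRawLongBits(d) == Double.doubleToRawLongBits(-0.0);
    }

    public static void main(String[] args) {
        double[] edge = {
            -1.0, -0.0, 0.0, 1.0, Double.MIN_VALUE, Double.MAX_VALUE,
            Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY
        };
        for (double y : edge) {
            for (double x : edge) {
                double plain = Math.atan2(y, x);
                double norm = Math.atan2(y + 0.0, x + 0.0);
                // Double.compare treats NaN as equal to NaN and
                // distinguishes -0.0 from 0.0, so any difference between
                // plain and norm must involve a -0.0 input.
                if (Double.compare(plain, norm) != 0) {
                    assert isNegZero(y) || isNegZero(x);
                }
            }
        }
        System.out.println("normalization only affects signed-zero inputs");
    }
}
```

NaN and infinity inputs pass through the `+ 0.0` normalization unchanged, so only the signed-zero quadrants are remapped to Spark's convention.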

spark/src/test/resources/sql-tests/expressions/math/atan2.sql

Lines changed: 14 additions & 3 deletions
```diff
@@ -21,7 +21,15 @@ statement
 CREATE TABLE test_atan2(y double, x double) USING parquet
 
 statement
-INSERT INTO test_atan2 VALUES (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (-1.0, -1.0), (0.0, 0.0), (NULL, 1.0), (1.0, NULL), (cast('NaN' as double), 1.0), (cast('Infinity' as double), 1.0)
+INSERT INTO test_atan2 VALUES
+(0.0, 0.0), (0.0, -0.0), (0.0, 1.0), (0.0, -1.0), (0.0, NULL), (0.0, cast('NaN' as double)), (0.0, cast('Infinity' as double)), (0.0, cast('-Infinity' as double)),
+(-0.0, 0.0), (-0.0, -0.0), (-0.0, 1.0), (-0.0, -1.0), (-0.0, NULL), (-0.0, cast('NaN' as double)), (-0.0, cast('Infinity' as double)), (-0.0, cast('-Infinity' as double)),
+(1.0, 0.0), (1.0, -0.0), (1.0, 1.0), (1.0, -1.0), (1.0, NULL), (1.0, cast('NaN' as double)), (1.0, cast('Infinity' as double)), (1.0, cast('-Infinity' as double)),
+(-1.0, 0.0), (-1.0, -0.0), (-1.0, 1.0), (-1.0, -1.0), (-1.0, NULL), (-1.0, cast('NaN' as double)), (-1.0, cast('Infinity' as double)), (-1.0, cast('-Infinity' as double)),
+(NULL, 0.0), (NULL, -0.0), (NULL, 1.0), (NULL, -1.0), (NULL, NULL), (NULL, cast('NaN' as double)), (NULL, cast('Infinity' as double)), (NULL, cast('-Infinity' as double)),
+(cast('NaN' as double), 0.0), (cast('NaN' as double), -0.0), (cast('NaN' as double), 1.0), (cast('NaN' as double), -1.0), (cast('NaN' as double), NULL), (cast('NaN' as double), cast('NaN' as double)), (cast('NaN' as double), cast('Infinity' as double)), (cast('NaN' as double), cast('-Infinity' as double)),
+(cast('Infinity' as double), 0.0), (cast('Infinity' as double), -0.0), (cast('Infinity' as double), 1.0), (cast('Infinity' as double), -1.0), (cast('Infinity' as double), NULL), (cast('Infinity' as double), cast('NaN' as double)), (cast('Infinity' as double), cast('Infinity' as double)), (cast('Infinity' as double), cast('-Infinity' as double)),
+(cast('-Infinity' as double), 0.0), (cast('-Infinity' as double), -0.0), (cast('-Infinity' as double), 1.0), (cast('-Infinity' as double), -1.0), (cast('-Infinity' as double), NULL), (cast('-Infinity' as double), cast('NaN' as double)), (cast('-Infinity' as double), cast('Infinity' as double)), (cast('-Infinity' as double), cast('-Infinity' as double))
 
 query tolerance=1e-6
 SELECT atan2(y, x) FROM test_atan2
@@ -34,6 +42,9 @@ SELECT atan2(y, 1.0) FROM test_atan2
 query tolerance=1e-6
 SELECT atan2(1.0, x) FROM test_atan2
 
--- literal + literal
+-- literal permutations
 query tolerance=1e-6
-SELECT atan2(1.0, 1.0), atan2(0.0, 0.0), atan2(-1.0, -1.0), atan2(NULL, 1.0)
+SELECT atan2(0.0, 0.0), atan2(0.0, -0.0), atan2(0.0, 1.0), atan2(0.0, -1.0),
+       atan2(-0.0, 0.0), atan2(-0.0, -0.0), atan2(-0.0, 1.0), atan2(-0.0, -1.0),
+       atan2(1.0, 0.0), atan2(1.0, -0.0), atan2(1.0, 1.0), atan2(1.0, -1.0),
+       atan2(-1.0, 0.0), atan2(-1.0, -0.0), atan2(-1.0, 1.0), atan2(-1.0, -1.0)
```

spark/src/test/resources/sql-tests/expressions/math/tan.sql

Lines changed: 2 additions & 2 deletions
```diff
@@ -22,11 +22,11 @@ statement
 CREATE TABLE test_tan(d double) USING parquet
 
 statement
-INSERT INTO test_tan VALUES (0.0), (0.7853981633974483), (-0.7853981633974483), (1.0), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_tan VALUES (0.0), (-0.0), (0.7853981633974483), (-0.7853981633974483), (1.0), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
 
 query tolerance=1e-6
 SELECT tan(d) FROM test_tan
 
 -- literal arguments
 query tolerance=1e-6
-SELECT tan(0.0), tan(0.7853981633974483), tan(NULL)
+SELECT tan(0.0), tan(-0.0), tan(0.7853981633974483), tan(NULL)
```
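The new `-0.0` rows exercise tan's sign-of-zero behavior: tan is an odd function, so `tan(-0.0)` must be `-0.0`, which is what the JVM's `Math.tan` (and therefore Spark) returns. A small sketch of the expected special cases (class name is illustrative):

```java
public class TanSignedZero {
    public static void main(String[] args) {
        // Math.tan returns a zero with the same sign as a zero argument
        System.out.println(Math.tan(0.0));                      // 0.0
        System.out.println(Math.tan(-0.0));                     // -0.0
        // tan(pi/4) is approximately 1.0
        System.out.println(Math.tan(0.7853981633974483));
        // NaN and infinite arguments both yield NaN
        System.out.println(Math.tan(Double.NaN));               // NaN
        System.out.println(Math.tan(Double.POSITIVE_INFINITY)); // NaN
    }
}
```

This is the behavior a native `tan` implementation has to match for Comet to drop the incompatibility flag.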

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 37 additions & 39 deletions
```diff
@@ -28,7 +28,7 @@ import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, Tan, TruncDate, TruncTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, FromUnixTime, Literal, StructsToJson, TruncDate, TruncTimestamp}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
 import org.apache.spark.sql.comet.CometProjectExec
 import org.apache.spark.sql.execution.{ProjectExec, SparkPlan}
@@ -1362,8 +1362,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   private val doubleValues: Seq[Double] = Seq(
     -1.0,
-    // TODO we should eventually enable negative zero but there are known issues still
-    // -0.0,
+    -0.0,
     0.0,
     +1.0,
     Double.MinValue,
@@ -1374,42 +1373,41 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     Double.NegativeInfinity)
 
   test("various math scalar functions") {
-    val data = doubleValues.map(n => (n, n))
-    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Tan]) -> "true") {
-      withParquetTable(data, "tbl") {
-        // expressions with single arg
-        for (expr <- Seq(
-            "acos",
-            "asin",
-            "atan",
-            "cos",
-            "cosh",
-            "exp",
-            "ln",
-            "log10",
-            "log2",
-            "sin",
-            "sinh",
-            "sqrt",
-            "tan",
-            "tanh",
-            "cot")) {
-          val (_, cometPlan) =
-            checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) FROM tbl"))
-          val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
-            op
-          }
-          assert(cometProjectExecs.length == 1, expr)
-        }
-        // expressions with two args
-        for (expr <- Seq("atan2", "pow")) {
-          val (_, cometPlan) =
-            checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM tbl"))
-          val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
-            op
-          }
-          assert(cometProjectExecs.length == 1, expr)
-        }
+    withParquetTable(doubleValues.map(n => (n, n)), "tbl") {
+      // expressions with single arg
+      for (expr <- Seq(
+          "acos",
+          "asin",
+          "atan",
+          "cos",
+          "cosh",
+          "exp",
+          "ln",
+          "log10",
+          "log2",
+          "sin",
+          "sinh",
+          "sqrt",
+          "tan",
+          "tanh",
+          "cot")) {
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) FROM tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
+          op
+        }
+        assert(cometProjectExecs.length == 1, expr)
+      }
+    }
+    withParquetTable(doubleValues.flatMap(m => doubleValues.map(n => (m, n))), "tbl") {
+      // expressions with two args
+      for (expr <- Seq("atan2", "pow")) {
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec =>
+          op
+        }
+        assert(cometProjectExecs.length == 1, expr)
       }
     }
   }
```
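The two-argument tests now build the full cross product of the edge-case doubles via `flatMap`, rather than only pairing each value with itself, which is what catches mixed-sign-zero cases like `atan2(-0.0, 0.0)`. The same construction sketched in Java (the value list is an illustrative subset of the suite's `doubleValues`, and the class name is hypothetical):

```java
import java.util.ArrayList;
import java.util.List;

public class EdgeCasePairs {
    public static void main(String[] args) {
        // an illustrative subset of the suite's doubleValues
        double[] doubleValues = {
            -1.0, -0.0, 0.0, 1.0, Double.MIN_VALUE, Double.MAX_VALUE,
            Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY
        };
        // equivalent of doubleValues.flatMap(m => doubleValues.map(n => (m, n)))
        List<double[]> pairs = new ArrayList<>();
        for (double m : doubleValues) {
            for (double n : doubleValues) {
                pairs.add(new double[] {m, n});
            }
        }
        System.out.println(pairs.size()); // 81 = 9 * 9 pairs
    }
}
```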
