Skip to content

Commit c917ac8

Browse files
authored
Merge branch 'main' into df54
2 parents ec18066 + 29557ca commit c917ac8

File tree

13 files changed

+372
-77
lines changed

13 files changed

+372
-77
lines changed

docs/source/user-guide/latest/compatibility.md

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
7676
timezone is UTC.
7777
[#2649](https://github.com/apache/datafusion-comet/issues/2649)
7878

79-
### Aggregate Expressions
80-
81-
- **Corr**: Returns null instead of NaN in some edge cases.
82-
[#2646](https://github.com/apache/datafusion-comet/issues/2646)
83-
8479
### Struct Expressions
8580

8681
- **StructsToJson (to_json)**: Does not support `+Infinity` and `-Infinity` for numeric types (float, double).

docs/source/user-guide/latest/expressions.md

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -195,27 +195,27 @@ Expressions that are not Spark-compatible will fall back to Spark by default and
195195

196196
## Aggregate Expressions
197197

198-
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
199-
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------- |
200-
| Average | | Yes, except for ANSI mode | |
201-
| BitAndAgg | | Yes | |
202-
| BitOrAgg | | Yes | |
203-
| BitXorAgg | | Yes | |
204-
| BoolAnd | `bool_and` | Yes | |
205-
| BoolOr | `bool_or` | Yes | |
206-
| Corr | | No | Returns null instead of NaN in some edge cases ([#2646](https://github.com/apache/datafusion-comet/issues/2646)) |
207-
| Count | | Yes | |
208-
| CovPopulation | | Yes | |
209-
| CovSample | | Yes | |
210-
| First | | No | This function is not deterministic. Results may not match Spark. |
211-
| Last | | No | This function is not deterministic. Results may not match Spark. |
212-
| Max | | Yes | |
213-
| Min | | Yes | |
214-
| StddevPop | | Yes | |
215-
| StddevSamp | | Yes | |
216-
| Sum | | Yes, except for ANSI mode | |
217-
| VariancePop | | Yes | |
218-
| VarianceSamp | | Yes | |
198+
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
199+
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------- |
200+
| Average | | Yes, except for ANSI mode | |
201+
| BitAndAgg | | Yes | |
202+
| BitOrAgg | | Yes | |
203+
| BitXorAgg | | Yes | |
204+
| BoolAnd | `bool_and` | Yes | |
205+
| BoolOr | `bool_or` | Yes | |
206+
| Corr | | Yes | |
207+
| Count | | Yes | |
208+
| CovPopulation | | Yes | |
209+
| CovSample | | Yes | |
210+
| First | | No | This function is not deterministic. Results may not match Spark. |
211+
| Last | | No | This function is not deterministic. Results may not match Spark. |
212+
| Max | | Yes | |
213+
| Min | | Yes | |
214+
| StddevPop | | Yes | |
215+
| StddevSamp | | Yes | |
216+
| Sum | | Yes, except for ANSI mode | |
217+
| VariancePop | | Yes | |
218+
| VarianceSamp | | Yes | |
219219

220220
## Window Functions
221221

docs/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@
296296
- [x] atan
297297
- [x] atan2
298298
- [ ] atanh
299-
- [ ] bin
299+
- [x] bin
300300
- [ ] bround
301301
- [ ] cbrt
302302
- [x] ceil

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
564564
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSpace::default()));
565565
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitCount::default()));
566566
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayContains::default()));
567+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBin::default()));
567568
}
568569

569570
/// Prepares arrow arrays for output.
@@ -1134,6 +1135,7 @@ pub extern "system" fn Java_org_apache_comet_Native_getRustThreadId(
11341135

11351136
use crate::execution::columnar_to_row::ColumnarToRowContext;
11361137
use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};
1138+
use datafusion_spark::function::math::bin::SparkBin;
11371139

11381140
/// Initialize a native columnar to row converter.
11391141
///

native/spark-expr/src/agg_funcs/correlation.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,19 +216,22 @@ impl Accumulator for CorrelationAccumulator {
216216
let stddev1 = self.stddev1.evaluate()?;
217217
let stddev2 = self.stddev2.evaluate()?;
218218

219+
if self.covar.get_count() == 0.0 {
220+
return Ok(ScalarValue::Float64(None));
221+
} else if self.covar.get_count() == 1.0 {
222+
if self.null_on_divide_by_zero {
223+
return Ok(ScalarValue::Float64(None));
224+
} else {
225+
return Ok(ScalarValue::Float64(Some(f64::NAN)));
226+
}
227+
}
219228
match (covar, stddev1, stddev2) {
220229
(
221230
ScalarValue::Float64(Some(c)),
222231
ScalarValue::Float64(Some(s1)),
223232
ScalarValue::Float64(Some(s2)),
224233
) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
225-
_ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
226-
_ => {
227-
if self.covar.get_count() == 1.0 {
228-
return Ok(ScalarValue::Float64(Some(f64::NAN)));
229-
}
230-
Ok(ScalarValue::Float64(None))
231-
}
234+
_ => Ok(ScalarValue::Float64(None)),
232235
}
233236
}
234237

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
4141
use crate::{EvalMode, SparkError};
4242
use arrow::array::builder::StringBuilder;
4343
use arrow::array::{
44-
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
44+
new_null_array, BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray,
45+
StringArray, StructArray,
4546
};
46-
use arrow::compute::can_cast_types;
4747
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4848
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4949
use arrow::error::ArrowError;
@@ -311,6 +311,9 @@ pub(crate) fn cast_array(
311311
};
312312

313313
let cast_result = match (&from_type, to_type) {
314+
// Null arrays carry no concrete values, so Arrow's native cast can change only the
315+
// logical type while preserving length and nullness.
316+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
314317
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
315318
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
316319
(Utf8, Timestamp(_, _)) => cast_string_to_timestamp(
@@ -387,8 +390,25 @@ pub(crate) fn cast_array(
387390
cast_options,
388391
)?),
389392
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
390-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
391-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
393+
(List(_), List(to)) => {
394+
// Cast list elements recursively so nested array casts follow Spark semantics
395+
// instead of relying on Arrow's top-level cast support.
396+
let list_array = array.as_list::<i32>();
397+
let casted_values = match (list_array.values().data_type(), to.data_type()) {
398+
// Spark legacy array casts produce null elements for array<Date> -> array<Int>.
399+
(Date32, Int32) => new_null_array(to.data_type(), list_array.values().len()),
400+
_ => cast_array(
401+
Arc::clone(list_array.values()),
402+
to.data_type(),
403+
cast_options,
404+
)?,
405+
};
406+
Ok(Arc::new(ListArray::new(
407+
Arc::clone(to),
408+
list_array.offsets().clone(),
409+
casted_values,
410+
list_array.nulls().cloned(),
411+
)) as ArrayRef)
392412
}
393413
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
394414
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -820,7 +840,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
820840
#[cfg(test)]
821841
mod tests {
822842
use super::*;
823-
use arrow::array::StringArray;
843+
use arrow::array::{ListArray, NullArray, StringArray};
844+
use arrow::buffer::OffsetBuffer;
824845
use arrow::datatypes::TimestampMicrosecondType;
825846
use arrow::datatypes::{Field, Fields};
826847
#[test]
@@ -946,8 +967,6 @@ mod tests {
946967

947968
#[test]
948969
fn test_cast_string_array_to_string() {
949-
use arrow::array::ListArray;
950-
use arrow::buffer::OffsetBuffer;
951970
let values_array =
952971
StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
953972
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
@@ -972,8 +991,6 @@ mod tests {
972991

973992
#[test]
974993
fn test_cast_i32_array_to_string() {
975-
use arrow::array::ListArray;
976-
use arrow::buffer::OffsetBuffer;
977994
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
978995
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
979996
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -994,4 +1011,33 @@ mod tests {
9941011
assert_eq!(r#"[null]"#, string_array.value(2));
9951012
assert_eq!(r#"[]"#, string_array.value(3));
9961013
}
1014+
1015+
// Checks that `cast_array` can convert a list whose element type is `Null`
// into a list of `Int32`.
//
// Input: a `ListArray` with a nullable `Null` item field, offsets
// `[0, 2, 3, 3]` (three list entries holding 2, 1 and 0 elements), a child
// `NullArray` of length 3, and no list-level null buffer.
//
// Expectations after the cast:
// - the list still has 3 entries and the same offsets, and
// - the child values are an `Int32` array of length 3 in which every slot
//   is null.
#[test]
1016+
fn test_cast_array_of_nulls_to_array() {
1017+
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 2, 3, 3].into());
1018+
let from_item_field = Arc::new(Field::new("item", DataType::Null, true));
1019+
let from_array: ArrayRef = Arc::new(ListArray::new(
1020+
from_item_field,
1021+
offsets_buffer,
1022+
Arc::new(NullArray::new(3)),
1023+
None,
1024+
));
1025+
1026+
// Target type keeps the "item" field name and nullability; only the element type changes.
let to_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
1027+
let to_array = cast_array(
1028+
from_array,
1029+
&to_type,
1030+
// NOTE(review): the meaning of the trailing `false` argument is not visible here — confirm
// against `SparkCastOptions::new`.
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
1031+
)
1032+
.unwrap();
1033+
1034+
// The list layout (entry count and offsets) must be unchanged by the cast.
let result = to_array.as_list::<i32>();
1035+
assert_eq!(3, result.len());
1036+
assert_eq!(result.value_offsets(), &[0, 2, 3, 3]);
1037+
1038+
// Every child value must come back as a null Int32.
let values = result.values().as_primitive::<Int32Type>();
1039+
assert_eq!(3, values.len());
1040+
assert_eq!(3, values.null_count());
1041+
assert!(values.iter().all(|value| value.is_none()));
1042+
}
9971043
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
142142

143143
(fromType, toType) match {
144144
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
145+
case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
146+
if toElementType != DataTypes.IntegerType && toElementType != DataTypes.StringType =>
147+
unsupported(fromType, toType)
145148
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
146149
Incompatible()
147150
case (dt: ArrayType, DataTypes.StringType) =>

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
122122
classOf[Cot] -> CometScalarFunction("cot"),
123123
classOf[UnaryMinus] -> CometUnaryMinus,
124124
classOf[Unhex] -> CometUnhex,
125-
classOf[Abs] -> CometAbs)
125+
classOf[Abs] -> CometAbs,
126+
classOf[Bin] -> CometScalarFunction("bin"))
126127

127128
private val mapExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
128129
classOf[GetMapValue] -> CometMapExtract,

spark/src/main/scala/org/apache/comet/serde/aggregates.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -584,13 +584,6 @@ object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with Come
584584
}
585585

586586
object CometCorr extends CometAggregateExpressionSerde[Corr] {
587-
588-
override def getSupportLevel(expr: Corr): SupportLevel =
589-
Incompatible(
590-
Some(
591-
"Returns null instead of NaN in some edge cases" +
592-
" (https://github.com/apache/datafusion-comet/issues/2646)"))
593-
594587
override def convert(
595588
aggExpr: AggregateExpression,
596589
corr: Corr,

spark/src/test/resources/sql-tests/expressions/aggregate/corr.sql

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18-
-- Config: spark.comet.expression.Corr.allowIncompatible=true
1918

2019
statement
2120
CREATE TABLE test_corr(x double, y double, grp string) USING parquet
@@ -28,3 +27,13 @@ SELECT corr(x, y) FROM test_corr
2827

2928
query tolerance=1e-6
3029
SELECT grp, corr(x, y) FROM test_corr GROUP BY grp ORDER BY grp
30+
31+
-- Test permutations of NULL and NaN
32+
statement
33+
CREATE TABLE test_corr_nan(x double, y double, grp string) USING parquet
34+
35+
statement
36+
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'both_nan'), (cast('NaN' as double), 1.0, 'nan_val'), (1.0, cast('NaN' as double), 'val_nan'), (NULL, cast('NaN' as double), 'null_nan'), (cast('NaN' as double), NULL, 'nan_null'), (NULL, NULL, 'both_null'), (NULL, 1.0, 'null_val'), (1.0, NULL, 'val_null'), (cast('NaN' as double), cast('NaN' as double), 'mixed'), (1.0, 2.0, 'mixed'), (3.0, 4.0, 'mixed'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan')
37+
38+
query tolerance=1e-6
39+
SELECT grp, corr(x, y) FROM test_corr_nan GROUP BY grp ORDER BY grp

0 commit comments

Comments
 (0)