Skip to content

Commit a951da9

Browse files
authored
feat: Support casting string float types (#2835)
1 parent f232887 commit a951da9

4 files changed

Lines changed: 125 additions & 46 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ The following cast operations are generally compatible with Spark except for the
158158
| string | short | |
159159
| string | integer | |
160160
| string | long | |
161+
| string | float | |
162+
| string | double | |
161163
| string | binary | |
162164
| string | date | Only supports years between 262143 BC and 262142 AD |
163165
| binary | string | |
@@ -180,8 +182,6 @@ The following cast operations are not compatible with Spark for all inputs and a
180182
|-|-|-|
181183
| float | decimal | There can be rounding differences |
182184
| double | decimal | There can be rounding differences |
183-
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
184-
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
185185
| string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10)
186186
or strings containing null bytes (e.g \\u0000) |
187187
| string | timestamp | Not all valid formats are supported |

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 85 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use arrow::{
4545
record_batch::RecordBatch,
4646
util::display::FormatOptions,
4747
};
48+
use base64::prelude::*;
4849
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
4950
use datafusion::common::{
5051
cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
@@ -66,8 +67,6 @@ use std::{
6667
sync::Arc,
6768
};
6869

69-
use base64::prelude::*;
70-
7170
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
7271

7372
const MICROS_PER_SECOND: i64 = 1000000;
@@ -217,12 +216,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
217216
use DataType::*;
218217
match to_type {
219218
Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
220-
Float32 | Float64 => {
221-
// https://github.com/apache/datafusion-comet/issues/326
222-
// Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
223-
// Does not support ANSI mode.
224-
options.allow_incompat
225-
}
219+
Float32 | Float64 => true,
226220
Decimal128(_, _) => {
227221
// https://github.com/apache/datafusion-comet/issues/325
228222
// Does not support fullwidth digits and null byte handling.
@@ -975,6 +969,7 @@ fn cast_array(
975969
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
976970
}
977971
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
972+
(Utf8, Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode),
978973
(Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
979974
cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
980975
}
@@ -1046,7 +1041,7 @@ fn cast_array(
10461041
}
10471042
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
10481043
_ if cast_options.is_adapting_schema
1049-
|| is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
1044+
|| is_datafusion_spark_compatible(from_type, to_type) =>
10501045
{
10511046
// use DataFusion cast only when we know that it is compatible with Spark
10521047
Ok(cast_with_options(&array, to_type, &native_cast_options)?)
@@ -1063,6 +1058,86 @@ fn cast_array(
10631058
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
10641059
}
10651060

1061+
fn cast_string_to_float(
1062+
array: &ArrayRef,
1063+
to_type: &DataType,
1064+
eval_mode: EvalMode,
1065+
) -> SparkResult<ArrayRef> {
1066+
match to_type {
1067+
DataType::Float32 => cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT"),
1068+
DataType::Float64 => cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE"),
1069+
_ => Err(SparkError::Internal(format!(
1070+
"Unsupported cast to float type: {:?}",
1071+
to_type
1072+
))),
1073+
}
1074+
}
1075+
1076+
fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
1077+
array: &ArrayRef,
1078+
eval_mode: EvalMode,
1079+
type_name: &str,
1080+
) -> SparkResult<ArrayRef>
1081+
where
1082+
T::Native: FromStr + num::Float,
1083+
{
1084+
let arr = array
1085+
.as_any()
1086+
.downcast_ref::<StringArray>()
1087+
.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
1088+
1089+
let mut builder = PrimitiveBuilder::<T>::with_capacity(arr.len());
1090+
1091+
for i in 0..arr.len() {
1092+
if arr.is_null(i) {
1093+
builder.append_null();
1094+
} else {
1095+
let str_value = arr.value(i).trim();
1096+
match parse_string_to_float(str_value) {
1097+
Some(v) => builder.append_value(v),
1098+
None => {
1099+
if eval_mode == EvalMode::Ansi {
1100+
return Err(invalid_value(arr.value(i), "STRING", type_name));
1101+
}
1102+
builder.append_null();
1103+
}
1104+
}
1105+
}
1106+
}
1107+
1108+
Ok(Arc::new(builder.finish()))
1109+
}
1110+
1111+
/// helper to parse floats from string inputs
1112+
fn parse_string_to_float<F>(s: &str) -> Option<F>
1113+
where
1114+
F: FromStr + num::Float,
1115+
{
1116+
// Handle +inf / -inf
1117+
if s.eq_ignore_ascii_case("inf")
1118+
|| s.eq_ignore_ascii_case("+inf")
1119+
|| s.eq_ignore_ascii_case("infinity")
1120+
|| s.eq_ignore_ascii_case("+infinity")
1121+
{
1122+
return Some(F::infinity());
1123+
}
1124+
if s.eq_ignore_ascii_case("-inf") || s.eq_ignore_ascii_case("-infinity") {
1125+
return Some(F::neg_infinity());
1126+
}
1127+
if s.eq_ignore_ascii_case("nan") {
1128+
return Some(F::nan());
1129+
}
1130+
// Remove D/F suffix if present
1131+
let pruned_float_str =
1132+
if s.ends_with("d") || s.ends_with("D") || s.ends_with('f') || s.ends_with('F') {
1133+
&s[..s.len() - 1]
1134+
} else {
1135+
s
1136+
};
1137+
// Rust's parse logic already handles scientific notations so we just rely on it
1138+
pruned_float_str.parse::<F>().ok()
1139+
}
1140+
10661141
fn cast_binary_to_string<O: OffsetSizeTrait>(
10671142
array: &dyn Array,
10681143
spark_cast_options: &SparkCastOptions,
@@ -1133,11 +1208,7 @@ fn cast_binary_formatter(value: &[u8]) -> String {
11331208

11341209
/// Determines if DataFusion supports the given cast in a way that is
11351210
/// compatible with Spark
1136-
fn is_datafusion_spark_compatible(
1137-
from_type: &DataType,
1138-
to_type: &DataType,
1139-
allow_incompat: bool,
1140-
) -> bool {
1211+
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
11411212
if from_type == to_type {
11421213
return true;
11431214
}
@@ -1190,10 +1261,6 @@ fn is_datafusion_spark_compatible(
11901261
| DataType::Decimal256(_, _)
11911262
| DataType::Utf8 // note that there can be formatting differences
11921263
),
1193-
DataType::Utf8 if allow_incompat => matches!(
1194-
to_type,
1195-
DataType::Binary | DataType::Float32 | DataType::Float64
1196-
),
11971264
DataType::Utf8 => matches!(to_type, DataType::Binary),
11981265
DataType::Date32 => matches!(to_type, DataType::Utf8),
11991266
DataType::Timestamp(_, _) => {

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
185185
case DataTypes.BinaryType =>
186186
Compatible()
187187
case DataTypes.FloatType | DataTypes.DoubleType =>
188-
// https://github.com/apache/datafusion-comet/issues/326
189-
Incompatible(
190-
Some(
191-
"Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
192-
"Does not support ANSI mode."))
188+
Compatible()
193189
case _: DecimalType =>
194190
// https://github.com/apache/datafusion-comet/issues/325
195191
Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10)

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -642,34 +642,50 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
642642
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType)
643643
}
644644

645-
ignore("cast StringType to FloatType") {
645+
test("cast StringType to DoubleType") {
646646
// https://github.com/apache/datafusion-comet/issues/326
647+
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
648+
}
649+
650+
test("cast StringType to FloatType") {
647651
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
648652
}
649653

650-
test("cast StringType to FloatType (partial support)") {
651-
withSQLConf(
652-
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
653-
SQLConf.ANSI_ENABLED.key -> "false") {
654-
castTest(
655-
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
656-
DataTypes.FloatType,
657-
testAnsi = false)
654+
val specialValues: Seq[String] = Seq(
655+
"1.5f",
656+
"1.5F",
657+
"2.0d",
658+
"2.0D",
659+
"3.14159265358979d",
660+
"inf",
661+
"Inf",
662+
"INF",
663+
"+inf",
664+
"+Infinity",
665+
"-inf",
666+
"-Infinity",
667+
"NaN",
668+
"nan",
669+
"NAN",
670+
"1.23e4",
671+
"1.23E4",
672+
"-1.23e-4",
673+
" 123.456789 ",
674+
"0.0",
675+
"-0.0",
676+
"",
677+
"xyz",
678+
null)
679+
680+
test("cast StringType to FloatType special values") {
681+
Seq(true, false).foreach { ansiMode =>
682+
castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = ansiMode)
658683
}
659684
}
660685

661-
ignore("cast StringType to DoubleType") {
662-
// https://github.com/apache/datafusion-comet/issues/326
663-
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
664-
}
665-
test("cast StringType to DoubleType (partial support)") {
666-
withSQLConf(
667-
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
668-
SQLConf.ANSI_ENABLED.key -> "false") {
669-
castTest(
670-
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
671-
DataTypes.DoubleType,
672-
testAnsi = false)
686+
test("cast StringType to DoubleType special values") {
687+
Seq(true, false).foreach { ansiMode =>
688+
castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = ansiMode)
673689
}
674690
}
675691

0 commit comments

Comments
 (0)