@@ -45,6 +45,7 @@ use arrow::{
4545 record_batch:: RecordBatch ,
4646 util:: display:: FormatOptions ,
4747} ;
48+ use base64:: prelude:: * ;
4849use chrono:: { DateTime , NaiveDate , TimeZone , Timelike } ;
4950use datafusion:: common:: {
5051 cast:: as_generic_string_array, internal_err, DataFusionError , Result as DataFusionResult ,
@@ -66,8 +67,6 @@ use std::{
6667 sync:: Arc ,
6768} ;
6869
69- use base64:: prelude:: * ;
70-
7170static TIMESTAMP_FORMAT : Option < & str > = Some ( "%Y-%m-%d %H:%M:%S%.f" ) ;
7271
7372const MICROS_PER_SECOND : i64 = 1000000 ;
@@ -217,12 +216,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
217216 use DataType :: * ;
218217 match to_type {
219218 Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true ,
220- Float32 | Float64 => {
221- // https://github.com/apache/datafusion-comet/issues/326
222- // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
223- // Does not support ANSI mode.
224- options. allow_incompat
225- }
219+ Float32 | Float64 => true ,
226220 Decimal128 ( _, _) => {
227221 // https://github.com/apache/datafusion-comet/issues/325
228222 // Does not support fullwidth digits and null byte handling.
@@ -975,6 +969,7 @@ fn cast_array(
975969 cast_string_to_timestamp ( & array, to_type, eval_mode, & cast_options. timezone )
976970 }
977971 ( Utf8 , Date32 ) => cast_string_to_date ( & array, to_type, eval_mode) ,
972+ ( Utf8 , Float32 | Float64 ) => cast_string_to_float ( & array, to_type, eval_mode) ,
978973 ( Utf8 | LargeUtf8 , Decimal128 ( precision, scale) ) => {
979974 cast_string_to_decimal ( & array, to_type, precision, scale, eval_mode)
980975 }
@@ -1046,7 +1041,7 @@ fn cast_array(
10461041 }
10471042 ( Binary , Utf8 ) => Ok ( cast_binary_to_string :: < i32 > ( & array, cast_options) ?) ,
10481043 _ if cast_options. is_adapting_schema
1049- || is_datafusion_spark_compatible ( from_type, to_type, cast_options . allow_incompat ) =>
1044+ || is_datafusion_spark_compatible ( from_type, to_type) =>
10501045 {
10511046 // use DataFusion cast only when we know that it is compatible with Spark
10521047 Ok ( cast_with_options ( & array, to_type, & native_cast_options) ?)
@@ -1063,6 +1058,86 @@ fn cast_array(
10631058 Ok ( spark_cast_postprocess ( cast_result?, from_type, to_type) )
10641059}
10651060
1061+ fn cast_string_to_float (
1062+ array : & ArrayRef ,
1063+ to_type : & DataType ,
1064+ eval_mode : EvalMode ,
1065+ ) -> SparkResult < ArrayRef > {
1066+ match to_type {
1067+ DataType :: Float32 => cast_string_to_float_impl :: < Float32Type > ( array, eval_mode, "FLOAT" ) ,
1068+ DataType :: Float64 => cast_string_to_float_impl :: < Float64Type > ( array, eval_mode, "DOUBLE" ) ,
1069+ _ => Err ( SparkError :: Internal ( format ! (
1070+ "Unsupported cast to float type: {:?}" ,
1071+ to_type
1072+ ) ) ) ,
1073+ }
1074+ }
1075+
1076+ fn cast_string_to_float_impl < T : ArrowPrimitiveType > (
1077+ array : & ArrayRef ,
1078+ eval_mode : EvalMode ,
1079+ type_name : & str ,
1080+ ) -> SparkResult < ArrayRef >
1081+ where
1082+ T :: Native : FromStr + num:: Float ,
1083+ {
1084+ let arr = array
1085+ . as_any ( )
1086+ . downcast_ref :: < StringArray > ( )
1087+ . ok_or_else ( || SparkError :: Internal ( "Expected string array" . to_string ( ) ) ) ?;
1088+
1089+ let mut builder = PrimitiveBuilder :: < T > :: with_capacity ( arr. len ( ) ) ;
1090+
1091+ for i in 0 ..arr. len ( ) {
1092+ if arr. is_null ( i) {
1093+ builder. append_null ( ) ;
1094+ } else {
1095+ let str_value = arr. value ( i) . trim ( ) ;
1096+ match parse_string_to_float ( str_value) {
1097+ Some ( v) => builder. append_value ( v) ,
1098+ None => {
1099+ if eval_mode == EvalMode :: Ansi {
1100+ return Err ( invalid_value ( arr. value ( i) , "STRING" , type_name) ) ;
1101+ }
1102+ builder. append_null ( ) ;
1103+ }
1104+ }
1105+ }
1106+ }
1107+
1108+ Ok ( Arc :: new ( builder. finish ( ) ) )
1109+ }
1110+
1111+ /// helper to parse floats from string inputs
1112+ fn parse_string_to_float < F > ( s : & str ) -> Option < F >
1113+ where
1114+ F : FromStr + num:: Float ,
1115+ {
1116+ // Handle +inf / -inf
1117+ if s. eq_ignore_ascii_case ( "inf" )
1118+ || s. eq_ignore_ascii_case ( "+inf" )
1119+ || s. eq_ignore_ascii_case ( "infinity" )
1120+ || s. eq_ignore_ascii_case ( "+infinity" )
1121+ {
1122+ return Some ( F :: infinity ( ) ) ;
1123+ }
1124+ if s. eq_ignore_ascii_case ( "-inf" ) || s. eq_ignore_ascii_case ( "-infinity" ) {
1125+ return Some ( F :: neg_infinity ( ) ) ;
1126+ }
1127+ if s. eq_ignore_ascii_case ( "nan" ) {
1128+ return Some ( F :: nan ( ) ) ;
1129+ }
1130+ // Remove D/F suffix if present
1131+ let pruned_float_str =
1132+ if s. ends_with ( "d" ) || s. ends_with ( "D" ) || s. ends_with ( 'f' ) || s. ends_with ( 'F' ) {
1133+ & s[ ..s. len ( ) - 1 ]
1134+ } else {
1135+ s
1136+ } ;
1137+ // Rust's parse logic already handles scientific notations so we just rely on it
1138+ pruned_float_str. parse :: < F > ( ) . ok ( )
1139+ }
1140+
10661141fn cast_binary_to_string < O : OffsetSizeTrait > (
10671142 array : & dyn Array ,
10681143 spark_cast_options : & SparkCastOptions ,
@@ -1133,11 +1208,7 @@ fn cast_binary_formatter(value: &[u8]) -> String {
11331208
11341209/// Determines if DataFusion supports the given cast in a way that is
11351210/// compatible with Spark
1136- fn is_datafusion_spark_compatible (
1137- from_type : & DataType ,
1138- to_type : & DataType ,
1139- allow_incompat : bool ,
1140- ) -> bool {
1211+ fn is_datafusion_spark_compatible ( from_type : & DataType , to_type : & DataType ) -> bool {
11411212 if from_type == to_type {
11421213 return true ;
11431214 }
@@ -1190,10 +1261,6 @@ fn is_datafusion_spark_compatible(
11901261 | DataType :: Decimal256 ( _, _)
11911262 | DataType :: Utf8 // note that there can be formatting differences
11921263 ) ,
1193- DataType :: Utf8 if allow_incompat => matches ! (
1194- to_type,
1195- DataType :: Binary | DataType :: Float32 | DataType :: Float64
1196- ) ,
11971264 DataType :: Utf8 => matches ! ( to_type, DataType :: Binary ) ,
11981265 DataType :: Date32 => matches ! ( to_type, DataType :: Utf8 ) ,
11991266 DataType :: Timestamp ( _, _) => {
0 commit comments