Skip to content

Commit 4caad8b

Browse files
authored
fix: support full-width and null characters, and negative scale in string to decimal (#3922)
* fix: support full-width and null characters, and negative scale in string to decimal
1 parent bb9cc4b commit 4caad8b

7 files changed

Lines changed: 211 additions & 80 deletions

File tree

docs/source/user-guide/latest/compatibility.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,19 @@ Cast operations in Comet fall into three levels of support:
136136
Spark.
137137
- **N/A**: Spark does not support this cast.
138138

139+
### String to Decimal
140+
141+
Comet's native `CAST(string AS DECIMAL)` implementation matches Apache Spark's behavior,
142+
including:
143+
144+
- Leading and trailing ASCII whitespace is trimmed before parsing.
145+
- Null bytes (`\u0000`) at the start or end of a string are trimmed, matching Spark's
146+
`UTF8String` behavior. Null bytes embedded in the middle of a string produce `NULL`.
147+
- Fullwidth Unicode digits (U+FF10–U+FF19, e.g. `１２３.４５`) are treated as their ASCII
148+
equivalents, so `CAST('１２３.４５' AS DECIMAL(10,2))` returns `123.45`.
149+
- Scientific notation (e.g. `1.23E+5`) is supported.
150+
- Special values (`inf`, `infinity`, `nan`) produce `NULL`.
151+
139152
### String to Timestamp
140153

141154
Comet's native `CAST(string AS TIMESTAMP)` implementation supports all timestamp formats accepted

native/spark-expr/src/conversion_funcs/string.rs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,40 @@ fn cast_string_to_decimal256_impl(
438438
))
439439
}
440440

441+
/// Normalize fullwidth Unicode digits (U+FF10–U+FF19) to their ASCII equivalents.
///
/// Spark's UTF8String parser treats fullwidth digits as numerically equivalent to
/// ASCII digits, e.g. "１２３.４５" parses as 123.45. A fullwidth digit n is the
/// code point U+FF10 + n, so subtracting the block base and adding b'0' yields
/// the corresponding ASCII digit.
///
/// All other characters (ASCII or any other code point, including non-digit
/// fullwidth forms) are passed through unchanged, so the output is valid UTF-8
/// by construction — no `unsafe` byte surgery is required.
fn normalize_fullwidth_digits(s: &str) -> String {
    s.chars()
        .map(|c| match c {
            // '０' (U+FF10) ..= '９' (U+FF19) map onto '0' ..= '9'.
            // The subtraction is in [0, 9], so the u8 cast and the addition
            // cannot overflow, and the result is always a valid ASCII char.
            '\u{FF10}'..='\u{FF19}' => char::from(b'0' + (c as u32 - 0xFF10) as u8),
            other => other,
        })
        .collect()
}
474+
441475
/// Parse a decimal string into mantissa and scale
442476
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc
443477
/// Parse a string to decimal following Spark's behavior
@@ -446,16 +480,30 @@ fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkRe
446480
let mut start = 0;
447481
let mut end = string_bytes.len();
448482

449-
// trim whitespaces
450-
while start < end && string_bytes[start].is_ascii_whitespace() {
483+
// Trim ASCII whitespace and null bytes from both ends. Spark's UTF8String
484+
// trims null bytes the same way it trims whitespace: "123\u0000" and
485+
// "\u0000123" both parse as 123. Null bytes in the middle are not trimmed
486+
// and will fail the digit validation in parse_decimal_str, producing NULL.
487+
while start < end && (string_bytes[start].is_ascii_whitespace() || string_bytes[start] == 0) {
451488
start += 1;
452489
}
453-
while end > start && string_bytes[end - 1].is_ascii_whitespace() {
490+
while end > start && (string_bytes[end - 1].is_ascii_whitespace() || string_bytes[end - 1] == 0)
491+
{
454492
end -= 1;
455493
}
456494

457495
let trimmed = &input_str[start..end];
458496

497+
// Normalize fullwidth digits to ASCII. Fast path skips the allocation for
498+
// pure-ASCII strings, which is the common case.
499+
let normalized;
500+
let trimmed = if trimmed.bytes().any(|b| b > 0x7F) {
501+
normalized = normalize_fullwidth_digits(trimmed);
502+
normalized.as_str()
503+
} else {
504+
trimmed
505+
};
506+
459507
if trimmed.is_empty() {
460508
return Ok(None);
461509
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
211211
case DataTypes.FloatType | DataTypes.DoubleType =>
212212
Compatible()
213213
case _: DecimalType =>
214-
// https://github.com/apache/datafusion-comet/issues/325
215-
Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10)
216-
|or strings containing null bytes (e.g \\u0000)""".stripMargin))
214+
Compatible()
217215
case DataTypes.DateType =>
218216
// https://github.com/apache/datafusion-comet/issues/327
219217
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))

spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,21 @@ trait ShimSparkErrorConverter {
327327
try {
328328
DataType.fromDDL(typeName)
329329
} catch {
330-
case _: Exception => StringType
330+
case _: Exception =>
331+
// fromDDL rejects types that are syntactically invalid in SQL DDL, such as
332+
// DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true).
333+
// Parse those manually rather than silently falling back to StringType.
334+
if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) {
335+
val inner = typeName.substring("DECIMAL(".length, typeName.length - 1)
336+
val parts = inner.split(",")
337+
if (parts.length == 2) {
338+
try {
339+
DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt)
340+
} catch {
341+
case _: Exception => StringType
342+
}
343+
} else StringType
344+
} else StringType
331345
}
332346
}
333347
}

spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,21 @@ trait ShimSparkErrorConverter {
323323
try {
324324
DataType.fromDDL(typeName)
325325
} catch {
326-
case _: Exception => StringType
326+
case _: Exception =>
327+
// fromDDL rejects types that are syntactically invalid in SQL DDL, such as
328+
// DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true).
329+
// Parse those manually rather than silently falling back to StringType.
330+
if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) {
331+
val inner = typeName.substring("DECIMAL(".length, typeName.length - 1)
332+
val parts = inner.split(",")
333+
if (parts.length == 2) {
334+
try {
335+
DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt)
336+
} catch {
337+
case _: Exception => StringType
338+
}
339+
} else StringType
340+
} else StringType
327341
}
328342
}
329343
}

spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,21 @@ trait ShimSparkErrorConverter {
328328
try {
329329
DataType.fromDDL(typeName)
330330
} catch {
331-
case _: Exception => StringType
331+
case _: Exception =>
332+
// fromDDL rejects types that are syntactically invalid in SQL DDL, such as
333+
// DECIMAL(p,s) with a negative scale (valid when allowNegativeScaleOfDecimal=true).
334+
// Parse those manually rather than silently falling back to StringType.
335+
if (typeName.toUpperCase.startsWith("DECIMAL(") && typeName.endsWith(")")) {
336+
val inner = typeName.substring("DECIMAL(".length, typeName.length - 1)
337+
val parts = inner.split(",")
338+
if (parts.length == 2) {
339+
try {
340+
DataTypes.createDecimalType(parts(0).trim.toInt, parts(1).trim.toInt)
341+
} catch {
342+
case _: Exception => StringType
343+
}
344+
} else StringType
345+
} else StringType
332346
}
333347
}
334348
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 101 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -767,102 +767,119 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
767767
}
768768

769769
// This is to pass the first `all cast combinations are covered`
770-
ignore("cast StringType to DecimalType(10,2)") {
770+
test("cast StringType to DecimalType(10,2)") {
771771
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
772772
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
773773
}
774774

775-
test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
776-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
777-
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
778-
Seq(true, false).foreach(ansiEnabled =>
779-
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
780-
}
775+
test("cast StringType to DecimalType(10,2) fuzz") {
776+
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
777+
Seq(true, false).foreach(ansiEnabled =>
778+
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
781779
}
782780

783781
test("cast StringType to DecimalType(2,2)") {
784-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
785-
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
786-
Seq(true, false).foreach(ansiEnabled =>
787-
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
788-
}
782+
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
783+
Seq(true, false).foreach(ansiEnabled =>
784+
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
789785
}
790786

791787
test("cast StringType to DecimalType check if right exception message is thrown") {
792-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
793-
val values = Seq("d11307\n").toDF("a")
794-
Seq(true, false).foreach(ansiEnabled =>
795-
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
796-
}
788+
val values = Seq("d11307\n").toDF("a")
789+
Seq(true, false).foreach(ansiEnabled =>
790+
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
797791
}
798792

799793
test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
800-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
801-
val values = gen.generateInts(10000).map(" " + _).toDF("a")
802-
Seq(true, false).foreach(ansiEnabled =>
803-
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
804-
}
794+
val values = gen.generateInts(10000).map(" " + _).toDF("a")
795+
Seq(true, false).foreach(ansiEnabled =>
796+
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
805797
}
806798

807799
test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
808-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
809-
val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a")
810-
Seq(true, false).foreach(ansiEnabled =>
811-
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
812-
}
800+
val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a")
801+
Seq(true, false).foreach(ansiEnabled =>
802+
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
813803
}
814804

815805
test("cast StringType to DecimalType(38,10) high precision") {
816-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
817-
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
818-
Seq(true, false).foreach(ansiEnabled =>
819-
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
820-
}
806+
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
807+
Seq(true, false).foreach(ansiEnabled =>
808+
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
809+
}
810+
811+
test("cast StringType to DecimalType - null bytes and fullwidth digits") {
812+
// Spark trims null bytes (\u0000) from both ends of a string before parsing,
813+
// matching its whitespace-trim behavior. Null bytes in the middle produce NULL.
814+
// Fullwidth digits (U+FF10-U+FF19) are treated as numeric equivalents to ASCII digits.
815+
val values = Seq(
816+
// null byte positions
817+
"123\u0000",
818+
"\u0000123",
819+
"12\u00003",
820+
"1\u00002\u00003",
821+
"\u0000",
822+
// null byte with decimal point
823+
"12\u0000.45",
824+
"12.\u000045",
825+
// fullwidth digits (U+FF10-U+FF19)
826+
"123.45", // "123.45" in fullwidth
827+
"123",
828+
"-123.45",
829+
"+123.45",
830+
"123.45E2",
831+
// mixed fullwidth and ASCII
832+
"123.45",
833+
null).toDF("a")
834+
castTest(values, DataTypes.createDecimalType(10, 2))
821835
}
822836

823837
test("cast StringType to DecimalType(10,2) basic values") {
824-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
825-
val values = Seq(
826-
"123.45",
827-
"-67.89",
828-
"-67.89",
829-
"-67.895",
830-
"67.895",
831-
"0.001",
832-
"999.99",
833-
"123.456",
834-
"123.45D",
835-
".5",
836-
"5.",
837-
"+123.45",
838-
" 123.45 ",
839-
"inf",
840-
"",
841-
"abc",
842-
null).toDF("a")
843-
Seq(true, false).foreach(ansiEnabled =>
844-
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
845-
}
838+
val values = Seq(
839+
"123.45",
840+
"-67.89",
841+
"-67.89",
842+
"-67.895",
843+
"67.895",
844+
"0.001",
845+
"999.99",
846+
"123.456",
847+
"123.45D",
848+
".5",
849+
"5.",
850+
"+123.45",
851+
" 123.45 ",
852+
"inf",
853+
"",
854+
"abc",
855+
// values from https://github.com/apache/datafusion-comet/issues/325
856+
"0",
857+
"1",
858+
"+1.0",
859+
".34",
860+
"-10.0",
861+
"4e7",
862+
null).toDF("a")
863+
Seq(true, false).foreach(ansiEnabled =>
864+
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
846865
}
847866

848867
test("cast StringType to Decimal type scientific notation") {
849-
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
850-
val values = Seq(
851-
"1.23E-5",
852-
"1.23e10",
853-
"1.23E+10",
854-
"-1.23e-5",
855-
"1e5",
856-
"1E-2",
857-
"-1.5e3",
858-
"1.23E0",
859-
"0e0",
860-
"1.23e",
861-
"e5",
862-
null).toDF("a")
863-
Seq(true, false).foreach(ansiEnabled =>
864-
castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled))
865-
}
868+
val values = Seq(
869+
"1.23E-5",
870+
"1.23e10",
871+
"1.23E+10",
872+
"-1.23e-5",
873+
"1e5",
874+
"1E-2",
875+
"-1.5e3",
876+
"1.23E0",
877+
"0e0",
878+
"1.23e",
879+
"e5",
880+
null).toDF("a")
881+
Seq(true, false).foreach(ansiEnabled =>
882+
castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled))
866883
}
867884

868885
test("cast StringType to BinaryType") {
@@ -1310,6 +1327,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
13101327
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
13111328
}
13121329

1330+
test("cast StringType to DecimalType with negative scale (allowNegativeScaleOfDecimal)") {
1331+
// With allowNegativeScaleOfDecimal=true, Spark allows DECIMAL(p, s) where s < 0.
1332+
// The value is rounded to the nearest 10^|s| — e.g. DECIMAL(10,-4) rounds to
1333+
// the nearest 10000. This requires the legacy SQL parser config to be enabled.
1334+
withSQLConf("spark.sql.legacy.allowNegativeScaleOfDecimal" -> "true") {
1335+
val values =
1336+
Seq("12500", "15000", "99990000", "-12500", "0", "0.001", "abc", null).toDF("a")
1337+
// testTry=false: try_cast uses SQL string interpolation (toType.sql → "DECIMAL(10,-4)")
1338+
// which the SQL parser rejects regardless of allowNegativeScaleOfDecimal.
1339+
castTest(values, DataTypes.createDecimalType(10, -4), testTry = false)
1340+
}
1341+
}
1342+
13131343
test("cast between decimals with negative precision") {
13141344
// cast to negative scale
13151345
checkSparkAnswerMaybeThrows(

0 commit comments

Comments
 (0)