Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 71 additions & 3 deletions native/spark-expr/src/string_funcs/substring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
#![allow(deprecated)]

use crate::kernels::strings::substring;
use arrow::datatypes::{DataType, Schema};
use arrow::array::{as_dictionary_array, as_largestring_array, as_string_array, Array, ArrayRef};
use arrow::datatypes::{DataType, Int32Type, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::common::DataFusionError;
use datafusion::logical_expr::ColumnarValue;
Expand Down Expand Up @@ -88,8 +89,14 @@ impl PhysicalExpr for SubstringExpr {
let arg = self.child.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => {
let result = substring(&array, self.start, self.len)?;

let result = if self.start < 0 {
// Spark and Arrow differ for negative start: Arrow clamps
// start to 0 then takes `len` chars, but Spark computes
// end = unclamped_start + len, then clamps both independently.
spark_substring_negative_start(&array, self.start, self.len)?
} else {
substring(&array, self.start, self.len)?
};
Ok(ColumnarValue::Array(result))
}
_ => Err(DataFusionError::Execution(
Expand All @@ -113,3 +120,64 @@ impl PhysicalExpr for SubstringExpr {
)))
}
}

/// Implement Spark's substring semantics for negative start positions.
/// Spark: start = numChars + pos, end = start + len, clamp both, empty if start >= end.
/// Arrow: start = max(0, numChars + pos), take len chars — differs when start is clamped.
fn spark_substring_negative_start(
array: &ArrayRef,
start: i64,
len: u64,
) -> datafusion::common::Result<ArrayRef> {
use arrow::array::{DictionaryArray, GenericStringBuilder};

match array.data_type() {
DataType::Utf8 => {
let str_array = as_string_array(array);
let mut builder = GenericStringBuilder::<i32>::new();
for i in 0..str_array.len() {
if str_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(spark_substr_negative(str_array.value(i), start, len));
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
DataType::LargeUtf8 => {
let str_array = as_largestring_array(array);
let mut builder = GenericStringBuilder::<i64>::new();
for i in 0..str_array.len() {
if str_array.is_null(i) {
builder.append_null();
} else {
builder.append_value(spark_substr_negative(str_array.value(i), start, len));
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
DataType::Dictionary(_, _) => {
let dict = as_dictionary_array::<Int32Type>(array);
let values = spark_substring_negative_start(dict.values(), start, len)?;
let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
Ok(Arc::new(result) as ArrayRef)
}
_ => Ok(Arc::clone(array)),
}
}

/// Computes Spark's `substring(s, pos, len)` for a single string and a negative `pos`.
///
/// Positions and lengths are counted in Unicode scalar values (`char`s), not bytes.
/// The window is `[num_chars + pos, num_chars + pos + len)`; both bounds are clamped
/// into `[0, num_chars]` independently, and an empty window yields an empty string.
///
/// * `s`   - input string.
/// * `pos` - start position relative to the end of the string (`-1` is the last char).
/// * `len` - maximum number of characters to return.
///
/// Returns the selected characters as a newly allocated `String`.
fn spark_substr_negative(s: &str, pos: i64, len: u64) -> String {
    // A plain `as i64` cast would wrap lengths above `i64::MAX` to a negative number
    // and turn "take everything" into an empty result, so saturate instead.
    let len = i64::try_from(len).unwrap_or(i64::MAX);
    let num_chars = i64::try_from(s.chars().count()).unwrap_or(i64::MAX);

    // The end is derived from the *unclamped* start; this is what distinguishes
    // Spark's behaviour from Arrow's clamp-then-take behaviour.
    let unclamped_start = num_chars.saturating_add(pos);
    let end = unclamped_start.saturating_add(len).min(num_chars);
    let start = unclamped_start.max(0);

    if start >= end {
        return String::new();
    }

    // Here 0 <= start < end <= num_chars, so both casts are lossless.
    s.chars()
        .skip(start as usize)
        .take((end - start) as usize)
        .collect()
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ SELECT substring(s, 1, -1) FROM test_substring
query
SELECT substring(s, 100) FROM test_substring

query
SELECT substring(s, -2, 3) FROM test_substring

query
SELECT substring(s, -10, 3) FROM test_substring

query
SELECT substring(s, -300, 3) FROM test_substring

-- literal + literal + literal
query ignore(https://github.com/apache/datafusion-comet/issues/3337)
SELECT substring('hello world', 1, 5), substring('hello world', -3), substring('', 1, 5), substring(NULL, 1, 5)
148 changes: 148 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,152 @@ class CometStringExpressionSuite extends CometTestBase {
}
}

test("substring") {
  // Rows: a regular string, the empty string, NULL and a short string.
  val rows = Seq(("hello world", ""), ("", ""), (null, ""), ("abc", ""))
  // Every query is run through Comet and compared with Spark's own answer.
  val queries = Seq(
    // positive start
    "SELECT substring(_1, 1, 5) FROM tbl",
    // negative start, no length
    "SELECT substring(_1, -3) FROM tbl",
    // zero start
    "SELECT substring(_1, 0, 3) FROM tbl",
    // zero length
    "SELECT substring(_1, 1, 0) FROM tbl",
    // negative length
    "SELECT substring(_1, 1, -1) FROM tbl",
    // start beyond string length
    "SELECT substring(_1, 100) FROM tbl",
    // negative start with length
    "SELECT substring(_1, -2, 3) FROM tbl",
    // negative start beyond string length with length
    "SELECT substring(_1, -10, 3) FROM tbl",
    // large negative start with length
    "SELECT substring(_1, -300, 3) FROM tbl")
  withParquetTable(rows, "tbl") {
    queries.foreach { sql =>
      checkSparkAnswerAndOperator(sql)
    }
  }
}

test("substring - negative start boundary cases") {
  // "abc" has 3 characters: -3 addresses its first char, -4 points one before it.
  val rows = Seq(("abc", ""), ("a", ""), ("ab", ""), ("", ""), (null, ""))
  val queries = Seq(
    // abs(start) == length of "abc" exactly: the window begins at the first char
    "SELECT substring(_1, -3, 2) FROM tbl",
    "SELECT substring(_1, -3) FROM tbl",
    // abs(start) == length + 1 for "abc": the unclamped start is -1, so only the part
    // of the window that overlaps the string is returned (e.g. -4, 2 yields "a")
    "SELECT substring(_1, -4, 2) FROM tbl",
    "SELECT substring(_1, -4) FROM tbl",
    // abs(start) == length - 1 for "abc": one char inside the boundary
    "SELECT substring(_1, -2, 5) FROM tbl",
    "SELECT substring(_1, -2) FROM tbl",
    // -1 addresses the last character
    "SELECT substring(_1, -1, 1) FROM tbl",
    "SELECT substring(_1, -1) FROM tbl",
    // -1 with a length larger than the remaining chars
    "SELECT substring(_1, -1, 100) FROM tbl")
  withParquetTable(rows, "tbl") {
    queries.foreach { sql =>
      checkSparkAnswerAndOperator(sql)
    }
  }
}

test("substring - negative start with zero and negative length") {
  val rows = Seq(("hello", ""), ("ab", ""), ("", ""), (null, ""))
  val queries = Seq(
    // negative start combined with a zero length
    "SELECT substring(_1, -3, 0) FROM tbl",
    "SELECT substring(_1, -100, 0) FROM tbl",
    // negative start combined with a negative length
    "SELECT substring(_1, -3, -1) FROM tbl",
    "SELECT substring(_1, -1, -5) FROM tbl",
    // negative start beyond the string length, zero length
    "SELECT substring(_1, -10, 0) FROM tbl",
    // negative start beyond the string length, negative length
    "SELECT substring(_1, -10, -1) FROM tbl")
  withParquetTable(rows, "tbl") {
    queries.foreach { sql =>
      checkSparkAnswerAndOperator(sql)
    }
  }
}

test("substring - single character and empty strings") {
  val rows = Seq(("x", ""), ("", ""), (null, ""))
  val starts = Seq(-2, -1, 0, 1, 2)
  val lengths = Seq(0, 1, 5)
  withParquetTable(rows, "tbl") {
    starts.foreach { start =>
      // For each start: first every explicit length, then the form without a length.
      val withLength = lengths.map(len => s"SELECT substring(_1, $start, $len) FROM tbl")
      val withoutLength = s"SELECT substring(_1, $start) FROM tbl"
      (withLength :+ withoutLength).foreach { sql =>
        checkSparkAnswerAndOperator(sql)
      }
    }
  }
}

test("substring - unicode multi-byte characters") {
  // Non-ASCII literals below need the style checker switched off.
  // scalastyle:off
  val rows = Seq(
    ("苹果手机", ""), // 4 Chinese characters (3 bytes each in UTF-8)
    ("café", ""), // ASCII followed by an accented character
    ("😀🎉🔥", ""), // emoji (4 bytes each in UTF-8)
    ("aé苹😀", ""), // mixed: ASCII + 2-byte + 3-byte + 4-byte
    ("", ""),
    (null, ""))
  // scalastyle:on
  val queries = Seq(
    // positive start landing on multi-byte characters
    "SELECT substring(_1, 2, 2) FROM tbl",
    "SELECT substring(_1, 1, 1) FROM tbl",
    // negative start on multi-byte characters
    "SELECT substring(_1, -2) FROM tbl",
    "SELECT substring(_1, -2, 1) FROM tbl",
    // negative start beyond the length of every string
    "SELECT substring(_1, -10, 2) FROM tbl",
    "SELECT substring(_1, -10) FROM tbl",
    // abs(start) at and just past the char length of the 4-char strings
    "SELECT substring(_1, -4, 2) FROM tbl",
    "SELECT substring(_1, -5, 2) FROM tbl",
    // extract the entire string
    "SELECT substring(_1, 1, 100) FROM tbl",
    "SELECT substring(_1, 1) FROM tbl")
  withParquetTable(rows, "tbl") {
    queries.foreach { sql =>
      checkSparkAnswerAndOperator(sql)
    }
  }
}

test("substring - large start and length values") {
  val rows = Seq(("hello world", ""), ("abc", ""), ("", ""), (null, ""))
  // Arguments at the extremes of the Int range, for the start and for the length.
  val queries = Seq(
    s"SELECT substring(_1, ${Int.MaxValue}, 5) FROM tbl",
    s"SELECT substring(_1, 1, ${Int.MaxValue}) FROM tbl",
    s"SELECT substring(_1, ${Int.MinValue + 1}, 5) FROM tbl",
    s"SELECT substring(_1, ${Int.MinValue + 1}) FROM tbl",
    s"SELECT substring(_1, ${Int.MaxValue}, ${Int.MaxValue}) FROM tbl")
  withParquetTable(rows, "tbl") {
    queries.foreach { sql =>
      checkSparkAnswerAndOperator(sql)
    }
  }
}

test("substring - dictionary encoded strings") {
  // Five distinct values (one of them NULL) repeated over 1000 rows, so the column
  // has few distinct values and is a candidate for dictionary encoding.
  val cycle = Seq("hello", "ab", "", null, "world!")
  val rows = (0 until 1000).map(i => Tuple1(cycle(i % 5)))
  val queries = Seq(
    // positive start
    "SELECT substring(_1, 2, 3) FROM tbl",
    // negative start within bounds
    "SELECT substring(_1, -3, 2) FROM tbl",
    "SELECT substring(_1, -3) FROM tbl",
    // negative start exceeding the length of some values
    "SELECT substring(_1, -4, 2) FROM tbl",
    "SELECT substring(_1, -4) FROM tbl",
    // negative start exceeding the length of every value
    "SELECT substring(_1, -100, 3) FROM tbl",
    "SELECT substring(_1, -100) FROM tbl",
    // zero start
    "SELECT substring(_1, 0, 3) FROM tbl",
    // -1: last char
    "SELECT substring(_1, -1, 1) FROM tbl")
  withSQLConf("parquet.enable.dictionary" -> "true") {
    withParquetTable(rows, "tbl") {
      queries.foreach { sql =>
        checkSparkAnswerAndOperator(sql)
      }
    }
  }
}

}
Loading