Skip to content

Commit af3bd81

Browse files
authored
perf: Improve criterion benchmarks for cast string to int (#3049)
1 parent 77899ee commit af3bd81

1 file changed

Lines changed: 76 additions & 22 deletions

File tree

native/spark-expr/benches/cast_from_string.rs

Lines changed: 76 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,45 +23,99 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323
use std::sync::Arc;
2424

2525
fn criterion_benchmark(c: &mut Criterion) {
26-
let batch = create_utf8_batch();
26+
let small_int_batch = create_small_int_string_batch();
27+
let int_batch = create_int_string_batch();
28+
let decimal_batch = create_decimal_string_batch();
2729
let expr = Arc::new(Column::new("a", 0));
30+
31+
for (mode, mode_name) in [
32+
(EvalMode::Legacy, "legacy"),
33+
(EvalMode::Ansi, "ansi"),
34+
(EvalMode::Try, "try"),
35+
] {
36+
let spark_cast_options = SparkCastOptions::new(mode, "", false);
37+
let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
38+
let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
39+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
40+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
41+
42+
let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name));
43+
group.bench_function("i8", |b| {
44+
b.iter(|| cast_to_i8.evaluate(&small_int_batch).unwrap());
45+
});
46+
group.bench_function("i16", |b| {
47+
b.iter(|| cast_to_i16.evaluate(&small_int_batch).unwrap());
48+
});
49+
group.bench_function("i32", |b| {
50+
b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap());
51+
});
52+
group.bench_function("i64", |b| {
53+
b.iter(|| cast_to_i64.evaluate(&int_batch).unwrap());
54+
});
55+
group.finish();
56+
}
57+
58+
// Benchmark decimal truncation (Legacy mode only)
2859
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
29-
let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
30-
let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
31-
let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
32-
let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);
60+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
61+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
3362

34-
let mut group = c.benchmark_group("cast_string_to_int");
35-
group.bench_function("cast_string_to_i8", |b| {
36-
b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap());
63+
let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals");
64+
group.bench_function("i32", |b| {
65+
b.iter(|| cast_to_i32.evaluate(&decimal_batch).unwrap());
3766
});
38-
group.bench_function("cast_string_to_i16", |b| {
39-
b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap());
40-
});
41-
group.bench_function("cast_string_to_i32", |b| {
42-
b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap());
43-
});
44-
group.bench_function("cast_string_to_i64", |b| {
45-
b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap());
67+
group.bench_function("i64", |b| {
68+
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
4669
});
70+
group.finish();
4771
}
4872

49-
// Create UTF8 batch with strings representing ints, floats, nulls
50-
fn create_utf8_batch() -> RecordBatch {
73+
/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
74+
fn create_small_int_string_batch() -> RecordBatch {
5175
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
5276
let mut b = StringBuilder::new();
5377
for i in 0..1000 {
5478
if i % 10 == 0 {
5579
b.append_null();
56-
} else if i % 2 == 0 {
57-
b.append_value(format!("{}", rand::random::<f64>()));
5880
} else {
59-
b.append_value(format!("{}", rand::random::<i64>()));
81+
b.append_value(format!("{}", rand::random::<i8>()));
6082
}
6183
}
6284
let array = b.finish();
85+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
86+
}
6387

64-
RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap()
88+
/// Create batch with valid integer strings (works for all eval modes)
89+
fn create_int_string_batch() -> RecordBatch {
90+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
91+
let mut b = StringBuilder::new();
92+
for i in 0..1000 {
93+
if i % 10 == 0 {
94+
b.append_null();
95+
} else {
96+
b.append_value(format!("{}", rand::random::<i32>()));
97+
}
98+
}
99+
let array = b.finish();
100+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
101+
}
102+
103+
/// Create batch with decimal strings (for Legacy mode decimal truncation)
104+
fn create_decimal_string_batch() -> RecordBatch {
105+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
106+
let mut b = StringBuilder::new();
107+
for i in 0..1000 {
108+
if i % 10 == 0 {
109+
b.append_null();
110+
} else {
111+
// Generate integers with decimal portions to test truncation
112+
let int_part: i32 = rand::random();
113+
let dec_part: u32 = rand::random::<u32>() % 1000;
114+
b.append_value(format!("{}.{}", int_part, dec_part));
115+
}
116+
}
117+
let array = b.finish();
118+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
65119
}
66120

67121
fn config() -> Criterion {

0 commit comments

Comments
 (0)