@@ -23,45 +23,99 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323use std:: sync:: Arc ;
2424
2525fn criterion_benchmark ( c : & mut Criterion ) {
26- let batch = create_utf8_batch ( ) ;
26+ let small_int_batch = create_small_int_string_batch ( ) ;
27+ let int_batch = create_int_string_batch ( ) ;
28+ let decimal_batch = create_decimal_string_batch ( ) ;
2729 let expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
30+
31+ for ( mode, mode_name) in [
32+ ( EvalMode :: Legacy , "legacy" ) ,
33+ ( EvalMode :: Ansi , "ansi" ) ,
34+ ( EvalMode :: Try , "try" ) ,
35+ ] {
36+ let spark_cast_options = SparkCastOptions :: new ( mode, "" , false ) ;
37+ let cast_to_i8 = Cast :: new ( expr. clone ( ) , DataType :: Int8 , spark_cast_options. clone ( ) ) ;
38+ let cast_to_i16 = Cast :: new ( expr. clone ( ) , DataType :: Int16 , spark_cast_options. clone ( ) ) ;
39+ let cast_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
40+ let cast_to_i64 = Cast :: new ( expr. clone ( ) , DataType :: Int64 , spark_cast_options) ;
41+
42+ let mut group = c. benchmark_group ( format ! ( "cast_string_to_int/{}" , mode_name) ) ;
43+ group. bench_function ( "i8" , |b| {
44+ b. iter ( || cast_to_i8. evaluate ( & small_int_batch) . unwrap ( ) ) ;
45+ } ) ;
46+ group. bench_function ( "i16" , |b| {
47+ b. iter ( || cast_to_i16. evaluate ( & small_int_batch) . unwrap ( ) ) ;
48+ } ) ;
49+ group. bench_function ( "i32" , |b| {
50+ b. iter ( || cast_to_i32. evaluate ( & int_batch) . unwrap ( ) ) ;
51+ } ) ;
52+ group. bench_function ( "i64" , |b| {
53+ b. iter ( || cast_to_i64. evaluate ( & int_batch) . unwrap ( ) ) ;
54+ } ) ;
55+ group. finish ( ) ;
56+ }
57+
58+ // Benchmark decimal truncation (Legacy mode only)
2859 let spark_cast_options = SparkCastOptions :: new ( EvalMode :: Legacy , "" , false ) ;
29- let cast_string_to_i8 = Cast :: new ( expr. clone ( ) , DataType :: Int8 , spark_cast_options. clone ( ) ) ;
30- let cast_string_to_i16 = Cast :: new ( expr. clone ( ) , DataType :: Int16 , spark_cast_options. clone ( ) ) ;
31- let cast_string_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
32- let cast_string_to_i64 = Cast :: new ( expr, DataType :: Int64 , spark_cast_options) ;
60+ let cast_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
61+ let cast_to_i64 = Cast :: new ( expr. clone ( ) , DataType :: Int64 , spark_cast_options) ;
3362
34- let mut group = c. benchmark_group ( "cast_string_to_int" ) ;
35- group. bench_function ( "cast_string_to_i8 " , |b| {
36- b. iter ( || cast_string_to_i8 . evaluate ( & batch ) . unwrap ( ) ) ;
63+ let mut group = c. benchmark_group ( "cast_string_to_int/legacy_decimals " ) ;
64+ group. bench_function ( "i32 " , |b| {
65+ b. iter ( || cast_to_i32 . evaluate ( & decimal_batch ) . unwrap ( ) ) ;
3766 } ) ;
38- group. bench_function ( "cast_string_to_i16" , |b| {
39- b. iter ( || cast_string_to_i16. evaluate ( & batch) . unwrap ( ) ) ;
40- } ) ;
41- group. bench_function ( "cast_string_to_i32" , |b| {
42- b. iter ( || cast_string_to_i32. evaluate ( & batch) . unwrap ( ) ) ;
43- } ) ;
44- group. bench_function ( "cast_string_to_i64" , |b| {
45- b. iter ( || cast_string_to_i64. evaluate ( & batch) . unwrap ( ) ) ;
67+ group. bench_function ( "i64" , |b| {
68+ b. iter ( || cast_to_i64. evaluate ( & decimal_batch) . unwrap ( ) ) ;
4669 } ) ;
70+ group. finish ( ) ;
4771}
4872
49- // Create UTF8 batch with strings representing ints, floats, nulls
50- fn create_utf8_batch ( ) -> RecordBatch {
73+ /// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
74+ fn create_small_int_string_batch ( ) -> RecordBatch {
5175 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
5276 let mut b = StringBuilder :: new ( ) ;
5377 for i in 0 ..1000 {
5478 if i % 10 == 0 {
5579 b. append_null ( ) ;
56- } else if i % 2 == 0 {
57- b. append_value ( format ! ( "{}" , rand:: random:: <f64 >( ) ) ) ;
5880 } else {
59- b. append_value ( format ! ( "{}" , rand:: random:: <i64 >( ) ) ) ;
81+ b. append_value ( format ! ( "{}" , rand:: random:: <i8 >( ) ) ) ;
6082 }
6183 }
6284 let array = b. finish ( ) ;
85+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( )
86+ }
6387
64- RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( array) ] ) . unwrap ( )
88+ /// Create batch with valid integer strings (works for all eval modes)
89+ fn create_int_string_batch ( ) -> RecordBatch {
90+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
91+ let mut b = StringBuilder :: new ( ) ;
92+ for i in 0 ..1000 {
93+ if i % 10 == 0 {
94+ b. append_null ( ) ;
95+ } else {
96+ b. append_value ( format ! ( "{}" , rand:: random:: <i32 >( ) ) ) ;
97+ }
98+ }
99+ let array = b. finish ( ) ;
100+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( )
101+ }
102+
103+ /// Create batch with decimal strings (for Legacy mode decimal truncation)
104+ fn create_decimal_string_batch ( ) -> RecordBatch {
105+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
106+ let mut b = StringBuilder :: new ( ) ;
107+ for i in 0 ..1000 {
108+ if i % 10 == 0 {
109+ b. append_null ( ) ;
110+ } else {
111+ // Generate integers with decimal portions to test truncation
112+ let int_part: i32 = rand:: random ( ) ;
113+ let dec_part: u32 = rand:: random :: < u32 > ( ) % 1000 ;
114+ b. append_value ( format ! ( "{}.{}" , int_part, dec_part) ) ;
115+ }
116+ }
117+ let array = b. finish ( ) ;
118+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( )
65119}
66120
67121fn config ( ) -> Criterion {
0 commit comments