Skip to content

Commit 3e9f850

Browse files
authored
chore: Cast module refactor boolean module (apache#3491)
1 parent 219859b commit 3e9f850

6 files changed

Lines changed: 430 additions & 127 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,7 @@ harness = false
9999
[[test]]
100100
name = "test_udf_registration"
101101
path = "tests/spark_expr_reg.rs"
102+
103+
[[bench]]
104+
name = "cast_from_boolean"
105+
harness = false
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{BooleanBuilder, RecordBatch};
19+
use arrow::datatypes::{DataType, Field, Schema};
20+
use criterion::{criterion_group, criterion_main, Criterion};
21+
use datafusion::physical_expr::expressions::Column;
22+
use datafusion::physical_expr::PhysicalExpr;
23+
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
24+
use std::sync::Arc;
25+
26+
fn criterion_benchmark(c: &mut Criterion) {
27+
let expr = Arc::new(Column::new("a", 0));
28+
let boolean_batch = create_boolean_batch();
29+
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
30+
let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
31+
let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
32+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
33+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone());
34+
let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone());
35+
let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone());
36+
let cast_to_str = Cast::new(expr.clone(), DataType::Utf8, spark_cast_options.clone());
37+
let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4), spark_cast_options);
38+
39+
let mut group = c.benchmark_group("cast_bool".to_string());
40+
group.bench_function("i8", |b| {
41+
b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap());
42+
});
43+
group.bench_function("i16", |b| {
44+
b.iter(|| cast_to_i16.evaluate(&boolean_batch).unwrap());
45+
});
46+
group.bench_function("i32", |b| {
47+
b.iter(|| cast_to_i32.evaluate(&boolean_batch).unwrap());
48+
});
49+
group.bench_function("i64", |b| {
50+
b.iter(|| cast_to_i64.evaluate(&boolean_batch).unwrap());
51+
});
52+
group.bench_function("f32", |b| {
53+
b.iter(|| cast_to_f32.evaluate(&boolean_batch).unwrap());
54+
});
55+
group.bench_function("f64", |b| {
56+
b.iter(|| cast_to_f64.evaluate(&boolean_batch).unwrap());
57+
});
58+
group.bench_function("str", |b| {
59+
b.iter(|| cast_to_str.evaluate(&boolean_batch).unwrap());
60+
});
61+
group.bench_function("decimal", |b| {
62+
b.iter(|| cast_to_decimal.evaluate(&boolean_batch).unwrap());
63+
});
64+
}
65+
66+
fn create_boolean_batch() -> RecordBatch {
67+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)]));
68+
let mut b = BooleanBuilder::with_capacity(1000);
69+
for i in 0..1000 {
70+
if i % 10 == 0 {
71+
b.append_null();
72+
} else {
73+
b.append_value(rand::random::<bool>());
74+
}
75+
}
76+
let array = b.finish();
77+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
78+
}
79+
80+
fn config() -> Criterion {
81+
Criterion::default()
82+
}
83+
84+
criterion_group! {
85+
name = benches;
86+
config = config();
87+
targets = criterion_benchmark
88+
}
89+
criterion_main!(benches);
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::SparkResult;
19+
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
20+
use arrow::datatypes::DataType;
21+
use std::sync::Arc;
22+
23+
pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool {
24+
use DataType::*;
25+
matches!(
26+
to_type,
27+
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
28+
)
29+
}
30+
31+
// only DF incompatible boolean cast
32+
pub fn cast_boolean_to_decimal(
33+
array: &ArrayRef,
34+
precision: u8,
35+
scale: i8,
36+
) -> SparkResult<ArrayRef> {
37+
let bool_array = array.as_boolean();
38+
let scaled_val = 10_i128.pow(scale as u32);
39+
let result: Decimal128Array = bool_array
40+
.iter()
41+
.map(|v| v.map(|b| if b { scaled_val } else { 0 }))
42+
.collect();
43+
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
44+
}
45+
46+
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cast::cast_array;
    use crate::{EvalMode, SparkCastOptions};
    use arrow::array::{
        Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
        Int64Array, Int8Array, StringArray,
    };
    use arrow::datatypes::DataType::Decimal128;
    use std::sync::Arc;

    /// Shared input for every cast test: `[true, false, null]`.
    fn test_input_bool_array() -> ArrayRef {
        let values = vec![Some(true), Some(false), None];
        Arc::new(BooleanArray::from(values))
    }

    /// Cast options shared by every test: legacy eval mode, "Asia/Kolkata" time zone.
    fn test_input_spark_opts() -> SparkCastOptions {
        SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false)
    }

    /// Casts the shared boolean input to `to_type` and downcasts the result to
    /// the concrete array type `A`, panicking if the cast fails or the result
    /// has a different array type.
    fn cast_input_to<A: Array + Clone + 'static>(to_type: &DataType) -> A {
        let casted = cast_array(test_input_bool_array(), to_type, &test_input_spark_opts()).unwrap();
        casted.as_any().downcast_ref::<A>().unwrap().clone()
    }

    #[test]
    fn test_is_df_cast_from_bool_spark_compatible() {
        let compatible = [
            DataType::Int8,
            DataType::Int16,
            DataType::Int32,
            DataType::Int64,
            DataType::Float32,
            DataType::Float64,
            DataType::Utf8,
        ];
        for to_type in &compatible {
            assert!(
                is_df_cast_from_bool_spark_compatible(to_type),
                "{:?} should be reported as compatible",
                to_type
            );
        }

        let incompatible = [
            DataType::Boolean,
            DataType::Decimal128(10, 4),
            DataType::Null,
        ];
        for to_type in &incompatible {
            assert!(
                !is_df_cast_from_bool_spark_compatible(to_type),
                "{:?} should be reported as incompatible",
                to_type
            );
        }
    }

    #[test]
    fn test_bool_to_int8_cast() {
        let out: Int8Array = cast_input_to(&DataType::Int8);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_int16_cast() {
        let out: Int16Array = cast_input_to(&DataType::Int16);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_int32_cast() {
        let out: Int32Array = cast_input_to(&DataType::Int32);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_int64_cast() {
        let out: Int64Array = cast_input_to(&DataType::Int64);
        assert_eq!(out.value(0), 1);
        assert_eq!(out.value(1), 0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_float32_cast() {
        let out: Float32Array = cast_input_to(&DataType::Float32);
        assert_eq!(out.value(0), 1.0);
        assert_eq!(out.value(1), 0.0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_float64_cast() {
        let out: Float64Array = cast_input_to(&DataType::Float64);
        assert_eq!(out.value(0), 1.0);
        assert_eq!(out.value(1), 0.0);
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_string_cast() {
        let out: StringArray = cast_input_to(&DataType::Utf8);
        assert_eq!(out.value(0), "true");
        assert_eq!(out.value(1), "false");
        assert!(out.is_null(2));
    }

    #[test]
    fn test_bool_to_decimal_cast() {
        // With scale 4, `true` is stored unscaled as 1 * 10^4.
        let expected = Decimal128Array::from(vec![10000_i128, 0_i128])
            .with_precision_and_scale(10, 4)
            .unwrap();
        let actual: Decimal128Array = cast_input_to(&Decimal128(10, 4));
        for row in 0..2 {
            assert_eq!(actual.value(row), expected.value(row));
        }
        assert!(actual.is_null(2));
    }
}

0 commit comments

Comments
 (0)