Skip to content

Commit c73ac2e

Browse files
authored
feat: implement cast from whole numbers to binary format and bool to decimal (#3083)
1 parent b0477a7 commit c73ac2e

3 files changed

Lines changed: 163 additions & 82 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
// under the License.
1717

1818
use crate::utils::array_with_timezone;
19+
use crate::EvalMode::Legacy;
1920
use crate::{timezone, BinaryOutputStyle};
2021
use crate::{EvalMode, SparkError, SparkResult};
2122
use arrow::array::builder::StringBuilder;
2223
use arrow::array::{
23-
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
24+
BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
2425
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
2526
};
2627
use arrow::compute::can_cast_types;
@@ -304,29 +305,32 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
304305

305306
fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
306307
use DataType::*;
307-
matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
308+
matches!(
309+
to_type,
310+
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
311+
)
308312
}
309313

310314
fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
311315
use DataType::*;
312316
matches!(
313317
to_type,
314-
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
318+
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
315319
)
316320
}
317321

318322
fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
319323
use DataType::*;
320324
matches!(
321325
to_type,
322-
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
326+
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) | Binary
323327
)
324328
}
325329

326330
fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
327331
use DataType::*;
328332
match to_type {
329-
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
333+
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Binary => true,
330334
Decimal128(_, _) => {
331335
// incompatible: no overflow check
332336
options.allow_incompat
@@ -338,7 +342,7 @@ fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
338342
fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
339343
use DataType::*;
340344
match to_type {
341-
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
345+
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Binary => true,
342346
Decimal128(_, _) => {
343347
// incompatible: no overflow check
344348
options.allow_incompat
@@ -501,6 +505,29 @@ macro_rules! cast_float_to_string {
501505
}};
502506
}
503507

508+
// Converts a whole-number array into a Binary array that holds each value's
// big-endian bytes; null slots stay null. The macro takes no eval mode: the
// conversion cannot fail for any integer value.
//
// `$array` is downcast to `$primitive_type`; a failed downcast is reported as
// `SparkError::Internal`. `$byte_size` is the byte width of one value and is
// only used to pre-size the builder's data buffer.
macro_rules! cast_whole_num_to_binary {
    ($array:expr, $primitive_type:ty, $byte_size:expr) => {{
        let values = $array
            .as_any()
            .downcast_ref::<$primitive_type>()
            .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?;

        let row_count = values.len();
        let mut out = BinaryBuilder::with_capacity(row_count, row_count * $byte_size);

        for slot in values.iter() {
            match slot {
                Some(v) => out.append_value(v.to_be_bytes()),
                None => out.append_null(),
            }
        }

        Ok(Arc::new(out.finish()) as ArrayRef)
    }};
}
530+
504531
macro_rules! cast_int_to_int_macro {
505532
(
506533
$array: expr,
@@ -1101,6 +1128,19 @@ fn cast_array(
11011128
}
11021129
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
11031130
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
1131+
(Int8, Binary) if (eval_mode == Legacy) => cast_whole_num_to_binary!(&array, Int8Array, 1),
1132+
(Int16, Binary) if (eval_mode == Legacy) => {
1133+
cast_whole_num_to_binary!(&array, Int16Array, 2)
1134+
}
1135+
(Int32, Binary) if (eval_mode == Legacy) => {
1136+
cast_whole_num_to_binary!(&array, Int32Array, 4)
1137+
}
1138+
(Int64, Binary) if (eval_mode == Legacy) => {
1139+
cast_whole_num_to_binary!(&array, Int64Array, 8)
1140+
}
1141+
(Boolean, Decimal128(precision, scale)) => {
1142+
cast_boolean_to_decimal(&array, *precision, *scale)
1143+
}
11041144
_ if cast_options.is_adapting_schema
11051145
|| is_datafusion_spark_compatible(from_type, to_type) =>
11061146
{
@@ -1163,6 +1203,16 @@ fn cast_date_to_timestamp(
11631203
))
11641204
}
11651205

1206+
fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
1207+
let bool_array = array.as_boolean();
1208+
let scaled_val = 10_i128.pow(scale as u32);
1209+
let result: Decimal128Array = bool_array
1210+
.iter()
1211+
.map(|v| v.map(|b| if b { scaled_val } else { 0 }))
1212+
.collect();
1213+
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
1214+
}
1215+
11661216
fn cast_string_to_float(
11671217
array: &ArrayRef,
11681218
to_type: &DataType,

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
147147
case (DataTypes.BooleanType, _) =>
148148
canCastFromBoolean(toType)
149149
case (DataTypes.ByteType, _) =>
150-
canCastFromByte(toType)
150+
canCastFromByte(toType, evalMode)
151151
case (DataTypes.ShortType, _) =>
152-
canCastFromShort(toType)
152+
canCastFromShort(toType, evalMode)
153153
case (DataTypes.IntegerType, _) =>
154-
canCastFromInt(toType)
154+
canCastFromInt(toType, evalMode)
155155
case (DataTypes.LongType, _) =>
156-
canCastFromLong(toType)
156+
canCastFromLong(toType, evalMode)
157157
case (DataTypes.FloatType, _) =>
158158
canCastFromFloat(toType)
159159
case (DataTypes.DoubleType, _) =>
@@ -264,58 +264,68 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
264264

265265
/**
 * Support level for casting a Boolean to `toType`. Byte, Short, Integer, Long, Float, Double
 * and any DecimalType are reported as Compatible; every other target goes through
 * `unsupported`.
 */
private def canCastFromBoolean(toType: DataType): SupportLevel =
  toType match {
    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType |
        DataTypes.LongType =>
      Compatible()
    case DataTypes.FloatType | DataTypes.DoubleType =>
      Compatible()
    case _: DecimalType =>
      Compatible()
    case _ =>
      unsupported(DataTypes.BooleanType, toType)
  }
271271

272-
private def canCastFromByte(toType: DataType): SupportLevel = toType match {
273-
case DataTypes.BooleanType =>
274-
Compatible()
275-
case DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType =>
276-
Compatible()
277-
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
278-
Compatible()
279-
case _ =>
280-
unsupported(DataTypes.ByteType, toType)
281-
}
272+
/**
 * Support level for casting a Byte to `toType`. Boolean, the other integral types, Float,
 * Double and any DecimalType are Compatible. Binary is Compatible only when `evalMode` is
 * LEGACY. Every other target goes through `unsupported`.
 */
private def canCastFromByte(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = {
  val legacy = evalMode == CometEvalMode.LEGACY
  toType match {
    case DataTypes.BooleanType | DataTypes.ShortType | DataTypes.IntegerType |
        DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType =>
      Compatible()
    case _: DecimalType =>
      Compatible()
    case DataTypes.BinaryType if legacy =>
      Compatible()
    case _ =>
      unsupported(DataTypes.ByteType, toType)
  }
}
282285

283-
private def canCastFromShort(toType: DataType): SupportLevel = toType match {
284-
case DataTypes.BooleanType =>
285-
Compatible()
286-
case DataTypes.ByteType | DataTypes.IntegerType | DataTypes.LongType =>
287-
Compatible()
288-
case DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
289-
Compatible()
290-
case _ =>
291-
unsupported(DataTypes.ShortType, toType)
292-
}
286+
/**
 * Support level for casting a Short to `toType`. Boolean, the other integral types, Float,
 * Double and any DecimalType are Compatible. Binary is Compatible only when `evalMode` is
 * LEGACY. Every other target goes through `unsupported`.
 */
private def canCastFromShort(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = {
  val legacy = evalMode == CometEvalMode.LEGACY
  toType match {
    case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.IntegerType |
        DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType =>
      Compatible()
    case _: DecimalType =>
      Compatible()
    case DataTypes.BinaryType if legacy =>
      Compatible()
    case _ =>
      unsupported(DataTypes.ShortType, toType)
  }
}
293299

294-
private def canCastFromInt(toType: DataType): SupportLevel = toType match {
295-
case DataTypes.BooleanType =>
296-
Compatible()
297-
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.LongType =>
298-
Compatible()
299-
case DataTypes.FloatType | DataTypes.DoubleType =>
300-
Compatible()
301-
case _: DecimalType =>
302-
Compatible()
303-
case _ =>
304-
unsupported(DataTypes.IntegerType, toType)
305-
}
300+
/**
 * Support level for casting an Integer to `toType`. Boolean, the other integral types, Float,
 * Double and any DecimalType are Compatible. Binary is Compatible only when `evalMode` is
 * LEGACY. Every other target goes through `unsupported`.
 */
private def canCastFromInt(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = {
  val legacy = evalMode == CometEvalMode.LEGACY
  toType match {
    case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
        DataTypes.LongType | DataTypes.FloatType | DataTypes.DoubleType =>
      Compatible()
    case _: DecimalType =>
      Compatible()
    case DataTypes.BinaryType if legacy =>
      Compatible()
    case _ =>
      unsupported(DataTypes.IntegerType, toType)
  }
}
306314

307-
private def canCastFromLong(toType: DataType): SupportLevel = toType match {
308-
case DataTypes.BooleanType =>
309-
Compatible()
310-
case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType =>
311-
Compatible()
312-
case DataTypes.FloatType | DataTypes.DoubleType =>
313-
Compatible()
314-
case _: DecimalType =>
315-
Compatible()
316-
case _ =>
317-
unsupported(DataTypes.LongType, toType)
318-
}
315+
/**
 * Support level for casting a Long to `toType`. Boolean, the other integral types, Float,
 * Double and any DecimalType are Compatible. Binary is Compatible only when `evalMode` is
 * LEGACY. Every other target goes through `unsupported`.
 */
private def canCastFromLong(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel = {
  val legacy = evalMode == CometEvalMode.LEGACY
  toType match {
    case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
        DataTypes.IntegerType | DataTypes.FloatType | DataTypes.DoubleType =>
      Compatible()
    case _: DecimalType =>
      Compatible()
    case DataTypes.BinaryType if legacy =>
      Compatible()
    case _ =>
      unsupported(DataTypes.LongType, toType)
  }
}
319329

320330
private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
321331
case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType |

0 commit comments

Comments
 (0)