Skip to content

Commit e77c9bd

Browse files
hsiang-c, claude, andygrove, parthchandra
authored
fix: add scalar support for bit_count expression (#3361)
* fix: add scalar support for bit_count expression

  The bit_count function now handles scalar inputs in addition to arrays. Scalar inputs return scalar outputs, maintaining proper type semantics.

  Enable bit_count tests in bitwise.sql

  Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

* Update spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql

  Co-authored-by: Andy Grove <agrove@apache.org>

* Revert conf

---------

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
Co-authored-by: Andy Grove <agrove@apache.org>
Co-authored-by: Parth Chandra <parthc@apache.org>
1 parent c2a6b8a commit e77c9bd

2 files changed

Lines changed: 52 additions & 7 deletions

File tree

native/spark-expr/src/bitwise_funcs/bitwise_count.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use arrow::{array::*, datatypes::DataType};
19-
use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result};
19+
use datafusion::common::{exec_err, internal_datafusion_err, Result};
2020
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
2121
use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
2222
use std::any::Any;
@@ -99,15 +99,38 @@ pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result<ColumnarValue> {
9999
DataType::Int16 => compute_op!(array, Int16Array),
100100
DataType::Int32 => compute_op!(array, Int32Array),
101101
DataType::Int64 => compute_op!(array, Int64Array),
102-
_ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()),
102+
_ => exec_err!("bit_count can't be evaluated because the array's type is {:?}, not signed int/boolean", array.data_type()),
103103
};
104104
result.map(ColumnarValue::Array)
105105
}
106-
[ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"),
106+
[ColumnarValue::Scalar(scalar)] => {
107+
use datafusion::common::ScalarValue;
108+
let result = match scalar {
109+
ScalarValue::Int8(Some(v)) => bit_count(v as i64),
110+
ScalarValue::Int16(Some(v)) => bit_count(v as i64),
111+
ScalarValue::Int32(Some(v)) => bit_count(v as i64),
112+
ScalarValue::Int64(Some(v)) => bit_count(v),
113+
ScalarValue::Boolean(Some(v)) => bit_count(if v { 1 } else { 0 }),
114+
ScalarValue::Int8(None)
115+
| ScalarValue::Int16(None)
116+
| ScalarValue::Int32(None)
117+
| ScalarValue::Int64(None)
118+
| ScalarValue::Boolean(None) => {
119+
return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None)))
120+
}
121+
_ => {
122+
return exec_err!(
123+
"bit_count can't be evaluated because the scalar's type is {:?}, not signed int/boolean",
124+
scalar.data_type()
125+
)
126+
}
127+
};
128+
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
129+
}
107130
}
108131
}
109132

110-
// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType)
133+
// Here’s the equivalent Rust implementation of the bitCount function (similar to Java's bitCount for LongType)
111134
fn bit_count(i: i64) -> i32 {
112135
let mut u = i as u64;
113136
u = u - ((u >> 1) & 0x5555555555555555);
@@ -121,7 +144,7 @@ fn bit_count(i: i64) -> i32 {
121144

122145
#[cfg(test)]
123146
mod tests {
124-
use datafusion::common::{cast::as_int32_array, Result};
147+
use datafusion::common::{cast::as_int32_array, Result, ScalarValue};
125148

126149
use super::*;
127150

@@ -133,8 +156,18 @@ mod tests {
133156
Some(12345),
134157
Some(89),
135158
Some(-3456),
159+
Some(i32::MIN),
160+
Some(i32::MAX),
136161
])));
137-
let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);
162+
let expected = &Int32Array::from(vec![
163+
Some(1),
164+
None,
165+
Some(6),
166+
Some(4),
167+
Some(54),
168+
Some(33),
169+
Some(31),
170+
]);
138171

139172
let ColumnarValue::Array(result) = spark_bit_count([args])? else {
140173
unreachable!()
@@ -145,4 +178,16 @@ mod tests {
145178

146179
Ok(())
147180
}
181+
182+
#[test]
183+
fn bitwise_count_scalar() {
184+
let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX)));
185+
186+
match spark_bit_count([args]) {
187+
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(actual)))) => {
188+
assert_eq!(actual, 63)
189+
}
190+
_ => unreachable!(),
191+
}
192+
}
148193
}

spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ SELECT bit_get(i, pos) FROM test_bit_get
7373
query
7474
SELECT 1111 & 2, 1111 | 2, 1111 ^ 2
7575

76-
query ignore(https://github.com/apache/datafusion-comet/issues/3341)
76+
query
7777
SELECT bit_count(0), bit_count(7), bit_count(-1)
7878

7979
query spark_answer_only

0 commit comments

Comments
 (0)