Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use datafusion::{
use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc,
SumInteger, ToCsv,
SparkBloomFilterVersion, SumInteger, ToCsv,
};
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
use iceberg::expr::Bind;
Expand Down Expand Up @@ -2287,10 +2287,17 @@ impl PhysicalPlanner {
let num_bits =
self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let version = match expr.version() {
spark_expression::BloomFilterVersion::V2 => SparkBloomFilterVersion::V2,
// Default (Unspecified or V1) preserves the pre-Spark-4.1 format that
// Comet has always emitted, keeping older Spark versions byte-equivalent.
_ => SparkBloomFilterVersion::V1,
};
let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
Arc::clone(&num_items),
Arc::clone(&num_bits),
datatype,
version,
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
Expand Down
11 changes: 11 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ message BloomFilterAgg {
Expr numItems = 2;
Expr numBits = 3;
DataType datatype = 4;
// Output serialization version. Spark 4.0 and earlier always wrote V1; Spark
// 4.1+ defaults to V2 (different bit-scattering algorithm and a `seed` field
// in the binary format). The JVM serde sets this to the matching version so
// Comet's aggregate output is byte-equivalent with Spark's.
BloomFilterVersion version = 5;
}

// Serialization format identifier for the bloom filter binary layout.
// V1 is the format written by Spark 4.0 and earlier; V2 (Spark 4.1+) uses a
// different bit-scattering algorithm and adds a `seed` field to the binary
// format (see the `version` field comment on BloomFilterAgg above this enum
// in the full file). UNSPECIFIED is the proto3 default when the sender did
// not set a version; consumers treat it the same as V1 to preserve the
// pre-Spark-4.1 output Comet has always emitted.
enum BloomFilterVersion {
  BLOOM_FILTER_VERSION_UNSPECIFIED = 0;
  BLOOM_FILTER_VERSION_V1 = 1;
  BLOOM_FILTER_VERSION_V2 = 2;
}

message CollectSet {
Expand Down
3 changes: 2 additions & 1 deletion native/spark-expr/benches/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::physical_expr::expressions::{Column, Literal};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_comet_spark_expr::BloomFilterAgg;
use datafusion_comet_spark_expr::{BloomFilterAgg, SparkBloomFilterVersion};
use futures::StreamExt;
use std::hint::black_box;
use std::sync::Arc;
Expand Down Expand Up @@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) {
Arc::clone(&num_items),
Arc::clone(&num_bits),
DataType::Binary,
SparkBloomFilterVersion::V1,
)));
b.to_async(&rt).iter(|| {
black_box(agg_test(
Expand Down
15 changes: 12 additions & 3 deletions native/spark-expr/src/bloom_filter/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::{any::Any, sync::Arc};

use crate::bloom_filter::spark_bloom_filter;
use crate::bloom_filter::spark_bloom_filter::SparkBloomFilter;
use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilterVersion};

use arrow::array::ArrayRef;
use arrow::array::BinaryArray;
Expand All @@ -37,6 +37,10 @@ pub struct BloomFilterAgg {
signature: Signature,
num_items: i32,
num_bits: i32,
/// Output serialization version. Spark <= 4.0 only knows V1; Spark 4.1+'s
/// `BloomFilter.create` defaults to V2, so the JVM serde sets this to V2 on
/// 4.1+ to keep `bloom_filter_agg` byte-equivalent with Spark's aggregator.
version: SparkBloomFilterVersion,
}

#[inline]
Expand All @@ -54,6 +58,7 @@ impl BloomFilterAgg {
num_items: Arc<dyn PhysicalExpr>,
num_bits: Arc<dyn PhysicalExpr>,
data_type: DataType,
version: SparkBloomFilterVersion,
) -> Self {
assert!(matches!(data_type, DataType::Binary));
Self {
Expand All @@ -70,6 +75,7 @@ impl BloomFilterAgg {
),
num_items: extract_i32_from_literal(num_items),
num_bits: extract_i32_from_literal(num_bits),
version,
}
}
}
Expand All @@ -92,10 +98,13 @@ impl AggregateUDFImpl for BloomFilterAgg {
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SparkBloomFilter::from((
Ok(Box::new(SparkBloomFilter::new(
self.version,
spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
self.num_bits,
))))
// Spark's BloomFilterAggregate always uses BloomFilterImplV2.DEFAULT_SEED (= 0).
0,
)))
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Expand Down
1 change: 1 addition & 0 deletions native/spark-expr/src/bloom_filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod bit;

mod spark_bit_array;
mod spark_bloom_filter;
pub use spark_bloom_filter::SparkBloomFilterVersion;

pub mod bloom_filter_agg;
pub use bloom_filter_might_contain::BloomFilterMightContain;
Expand Down
Loading
Loading