Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 242 additions & 0 deletions native/core/src/execution/merge_as_partial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
//!
//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
//! state (not final values). DataFusion has no equivalent mode — `Partial` calls
//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
//! evaluated results.
//!
//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which
//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
//! semantics with state output.

use std::any::Any;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};

use arrow::array::{ArrayRef, BooleanArray};
use arrow::datatypes::{DataType, FieldRef};
use datafusion::common::Result;
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::function::StateFieldsArgs;
use datafusion::logical_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
Signature, Volatility,
};
use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
use datafusion::scalar::ScalarValue;

/// An AggregateUDF wrapper that gives merge semantics in Partial mode.
///
/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
/// and redirects it to `merge_batch` on the inner accumulator, effectively
/// implementing PartialMerge: merge inputs, output state.
///
/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping
/// references to UnboundColumn expressions that would panic if evaluated.
#[derive(Debug)]
pub struct MergeAsPartialUDF {
    /// Display name for this wrapper (`merge_as_partial_<inner name>`).
    name: String,
    /// Permissive signature accepting state-field types, which vary per aggregate.
    signature: Signature,
    /// The wrapped aggregate UDF, cloned from the original expression.
    inner_udf: AggregateUDF,
    /// Return type captured from the original expression at construction time.
    return_type: DataType,
    /// State fields captured from the original expression at construction time.
    cached_state_fields: Vec<FieldRef>,
}

// Equality and hashing are based solely on the wrapper's name. The name embeds
// the inner aggregate's name, so distinct wrapped functions normally compare
// unequal. NOTE(review): two wrappers around different aggregates that happen
// to share a name would compare equal — confirm names are unique per plan.
impl PartialEq for MergeAsPartialUDF {
    fn eq(&self, other: &Self) -> bool {
        other.name == self.name
    }
}

impl Eq for MergeAsPartialUDF {}

impl Hash for MergeAsPartialUDF {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.name, state);
    }
}

impl MergeAsPartialUDF {
    /// Builds a wrapper around an existing aggregate expression.
    ///
    /// Everything needed from `inner_expr` (return type, state fields, the
    /// underlying UDF) is captured up front, so the wrapper never needs to
    /// evaluate the original expression's argument expressions later.
    ///
    /// Returns an error if the inner expression's state fields cannot be
    /// computed.
    pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> {
        Ok(Self {
            name: format!("merge_as_partial_{}", inner_expr.name()),
            return_type: inner_expr.field().data_type().clone(),
            cached_state_fields: inner_expr.state_fields()?,
            // Deliberately permissive: the inputs are state columns whose
            // types differ from one aggregate function to the next.
            signature: Signature::variadic_any(Volatility::Immutable),
            inner_udf: inner_expr.fun().clone(),
        })
    }
}

impl AggregateUDFImpl for MergeAsPartialUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// In Partial mode the output schema is derived from `state_fields`, not
    /// the return type; the inner function's return type is reported here for
    /// consistency.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(self.return_type.clone())
    }

    /// State fields must match the inner aggregate's state fields so that the
    /// output of this PartialMerge stage stays compatible with subsequent
    /// Final or PartialMerge stages.
    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
        Ok(self.cached_state_fields.clone())
    }

    /// Creates the inner accumulator (args carry the correct Column refs, not
    /// UnboundColumns) and wraps it so update calls are redirected to merges.
    fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        Ok(Box::new(MergeAsPartialAccumulator {
            inner: self.inner_udf.accumulator(args)?,
        }))
    }

    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
        self.inner_udf.groups_accumulator_supported(args)
    }

    fn create_groups_accumulator(
        &self,
        args: AccumulatorArgs,
    ) -> Result<Box<dyn GroupsAccumulator>> {
        Ok(Box::new(MergeAsPartialGroupsAccumulator {
            inner: self.inner_udf.create_groups_accumulator(args)?,
        }))
    }

    fn reverse_expr(&self) -> ReversedUDAF {
        ReversedUDAF::NotSupported
    }

    /// Delegate to the inner aggregate rather than `ScalarValue::try_from`,
    /// because some aggregates override the default value (e.g. COUNT
    /// defaults to 0, not NULL). A hard-coded NULL default can disagree with
    /// the inner function's semantics for empty groups.
    fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
        self.inner_udf.default_value(data_type)
    }

    fn is_descending(&self) -> Option<bool> {
        // NOTE(review): not delegated to the inner UDF; confirm ordering
        // metadata is irrelevant for a state-producing (Partial-mode) stage.
        None
    }
}

/// Accumulator wrapper that redirects update_batch to merge_batch.
struct MergeAsPartialAccumulator {
    /// The wrapped accumulator that actually holds the aggregate state.
    inner: Box<dyn Accumulator>,
}

// Manual Debug impl that prints only the type name, omitting the inner
// accumulator's state.
impl Debug for MergeAsPartialAccumulator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MergeAsPartialAccumulator")
    }
}

impl Accumulator for MergeAsPartialAccumulator {
    /// The key trick: the "values" arriving here are intermediate state
    /// buffers, so updating means merging them into the inner state.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.inner.merge_batch(values)
    }

    // A genuine merge request — forward as-is.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.inner.merge_batch(states)
    }

    // State output is what Partial mode emits; delegate to the inner state.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.inner.state()
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        self.inner.evaluate()
    }

    fn size(&self) -> usize {
        self.inner.size()
    }
}

/// GroupsAccumulator wrapper that redirects update_batch to merge_batch.
struct MergeAsPartialGroupsAccumulator {
    /// The wrapped groups accumulator that actually holds per-group state.
    inner: Box<dyn GroupsAccumulator>,
}

// Manual Debug impl that prints only the type name, omitting the inner
// accumulator's state.
impl Debug for MergeAsPartialGroupsAccumulator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MergeAsPartialGroupsAccumulator")
    }
}

impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
    /// The key trick: the incoming `values` are intermediate state buffers,
    /// so an "update" is really a merge into the inner per-group state.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.inner
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    // A genuine merge request — forward as-is.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> Result<()> {
        self.inner
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    // State output is what Partial mode emits; delegate to the inner state.
    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
        self.inner.state(emit_to)
    }

    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
        self.inner.evaluate(emit_to)
    }

    fn size(&self) -> usize {
        self.inner.size()
    }
}
1 change: 1 addition & 0 deletions native/core/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
pub mod columnar_to_row;
pub mod expressions;
pub mod jni_api;
pub(crate) mod merge_as_partial;
pub(crate) mod metrics;
pub mod operators;
pub(crate) mod planner;
Expand Down
81 changes: 76 additions & 5 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -967,19 +967,90 @@ impl PhysicalPlanner {
let group_by = PhysicalGroupBy::new_single(group_exprs?);
let schema = child.schema();

let mode = if agg.mode == 0 {
DFAggregateMode::Partial
} else {
DFAggregateMode::Final
let mode = match agg.mode {
0 => DFAggregateMode::Partial,
1 => DFAggregateMode::Final,
2 => DFAggregateMode::Partial, // PartialMerge: Partial + MergeAsPartial
other => {
return Err(ExecutionError::GeneralError(format!(
"Unsupported aggregate mode: {other}"
)))
}
};

// Check if any expression uses PartialMerge mode (2). When present,
// those expressions are wrapped with MergeAsPartial to get merge
// semantics inside a Partial-mode AggregateExec.
let has_partial_merge = agg.mode == 2 || agg.expr_modes.contains(&2);

let agg_exprs: PhyAggResult = agg
.agg_exprs
.iter()
.map(|expr| self.create_agg_expr(expr, Arc::clone(&schema)))
.collect();

let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect();
let aggr_expr: Vec<Arc<AggregateFunctionExpr>> = if has_partial_merge {
// Wrap PartialMerge expressions with MergeAsPartial.
// State fields in the child's output start at initial_input_buffer_offset.
let mut state_offset = agg.initial_input_buffer_offset as usize;
let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() {
agg.expr_modes.clone()
} else {
vec![agg.mode; agg.agg_exprs.len()]
};

agg_exprs?
.into_iter()
.enumerate()
.map(|(idx, expr)| {
if per_expr_modes[idx] == 2 {
// PartialMerge: wrap with MergeAsPartial
let state_fields = expr
.state_fields()
.map_err(|e| ExecutionError::GeneralError(e.to_string()))?;
let num_state_fields = state_fields.len();

let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields)
.map(|i| {
let col_idx = state_offset + i;
let field = schema.field(col_idx);
Arc::new(Column::new(field.name(), col_idx))
as Arc<dyn PhysicalExpr>
})
.collect();
state_offset += num_state_fields;

let merge_udf =
crate::execution::merge_as_partial::MergeAsPartialUDF::new(
&expr,
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?;
let merge_udf_arc = Arc::new(
datafusion::logical_expr::AggregateUDF::new_from_impl(
merge_udf,
),
);

let merge_expr =
AggregateExprBuilder::new(merge_udf_arc, state_cols)
.schema(Arc::clone(&schema))
.alias(format!("col_{idx}"))
.with_ignore_nulls(expr.ignore_nulls())
.with_distinct(expr.is_distinct())
.build()
.map_err(|e| {
ExecutionError::DataFusionError(e.to_string())
})?;

Ok(Arc::new(merge_expr))
} else {
Ok(Arc::new(expr))
}
})
.collect::<Result<Vec<_>, ExecutionError>>()?
} else {
agg_exprs?.into_iter().map(Arc::new).collect()
};

// Build per-aggregate filter expressions from the FILTER (WHERE ...) clause.
// Filters are only present in Partial mode; Final/PartialMerge always get None.
Expand Down
8 changes: 8 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,13 @@ message HashAggregate {
repeated spark.spark_expression.AggExpr agg_exprs = 2;
repeated spark.spark_expression.Expr result_exprs = 3;
AggregateMode mode = 5;
// Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial).
// When set, each entry corresponds to agg_exprs at the same index.
// When empty, all expressions use the `mode` field.
repeated AggregateMode expr_modes = 6;
// Offset in the child's output where aggregate buffer attributes start.
// Used by PartialMerge to locate state fields in the input.
int32 initial_input_buffer_offset = 7;
}

message Limit {
Expand Down Expand Up @@ -319,6 +326,7 @@ message ParquetWriter {
// Aggregation execution modes, mirroring Spark's AggregateMode.
enum AggregateMode {
// Accumulate raw input rows and emit intermediate state buffers.
Partial = 0;
// Merge intermediate state buffers and emit final evaluated values.
Final = 1;
// Merge intermediate state buffers and emit intermediate state
// (Spark's PartialMerge).
PartialMerge = 2;
}

message Expand {
Expand Down
16 changes: 8 additions & 8 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,15 @@ object QueryPlanSerde extends Logging with CometExprShim {
binding: Boolean,
conf: SQLConf): Option[AggExpr] = {

// Support Count(distinct single_value)
// COUNT(DISTINCT x) - supported
// COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x)
// COUNT(DISTINCT x, y) - not supported
// Distinct aggregates with a single column are supported (e.g., COUNT(DISTINCT x),
// SUM(DISTINCT x), AVG(DISTINCT x)). The multi-stage plan generated by Spark
// guarantees distinct semantics through grouping — the native side does not need
// to handle deduplication.
// Multi-column distinct is only supported for COUNT (e.g., COUNT(DISTINCT x, y)).
if (aggExpr.isDistinct
&&
!(aggExpr.aggregateFunction.prettyName == "count" &&
aggExpr.aggregateFunction.children.length == 1)) {
withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr")
&& aggExpr.aggregateFunction.children.length > 1
&& aggExpr.aggregateFunction.prettyName != "count") {
withInfo(aggExpr, s"Multi-column distinct aggregate not supported for: $aggExpr")
return None
}

Expand Down
Loading
Loading