diff --git a/native/core/src/execution/merge_as_partial.rs b/native/core/src/execution/merge_as_partial.rs new file mode 100644 index 0000000000..dc671a22d8 --- /dev/null +++ b/native/core/src/execution/merge_as_partial.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode. +//! +//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate +//! state (not final values). DataFusion has no equivalent mode — `Partial` calls +//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs +//! evaluated results. +//! +//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which +//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge +//! semantics with state output. 
+ +use std::any::Any; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion::common::Result; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::function::StateFieldsArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, +}; +use datafusion::physical_expr::aggregate::AggregateFunctionExpr; +use datafusion::scalar::ScalarValue; + +/// An AggregateUDF wrapper that gives merge semantics in Partial mode. +/// +/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch` +/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch` +/// and redirects it to `merge_batch` on the inner accumulator, effectively +/// implementing PartialMerge: merge inputs, output state. +/// +/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping +/// references to UnboundColumn expressions that would panic if evaluated. +#[derive(Debug)] +pub struct MergeAsPartialUDF { + /// The inner aggregate UDF, cloned from the original expression. + inner_udf: AggregateUDF, + /// Pre-computed return type from the original expression. + return_type: DataType, + /// Pre-computed state fields from the original expression. + cached_state_fields: Vec<FieldRef>, + /// Cached signature that accepts state field types. + signature: Signature, + /// Name for this wrapper. 
+ name: String, +} + +impl PartialEq for MergeAsPartialUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} + +impl Eq for MergeAsPartialUDF {} + +impl Hash for MergeAsPartialUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + self.name.hash(state); + } +} + +impl MergeAsPartialUDF { + pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> { + let name = format!("merge_as_partial_{}", inner_expr.name()); + let return_type = inner_expr.field().data_type().clone(); + let cached_state_fields = inner_expr.state_fields()?; + + // Use a permissive signature since we accept state field types which + // vary per aggregate function. + let signature = Signature::variadic_any(Volatility::Immutable); + + Ok(Self { + inner_udf: inner_expr.fun().clone(), + return_type, + cached_state_fields, + signature, + name, + }) + } +} + +impl AggregateUDFImpl for MergeAsPartialUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + // In Partial mode, return_type isn't used for output schema (state_fields is). + // Return the inner function's return type for consistency. + Ok(self.return_type.clone()) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> { + // State fields must match the inner aggregate's state fields so that + // the output of this PartialMerge stage is compatible with subsequent + // Final or PartialMerge stages. + Ok(self.cached_state_fields.clone()) + } + + fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + // Create the inner accumulator using the provided args (which have the + // correct Column refs, not UnboundColumns). 
+ let inner_acc = self.inner_udf.accumulator(args)?; + Ok(Box::new(MergeAsPartialAccumulator { inner: inner_acc })) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner_udf.groups_accumulator_supported(args) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result<Box<dyn GroupsAccumulator>> { + let inner_acc = self.inner_udf.create_groups_accumulator(args)?; + Ok(Box::new(MergeAsPartialGroupsAccumulator { + inner: inner_acc, + })) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::NotSupported + } + + fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { + ScalarValue::try_from(data_type) + } + + fn is_descending(&self) -> Option<bool> { + None + } +} + +/// Accumulator wrapper that redirects update_batch to merge_batch. +struct MergeAsPartialAccumulator { + inner: Box<dyn Accumulator>, +} + +impl Debug for MergeAsPartialAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MergeAsPartialAccumulator").finish() + } +} + +impl Accumulator for MergeAsPartialAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Redirect update to merge — this is the key trick. + self.inner.merge_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } + + fn evaluate(&mut self) -> Result<ScalarValue> { + self.inner.evaluate() + } + + fn state(&mut self) -> Result<Vec<ScalarValue>> { + self.inner.state() + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// GroupsAccumulator wrapper that redirects update_batch to merge_batch. 
+struct MergeAsPartialGroupsAccumulator { + inner: Box<dyn GroupsAccumulator>, +} + +impl Debug for MergeAsPartialGroupsAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MergeAsPartialGroupsAccumulator").finish() + } +} + +impl GroupsAccumulator for MergeAsPartialGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Redirect update to merge — this is the key trick. + self.inner + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.inner + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { + self.inner.evaluate(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { + self.inner.state(emit_to) + } + + fn size(&self) -> usize { + self.inner.size() + } +} diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index f556fce41c..ec247f72b7 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -19,6 +19,7 @@ pub mod columnar_to_row; pub mod expressions; pub mod jni_api; +pub(crate) mod merge_as_partial; pub(crate) mod metrics; pub mod operators; pub(crate) mod planner; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 176104a3a5..55a09501d7 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -967,19 +967,90 @@ impl PhysicalPlanner { let group_by = PhysicalGroupBy::new_single(group_exprs?); let schema = child.schema(); - let mode = if agg.mode == 0 { - DFAggregateMode::Partial - } else { - DFAggregateMode::Final + let mode = match agg.mode { + 0 => DFAggregateMode::Partial, + 1 => 
DFAggregateMode::Final, + 2 => DFAggregateMode::Partial, // PartialMerge: Partial + MergeAsPartial + other => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported aggregate mode: {other}" + ))) + } }; + // Check if any expression uses PartialMerge mode (2). When present, + // those expressions are wrapped with MergeAsPartial to get merge + // semantics inside a Partial-mode AggregateExec. + let has_partial_merge = agg.mode == 2 || agg.expr_modes.contains(&2); + let agg_exprs: PhyAggResult = agg .agg_exprs .iter() .map(|expr| self.create_agg_expr(expr, Arc::clone(&schema))) .collect(); - let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect(); + let aggr_expr: Vec<Arc<AggregateFunctionExpr>> = if has_partial_merge { + // Wrap PartialMerge expressions with MergeAsPartial. + // State fields in the child's output start at initial_input_buffer_offset. + let mut state_offset = agg.initial_input_buffer_offset as usize; + let per_expr_modes: Vec<i32> = if !agg.expr_modes.is_empty() { + agg.expr_modes.clone() + } else { + vec![agg.mode; agg.agg_exprs.len()] + }; + + agg_exprs? 
+ .into_iter() + .enumerate() + .map(|(idx, expr)| { + if per_expr_modes[idx] == 2 { + // PartialMerge: wrap with MergeAsPartial + let state_fields = expr + .state_fields() + .map_err(|e| ExecutionError::GeneralError(e.to_string()))?; + let num_state_fields = state_fields.len(); + + let state_cols: Vec<Arc<dyn PhysicalExpr>> = (0..num_state_fields) + .map(|i| { + let col_idx = state_offset + i; + let field = schema.field(col_idx); + Arc::new(Column::new(field.name(), col_idx)) + as Arc<dyn PhysicalExpr> + }) + .collect(); + state_offset += num_state_fields; + + let merge_udf = + crate::execution::merge_as_partial::MergeAsPartialUDF::new( + &expr, + ) + .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; + let merge_udf_arc = Arc::new( + datafusion::logical_expr::AggregateUDF::new_from_impl( + merge_udf, + ), + ); + + let merge_expr = + AggregateExprBuilder::new(merge_udf_arc, state_cols) + .schema(Arc::clone(&schema)) + .alias(format!("col_{idx}")) + .with_ignore_nulls(expr.ignore_nulls()) + .with_distinct(expr.is_distinct()) + .build() + .map_err(|e| { + ExecutionError::DataFusionError(e.to_string()) + })?; + + Ok(Arc::new(merge_expr)) + } else { + Ok(Arc::new(expr)) + } + }) + .collect::<Result<Vec<_>, ExecutionError>>()? + } else { + agg_exprs?.into_iter().map(Arc::new).collect() + }; // Build per-aggregate filter expressions from the FILTER (WHERE ...) clause. // Filters are only present in Partial mode; Final/PartialMerge always get None. diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index fb438b26a4..a4b29c28f4 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -270,6 +270,13 @@ message HashAggregate { repeated spark.spark_expression.AggExpr agg_exprs = 2; repeated spark.spark_expression.Expr result_exprs = 3; AggregateMode mode = 5; + // Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial). + // When set, each entry corresponds to agg_exprs at the same index. 
+ // When empty, all expressions use the `mode` field. + repeated AggregateMode expr_modes = 6; + // Offset in the child's output where aggregate buffer attributes start. + // Used by PartialMerge to locate state fields in the input. + int32 initial_input_buffer_offset = 7; } message Limit { @@ -319,6 +326,7 @@ message ParquetWriter { enum AggregateMode { Partial = 0; Final = 1; + PartialMerge = 2; } message Expand { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b74785bd1f..4cc1f2c662 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -461,15 +461,15 @@ object QueryPlanSerde extends Logging with CometExprShim { binding: Boolean, conf: SQLConf): Option[AggExpr] = { - // Support Count(distinct single_value) - // COUNT(DISTINCT x) - supported - // COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x) - // COUNT(DISTINCT x, y) - not supported + // Distinct aggregates with a single column are supported (e.g., COUNT(DISTINCT x), + // SUM(DISTINCT x), AVG(DISTINCT x)). The multi-stage plan generated by Spark + // guarantees distinct semantics through grouping — the native side does not need + // to handle deduplication. + // Multi-column distinct is only supported for COUNT (e.g., COUNT(DISTINCT x, y)). 
if (aggExpr.isDistinct - && - !(aggExpr.aggregateFunction.prettyName == "count" && - aggExpr.aggregateFunction.children.length == 1)) { - withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr") + && aggExpr.aggregateFunction.children.length > 1 + && aggExpr.aggregateFunction.prettyName != "count") { + withInfo(aggExpr, s"Multi-column distinct aggregate not supported for: $aggExpr") return None } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a6f6b03330..775b83fa58 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, First, Last, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -1356,8 +1356,11 @@ trait CometBaseAggregate { childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { val modes = aggregate.aggregateExpressions.map(_.mode).distinct - // In distinct aggregates there can be a combination of modes - val multiMode = modes.size > 1 + val modeSet = modes.toSet + val hasPartialMerge = modeSet.contains(PartialMerge) + // In distinct aggregates there can be a combination of modes. 
+ // We support {Partial, PartialMerge} mix; other combinations are rejected. + val multiMode = modes.size > 1 && modeSet != Set(Partial, PartialMerge) // For a final mode HashAggregate, we only need to transform the HashAggregate // if there is Comet partial aggregation. val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty @@ -1429,33 +1432,57 @@ trait CometBaseAggregate { hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) Some(builder.setHashAgg(hashAggBuilder).build()) } else { - val modes = aggregateExpressions.map(_.mode).distinct - - if (modes.size != 1) { - // This shouldn't happen as all aggregation expressions should share the same mode. - // Fallback to Spark nevertheless here. - withInfo(aggregate, "All aggregate expressions do not have the same mode") + // Validate mode combinations. We support: + // - All Partial + // - All Final + // - All PartialMerge + // - Mixed {Partial, PartialMerge} (for distinct aggregate plans) + val isMixedPartialMerge = modeSet == Set(Partial, PartialMerge) + if (modes.size > 1 && !isMixedPartialMerge) { + withInfo(aggregate, s"Unsupported mixed aggregation modes: ${modes.mkString(", ")}") return None } - val mode = modes.head match { - case Partial => CometAggregateMode.Partial - case Final => CometAggregateMode.Final - case _ => - withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") + // Determine the proto mode. For uniform modes, use that mode directly. + // For mixed {Partial, PartialMerge}, use Partial as the base mode since + // PartialMerge expressions are wrapped with MergeAsPartial on the native side. 
+ val mode = if (isMixedPartialMerge) { + CometAggregateMode.Partial + } else { + modes.head match { + case Partial => CometAggregateMode.Partial + case Final => CometAggregateMode.Final + case PartialMerge => CometAggregateMode.PartialMerge + case _ => + withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") + return None + } + } + + // FIRST/LAST are order-dependent aggregates whose merge result depends on + // hash table processing order. In PartialMerge mode, DataFusion's hash table + // may process rows in a different order than Spark's, producing different results. + if (hasPartialMerge) { + val unsupportedAggs = aggregateExpressions.filter { a => + a.mode == PartialMerge && (a.aggregateFunction.isInstanceOf[First] || + a.aggregateFunction.isInstanceOf[Last]) + } + if (unsupportedAggs.nonEmpty) { + withInfo( + aggregate, + "PartialMerge not supported for order-dependent aggregates: " + + unsupportedAggs.map(_.aggregateFunction.prettyName).mkString(", ")) return None + } } - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - val binding = mode != CometAggregateMode.Final - // `output` is only used when `binding` is true (i.e., non-Final) + // Per-expression binding: Partial expressions bind to child output, + // PartialMerge/Final expressions do not (native planner handles their input). 
val output = child.output - - val aggExprs = - aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf)) + val aggExprs = aggregateExpressions.map { a => + val exprBinding = a.mode != PartialMerge && a.mode != Final + aggExprToProto(a, output, exprBinding, aggregate.conf) + } if (aggExprs.exists(_.isEmpty)) { withInfo( @@ -1483,6 +1510,23 @@ trait CometBaseAggregate { hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) } hashAggBuilder.setModeValue(mode.getNumber) + + // Send per-expression modes and buffer offset for PartialMerge handling + if (hasPartialMerge) { + val exprModes = aggregateExpressions.map { a => + a.mode match { + case Partial => CometAggregateMode.Partial + case PartialMerge => CometAggregateMode.PartialMerge + case Final => CometAggregateMode.Final + case other => + withInfo(aggregate, s"Unsupported aggregation mode $other") + return None + } + } + hashAggBuilder.addAllExprModes(exprModes.asJava) + hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset) + } + Some(builder.setHashAgg(hashAggBuilder).build()) } else { val allChildren: Seq[Expression] = @@ -1496,14 +1540,21 @@ trait CometBaseAggregate { /** * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with - * partial mode, it will return None. + * partial or partial-merge mode, it will return None. 
*/ private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = { + def isPartialOrMerge(mode: AggregateMode): Boolean = + mode == Partial || mode == PartialMerge + plan.collectFirst { - case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => + case agg: CometHashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => Some(agg) - case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None - case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => + case agg: HashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => + None + case agg: ObjectHashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => None case a: AQEShuffleReadExec => findCometPartialAgg(a.child) case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan) @@ -1600,7 +1651,7 @@ object CometObjectHashAggregateExec * case branch here mapping it to the native state type. */ private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = { - // CometBaseAggregate.doConvert guarantees all expressions share the same mode. + // This adjustment only applies to pure-Partial aggregates (checked below). val modes = op.aggregateExpressions.map(_.mode).distinct if (modes != Seq(Partial)) { return op.output @@ -1642,11 +1693,8 @@ case class CometHashAggregateExec( // The aggExprs could be empty. For example, if the aggregate functions only have // distinct aggregate functions or only have group by, the aggExprs is empty and - // modes is empty too. If aggExprs is not empty, we need to verify all the - // aggregates have the same mode. + // modes is empty too. 
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct - assert(modes.length == 1 || modes.isEmpty) - val mode = modes.headOption override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions) @@ -1663,7 +1711,7 @@ case class CometHashAggregateExec( } override def stringArgs: Iterator[Any] = - Iterator(input, mode, groupingExpressions, aggregateExpressions, child) + Iterator(input, modes, groupingExpressions, aggregateExpressions, child) override def equals(obj: Any): Boolean = { obj match { @@ -1672,7 +1720,7 @@ case class CometHashAggregateExec( this.groupingExpressions == other.groupingExpressions && this.aggregateExpressions == other.aggregateExpressions && this.input == other.input && - this.mode == other.mode && + this.modes == other.modes && this.child == other.child && this.serializedPlanOpt == other.serializedPlanOpt case _ => @@ -1681,7 +1729,7 @@ case class CometHashAggregateExec( } override def hashCode(): Int = - Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child) + Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child) override protected def outputExpressions: Seq[NamedExpression] = resultExpressions } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 95f3774e01..484be72b1c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -638,10 +638,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { dictionaryEnabled) { withView("v") { sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1") - checkSparkAnswerAndFallbackReason( + checkSparkAnswerAndOperator( "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," + - " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP 
BY _2", - "Unsupported aggregation mode PartialMerge") + " COUNT(DISTINCT _1), AVG(_1)" + + " FROM v GROUP BY _2 ORDER BY _2") } } } @@ -650,6 +650,83 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // FIRST/LAST are order-dependent aggregates whose merge result depends on hash table + // processing order. In PartialMerge mode, DataFusion's hash table may process rows + // in a different order than Spark's, so we fall back to Spark for correctness. + test("partialMerge - FIRST/LAST with distinct aggregates falls back") { + val numValues = 10000 + Seq(100).foreach { numGroups => + Seq(128).foreach { batchSize => + withSQLConf( + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) { + withParquetTable( + (0 until numValues).map(i => (i, Random.nextInt() % numGroups)), + "tbl", + false) { + withView("v") { + sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1") + checkSparkAnswerAndFallbackReason( + "SELECT _2, FIRST(_1), LAST(_1), COUNT(DISTINCT _1)" + + " FROM v GROUP BY _2 ORDER BY _2", + "PartialMerge not supported for order-dependent aggregates") + } + } + } + } + } + } + + test("partialMerge - cnt distinct + sum") { + withTempDir(dir => { + withSQLConf("spark.comet.enabled" -> "false") { + sql(""" + CREATE OR REPLACE TEMP VIEW t (v, v1, i) AS + VALUES + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2), + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2), + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2) + """) + sql("select * from t") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.getAbsolutePath) + } + + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + "spark.comet.exec.shuffle.fallbackToColumnar" -> "false", + "spark.comet.cast.allowIncompatible" -> "true", + "spark.sql.adaptive.enabled" -> "false", + 
"spark.comet.explain.native.enabled" -> "true", + "spark.comet.enabled" -> "true", + "spark.comet.expression.Cast.allowIncompatible" -> "true", + "spark.comet.exec.shuffle.enableFastEncoding" -> "true", + "spark.comet.exec.shuffle.enabled" -> "true", + "spark.comet.explainFallback.enabled" -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_iceberg_compat", + "spark.shuffle.manager" -> + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager", + "spark.comet.logFallbackReasons.enabled" -> "true") { + spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("t2") + checkSparkAnswerAndOperator("SELECT i, sum(v1), count(distinct v) FROM t2 group by i") + } + }) + } + test("multiple group-by columns + single aggregate column (first/last), with nulls") { val numValues = 10000 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 22983119bb..6c8420ad2c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -484,9 +484,11 @@ class CometExecSuite extends CometTestBase { case s: CometHashAggregateExec => s }.get - assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode]) + assert(agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode]) val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec] - assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode]) + assert( + newAgg.modes.nonEmpty && + newAgg.modes.headOption.get.isInstanceOf[AggregateMode]) } }