From f12ce3e33bc1c4ab2f9f85a63fcaf1326883e268 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 10:19:37 -0700 Subject: [PATCH 1/8] feat: support `PartialMerge` --- native/core/src/execution/planner.rs | 15 ++++-- native/proto/src/proto/operator.proto | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 16 +++--- .../apache/spark/sql/comet/operators.scala | 38 ++++++++------ .../comet/exec/CometAggregateSuite.scala | 52 ++++++++++++++++++- .../apache/comet/exec/CometExecSuite.scala | 7 ++- 6 files changed, 98 insertions(+), 31 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 176104a3a5..b5dbc4a26a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -967,10 +967,17 @@ impl PhysicalPlanner { let group_by = PhysicalGroupBy::new_single(group_exprs?); let schema = child.schema(); - let mode = if agg.mode == 0 { - DFAggregateMode::Partial - } else { - DFAggregateMode::Final + let mode = match agg.mode { + 0 => DFAggregateMode::Partial, + // Both Final and PartialMerge use merge semantics in DataFusion. + // The output difference (final values vs intermediate buffers) is + // handled by the presence/absence of result_exprs. 
+ 1 | 2 => DFAggregateMode::Final, + other => { + return Err(ExecutionError::GeneralError(format!( + "Unsupported aggregate mode: {other}" + ))) + } }; let agg_exprs: PhyAggResult = agg diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index fb438b26a4..1f94bd72ef 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -319,6 +319,7 @@ message ParquetWriter { enum AggregateMode { Partial = 0; Final = 1; + PartialMerge = 2; } message Expand { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b74785bd1f..4cc1f2c662 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -461,15 +461,15 @@ object QueryPlanSerde extends Logging with CometExprShim { binding: Boolean, conf: SQLConf): Option[AggExpr] = { - // Support Count(distinct single_value) - // COUNT(DISTINCT x) - supported - // COUNT(DISTINCT x, x) - supported through transition to COUNT(DISTINCT x) - // COUNT(DISTINCT x, y) - not supported + // Distinct aggregates with a single column are supported (e.g., COUNT(DISTINCT x), + // SUM(DISTINCT x), AVG(DISTINCT x)). The multi-stage plan generated by Spark + // guarantees distinct semantics through grouping — the native side does not need + // to handle deduplication. + // Multi-column distinct is only supported for COUNT (e.g., COUNT(DISTINCT x, y)). 
if (aggExpr.isDistinct - && - !(aggExpr.aggregateFunction.prettyName == "count" && - aggExpr.aggregateFunction.children.length == 1)) { - withInfo(aggExpr, s"Distinct aggregate not supported for: $aggExpr") + && aggExpr.aggregateFunction.children.length > 1 + && aggExpr.aggregateFunction.prettyName != "count") { + withInfo(aggExpr, s"Multi-column distinct aggregate not supported for: $aggExpr") return None } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a6f6b03330..341f74bce8 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1441,16 +1441,18 @@ trait CometBaseAggregate { val mode = modes.head match { case Partial => CometAggregateMode.Partial case Final => CometAggregateMode.Final + case PartialMerge => CometAggregateMode.PartialMerge case _ => withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") return None } - // In final mode, the aggregate expressions are bound to the output of the - // child and partial aggregate expressions buffer attributes produced by partial - // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, - // we don't have to do this because we don't use the merging expression. - val binding = mode != CometAggregateMode.Final + // In final/partial-merge mode, the aggregate expressions are bound to the + // output of the child and partial aggregate expressions buffer attributes + // produced by partial aggregation. This is done in Spark `HashAggregateExec` + // internally. In Comet, we don't have to do this because we don't use the + // merging expression. + val binding = mode == CometAggregateMode.Partial // `output` is only used when `binding` is true (i.e., non-Final) val output = child.output @@ -1496,14 +1498,21 @@ trait CometBaseAggregate { /** * Find the first Comet partial aggregate in the plan. 
If it reaches a Spark HashAggregate with - * partial mode, it will return None. + * partial or partial-merge mode, it will return None. */ private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = { + def isPartialOrMerge(mode: AggregateMode): Boolean = + mode == Partial || mode == PartialMerge + plan.collectFirst { - case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => + case agg: CometHashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => Some(agg) - case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None - case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => + case agg: HashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => + None + case agg: ObjectHashAggregateExec + if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => None case a: AQEShuffleReadExec => findCometPartialAgg(a.child) case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan) @@ -1642,11 +1651,8 @@ case class CometHashAggregateExec( // The aggExprs could be empty. For example, if the aggregate functions only have // distinct aggregate functions or only have group by, the aggExprs is empty and - // modes is empty too. If aggExprs is not empty, we need to verify all the - // aggregates have the same mode. + // modes is empty too. 
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct - assert(modes.length == 1 || modes.isEmpty) - val mode = modes.headOption override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions) @@ -1663,7 +1669,7 @@ case class CometHashAggregateExec( } override def stringArgs: Iterator[Any] = - Iterator(input, mode, groupingExpressions, aggregateExpressions, child) + Iterator(input, modes, groupingExpressions, aggregateExpressions, child) override def equals(obj: Any): Boolean = { obj match { @@ -1672,7 +1678,7 @@ case class CometHashAggregateExec( this.groupingExpressions == other.groupingExpressions && this.aggregateExpressions == other.aggregateExpressions && this.input == other.input && - this.mode == other.mode && + this.modes == other.modes && this.child == other.child && this.serializedPlanOpt == other.serializedPlanOpt case _ => @@ -1681,7 +1687,7 @@ case class CometHashAggregateExec( } override def hashCode(): Int = - Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child) + Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child) override protected def outputExpressions: Seq[NamedExpression] = resultExpressions } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 95f3774e01..2a1b98bcd9 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -641,7 +641,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndFallbackReason( "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," + " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2", - "Unsupported aggregation mode PartialMerge") + "All aggregate expressions do not have the same mode") } } } @@ -650,6 +650,56 @@ 
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("partialMerge - cnt distinct + sum") { + withTempDir(dir => { + withSQLConf("spark.comet.enabled" -> "false") { + sql(""" + CREATE OR REPLACE TEMP VIEW t (v, v1, i) AS + VALUES + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2), + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2), + ('c', 'a', 1), + ('c1', 'a1', 1), + ('c2', 'a2', 2), + ('c3', 'a3', 2), + ('c4', 'a4', 2) + """) + sql("select * from t") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.getAbsolutePath) + } + + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + "spark.comet.exec.shuffle.fallbackToColumnar" -> "false", + "spark.comet.cast.allowIncompatible" -> "true", + "spark.sql.adaptive.enabled" -> "false", + "spark.comet.explain.native.enabled" -> "true", + "spark.comet.enabled" -> "true", + "spark.comet.expression.Cast.allowIncompatible" -> "true", + "spark.comet.exec.shuffle.enableFastEncoding" -> "true", + "spark.comet.exec.shuffle.enabled" -> "true", + "spark.comet.explainFallback.enabled" -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_iceberg_compat", + "spark.shuffle.manager" -> + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager", + "spark.comet.logFallbackReasons.enabled" -> "true") { + spark.read.parquet(dir.getAbsolutePath).createOrReplaceTempView("t2") + checkSparkAnswerAndOperator("SELECT i, sum(v1), count(distinct v) FROM t2 group by i") + } + }) + } + test("multiple group-by columns + single aggregate column (first/last), with nulls") { val numValues = 10000 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 22983119bb..2f82753066 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -484,9 +484,12 @@ class CometExecSuite extends CometTestBase { case s: CometHashAggregateExec => s }.get - assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode]) + assert( + agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode]) val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec] - assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode]) + assert( + newAgg.modes.nonEmpty && + newAgg.modes.headOption.get.isInstanceOf[AggregateMode]) } } From 10709c4cad689324efa06f1585d818e321aad9b4 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 10:23:10 -0700 Subject: [PATCH 2/8] feat: support `PartialMerge` --- .../src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 2f82753066..6c8420ad2c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -484,8 +484,7 @@ class CometExecSuite extends CometTestBase { case s: CometHashAggregateExec => s }.get - assert( - agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode]) + assert(agg.modes.nonEmpty && agg.modes.headOption.get.isInstanceOf[AggregateMode]) val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec] assert( newAgg.modes.nonEmpty && From 0c0327acbd7c9df35c7c5e844d25b198290a6f91 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 12:22:53 -0700 Subject: [PATCH 3/8] feat: support `PartialMerge` --- native/core/src/execution/mod.rs | 1 + native/core/src/execution/planner.rs | 73 ++++++++++++++++-- native/proto/src/proto/operator.proto | 7 ++ .../apache/spark/sql/comet/operators.scala | 75 +++++++++++++------ 
.../comet/exec/CometAggregateSuite.scala | 5 +- 5 files changed, 129 insertions(+), 32 deletions(-) diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index f556fce41c..ec247f72b7 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -19,6 +19,7 @@ pub mod columnar_to_row; pub mod expressions; pub mod jni_api; +pub(crate) mod merge_as_partial; pub(crate) mod metrics; pub mod operators; pub(crate) mod planner; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b5dbc4a26a..70401b1579 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -969,10 +969,8 @@ impl PhysicalPlanner { let mode = match agg.mode { 0 => DFAggregateMode::Partial, - // Both Final and PartialMerge use merge semantics in DataFusion. - // The output difference (final values vs intermediate buffers) is - // handled by the presence/absence of result_exprs. - 1 | 2 => DFAggregateMode::Final, + 1 => DFAggregateMode::Final, + 2 => DFAggregateMode::Partial, // PartialMerge uses Partial + MergeAsPartial other => { return Err(ExecutionError::GeneralError(format!( "Unsupported aggregate mode: {other}" @@ -980,13 +978,78 @@ impl PhysicalPlanner { } }; + // Determine per-expression modes. PartialMerge (2) expressions use + // MergeAsPartial wrapper so they run merge semantics in Partial mode. + let per_expr_modes: Vec = if !agg.expr_modes.is_empty() { + agg.expr_modes.clone() + } else { + vec![agg.mode; agg.agg_exprs.len()] + }; + + let has_partial_merge = per_expr_modes.contains(&2); + let agg_exprs: PhyAggResult = agg .agg_exprs .iter() .map(|expr| self.create_agg_expr(expr, Arc::clone(&schema))) .collect(); - let aggr_expr = agg_exprs?.into_iter().map(Arc::new).collect(); + let aggr_expr: Vec> = if has_partial_merge { + // Wrap PartialMerge expressions with MergeAsPartial. + // State fields in the child's output start at initial_input_buffer_offset. 
+ let mut state_offset = agg.initial_input_buffer_offset as usize; + let child_schema = child.schema(); + + agg_exprs? + .into_iter() + .enumerate() + .map(|(idx, expr)| { + let expr_mode = per_expr_modes[idx]; + if expr_mode == 2 { + // PartialMerge: wrap with MergeAsPartial + let state_fields = expr.state_fields().map_err(|e| { + ExecutionError::GeneralError(e.to_string()) + })?; + let num_state_fields = state_fields.len(); + + // Create Column refs pointing to state field positions + let state_cols: Vec> = (0..num_state_fields) + .map(|i| { + let col_idx = state_offset + i; + let field = child_schema.field(col_idx); + Arc::new(Column::new(field.name(), col_idx)) + as Arc + }) + .collect(); + state_offset += num_state_fields; + + let merge_udf = crate::execution::merge_as_partial::MergeAsPartialUDF::new(&expr) + .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; + let merge_udf_arc = Arc::new( + datafusion::logical_expr::AggregateUDF::new_from_impl(merge_udf), + ); + + let merge_expr = AggregateExprBuilder::new( + merge_udf_arc, + state_cols, + ) + .schema(Arc::clone(&child_schema)) + .alias(format!("col_{idx}")) + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; + + Ok(Arc::new(merge_expr)) + } else { + // Partial: use as-is + Ok(Arc::new(expr)) + } + }) + .collect::, ExecutionError>>()? + } else { + agg_exprs?.into_iter().map(Arc::new).collect() + }; // Build per-aggregate filter expressions from the FILTER (WHERE ...) clause. // Filters are only present in Partial mode; Final/PartialMerge always get None. 
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 1f94bd72ef..a4b29c28f4 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -270,6 +270,13 @@ message HashAggregate { repeated spark.spark_expression.AggExpr agg_exprs = 2; repeated spark.spark_expression.Expr result_exprs = 3; AggregateMode mode = 5; + // Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial). + // When set, each entry corresponds to agg_exprs at the same index. + // When empty, all expressions use the `mode` field. + repeated AggregateMode expr_modes = 6; + // Offset in the child's output where aggregate buffer attributes start. + // Used by PartialMerge to locate state fields in the input. + int32 initial_input_buffer_offset = 7; } message Limit { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 341f74bce8..f9d514865a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1356,8 +1356,9 @@ trait CometBaseAggregate { childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { val modes = aggregate.aggregateExpressions.map(_.mode).distinct - // In distinct aggregates there can be a combination of modes - val multiMode = modes.size > 1 + // In distinct aggregates there can be a combination of modes. + // We support {Partial, PartialMerge} mix; other combinations are rejected. + val multiMode = modes.size > 1 && modes.toSet != Set(Partial, PartialMerge) // For a final mode HashAggregate, we only need to transform the HashAggregate // if there is Comet partial aggregation. 
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty @@ -1430,34 +1431,42 @@ trait CometBaseAggregate { Some(builder.setHashAgg(hashAggBuilder).build()) } else { val modes = aggregateExpressions.map(_.mode).distinct - - if (modes.size != 1) { - // This shouldn't happen as all aggregation expressions should share the same mode. - // Fallback to Spark nevertheless here. - withInfo(aggregate, "All aggregate expressions do not have the same mode") + val modeSet = modes.toSet + + // Validate mode combinations. We support: + // - All Partial + // - All Final + // - All PartialMerge + // - Mixed {Partial, PartialMerge} (for distinct aggregate plans) + val isMixedPartialMerge = modeSet == Set(Partial, PartialMerge) + if (modes.size > 1 && !isMixedPartialMerge) { + withInfo(aggregate, s"Unsupported mixed aggregation modes: ${modes.mkString(", ")}") return None } - val mode = modes.head match { - case Partial => CometAggregateMode.Partial - case Final => CometAggregateMode.Final - case PartialMerge => CometAggregateMode.PartialMerge - case _ => - withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") - return None + // Determine the proto mode. For uniform modes, use that mode directly. + // For mixed {Partial, PartialMerge}, use Partial as the base mode since + // PartialMerge expressions are wrapped with MergeAsPartial on the native side. + val mode = if (isMixedPartialMerge) { + CometAggregateMode.Partial + } else { + modes.head match { + case Partial => CometAggregateMode.Partial + case Final => CometAggregateMode.Final + case PartialMerge => CometAggregateMode.PartialMerge + case _ => + withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}") + return None + } } - // In final/partial-merge mode, the aggregate expressions are bound to the - // output of the child and partial aggregate expressions buffer attributes - // produced by partial aggregation. This is done in Spark `HashAggregateExec` - // internally. 
In Comet, we don't have to do this because we don't use the - // merging expression. - val binding = mode == CometAggregateMode.Partial - // `output` is only used when `binding` is true (i.e., non-Final) + // Per-expression binding: Partial expressions bind to child output, + // PartialMerge/Final expressions do not (native planner handles their input). val output = child.output - - val aggExprs = - aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf)) + val aggExprs = aggregateExpressions.map { a => + val exprBinding = a.mode != PartialMerge && a.mode != Final + aggExprToProto(a, output, exprBinding, aggregate.conf) + } if (aggExprs.exists(_.isEmpty)) { withInfo( @@ -1485,6 +1494,24 @@ trait CometBaseAggregate { hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) } hashAggBuilder.setModeValue(mode.getNumber) + + // Send per-expression modes and buffer offset for PartialMerge handling + val hasPartialMerge = aggregateExpressions.exists(_.mode == PartialMerge) + if (hasPartialMerge) { + val exprModes = aggregateExpressions.map { a => + a.mode match { + case Partial => CometAggregateMode.Partial + case PartialMerge => CometAggregateMode.PartialMerge + case Final => CometAggregateMode.Final + case other => + withInfo(aggregate, s"Unsupported aggregation mode $other") + return None + } + } + hashAggBuilder.addAllExprModes(exprModes.asJava) + hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset) + } + Some(builder.setHashAgg(hashAggBuilder).build()) } else { val allChildren: Seq[Expression] = diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 2a1b98bcd9..0251e26f3d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -638,10 +638,9 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper { dictionaryEnabled) { withView("v") { sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1") - checkSparkAnswerAndFallbackReason( + checkSparkAnswerAndOperator( "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," + - " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2", - "All aggregate expressions do not have the same mode") + " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2") } } } From 336aadd074b15d3abe03dbb39d09f36d4d58df19 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 13:27:01 -0700 Subject: [PATCH 4/8] feat: support `PartialMerge` --- native/core/src/execution/planner.rs | 4 +-- .../apache/spark/sql/comet/operators.scala | 18 +++++++++++ .../comet/exec/CometAggregateSuite.scala | 30 ++++++++++++++++++- 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 70401b1579..d07de8af91 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1035,8 +1035,8 @@ impl PhysicalPlanner { ) .schema(Arc::clone(&child_schema)) .alias(format!("col_{idx}")) - .with_ignore_nulls(false) - .with_distinct(false) + .with_ignore_nulls(expr.ignore_nulls()) + .with_distinct(expr.is_distinct()) .build() .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index f9d514865a..4ce9e5c5c0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1460,6 +1460,24 @@ trait CometBaseAggregate { } } + // FIRST/LAST are order-dependent aggregates whose merge result depends on + // hash table processing order. 
In PartialMerge mode, DataFusion's hash table + // may process rows in a different order than Spark's, producing different results. + val hasPartialMergeMode = modeSet.contains(PartialMerge) + if (hasPartialMergeMode) { + val unsupportedAggs = aggregateExpressions.filter { a => + a.mode == PartialMerge && (a.aggregateFunction.isInstanceOf[First] || + a.aggregateFunction.isInstanceOf[Last]) + } + if (unsupportedAggs.nonEmpty) { + withInfo( + aggregate, + s"PartialMerge not supported for order-dependent aggregates: " + + unsupportedAggs.map(_.aggregateFunction.prettyName).mkString(", ")) + return None + } + } + // Per-expression binding: Partial expressions bind to child output, // PartialMerge/Final expressions do not (native planner handles their input). val output = child.output diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 0251e26f3d..484be72b1c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -640,7 +640,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1") checkSparkAnswerAndOperator( "SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1)," + - " COUNT(DISTINCT _1), AVG(_1), FIRST(_1), LAST(_1) FROM v GROUP BY _2") + " COUNT(DISTINCT _1), AVG(_1)" + + " FROM v GROUP BY _2 ORDER BY _2") } } } @@ -649,6 +650,33 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // FIRST/LAST are order-dependent aggregates whose merge result depends on hash table + // processing order. In PartialMerge mode, DataFusion's hash table may process rows + // in a different order than Spark's, so we fall back to Spark for correctness. 
+ test("partialMerge - FIRST/LAST with distinct aggregates falls back") { + val numValues = 10000 + Seq(100).foreach { numGroups => + Seq(128).foreach { batchSize => + withSQLConf( + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) { + withParquetTable( + (0 until numValues).map(i => (i, Random.nextInt() % numGroups)), + "tbl", + false) { + withView("v") { + sql("CREATE TEMP VIEW v AS SELECT _1, _2 FROM tbl ORDER BY _1") + checkSparkAnswerAndFallbackReason( + "SELECT _2, FIRST(_1), LAST(_1), COUNT(DISTINCT _1)" + + " FROM v GROUP BY _2 ORDER BY _2", + "PartialMerge not supported for order-dependent aggregates") + } + } + } + } + } + } + test("partialMerge - cnt distinct + sum") { withTempDir(dir => { withSQLConf("spark.comet.enabled" -> "false") { From 18645c9e3444d71e573436f4efd864eee6a139ad Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 13:53:54 -0700 Subject: [PATCH 5/8] feat: support `PartialMerge` --- native/core/src/execution/planner.rs | 70 ++++++++++--------- .../apache/spark/sql/comet/operators.scala | 13 ++-- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index d07de8af91..44186c6f3f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -116,8 +116,8 @@ use datafusion_comet_proto::{ }, spark_operator::{ self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct, - upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, - CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType, + upper_window_frame_bound::UpperFrameBoundStruct, AggregateMode as ProtoAggregateMode, + BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType, }, spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; @@ -967,10 +967,12 @@ impl PhysicalPlanner { let group_by = 
PhysicalGroupBy::new_single(group_exprs?); let schema = child.schema(); + let partial_merge = ProtoAggregateMode::PartialMerge as i32; + let mode = match agg.mode { 0 => DFAggregateMode::Partial, 1 => DFAggregateMode::Final, - 2 => DFAggregateMode::Partial, // PartialMerge uses Partial + MergeAsPartial + m if m == partial_merge => DFAggregateMode::Partial, other => { return Err(ExecutionError::GeneralError(format!( "Unsupported aggregate mode: {other}" @@ -978,15 +980,11 @@ impl PhysicalPlanner { } }; - // Determine per-expression modes. PartialMerge (2) expressions use - // MergeAsPartial wrapper so they run merge semantics in Partial mode. - let per_expr_modes: Vec = if !agg.expr_modes.is_empty() { - agg.expr_modes.clone() - } else { - vec![agg.mode; agg.agg_exprs.len()] - }; - - let has_partial_merge = per_expr_modes.contains(&2); + // Check if any expression uses PartialMerge mode. When present, + // those expressions are wrapped with MergeAsPartial to get merge + // semantics inside a Partial-mode AggregateExec. + let has_partial_merge = + agg.mode == partial_merge || agg.expr_modes.contains(&partial_merge); let agg_exprs: PhyAggResult = agg .agg_exprs @@ -998,51 +996,57 @@ impl PhysicalPlanner { // Wrap PartialMerge expressions with MergeAsPartial. // State fields in the child's output start at initial_input_buffer_offset. let mut state_offset = agg.initial_input_buffer_offset as usize; - let child_schema = child.schema(); + let per_expr_modes: Vec = if !agg.expr_modes.is_empty() { + agg.expr_modes.clone() + } else { + vec![agg.mode; agg.agg_exprs.len()] + }; agg_exprs? 
.into_iter() .enumerate() .map(|(idx, expr)| { - let expr_mode = per_expr_modes[idx]; - if expr_mode == 2 { + if per_expr_modes[idx] == partial_merge { // PartialMerge: wrap with MergeAsPartial - let state_fields = expr.state_fields().map_err(|e| { - ExecutionError::GeneralError(e.to_string()) - })?; + let state_fields = expr + .state_fields() + .map_err(|e| ExecutionError::GeneralError(e.to_string()))?; let num_state_fields = state_fields.len(); - // Create Column refs pointing to state field positions let state_cols: Vec> = (0..num_state_fields) .map(|i| { let col_idx = state_offset + i; - let field = child_schema.field(col_idx); + let field = schema.field(col_idx); Arc::new(Column::new(field.name(), col_idx)) as Arc }) .collect(); state_offset += num_state_fields; - let merge_udf = crate::execution::merge_as_partial::MergeAsPartialUDF::new(&expr) + let merge_udf = + crate::execution::merge_as_partial::MergeAsPartialUDF::new( + &expr, + ) .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; let merge_udf_arc = Arc::new( - datafusion::logical_expr::AggregateUDF::new_from_impl(merge_udf), + datafusion::logical_expr::AggregateUDF::new_from_impl( + merge_udf, + ), ); - let merge_expr = AggregateExprBuilder::new( - merge_udf_arc, - state_cols, - ) - .schema(Arc::clone(&child_schema)) - .alias(format!("col_{idx}")) - .with_ignore_nulls(expr.ignore_nulls()) - .with_distinct(expr.is_distinct()) - .build() - .map_err(|e| ExecutionError::DataFusionError(e.to_string()))?; + let merge_expr = + AggregateExprBuilder::new(merge_udf_arc, state_cols) + .schema(Arc::clone(&schema)) + .alias(format!("col_{idx}")) + .with_ignore_nulls(expr.ignore_nulls()) + .with_distinct(expr.is_distinct()) + .build() + .map_err(|e| { + ExecutionError::DataFusionError(e.to_string()) + })?; Ok(Arc::new(merge_expr)) } else { - // Partial: use as-is Ok(Arc::new(expr)) } }) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 4ce9e5c5c0..6f71e652e2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1356,9 +1356,11 @@ trait CometBaseAggregate { childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { val modes = aggregate.aggregateExpressions.map(_.mode).distinct + val modeSet = modes.toSet + val hasPartialMerge = modeSet.contains(PartialMerge) // In distinct aggregates there can be a combination of modes. // We support {Partial, PartialMerge} mix; other combinations are rejected. - val multiMode = modes.size > 1 && modes.toSet != Set(Partial, PartialMerge) + val multiMode = modes.size > 1 && modeSet != Set(Partial, PartialMerge) // For a final mode HashAggregate, we only need to transform the HashAggregate // if there is Comet partial aggregation. val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty @@ -1430,9 +1432,6 @@ trait CometBaseAggregate { hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) Some(builder.setHashAgg(hashAggBuilder).build()) } else { - val modes = aggregateExpressions.map(_.mode).distinct - val modeSet = modes.toSet - // Validate mode combinations. We support: // - All Partial // - All Final @@ -1463,8 +1462,7 @@ trait CometBaseAggregate { // FIRST/LAST are order-dependent aggregates whose merge result depends on // hash table processing order. In PartialMerge mode, DataFusion's hash table // may process rows in a different order than Spark's, producing different results. 
- val hasPartialMergeMode = modeSet.contains(PartialMerge) - if (hasPartialMergeMode) { + if (hasPartialMerge) { val unsupportedAggs = aggregateExpressions.filter { a => a.mode == PartialMerge && (a.aggregateFunction.isInstanceOf[First] || a.aggregateFunction.isInstanceOf[Last]) @@ -1514,7 +1512,6 @@ trait CometBaseAggregate { hashAggBuilder.setModeValue(mode.getNumber) // Send per-expression modes and buffer offset for PartialMerge handling - val hasPartialMerge = aggregateExpressions.exists(_.mode == PartialMerge) if (hasPartialMerge) { val exprModes = aggregateExpressions.map { a => a.mode match { @@ -1654,7 +1651,7 @@ object CometObjectHashAggregateExec * case branch here mapping it to the native state type. */ private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = { - // CometBaseAggregate.doConvert guarantees all expressions share the same mode. + // This adjustment only applies to pure-Partial aggregates (checked below). val modes = op.aggregateExpressions.map(_.mode).distinct if (modes != Seq(Partial)) { return op.output From 2eeeb2a5e547fdc48d11487ceff082398098ca3c Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 14:01:26 -0700 Subject: [PATCH 6/8] feat: support `PartialMerge` --- native/core/src/execution/merge_as_partial.rs | 242 ++++++++++++++++++ native/core/src/execution/planner.rs | 14 +- 2 files changed, 247 insertions(+), 9 deletions(-) create mode 100644 native/core/src/execution/merge_as_partial.rs diff --git a/native/core/src/execution/merge_as_partial.rs b/native/core/src/execution/merge_as_partial.rs new file mode 100644 index 0000000000..dc671a22d8 --- /dev/null +++ b/native/core/src/execution/merge_as_partial.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode. +//! +//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate +//! state (not final values). DataFusion has no equivalent mode — `Partial` calls +//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs +//! evaluated results. +//! +//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which +//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge +//! semantics with state output. + +use std::any::Any; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion::common::Result; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::function::StateFieldsArgs; +use datafusion::logical_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, +}; +use datafusion::physical_expr::aggregate::AggregateFunctionExpr; +use datafusion::scalar::ScalarValue; + +/// An AggregateUDF wrapper that gives merge semantics in Partial mode. +/// +/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch` +/// on each accumulator and outputs `state()`. 
This wrapper intercepts `update_batch` +/// and redirects it to `merge_batch` on the inner accumulator, effectively +/// implementing PartialMerge: merge inputs, output state. +/// +/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping +/// references to UnboundColumn expressions that would panic if evaluated. +#[derive(Debug)] +pub struct MergeAsPartialUDF { + /// The inner aggregate UDF, cloned from the original expression. + inner_udf: AggregateUDF, + /// Pre-computed return type from the original expression. + return_type: DataType, + /// Pre-computed state fields from the original expression. + cached_state_fields: Vec<FieldRef>, + /// Cached signature that accepts state field types. + signature: Signature, + /// Name for this wrapper. + name: String, +} + +impl PartialEq for MergeAsPartialUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + } +} + +impl Eq for MergeAsPartialUDF {} + +impl Hash for MergeAsPartialUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + self.name.hash(state); + } +} + +impl MergeAsPartialUDF { + pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> { + let name = format!("merge_as_partial_{}", inner_expr.name()); + let return_type = inner_expr.field().data_type().clone(); + let cached_state_fields = inner_expr.state_fields()?; + + // Use a permissive signature since we accept state field types which + // vary per aggregate function. + let signature = Signature::variadic_any(Volatility::Immutable); + + Ok(Self { + inner_udf: inner_expr.fun().clone(), + return_type, + cached_state_fields, + signature, + name, + }) + } +} + +impl AggregateUDFImpl for MergeAsPartialUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + // In Partial mode, return_type isn't used for output schema (state_fields is).
+ // Return the inner function's return type for consistency. + Ok(self.return_type.clone()) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> { + // State fields must match the inner aggregate's state fields so that + // the output of this PartialMerge stage is compatible with subsequent + // Final or PartialMerge stages. + Ok(self.cached_state_fields.clone()) + } + + fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { + // Create the inner accumulator using the provided args (which have the + // correct Column refs, not UnboundColumns). + let inner_acc = self.inner_udf.accumulator(args)?; + Ok(Box::new(MergeAsPartialAccumulator { inner: inner_acc })) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner_udf.groups_accumulator_supported(args) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result<Box<dyn GroupsAccumulator>> { + let inner_acc = self.inner_udf.create_groups_accumulator(args)?; + Ok(Box::new(MergeAsPartialGroupsAccumulator { + inner: inner_acc, + })) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::NotSupported + } + + fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> { + ScalarValue::try_from(data_type) + } + + fn is_descending(&self) -> Option<bool> { + None + } +} + +/// Accumulator wrapper that redirects update_batch to merge_batch. +struct MergeAsPartialAccumulator { + inner: Box<dyn Accumulator>, +} + +impl Debug for MergeAsPartialAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MergeAsPartialAccumulator").finish() + } +} + +impl Accumulator for MergeAsPartialAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Redirect update to merge — this is the key trick.
+ self.inner.merge_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } + + fn evaluate(&mut self) -> Result<ScalarValue> { + self.inner.evaluate() + } + + fn state(&mut self) -> Result<Vec<ScalarValue>> { + self.inner.state() + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// GroupsAccumulator wrapper that redirects update_batch to merge_batch. +struct MergeAsPartialGroupsAccumulator { + inner: Box<dyn GroupsAccumulator>, +} + +impl Debug for MergeAsPartialGroupsAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MergeAsPartialGroupsAccumulator").finish() + } +} + +impl GroupsAccumulator for MergeAsPartialGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // Redirect update to merge — this is the key trick. + self.inner + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.inner + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> { + self.inner.evaluate(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { + self.inner.state(emit_to) + } + + fn size(&self) -> usize { + self.inner.size() + } +} diff --git a/native/core/src/execution/planner.rs index 44186c6f3f..7c9c03c88d 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -116,8 +116,7 @@ use datafusion_comet_proto::{ }, spark_operator::{ self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct, - upper_window_frame_bound::UpperFrameBoundStruct, AggregateMode as ProtoAggregateMode, - BuildSide, CompressionCodec as
SparkCompressionCodec, JoinType, Operator, WindowFrameType, + upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType, }, spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; @@ -967,12 +966,10 @@ impl PhysicalPlanner { let group_by = PhysicalGroupBy::new_single(group_exprs?); let schema = child.schema(); - let partial_merge = ProtoAggregateMode::PartialMerge as i32; - let mode = match agg.mode { 0 => DFAggregateMode::Partial, 1 => DFAggregateMode::Final, - m if m == partial_merge => DFAggregateMode::Partial, + 2 => DFAggregateMode::Partial, // PartialMerge: Partial + MergeAsPartial other => { return Err(ExecutionError::GeneralError(format!( "Unsupported aggregate mode: {other}" @@ -980,11 +977,10 @@ impl PhysicalPlanner { } }; - // Check if any expression uses PartialMerge mode. When present, + // Check if any expression uses PartialMerge mode (2). When present, // those expressions are wrapped with MergeAsPartial to get merge // semantics inside a Partial-mode AggregateExec. 
- let has_partial_merge = - agg.mode == partial_merge || agg.expr_modes.contains(&partial_merge); + let has_partial_merge = agg.mode == 2 || agg.expr_modes.contains(&2); let agg_exprs: PhyAggResult = agg .agg_exprs @@ -1006,7 +1002,7 @@ impl PhysicalPlanner { .into_iter() .enumerate() .map(|(idx, expr)| { - if per_expr_modes[idx] == partial_merge { + if per_expr_modes[idx] == 2 { // PartialMerge: wrap with MergeAsPartial let state_fields = expr .state_fields() From 6d04b604f530e8a5c41f4593851a34d2057320d0 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 14:02:45 -0700 Subject: [PATCH 7/8] feat: support `PartialMerge` --- native/core/src/execution/planner.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 7c9c03c88d..55a09501d7 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -116,7 +116,8 @@ use datafusion_comet_proto::{ }, spark_operator::{ self, lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct, - upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType, + upper_window_frame_bound::UpperFrameBoundStruct, BuildSide, + CompressionCodec as SparkCompressionCodec, JoinType, Operator, WindowFrameType, }, spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; From 84ac3877fb1ad99ca9f2186636a2403c43cd9fbb Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 20 Apr 2026 14:17:11 -0700 Subject: [PATCH 8/8] feat: support `PartialMerge` --- .../src/main/scala/org/apache/spark/sql/comet/operators.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 6f71e652e2..775b83fa58 100644 --- 
a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -30,7 +30,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, CollectSet, Final, First, Last, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -1470,7 +1470,7 @@ trait CometBaseAggregate { if (unsupportedAggs.nonEmpty) { withInfo( aggregate, - s"PartialMerge not supported for order-dependent aggregates: " + + "PartialMerge not supported for order-dependent aggregates: " + unsupportedAggs.map(_.aggregateFunction.prettyName).mkString(", ")) return None }