@@ -22,15 +22,15 @@ package org.apache.comet.serde
2222import scala .jdk .CollectionConverters ._
2323
2424import org .apache .spark .sql .catalyst .expressions .{Attribute , EvalMode }
25- import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , Corr , Count , CovPopulation , CovSample , Covariance , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
25+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , Corr , Count , Covariance , CovPopulation , CovSample , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
2626import org .apache .spark .sql .internal .SQLConf
2727import org .apache .spark .sql .types .{ByteType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType }
2828
2929import org .apache .comet .CometConf
3030import org .apache .comet .CometConf .COMET_EXEC_STRICT_FLOATING_POINT
3131import org .apache .comet .CometSparkSessionExtensions .withInfo
32- import org .apache .comet .serde .QueryPlanSerde .{exprToProto , serializeDataType }
33- import org .apache .comet .shims .CometExprShim
32+ import org .apache .comet .serde .QueryPlanSerde .{evalModeToProto , exprToProto , serializeDataType }
33+ import org .apache .comet .shims .{ CometEvalModeUtil , CometSumShim }
3434
3535object CometMin extends CometAggregateExpressionSerde [Min ] {
3636
@@ -211,7 +211,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
211211 }
212212}
213213
214- object CometSum extends CometAggregateExpressionSerde [Sum ] with CometExprShim {
214+ object CometSum extends CometAggregateExpressionSerde [Sum ] with CometSumShim {
215215
216216 override def getSupportLevel (sum : Sum ): SupportLevel = {
217217 sparkEvalMode(sum) match {
@@ -243,7 +243,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] with CometExprShim {
243243 val builder = ExprOuterClass .Sum .newBuilder()
244244 builder.setChild(childExpr.get)
245245 builder.setDatatype(dataType.get)
246- builder.setFailOnError(sparkEvalMode(sum) == EvalMode .ANSI )
246+ builder.setEvalMode(
247+ evalModeToProto(CometEvalModeUtil .fromSparkEvalMode(sparkEvalMode(sum))))
247248
248249 Some (
249250 ExprOuterClass .AggExpr
0 commit comments