@@ -22,15 +22,15 @@ package org.apache.comet.serde
2222import scala .jdk .CollectionConverters ._
2323
2424import org .apache .spark .sql .catalyst .expressions .{Attribute , EvalMode }
25- import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , Corr , Count , Covariance , CovPopulation , CovSample , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
25+ import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , BitAndAgg , BitOrAgg , BitXorAgg , BloomFilterAggregate , CentralMomentAgg , Corr , Count , CovPopulation , CovSample , Covariance , First , Last , Max , Min , StddevPop , StddevSamp , Sum , VariancePop , VarianceSamp }
2626import org .apache .spark .sql .internal .SQLConf
2727import org .apache .spark .sql .types .{ByteType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType }
2828
2929import org .apache .comet .CometConf
3030import org .apache .comet .CometConf .COMET_EXEC_STRICT_FLOATING_POINT
3131import org .apache .comet .CometSparkSessionExtensions .withInfo
32- import org .apache .comet .serde .QueryPlanSerde .{evalModeToProto , exprToProto , serializeDataType }
33- import org .apache .comet .shims .CometEvalModeUtil
32+ import org .apache .comet .serde .QueryPlanSerde .{exprToProto , serializeDataType }
33+ import org .apache .comet .shims .CometExprShim
3434
3535object CometMin extends CometAggregateExpressionSerde [Min ] {
3636
@@ -211,10 +211,10 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
211211 }
212212}
213213
214- object CometSum extends CometAggregateExpressionSerde [Sum ] {
214+ object CometSum extends CometAggregateExpressionSerde [Sum ] with CometExprShim {
215215
216216 override def getSupportLevel (sum : Sum ): SupportLevel = {
217- sum.evalMode match {
217+ sparkEvalMode( sum) match {
218218 case EvalMode .ANSI if ! sum.dataType.isInstanceOf [DecimalType ] =>
219219 Incompatible (Some (" ANSI mode for non decimal inputs is not supported" ))
220220 case EvalMode .TRY if ! sum.dataType.isInstanceOf [DecimalType ] =>
@@ -243,7 +243,7 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
243243 val builder = ExprOuterClass .Sum .newBuilder()
244244 builder.setChild(childExpr.get)
245245 builder.setDatatype(dataType.get)
246- builder.setEvalMode(evalModeToProto( CometEvalModeUtil .fromSparkEvalMode( sum.evalMode)) )
246+ builder.setFailOnError(sparkEvalMode( sum) == EvalMode . ANSI )
247247
248248 Some (
249249 ExprOuterClass .AggExpr
0 commit comments