@@ -37,6 +37,8 @@ import org.apache.parquet.hadoop.example.{ExampleParquetWriter, GroupWriteSuppor
3737import org .apache .parquet .schema .{MessageType , MessageTypeParser }
3838import org .apache .spark ._
3939import org .apache .spark .internal .config .{MEMORY_OFFHEAP_ENABLED , MEMORY_OFFHEAP_SIZE , SHUFFLE_MANAGER }
40+ import org .apache .spark .sql .catalyst .plans .logical
41+ import org .apache .spark .sql .catalyst .util .sideBySide
4042import org .apache .spark .sql .comet .CometPlanChecker
4143import org .apache .spark .sql .comet .execution .shuffle .{CometColumnarShuffle , CometNativeShuffle , CometShuffleExchangeExec }
4244import org .apache .spark .sql .execution ._
@@ -128,7 +130,7 @@ abstract class CometTestBase
128130 if (withTol.isDefined) {
129131 checkAnswerWithTolerance(dfComet, expected, withTol.get)
130132 } else {
131- checkAnswer (dfComet, expected)
133+ checkCometAnswer (dfComet, expected)
132134 }
133135
134136 if (assertCometNative) {
@@ -358,6 +360,48 @@ abstract class CometTestBase
358360 }
359361 }
360362
363+ /**
364+ * Compares the Comet DataFrame result against the expected Spark answer, using labels that
365+ * correctly identify which side is Comet and which is Spark. This avoids the misleading "Spark
366+ * Answer" label that Spark's built-in `checkAnswer` would apply to the Comet result.
367+ */
368+ protected def checkCometAnswer (cometDf : DataFrame , sparkAnswer : Seq [Row ]): Unit = {
369+ val isSorted = cometDf.logicalPlan.collect { case s : logical.Sort => s }.nonEmpty
370+ val cometAnswer =
371+ try cometDf.collect().toSeq
372+ catch {
373+ case e : Exception =>
374+ fail(s """ Exception thrown while executing query in Comet:
375+ | ${cometDf.queryExecution}
376+ |== Exception ==
377+ | $e
378+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
379+ """ .stripMargin)
380+ }
381+ if (! QueryTest .compare(
382+ QueryTest .prepareAnswer(sparkAnswer, isSorted),
383+ QueryTest .prepareAnswer(cometAnswer, isSorted))) {
384+ val getRowType : Option [Row ] => String = row =>
385+ row
386+ .map(r => if (r.schema == null ) " struct<>" else r.schema.catalogString)
387+ .getOrElse(" struct<>" )
388+ fail(s """ Results do not match for query:
389+ |Timezone: ${java.util.TimeZone .getDefault}
390+ |Timezone Env: ${sys.env.getOrElse(" TZ" , " " )}
391+ |
392+ | ${cometDf.queryExecution}
393+ |== Results ==
394+ | ${sideBySide(
395+ s " == Spark Answer - ${sparkAnswer.size} == " +:
396+ getRowType(sparkAnswer.headOption) +:
397+ QueryTest .prepareAnswer(sparkAnswer, isSorted).map(_.toString()),
398+ s " == Comet Answer - ${cometAnswer.size} == " +:
399+ getRowType(cometAnswer.headOption) +:
400+ QueryTest .prepareAnswer(cometAnswer, isSorted).map(_.toString())).mkString(" \n " )}
401+ """ .stripMargin)
402+ }
403+ }
404+
361405 /**
362406 * A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
363407 */
0 commit comments