diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
index b0ded47580..b3b89d043a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
@@ -84,6 +84,26 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM
     }
   }
 
+  /**
+   * Reports native Parquet writer SQL metrics to Spark's task-level
+   * [[org.apache.spark.executor.OutputMetrics]] so the Spark UI Stages tab Output column shows
+   * bytes and records written.
+   *
+   * Must be registered on the task thread before [[org.apache.comet.CometExecIterator]] so
+   * Spark's completion listener stack invokes the iterator `close` (final SQL metric update)
+   * before this listener runs.
+   */
+  def reportNativeWriteOutputMetrics(ctx: TaskContext): Unit = {
+    ctx.addTaskCompletionListener[Unit] { _ =>
+      metrics.get("bytes_written").foreach { m =>
+        ctx.taskMetrics().outputMetrics.setBytesWritten(m.value)
+      }
+      metrics.get("rows_written").foreach { m =>
+        ctx.taskMetrics().outputMetrics.setRecordsWritten(m.value)
+      }
+    }
+  }
+
   /**
    * Gets a child node. Called from native.
    */
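Aside (not part of the patch): the ordering described in the Scaladoc above relies on Spark invoking task completion listeners in reverse registration order (LIFO), so a listener registered first runs last. A minimal standalone sketch of that behavior, assuming a local SparkSession; the object name and printed messages are illustrative only:

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.SparkSession

    // Hypothetical demo, not Comet code: shows LIFO ordering of task completion listeners.
    object CompletionListenerOrderDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("listener-order").getOrCreate()
        spark.sparkContext.parallelize(1 to 10, numSlices = 1).mapPartitions { iter =>
          val ctx = TaskContext.get()
          // Registered first => runs last, which is how reportNativeWriteOutputMetrics can
          // observe the final SQL metric values produced during the iterator's close.
          ctx.addTaskCompletionListener[Unit](_ => println("registered first, runs last"))
          // Registered second (as CometExecIterator's listener would be) => runs first.
          ctx.addTaskCompletionListener[Unit](_ => println("registered second, runs first"))
          iter
        }.count()
        spark.stop()
      }
    }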
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index 39e7ac6eef..d93bda5f9f 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -24,6 +24,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -197,6 +198,9 @@ case class CometNativeWriteExec(
       }
 
       val nativeMetrics = CometMetricNode.fromCometPlan(this)
+      // Register before CometExecIterator so completion listeners run after iterator close
+      // (Spark runs task completion callbacks in reverse registration order).
+      Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics)
 
       val size = modifiedNativeOp.getSerializedSize
       val planBytes = new Array[Byte](size)
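Also an aside, not part of the patch: the task-level OutputMetrics registered above complement the write operator's own SQL metrics. A rough way to eyeball both after a write is a QueryExecutionListener that dumps each operator's metric map; the metric keys `bytes_written` / `rows_written` come from the hunk above, while the session setup, paths, and object name here are illustrative assumptions:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    // Hypothetical inspection sketch, not Comet code.
    object WriteMetricsInspection {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("write-metrics").getOrCreate()

        spark.listenerManager.register(new QueryExecutionListener {
          override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
            // Print every operator's SQL metrics; with the Comet native Parquet writer enabled,
            // the write node should expose "bytes_written" / "rows_written" (keys from the hunk above).
            qe.executedPlan.foreach { node =>
              node.metrics.foreach { case (key, metric) =>
                println(s"${node.nodeName}.$key = ${metric.value}")
              }
            }
          }
          override def onFailure(funcName: String, qe: QueryExecution, ex: Exception): Unit = ()
        })

        spark.range(1000).write.mode("overwrite").parquet("/tmp/native-write-metrics-demo")
        spark.stop()
      }
    }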
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala
index cc02551bf8..a8b7d5d9da 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala
@@ -19,18 +19,23 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.File
+
 import scala.collection.mutable
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.SparkListenerJobStart
 import org.apache.spark.scheduler.SparkListenerTaskEnd
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
 
@@ -100,6 +105,91 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("native parquet write reports task-level output metrics") {
+    withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") {
+      withTempPath { dir =>
+        val outPath = new File(dir, "written").getAbsolutePath
+        val expectedRows = 5000L
+        val outputBytes = mutable.ArrayBuffer.empty[Long]
+        val outputRecords = mutable.ArrayBuffer.empty[Long]
+        val targetStageIds = mutable.HashSet.empty[Int]
+        val jobGroupId = s"native-write-metrics-${java.util.UUID.randomUUID().toString}"
+
+        val listener = new SparkListener {
+          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+            val isTargetJob = Option(jobStart.properties)
+              .flatMap(props => Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)))
+              .contains(jobGroupId)
+            if (isTargetJob) {
+              targetStageIds.synchronized {
+                targetStageIds ++= jobStart.stageInfos.map(_.stageId)
+              }
+            }
+          }
+
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+            val isTargetStage = targetStageIds.synchronized {
+              targetStageIds.contains(taskEnd.stageId)
+            }
+            if (isTargetStage) {
+              val om = taskEnd.taskMetrics.outputMetrics
+              if (om.bytesWritten > 0) {
+                outputBytes.synchronized {
+                  outputBytes += om.bytesWritten
+                  outputRecords += om.recordsWritten
+                }
+              }
+            }
+          }
+        }
+        spark.sparkContext.addSparkListener(listener)
+
+        try {
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
+            spark.sparkContext.setJobGroup(jobGroupId, "native parquet write output metrics")
+            try {
+              sql("SELECT * FROM tbl").write.parquet(outPath)
+            } finally {
+              spark.sparkContext.clearJobGroup()
+            }
+          }
+
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          assert(outputBytes.nonEmpty, "No task reported outputMetrics.bytesWritten")
+          val totalOutputBytes = outputBytes.sum
+          val totalOutputRecords = outputRecords.sum
+
+          assert(
+            totalOutputRecords == expectedRows,
+            s"recordsWritten mismatch: metrics=$totalOutputRecords, expected=$expectedRows")
+
+          val outputDir = new File(outPath)
+          val fileBytes = Option(outputDir.listFiles())
+            .getOrElse(Array.empty)
+            .filter(f => f.isFile && f.getName.startsWith("part-"))
+            .map(_.length())
+            .sum
+
+          assert(fileBytes > 0L, s"Expected written parquet bytes should be > 0, got $fileBytes")
+          val ratio = totalOutputBytes.toDouble / fileBytes.toDouble
+          assert(
+            ratio >= 0.7 && ratio <= 1.3,
+            s"bytesWritten ratio out of range: metrics=$totalOutputBytes, files=$fileBytes, ratio=$ratio")
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
+
   test("native_datafusion scan reports task-level input metrics matching Spark") {
     val totalRows = 10000
     withTempPath { dir =>