Commit dae02c0

fix: report task output metrics in Spark UI (#3999)
1 parent ee140c7 commit dae02c0

3 files changed: 115 additions & 1 deletion


spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala

Lines changed: 20 additions & 0 deletions
@@ -84,6 +84,26 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode])
     }
   }
 
+  /**
+   * Reports native Parquet writer SQL metrics to Spark's task-level
+   * [[org.apache.spark.executor.OutputMetrics]] so the Spark UI Stages tab Output column shows
+   * bytes and records written.
+   *
+   * Must be registered on the task thread before [[org.apache.comet.CometExecIterator]] so
+   * Spark's completion listener stack invokes the iterator `close` (final SQL metric update)
+   * before this listener runs.
+   */
+  def reportNativeWriteOutputMetrics(ctx: TaskContext): Unit = {
+    ctx.addTaskCompletionListener[Unit] { _ =>
+      metrics.get("bytes_written").foreach { m =>
+        ctx.taskMetrics().outputMetrics.setBytesWritten(m.value)
+      }
+      metrics.get("rows_written").foreach { m =>
+        ctx.taskMetrics().outputMetrics.setRecordsWritten(m.value)
+      }
+    }
+  }
+
   /**
    * Gets a child node. Called from native.
    */
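A side note on why this helper lives where it does: the OutputMetrics setters it calls appear to be package-private to org.apache.spark, which CometMetricNode's package placement satisfies. Below is a minimal standalone sketch of the same mechanism, not part of the commit, assuming local mode and a hypothetical OutputMetricsSketch object placed under the same package: a task-side completion listener sets the task's output metrics, and a driver-side SparkListener observes them at task end.

// Sketch only: placed under org.apache.spark.sql.comet so the package-private
// OutputMetrics setters are accessible, mirroring CometMetricNode's placement.
package org.apache.spark.sql.comet

import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

object OutputMetricsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("sketch"))
    sc.addSparkListener(new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        // These are the values the Spark UI Stages tab Output column renders.
        val om = taskEnd.taskMetrics.outputMetrics
        println(s"bytes=${om.bytesWritten} records=${om.recordsWritten}")
      }
    })
    sc.parallelize(1 to 10, 1).foreachPartition { it =>
      val rows = it.size.toLong // drain the partition, as a stand-in for a native write
      val ctx = TaskContext.get()
      ctx.addTaskCompletionListener[Unit] { _ =>
        ctx.taskMetrics().outputMetrics.setBytesWritten(1024L) // placeholder byte count
        ctx.taskMetrics().outputMetrics.setRecordsWritten(rows)
      }
    }
    Thread.sleep(1000) // crude wait for listener delivery; the suite uses waitUntilEmpty()
    sc.stop()
  }
}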

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -197,6 +198,9 @@ case class CometNativeWriteExec(
     }
 
     val nativeMetrics = CometMetricNode.fromCometPlan(this)
+    // Register before CometExecIterator so completion listeners run after iterator close
+    // (Spark runs task completion callbacks in reverse registration order).
+    Option(TaskContext.get()).foreach(nativeMetrics.reportNativeWriteOutputMetrics)
 
     val size = modifiedNativeOp.getSerializedSize
     val planBytes = new Array[Byte](size)
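The ordering comment above is the crux of the fix: Spark runs task completion listeners in reverse (LIFO) registration order, so registering this listener before the CometExecIterator is created means the iterator's close callback, which publishes the final SQL metric values, fires first. A minimal sketch of that behavior, assuming local mode; CompletionOrderSketch and its println markers are illustrative only.

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object CompletionOrderSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("order"))
    try {
      sc.parallelize(Seq(1), 1).foreachPartition { _ =>
        val ctx = TaskContext.get()
        // Registered first => runs last: safe point to read finalized metrics.
        ctx.addTaskCompletionListener[Unit] { _ => println("2: report output metrics") }
        // Registered second => runs first: stands in for CometExecIterator's close.
        ctx.addTaskCompletionListener[Unit] { _ => println("1: iterator close, metrics final") }
      }
    } finally {
      sc.stop()
    }
  }
}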

spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala

Lines changed: 91 additions & 1 deletion
@@ -19,18 +19,23 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.File
+
 import scala.collection.mutable
 
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.SparkListener
+import org.apache.spark.scheduler.SparkListenerJobStart
 import org.apache.spark.scheduler.SparkListenerTaskEnd
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.execution.shuffle.CometNativeShuffle
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
 
@@ -100,6 +105,91 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("native parquet write reports task-level output metrics") {
+    withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") {
+      withTempPath { dir =>
+        val outPath = new File(dir, "written").getAbsolutePath
+        val expectedRows = 5000L
+        val outputBytes = mutable.ArrayBuffer.empty[Long]
+        val outputRecords = mutable.ArrayBuffer.empty[Long]
+        val targetStageIds = mutable.HashSet.empty[Int]
+        val jobGroupId = s"native-write-metrics-${java.util.UUID.randomUUID().toString}"
+
+        val listener = new SparkListener {
+          override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+            val isTargetJob = Option(jobStart.properties)
+              .flatMap(props => Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)))
+              .contains(jobGroupId)
+            if (isTargetJob) {
+              targetStageIds.synchronized {
+                targetStageIds ++= jobStart.stageInfos.map(_.stageId)
+              }
+            }
+          }
+
+          override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+            val isTargetStage = targetStageIds.synchronized {
+              targetStageIds.contains(taskEnd.stageId)
+            }
+            if (isTargetStage) {
+              val om = taskEnd.taskMetrics.outputMetrics
+              if (om.bytesWritten > 0) {
+                outputBytes.synchronized {
+                  outputBytes += om.bytesWritten
+                  outputRecords += om.recordsWritten
+                }
+              }
+            }
+          }
+        }
+        spark.sparkContext.addSparkListener(listener)
+
+        try {
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
+            spark.sparkContext.setJobGroup(jobGroupId, "native parquet write output metrics")
+            try {
+              sql("SELECT * FROM tbl").write.parquet(outPath)
+            } finally {
+              spark.sparkContext.clearJobGroup()
+            }
+          }
+
+          spark.sparkContext.listenerBus.waitUntilEmpty()
+
+          assert(outputBytes.nonEmpty, "No task reported outputMetrics.bytesWritten")
+          val totalOutputBytes = outputBytes.sum
+          val totalOutputRecords = outputRecords.sum
+
+          assert(
+            totalOutputRecords == expectedRows,
+            s"recordsWritten mismatch: metrics=$totalOutputRecords, expected=$expectedRows")
+
+          val outputDir = new File(outPath)
+          val fileBytes = Option(outputDir.listFiles())
+            .getOrElse(Array.empty)
+            .filter(f => f.isFile && f.getName.startsWith("part-"))
+            .map(_.length())
+            .sum
+
+          assert(fileBytes > 0L, s"Expected written parquet bytes should be > 0, got $fileBytes")
+          val ratio = totalOutputBytes.toDouble / fileBytes.toDouble
+          assert(
+            ratio >= 0.7 && ratio <= 1.3,
+            s"bytesWritten ratio out of range: metrics=$totalOutputBytes, files=$fileBytes, ratio=$ratio")
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
+
   test("native_datafusion scan reports task-level input metrics matching Spark") {
     val totalRows = 10000
     withTempPath { dir =>
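The test's listener bookkeeping relies on a job-group pattern worth calling out: setJobGroup stamps every job launched from the calling thread with a group id, and SparkListenerJobStart carries it as the "spark.jobGroup.id" property (the value behind the private SparkContext.SPARK_JOB_GROUP_ID constant the suite references), so stages can be attributed to exactly the triggering query even in a shared session. A stripped-down sketch of that pattern, assuming local mode; JobGroupFilterSketch and the group id are illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

object JobGroupFilterSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("jobGroup"))
    val groupId = "my-isolated-job"
    sc.addSparkListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        // Match only jobs launched under our group id.
        val matches = Option(jobStart.properties)
          .flatMap(p => Option(p.getProperty("spark.jobGroup.id")))
          .contains(groupId)
        if (matches) {
          println(s"stages for our job: ${jobStart.stageInfos.map(_.stageId).mkString(",")}")
        }
      }
    })
    sc.setJobGroup(groupId, "demo action")
    try {
      sc.parallelize(1 to 10, 2).count() // only this job's stages match the filter
    } finally {
      sc.clearJobGroup()
    }
    Thread.sleep(1000) // crude wait for listener delivery
    sc.stop()
  }
}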
