Skip to content

Commit 1d49074

Browse files
authored
chore: Add Comet writer nested types test assertion (#3480)
1 parent e87e1a3 commit 1d49074

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -447,17 +447,7 @@ class CometParquetWriterSuite extends CometTestBase {
447447
}
448448
}
449449

450-
private def writeWithCometNativeWriteExec(
451-
inputPath: String,
452-
outputPath: String,
453-
num_partitions: Option[Int] = None): Option[SparkPlan] = {
454-
val df = spark.read.parquet(inputPath)
455-
456-
val plan = captureWritePlan(
457-
path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
458-
outputPath)
459-
460-
// Count CometNativeWriteExec instances in the plan
450+
private def assertHasCometNativeWriteExec(plan: SparkPlan): Unit = {
461451
var nativeWriteCount = 0
462452
plan.foreach {
463453
case _: CometNativeWriteExec =>
@@ -474,6 +464,19 @@ class CometParquetWriterSuite extends CometTestBase {
474464
assert(
475465
nativeWriteCount == 1,
476466
s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
467+
}
468+
469+
private def writeWithCometNativeWriteExec(
470+
inputPath: String,
471+
outputPath: String,
472+
num_partitions: Option[Int] = None): Option[SparkPlan] = {
473+
val df = spark.read.parquet(inputPath)
474+
475+
val plan = captureWritePlan(
476+
path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
477+
outputPath)
478+
479+
assertHasCometNativeWriteExec(plan)
477480

478481
Some(plan)
479482
}
@@ -524,7 +527,10 @@ class CometParquetWriterSuite extends CometTestBase {
524527
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") {
525528

526529
val parquetDf = spark.read.parquet(inputPath)
527-
parquetDf.write.parquet(outputPath)
530+
531+
// Capture plan and verify CometNativeWriteExec is used
532+
val plan = captureWritePlan(path => parquetDf.write.parquet(path), outputPath)
533+
assertHasCometNativeWriteExec(plan)
528534
}
529535

530536
// Verify round-trip: read with Spark and Comet, compare results

0 commit comments

Comments (0)