@@ -447,17 +447,7 @@ class CometParquetWriterSuite extends CometTestBase {
447447 }
448448 }
449449
450- private def writeWithCometNativeWriteExec (
451- inputPath : String ,
452- outputPath : String ,
453- num_partitions : Option [Int ] = None ): Option [SparkPlan ] = {
454- val df = spark.read.parquet(inputPath)
455-
456- val plan = captureWritePlan(
457- path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
458- outputPath)
459-
460- // Count CometNativeWriteExec instances in the plan
450+ private def assertHasCometNativeWriteExec (plan : SparkPlan ): Unit = {
461451 var nativeWriteCount = 0
462452 plan.foreach {
463453 case _ : CometNativeWriteExec =>
@@ -474,6 +464,19 @@ class CometParquetWriterSuite extends CometTestBase {
474464 assert(
475465 nativeWriteCount == 1 ,
476466 s " Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount: \n ${plan.treeString}" )
467+ }
468+
469+ private def writeWithCometNativeWriteExec (
470+ inputPath : String ,
471+ outputPath : String ,
472+ num_partitions : Option [Int ] = None ): Option [SparkPlan ] = {
473+ val df = spark.read.parquet(inputPath)
474+
475+ val plan = captureWritePlan(
476+ path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path),
477+ outputPath)
478+
479+ assertHasCometNativeWriteExec(plan)
477480
478481 Some (plan)
479482 }
@@ -524,7 +527,10 @@ class CometParquetWriterSuite extends CometTestBase {
524527 SQLConf .SESSION_LOCAL_TIMEZONE .key -> " America/Halifax" ) {
525528
526529 val parquetDf = spark.read.parquet(inputPath)
527- parquetDf.write.parquet(outputPath)
530+
531+ // Capture plan and verify CometNativeWriteExec is used
532+ val plan = captureWritePlan(path => parquetDf.write.parquet(path), outputPath)
533+ assertHasCometNativeWriteExec(plan)
528534 }
529535
530536 // Verify round-trip: read with Spark and Comet, compare results
0 commit comments