@@ -383,7 +383,9 @@ class CometParquetWriterSuite extends CometTestBase {
383383 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
384384 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
385385 withTempPath { path =>
386- spark.range(100 ).repartition(10 ).where(" id = 50" ).write.parquet(path.toString)
386+ assertNativeWriter {
387+ spark.range(100 ).repartition(10 ).where(" id = 50" ).write.parquet(path.toString)
388+ }
387389 val partFiles = path
388390 .listFiles()
389391 .filter(f => f.isFile && ! f.getName.startsWith(" ." ) && ! f.getName.startsWith(" _" ))
@@ -400,7 +402,9 @@ class CometParquetWriterSuite extends CometTestBase {
400402 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
401403 withTable(" t1" , " t2" ) {
402404 sql(" CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING parquet" )
403- sql(" CREATE TABLE t2 USING parquet AS SELECT * FROM t1" )
405+ assertNativeWriter {
406+ sql(" CREATE TABLE t2 USING parquet AS SELECT * FROM t1" )
407+ }
404408 checkAnswer(
405409 sql(" desc t2" ).selectExpr(" data_type" ).where(" data_type like '%char%'" ),
406410 Seq (Row (" char(5)" ), Row (" varchar(4)" )))
@@ -417,7 +421,9 @@ class CometParquetWriterSuite extends CometTestBase {
417421 withTable(" t1" , " t2" ) {
418422 sql(" CREATE TABLE t1(col CHAR(5)) USING parquet" )
419423 withSQLConf(SQLConf .CHAR_AS_VARCHAR .key -> " true" ) {
420- sql(" CREATE TABLE t2 USING parquet AS SELECT * FROM t1" )
424+ assertNativeWriter {
425+ sql(" CREATE TABLE t2 USING parquet AS SELECT * FROM t1" )
426+ }
421427 checkAnswer(
422428 sql(" desc t2" ).selectExpr(" data_type" ).where(" data_type like '%char%'" ),
423429 Seq (Row (" varchar(5)" )))
@@ -436,11 +442,17 @@ class CometParquetWriterSuite extends CometTestBase {
436442 val path = dir.toURI.getPath
437443 withTable(" tab1" , " tab2" ) {
438444 sql(s """ create table tab1 (a int) using parquet location ' $path' """ )
439- sql(" insert into tab1 values(1)" )
445+ assertNativeWriter {
446+ sql(" insert into tab1 values(1)" )
447+ }
440448 checkAnswer(sql(" select * from tab1" ), Seq (Row (1 )))
441449 sql(" create table tab2 (a int) using parquet" )
442- sql(" insert into tab2 values(2)" )
443- sql(s """ insert overwrite local directory ' $path' using parquet select * from tab2 """ )
450+ assertNativeWriter {
451+ sql(" insert into tab2 values(2)" )
452+ }
453+ assertNativeWriter {
454+ sql(s """ insert overwrite local directory ' $path' using parquet select * from tab2 """ )
455+ }
444456 sql(" refresh table tab1" )
445457 checkAnswer(sql(" select * from tab1" ), Seq (Row (2 )))
446458 }
@@ -456,7 +468,9 @@ class CometParquetWriterSuite extends CometTestBase {
456468 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
457469 withTable(" t" ) {
458470 sql(" create table t(i boolean, s bigint) using parquet" )
459- sql(" insert into t(i) values(true)" )
471+ assertNativeWriter {
472+ sql(" insert into t(i) values(true)" )
473+ }
460474 checkAnswer(spark.table(" t" ), Row (true , null ))
461475 }
462476 }
@@ -471,7 +485,9 @@ class CometParquetWriterSuite extends CometTestBase {
471485 withTable(" t" ) {
472486 sql(" create table t(i boolean) using parquet" )
473487 sql(" alter table t add column s string default concat('abc', 'def')" )
474- sql(" insert into t values(true, default)" )
488+ assertNativeWriter {
489+ sql(" insert into t values(true, default)" )
490+ }
475491 checkAnswer(spark.table(" t" ), Row (true , " abcdef" ))
476492 }
477493 }
@@ -485,9 +501,13 @@ class CometParquetWriterSuite extends CometTestBase {
485501 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
486502 withTable(" t1" , " t2" ) {
487503 sql(" create table t1(i boolean, s bigint default 42) using parquet" )
488- sql(" insert into t1 values (true, 41), (false, default)" )
504+ assertNativeWriter {
505+ sql(" insert into t1 values (true, 41), (false, default)" )
506+ }
489507 sql(" create table t2(i boolean, s bigint) using parquet" )
490- sql(" insert into t2 select * from t1 order by s" )
508+ assertNativeWriter {
509+ sql(" insert into t2 select * from t1 order by s" )
510+ }
491511 checkAnswer(spark.table(" t2" ), Seq (Row (true , 41 ), Row (false , 42 )))
492512 }
493513 }
@@ -502,10 +522,14 @@ class CometParquetWriterSuite extends CometTestBase {
502522 withTable(" tbl" , " tbl2" ) {
503523 withView(" view1" ) {
504524 val df = spark.range(10 ).toDF(" id" )
505- df.write.format(" parquet" ).saveAsTable(" tbl" )
525+ assertNativeWriter {
526+ df.write.format(" parquet" ).saveAsTable(" tbl" )
527+ }
506528 spark.sql(" CREATE VIEW view1 AS SELECT id FROM tbl" )
507529 spark.sql(" CREATE TABLE tbl2(ID long) USING parquet" )
508- spark.sql(" INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1" )
530+ assertNativeWriter {
531+ spark.sql(" INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1" )
532+ }
509533 checkAnswer(spark.table(" tbl2" ), (0 until 10 ).map(Row (_)))
510534 }
511535 }
@@ -519,11 +543,13 @@ class CometParquetWriterSuite extends CometTestBase {
519543 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
520544 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
521545 withTempPath { dir =>
522- spark
523- .range(1 )
524- .selectExpr(" current_timestamp() as ts" )
525- .write
526- .parquet(dir.toString + " /spark" )
546+ assertNativeWriter {
547+ spark
548+ .range(1 )
549+ .selectExpr(" current_timestamp() as ts" )
550+ .write
551+ .parquet(dir.toString + " /spark" )
552+ }
527553 val result = spark.read.parquet(dir.toString + " /spark" ).collect()
528554 assert(result.length == 1 )
529555 }
@@ -539,8 +565,12 @@ class CometParquetWriterSuite extends CometTestBase {
539565 withTable(" tab1" , " tab2" ) {
540566 sql(""" CREATE TABLE tab1 (s struct<a: string, b: string>) USING parquet""" )
541567 sql(""" CREATE TABLE tab2 (s struct<c: string, d: string>) USING parquet""" )
542- sql(" INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))" )
543- sql(" INSERT INTO tab2 SELECT * FROM tab1" )
568+ assertNativeWriter {
569+ sql(" INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))" )
570+ }
571+ assertNativeWriter {
572+ sql(" INSERT INTO tab2 SELECT * FROM tab1" )
573+ }
544574 checkAnswer(spark.table(" tab2" ), Row (Row (" x" , " y" )))
545575 }
546576 }
@@ -553,7 +583,9 @@ class CometParquetWriterSuite extends CometTestBase {
553583 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
554584 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
555585 withTempPath { dir =>
556- spark.range(1 ).repartition(1 ).write.parquet(dir.getAbsolutePath)
586+ assertNativeWriter {
587+ spark.range(1 ).repartition(1 ).write.parquet(dir.getAbsolutePath)
588+ }
557589 val files = dir.listFiles().filter(_.getName.endsWith(" .parquet" ))
558590 assert(files.nonEmpty, " Expected parquet files to be written" )
559591 }
@@ -568,7 +600,9 @@ class CometParquetWriterSuite extends CometTestBase {
568600 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
569601 withTempDir { dir =>
570602 val path = dir.getCanonicalPath
571- spark.range(10 ).repartition(10 ).write.mode(" overwrite" ).parquet(path)
603+ assertNativeWriter {
604+ spark.range(10 ).repartition(10 ).write.mode(" overwrite" ).parquet(path)
605+ }
572606 val files = new File (path).listFiles().filter(_.getName.startsWith(" part-" ))
573607 assert(files.length > 0 , " Expected part files to be written" )
574608 }
@@ -584,7 +618,9 @@ class CometParquetWriterSuite extends CometTestBase {
584618 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
585619 CometConf .getOperatorAllowIncompatConfigKey(classOf [DataWritingCommandExec ]) -> " true" ) {
586620 withTable(" t" ) {
587- sql(" CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2" )
621+ assertNativeWriter {
622+ sql(" CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2" )
623+ }
588624 checkAnswer(spark.table(" t" ), Seq (Row (1 ), Row (2 )))
589625 }
590626 }
@@ -600,7 +636,9 @@ class CometParquetWriterSuite extends CometTestBase {
600636 withTable(" t1" , " t2" ) {
601637 sql(" CREATE TABLE t1(a INT) USING parquet" )
602638 sql(" CREATE TABLE t2(a INT) USING parquet" )
603- sql(" FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a" )
639+ assertNativeWriter {
640+ sql(" FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a" )
641+ }
604642 checkAnswer(spark.table(" t1" ), Row (1 ))
605643 checkAnswer(spark.table(" t2" ), Row (1 ))
606644 }
@@ -626,15 +664,72 @@ class CometParquetWriterSuite extends CometTestBase {
626664 }
627665
628666 /**
629- * Captures the execution plan during a write operation.
667+ * Executes a code block and asserts that CometNativeWriteExec is in the write plan.
668+ * This is used for verifying native writer is called in SQL commands.
630669 *
631- * @param writeOp
632- * The write operation to execute (takes output path as parameter)
633- * @param outputPath
634- * The path to write to
635- * @return
636- * The captured execution plan
670+ * @param block
671+ * The code block to execute (should contain write operations)
637672 */
private def assertNativeWriter(block: => Unit): Unit = {
  // The listener callback and the polling loop below run on different threads, so the captured
  // execution is handed over through an AtomicReference instead of a plain local `var`, whose
  // update is not guaranteed to become visible to the polling thread.
  val capturedPlan =
    new java.util.concurrent.atomic.AtomicReference[Option[QueryExecution]](None)

  val listener = new org.apache.spark.sql.util.QueryExecutionListener {
    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
      // Only executions reported as "save" or "...command..." are kept; the latest one wins.
      if (funcName == "save" || funcName.contains("command")) {
        capturedPlan.set(Some(qe))
      }
    }

    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
  }

  spark.listenerManager.register(listener)

  try {
    block

    // Poll until the listener has reported an execution or the time budget is exhausted.
    val maxWaitTimeMs = 15000
    val checkIntervalMs = 100
    val maxIterations = maxWaitTimeMs / checkIntervalMs
    var iterations = 0

    while (capturedPlan.get().isEmpty && iterations < maxIterations) {
      Thread.sleep(checkIntervalMs)
      iterations += 1
    }

    val captured = capturedPlan.get()
    assert(
      captured.isDefined,
      s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")

    captured.foreach { qe =>
      val plan = stripAQEPlan(qe.executedPlan)

      // `collect` visits every node of the tree exactly once. The previous hand-written
      // traversal walked the child of each DataWritingCommandExec a second time, so a native
      // writer sitting below such a node was counted twice.
      val nativeWriteCount = plan.collect { case w: CometNativeWriteExec => w }.size

      assert(
        nativeWriteCount == 1,
        "Expected exactly one CometNativeWriteExec in the plan, " +
          s"but found $nativeWriteCount:\n${plan.treeString}")
    }
  } finally {
    // Always unregister so the listener does not leak into subsequent tests.
    spark.listenerManager.unregister(listener)
  }
}
732+
638733 private def captureWritePlan (writeOp : String => Unit , outputPath : String ): SparkPlan = {
639734 var capturedPlan : Option [QueryExecution ] = None
640735
0 commit comments