@@ -20,13 +20,16 @@
 package org.apache.comet.parquet
 
 import java.io.File
+
 import scala.util.Random
+
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometNativeWriteExec, CometScanExec}
 import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
+
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}
@@ -383,9 +386,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTempPath { path =>
-        assertNativeWriter {
-          spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
-        }
+        spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
         val partFiles = path
           .listFiles()
           .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
@@ -402,9 +403,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTable("t1", "t2") {
         sql("CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING parquet")
-        assertNativeWriter {
-          sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
-        }
+        sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
         checkAnswer(
           sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
           Seq(Row("char(5)"), Row("varchar(4)")))
@@ -421,9 +420,7 @@ class CometParquetWriterSuite extends CometTestBase {
       withTable("t1", "t2") {
         sql("CREATE TABLE t1(col CHAR(5)) USING parquet")
         withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
-          assertNativeWriter {
-            sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
-          }
+          sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
           checkAnswer(
             sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
             Seq(Row("varchar(5)")))
@@ -442,17 +439,11 @@ class CometParquetWriterSuite extends CometTestBase {
       val path = dir.toURI.getPath
       withTable("tab1", "tab2") {
         sql(s"""create table tab1 (a int) using parquet location '$path'""")
-        assertNativeWriter {
-          sql("insert into tab1 values(1)")
-        }
+        sql("insert into tab1 values(1)")
         checkAnswer(sql("select * from tab1"), Seq(Row(1)))
         sql("create table tab2 (a int) using parquet")
-        assertNativeWriter {
-          sql("insert into tab2 values(2)")
-        }
-        assertNativeWriter {
-          sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
-        }
+        sql("insert into tab2 values(2)")
+        sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
         sql("refresh table tab1")
         checkAnswer(sql("select * from tab1"), Seq(Row(2)))
       }
@@ -468,9 +459,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTable("t") {
         sql("create table t(i boolean, s bigint) using parquet")
-        assertNativeWriter {
-          sql("insert into t(i) values(true)")
-        }
+        sql("insert into t(i) values(true)")
         checkAnswer(spark.table("t"), Row(true, null))
       }
     }
@@ -485,9 +474,7 @@ class CometParquetWriterSuite extends CometTestBase {
       withTable("t") {
         sql("create table t(i boolean) using parquet")
         sql("alter table t add column s string default concat('abc', 'def')")
-        assertNativeWriter {
-          sql("insert into t values(true, default)")
-        }
+        sql("insert into t values(true, default)")
         checkAnswer(spark.table("t"), Row(true, "abcdef"))
       }
     }
@@ -501,13 +488,9 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTable("t1", "t2") {
         sql("create table t1(i boolean, s bigint default 42) using parquet")
-        assertNativeWriter {
-          sql("insert into t1 values (true, 41), (false, default)")
-        }
+        sql("insert into t1 values (true, 41), (false, default)")
         sql("create table t2(i boolean, s bigint) using parquet")
-        assertNativeWriter {
-          sql("insert into t2 select * from t1 order by s")
-        }
+        sql("insert into t2 select * from t1 order by s")
         checkAnswer(spark.table("t2"), Seq(Row(true, 41), Row(false, 42)))
       }
     }
@@ -522,14 +505,10 @@ class CometParquetWriterSuite extends CometTestBase {
       withTable("tbl", "tbl2") {
         withView("view1") {
           val df = spark.range(10).toDF("id")
-          assertNativeWriter {
-            df.write.format("parquet").saveAsTable("tbl")
-          }
+          df.write.format("parquet").saveAsTable("tbl")
           spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
           spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
-          assertNativeWriter {
-            spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
-          }
+          spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
           checkAnswer(spark.table("tbl2"), (0 until 10).map(Row(_)))
         }
       }
@@ -543,13 +522,11 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTempPath { dir =>
-        assertNativeWriter {
-          spark
-            .range(1)
-            .selectExpr("current_timestamp() as ts")
-            .write
-            .parquet(dir.toString + "/spark")
-        }
+        spark
+          .range(1)
+          .selectExpr("current_timestamp() as ts")
+          .write
+          .parquet(dir.toString + "/spark")
         val result = spark.read.parquet(dir.toString + "/spark").collect()
         assert(result.length == 1)
       }
@@ -565,12 +542,8 @@ class CometParquetWriterSuite extends CometTestBase {
       withTable("tab1", "tab2") {
         sql("""CREATE TABLE tab1 (s struct<a: string, b: string>) USING parquet""")
         sql("""CREATE TABLE tab2 (s struct<c: string, d: string>) USING parquet""")
-        assertNativeWriter {
-          sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
-        }
-        assertNativeWriter {
-          sql("INSERT INTO tab2 SELECT * FROM tab1")
-        }
+        sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
+        sql("INSERT INTO tab2 SELECT * FROM tab1")
         checkAnswer(spark.table("tab2"), Row(Row("x", "y")))
       }
     }
@@ -583,9 +556,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTempPath { dir =>
-        assertNativeWriter {
-          spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
-        }
+        spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
         val files = dir.listFiles().filter(_.getName.endsWith(".parquet"))
         assert(files.nonEmpty, "Expected parquet files to be written")
       }
@@ -600,9 +571,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTempDir { dir =>
         val path = dir.getCanonicalPath
-        assertNativeWriter {
-          spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
-        }
+        spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
         val files = new File(path).listFiles().filter(_.getName.startsWith("part-"))
         assert(files.length > 0, "Expected part files to be written")
       }
@@ -618,9 +587,7 @@ class CometParquetWriterSuite extends CometTestBase {
       CometConf.COMET_EXEC_ENABLED.key -> "true",
       CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
       withTable("t") {
-        assertNativeWriter {
-          sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
-        }
+        sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
         checkAnswer(spark.table("t"), Seq(Row(1), Row(2)))
       }
     }
@@ -636,9 +603,7 @@ class CometParquetWriterSuite extends CometTestBase {
       withTable("t1", "t2") {
         sql("CREATE TABLE t1(a INT) USING parquet")
         sql("CREATE TABLE t2(a INT) USING parquet")
-        assertNativeWriter {
-          sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
-        }
+        sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
         checkAnswer(spark.table("t1"), Row(1))
         checkAnswer(spark.table("t2"), Row(1))
       }
@@ -664,72 +629,15 @@ class CometParquetWriterSuite extends CometTestBase {
   }
 
   /**
-   * Executes a code block and asserts that CometNativeWriteExec is in the write plan.
-   * This is used for verifying native writer is called in SQL commands.
+   * Captures the execution plan during a write operation.
    *
-   * @param block
-   *   The code block to execute (should contain write operations)
+   * @param writeOp
+   *   The write operation to execute (takes output path as parameter)
+   * @param outputPath
+   *   The path to write to
+   * @return
+   *   The captured execution plan
    */
-  private def assertNativeWriter(block: => Unit): Unit = {
-    var capturedPlan: Option[QueryExecution] = None
-
-    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        if (funcName == "save" || funcName.contains("command")) {
-          capturedPlan = Some(qe)
-        }
-      }
-
-      override def onFailure(
-          funcName: String,
-          qe: QueryExecution,
-          exception: Exception): Unit = {}
-    }
-
-    spark.listenerManager.register(listener)
-
-    try {
-      block
-
-      // Wait for listener to be called with timeout
-      val maxWaitTimeMs = 15000
-      val checkIntervalMs = 100
-      val maxIterations = maxWaitTimeMs / checkIntervalMs
-      var iterations = 0
-
-      while (capturedPlan.isEmpty && iterations < maxIterations) {
-        Thread.sleep(checkIntervalMs)
-        iterations += 1
-      }
-
-      assert(
-        capturedPlan.isDefined,
-        s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
-
-      val plan = stripAQEPlan(capturedPlan.get.executedPlan)
-
-      // Count CometNativeWriteExec instances in the plan
-      var nativeWriteCount = 0
-      plan.foreach {
-        case _: CometNativeWriteExec =>
-          nativeWriteCount += 1
-        case d: DataWritingCommandExec =>
-          d.child.foreach {
-            case _: CometNativeWriteExec =>
-              nativeWriteCount += 1
-            case _ =>
-          }
-        case _ =>
-      }
-
-      assert(
-        nativeWriteCount == 1,
-        s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
-    } finally {
-      spark.listenerManager.unregister(listener)
-    }
-  }
-
   private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
     var capturedPlan: Option[QueryExecution] = None
 
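
For context, a minimal sketch of how the new captureWritePlan helper might be exercised in a test. This assumes, from the signature visible above, that the helper runs the supplied write operation against the output path and returns the executed write plan; the test name, data, and assertion below are illustrative only and are not part of this diff.

  // Illustrative only: assumes captureWritePlan returns the executed plan for the write.
  test("parquet write produces CometNativeWriteExec (sketch)") {
    withSQLConf(
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
      withTempPath { dir =>
        // Run the write through the helper and capture its executed plan
        val plan = captureWritePlan(
          path => spark.range(10).repartition(2).write.parquet(path),
          dir.getAbsolutePath)
        // The captured plan should include the Comet native write operator
        val nativeWrites = plan.collect { case w: CometNativeWriteExec => w }
        assert(nativeWrites.nonEmpty, s"Expected CometNativeWriteExec in:\n${plan.treeString}")
      }
    }
  }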