Skip to content

Commit fec7c20

Browse files
committed
enable_spark_tests_comet_native_writer_fix_spark_rebase_main
1 parent 3481d43 commit fec7c20

1 file changed

Lines changed: 125 additions & 30 deletions

File tree

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 125 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,9 @@ class CometParquetWriterSuite extends CometTestBase {
383383
CometConf.COMET_EXEC_ENABLED.key -> "true",
384384
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
385385
withTempPath { path =>
386-
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
386+
assertNativeWriter {
387+
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
388+
}
387389
val partFiles = path
388390
.listFiles()
389391
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
@@ -400,7 +402,9 @@ class CometParquetWriterSuite extends CometTestBase {
400402
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
401403
withTable("t1", "t2") {
402404
sql("CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING parquet")
403-
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
405+
assertNativeWriter {
406+
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
407+
}
404408
checkAnswer(
405409
sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
406410
Seq(Row("char(5)"), Row("varchar(4)")))
@@ -417,7 +421,9 @@ class CometParquetWriterSuite extends CometTestBase {
417421
withTable("t1", "t2") {
418422
sql("CREATE TABLE t1(col CHAR(5)) USING parquet")
419423
withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
420-
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
424+
assertNativeWriter {
425+
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
426+
}
421427
checkAnswer(
422428
sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
423429
Seq(Row("varchar(5)")))
@@ -436,11 +442,17 @@ class CometParquetWriterSuite extends CometTestBase {
436442
val path = dir.toURI.getPath
437443
withTable("tab1", "tab2") {
438444
sql(s"""create table tab1 (a int) using parquet location '$path'""")
439-
sql("insert into tab1 values(1)")
445+
assertNativeWriter {
446+
sql("insert into tab1 values(1)")
447+
}
440448
checkAnswer(sql("select * from tab1"), Seq(Row(1)))
441449
sql("create table tab2 (a int) using parquet")
442-
sql("insert into tab2 values(2)")
443-
sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
450+
assertNativeWriter {
451+
sql("insert into tab2 values(2)")
452+
}
453+
assertNativeWriter {
454+
sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
455+
}
444456
sql("refresh table tab1")
445457
checkAnswer(sql("select * from tab1"), Seq(Row(2)))
446458
}
@@ -456,7 +468,9 @@ class CometParquetWriterSuite extends CometTestBase {
456468
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
457469
withTable("t") {
458470
sql("create table t(i boolean, s bigint) using parquet")
459-
sql("insert into t(i) values(true)")
471+
assertNativeWriter {
472+
sql("insert into t(i) values(true)")
473+
}
460474
checkAnswer(spark.table("t"), Row(true, null))
461475
}
462476
}
@@ -471,7 +485,9 @@ class CometParquetWriterSuite extends CometTestBase {
471485
withTable("t") {
472486
sql("create table t(i boolean) using parquet")
473487
sql("alter table t add column s string default concat('abc', 'def')")
474-
sql("insert into t values(true, default)")
488+
assertNativeWriter {
489+
sql("insert into t values(true, default)")
490+
}
475491
checkAnswer(spark.table("t"), Row(true, "abcdef"))
476492
}
477493
}
@@ -485,9 +501,13 @@ class CometParquetWriterSuite extends CometTestBase {
485501
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
486502
withTable("t1", "t2") {
487503
sql("create table t1(i boolean, s bigint default 42) using parquet")
488-
sql("insert into t1 values (true, 41), (false, default)")
504+
assertNativeWriter {
505+
sql("insert into t1 values (true, 41), (false, default)")
506+
}
489507
sql("create table t2(i boolean, s bigint) using parquet")
490-
sql("insert into t2 select * from t1 order by s")
508+
assertNativeWriter {
509+
sql("insert into t2 select * from t1 order by s")
510+
}
491511
checkAnswer(spark.table("t2"), Seq(Row(true, 41), Row(false, 42)))
492512
}
493513
}
@@ -502,10 +522,14 @@ class CometParquetWriterSuite extends CometTestBase {
502522
withTable("tbl", "tbl2") {
503523
withView("view1") {
504524
val df = spark.range(10).toDF("id")
505-
df.write.format("parquet").saveAsTable("tbl")
525+
assertNativeWriter {
526+
df.write.format("parquet").saveAsTable("tbl")
527+
}
506528
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
507529
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
508-
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
530+
assertNativeWriter {
531+
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
532+
}
509533
checkAnswer(spark.table("tbl2"), (0 until 10).map(Row(_)))
510534
}
511535
}
@@ -519,11 +543,13 @@ class CometParquetWriterSuite extends CometTestBase {
519543
CometConf.COMET_EXEC_ENABLED.key -> "true",
520544
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
521545
withTempPath { dir =>
522-
spark
523-
.range(1)
524-
.selectExpr("current_timestamp() as ts")
525-
.write
526-
.parquet(dir.toString + "/spark")
546+
assertNativeWriter {
547+
spark
548+
.range(1)
549+
.selectExpr("current_timestamp() as ts")
550+
.write
551+
.parquet(dir.toString + "/spark")
552+
}
527553
val result = spark.read.parquet(dir.toString + "/spark").collect()
528554
assert(result.length == 1)
529555
}
@@ -539,8 +565,12 @@ class CometParquetWriterSuite extends CometTestBase {
539565
withTable("tab1", "tab2") {
540566
sql("""CREATE TABLE tab1 (s struct<a: string, b: string>) USING parquet""")
541567
sql("""CREATE TABLE tab2 (s struct<c: string, d: string>) USING parquet""")
542-
sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
543-
sql("INSERT INTO tab2 SELECT * FROM tab1")
568+
assertNativeWriter {
569+
sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
570+
}
571+
assertNativeWriter {
572+
sql("INSERT INTO tab2 SELECT * FROM tab1")
573+
}
544574
checkAnswer(spark.table("tab2"), Row(Row("x", "y")))
545575
}
546576
}
@@ -553,7 +583,9 @@ class CometParquetWriterSuite extends CometTestBase {
553583
CometConf.COMET_EXEC_ENABLED.key -> "true",
554584
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
555585
withTempPath { dir =>
556-
spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
586+
assertNativeWriter {
587+
spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
588+
}
557589
val files = dir.listFiles().filter(_.getName.endsWith(".parquet"))
558590
assert(files.nonEmpty, "Expected parquet files to be written")
559591
}
@@ -568,7 +600,9 @@ class CometParquetWriterSuite extends CometTestBase {
568600
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
569601
withTempDir { dir =>
570602
val path = dir.getCanonicalPath
571-
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
603+
assertNativeWriter {
604+
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
605+
}
572606
val files = new File(path).listFiles().filter(_.getName.startsWith("part-"))
573607
assert(files.length > 0, "Expected part files to be written")
574608
}
@@ -584,7 +618,9 @@ class CometParquetWriterSuite extends CometTestBase {
584618
CometConf.COMET_EXEC_ENABLED.key -> "true",
585619
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
586620
withTable("t") {
587-
sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
621+
assertNativeWriter {
622+
sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
623+
}
588624
checkAnswer(spark.table("t"), Seq(Row(1), Row(2)))
589625
}
590626
}
@@ -600,7 +636,9 @@ class CometParquetWriterSuite extends CometTestBase {
600636
withTable("t1", "t2") {
601637
sql("CREATE TABLE t1(a INT) USING parquet")
602638
sql("CREATE TABLE t2(a INT) USING parquet")
603-
sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
639+
assertNativeWriter {
640+
sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
641+
}
604642
checkAnswer(spark.table("t1"), Row(1))
605643
checkAnswer(spark.table("t2"), Row(1))
606644
}
@@ -626,15 +664,72 @@ class CometParquetWriterSuite extends CometTestBase {
626664
}
627665

628666
/**
629-
* Captures the execution plan during a write operation.
667+
* Executes a code block and asserts that CometNativeWriteExec is in the write plan.
668+
* This is used for verifying native writer is called in SQL commands.
630669
*
631-
* @param writeOp
632-
* The write operation to execute (takes output path as parameter)
633-
* @param outputPath
634-
* The path to write to
635-
* @return
636-
* The captured execution plan
670+
* @param block
671+
* The code block to execute (should contain write operations)
637672
*/
673+
private def assertNativeWriter(block: => Unit): Unit = {
674+
var capturedPlan: Option[QueryExecution] = None
675+
676+
val listener = new org.apache.spark.sql.util.QueryExecutionListener {
677+
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
678+
if (funcName == "save" || funcName.contains("command")) {
679+
capturedPlan = Some(qe)
680+
}
681+
}
682+
683+
override def onFailure(
684+
funcName: String,
685+
qe: QueryExecution,
686+
exception: Exception): Unit = {}
687+
}
688+
689+
spark.listenerManager.register(listener)
690+
691+
try {
692+
block
693+
694+
// Wait for listener to be called with timeout
695+
val maxWaitTimeMs = 15000
696+
val checkIntervalMs = 100
697+
val maxIterations = maxWaitTimeMs / checkIntervalMs
698+
var iterations = 0
699+
700+
while (capturedPlan.isEmpty && iterations < maxIterations) {
701+
Thread.sleep(checkIntervalMs)
702+
iterations += 1
703+
}
704+
705+
assert(
706+
capturedPlan.isDefined,
707+
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
708+
709+
val plan = stripAQEPlan(capturedPlan.get.executedPlan)
710+
711+
// Count CometNativeWriteExec instances in the plan
712+
var nativeWriteCount = 0
713+
plan.foreach {
714+
case _: CometNativeWriteExec =>
715+
nativeWriteCount += 1
716+
case d: DataWritingCommandExec =>
717+
d.child.foreach {
718+
case _: CometNativeWriteExec =>
719+
nativeWriteCount += 1
720+
case _ =>
721+
}
722+
case _ =>
723+
}
724+
725+
assert(
726+
nativeWriteCount == 1,
727+
s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
728+
} finally {
729+
spark.listenerManager.unregister(listener)
730+
}
731+
}
732+
638733
private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
639734
var capturedPlan: Option[QueryExecution] = None
640735

0 commit comments

Comments (0)