Skip to content

Commit 570521f

Browse files
committed
Enable Spark tests for Comet native writer fix (Spark rebase, main)
1 parent 370d98d commit 570521f

1 file changed

Lines changed: 33 additions & 125 deletions

File tree

spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 33 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
package org.apache.comet.parquet
2121

2222
import java.io.File
23+
2324
import scala.util.Random
25+
2426
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
2527
import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeScanExec, CometNativeWriteExec, CometScanExec}
2628
import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution, SparkPlan}
2729
import org.apache.spark.sql.execution.command.DataWritingCommandExec
2830
import org.apache.spark.sql.internal.SQLConf
2931
import org.apache.spark.sql.types.StructType
32+
3033
import org.apache.comet.CometConf
3134
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
3235
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}
@@ -383,9 +386,7 @@ class CometParquetWriterSuite extends CometTestBase {
383386
CometConf.COMET_EXEC_ENABLED.key -> "true",
384387
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
385388
withTempPath { path =>
386-
assertNativeWriter {
387-
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
388-
}
389+
spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
389390
val partFiles = path
390391
.listFiles()
391392
.filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
@@ -402,9 +403,7 @@ class CometParquetWriterSuite extends CometTestBase {
402403
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
403404
withTable("t1", "t2") {
404405
sql("CREATE TABLE t1(i CHAR(5), c VARCHAR(4)) USING parquet")
405-
assertNativeWriter {
406-
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
407-
}
406+
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
408407
checkAnswer(
409408
sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
410409
Seq(Row("char(5)"), Row("varchar(4)")))
@@ -421,9 +420,7 @@ class CometParquetWriterSuite extends CometTestBase {
421420
withTable("t1", "t2") {
422421
sql("CREATE TABLE t1(col CHAR(5)) USING parquet")
423422
withSQLConf(SQLConf.CHAR_AS_VARCHAR.key -> "true") {
424-
assertNativeWriter {
425-
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
426-
}
423+
sql("CREATE TABLE t2 USING parquet AS SELECT * FROM t1")
427424
checkAnswer(
428425
sql("desc t2").selectExpr("data_type").where("data_type like '%char%'"),
429426
Seq(Row("varchar(5)")))
@@ -442,17 +439,11 @@ class CometParquetWriterSuite extends CometTestBase {
442439
val path = dir.toURI.getPath
443440
withTable("tab1", "tab2") {
444441
sql(s"""create table tab1 (a int) using parquet location '$path'""")
445-
assertNativeWriter {
446-
sql("insert into tab1 values(1)")
447-
}
442+
sql("insert into tab1 values(1)")
448443
checkAnswer(sql("select * from tab1"), Seq(Row(1)))
449444
sql("create table tab2 (a int) using parquet")
450-
assertNativeWriter {
451-
sql("insert into tab2 values(2)")
452-
}
453-
assertNativeWriter {
454-
sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
455-
}
445+
sql("insert into tab2 values(2)")
446+
sql(s"""insert overwrite local directory '$path' using parquet select * from tab2""")
456447
sql("refresh table tab1")
457448
checkAnswer(sql("select * from tab1"), Seq(Row(2)))
458449
}
@@ -468,9 +459,7 @@ class CometParquetWriterSuite extends CometTestBase {
468459
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
469460
withTable("t") {
470461
sql("create table t(i boolean, s bigint) using parquet")
471-
assertNativeWriter {
472-
sql("insert into t(i) values(true)")
473-
}
462+
sql("insert into t(i) values(true)")
474463
checkAnswer(spark.table("t"), Row(true, null))
475464
}
476465
}
@@ -485,9 +474,7 @@ class CometParquetWriterSuite extends CometTestBase {
485474
withTable("t") {
486475
sql("create table t(i boolean) using parquet")
487476
sql("alter table t add column s string default concat('abc', 'def')")
488-
assertNativeWriter {
489-
sql("insert into t values(true, default)")
490-
}
477+
sql("insert into t values(true, default)")
491478
checkAnswer(spark.table("t"), Row(true, "abcdef"))
492479
}
493480
}
@@ -501,13 +488,9 @@ class CometParquetWriterSuite extends CometTestBase {
501488
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
502489
withTable("t1", "t2") {
503490
sql("create table t1(i boolean, s bigint default 42) using parquet")
504-
assertNativeWriter {
505-
sql("insert into t1 values (true, 41), (false, default)")
506-
}
491+
sql("insert into t1 values (true, 41), (false, default)")
507492
sql("create table t2(i boolean, s bigint) using parquet")
508-
assertNativeWriter {
509-
sql("insert into t2 select * from t1 order by s")
510-
}
493+
sql("insert into t2 select * from t1 order by s")
511494
checkAnswer(spark.table("t2"), Seq(Row(true, 41), Row(false, 42)))
512495
}
513496
}
@@ -522,14 +505,10 @@ class CometParquetWriterSuite extends CometTestBase {
522505
withTable("tbl", "tbl2") {
523506
withView("view1") {
524507
val df = spark.range(10).toDF("id")
525-
assertNativeWriter {
526-
df.write.format("parquet").saveAsTable("tbl")
527-
}
508+
df.write.format("parquet").saveAsTable("tbl")
528509
spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl")
529510
spark.sql("CREATE TABLE tbl2(ID long) USING parquet")
530-
assertNativeWriter {
531-
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
532-
}
511+
spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1")
533512
checkAnswer(spark.table("tbl2"), (0 until 10).map(Row(_)))
534513
}
535514
}
@@ -543,13 +522,11 @@ class CometParquetWriterSuite extends CometTestBase {
543522
CometConf.COMET_EXEC_ENABLED.key -> "true",
544523
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
545524
withTempPath { dir =>
546-
assertNativeWriter {
547-
spark
548-
.range(1)
549-
.selectExpr("current_timestamp() as ts")
550-
.write
551-
.parquet(dir.toString + "/spark")
552-
}
525+
spark
526+
.range(1)
527+
.selectExpr("current_timestamp() as ts")
528+
.write
529+
.parquet(dir.toString + "/spark")
553530
val result = spark.read.parquet(dir.toString + "/spark").collect()
554531
assert(result.length == 1)
555532
}
@@ -565,12 +542,8 @@ class CometParquetWriterSuite extends CometTestBase {
565542
withTable("tab1", "tab2") {
566543
sql("""CREATE TABLE tab1 (s struct<a: string, b: string>) USING parquet""")
567544
sql("""CREATE TABLE tab2 (s struct<c: string, d: string>) USING parquet""")
568-
assertNativeWriter {
569-
sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
570-
}
571-
assertNativeWriter {
572-
sql("INSERT INTO tab2 SELECT * FROM tab1")
573-
}
545+
sql("INSERT INTO tab1 VALUES (named_struct('a', 'x', 'b', 'y'))")
546+
sql("INSERT INTO tab2 SELECT * FROM tab1")
574547
checkAnswer(spark.table("tab2"), Row(Row("x", "y")))
575548
}
576549
}
@@ -583,9 +556,7 @@ class CometParquetWriterSuite extends CometTestBase {
583556
CometConf.COMET_EXEC_ENABLED.key -> "true",
584557
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
585558
withTempPath { dir =>
586-
assertNativeWriter {
587-
spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
588-
}
559+
spark.range(1).repartition(1).write.parquet(dir.getAbsolutePath)
589560
val files = dir.listFiles().filter(_.getName.endsWith(".parquet"))
590561
assert(files.nonEmpty, "Expected parquet files to be written")
591562
}
@@ -600,9 +571,7 @@ class CometParquetWriterSuite extends CometTestBase {
600571
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
601572
withTempDir { dir =>
602573
val path = dir.getCanonicalPath
603-
assertNativeWriter {
604-
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
605-
}
574+
spark.range(10).repartition(10).write.mode("overwrite").parquet(path)
606575
val files = new File(path).listFiles().filter(_.getName.startsWith("part-"))
607576
assert(files.length > 0, "Expected part files to be written")
608577
}
@@ -618,9 +587,7 @@ class CometParquetWriterSuite extends CometTestBase {
618587
CometConf.COMET_EXEC_ENABLED.key -> "true",
619588
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
620589
withTable("t") {
621-
assertNativeWriter {
622-
sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
623-
}
590+
sql("CREATE TABLE t USING parquet AS SELECT 1 AS c UNION ALL SELECT 2")
624591
checkAnswer(spark.table("t"), Seq(Row(1), Row(2)))
625592
}
626593
}
@@ -636,9 +603,7 @@ class CometParquetWriterSuite extends CometTestBase {
636603
withTable("t1", "t2") {
637604
sql("CREATE TABLE t1(a INT) USING parquet")
638605
sql("CREATE TABLE t2(a INT) USING parquet")
639-
assertNativeWriter {
640-
sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
641-
}
606+
sql("FROM (SELECT 1 AS a) src INSERT INTO t1 SELECT a INSERT INTO t2 SELECT a")
642607
checkAnswer(spark.table("t1"), Row(1))
643608
checkAnswer(spark.table("t2"), Row(1))
644609
}
@@ -664,72 +629,15 @@ class CometParquetWriterSuite extends CometTestBase {
664629
}
665630

666631
/**
667-
* Executes a code block and asserts that CometNativeWriteExec is in the write plan.
668-
* This is used for verifying native writer is called in SQL commands.
632+
* Captures the execution plan during a write operation.
669633
*
670-
* @param block
671-
* The code block to execute (should contain write operations)
634+
* @param writeOp
635+
* The write operation to execute (takes output path as parameter)
636+
* @param outputPath
637+
* The path to write to
638+
* @return
639+
* The captured execution plan
672640
*/
673-
private def assertNativeWriter(block: => Unit): Unit = {
674-
var capturedPlan: Option[QueryExecution] = None
675-
676-
val listener = new org.apache.spark.sql.util.QueryExecutionListener {
677-
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
678-
if (funcName == "save" || funcName.contains("command")) {
679-
capturedPlan = Some(qe)
680-
}
681-
}
682-
683-
override def onFailure(
684-
funcName: String,
685-
qe: QueryExecution,
686-
exception: Exception): Unit = {}
687-
}
688-
689-
spark.listenerManager.register(listener)
690-
691-
try {
692-
block
693-
694-
// Wait for listener to be called with timeout
695-
val maxWaitTimeMs = 15000
696-
val checkIntervalMs = 100
697-
val maxIterations = maxWaitTimeMs / checkIntervalMs
698-
var iterations = 0
699-
700-
while (capturedPlan.isEmpty && iterations < maxIterations) {
701-
Thread.sleep(checkIntervalMs)
702-
iterations += 1
703-
}
704-
705-
assert(
706-
capturedPlan.isDefined,
707-
s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
708-
709-
val plan = stripAQEPlan(capturedPlan.get.executedPlan)
710-
711-
// Count CometNativeWriteExec instances in the plan
712-
var nativeWriteCount = 0
713-
plan.foreach {
714-
case _: CometNativeWriteExec =>
715-
nativeWriteCount += 1
716-
case d: DataWritingCommandExec =>
717-
d.child.foreach {
718-
case _: CometNativeWriteExec =>
719-
nativeWriteCount += 1
720-
case _ =>
721-
}
722-
case _ =>
723-
}
724-
725-
assert(
726-
nativeWriteCount == 1,
727-
s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}")
728-
} finally {
729-
spark.listenerManager.unregister(listener)
730-
}
731-
}
732-
733641
private def captureWritePlan(writeOp: String => Unit, outputPath: String): SparkPlan = {
734642
var capturedPlan: Option[QueryExecution] = None
735643

0 commit comments

Comments (0)