Skip to content

Commit f089295

Browse files
authored
fix: fall back to Spark for shuffle/sort/aggregate on non-default collated strings [Spark 4] (#4035)
1 parent beecc2d commit f089295

8 files changed

Lines changed: 111 additions & 61 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ jobs:
351351
- name: "sql"
352352
value: |
353353
org.apache.spark.sql.CometToPrettyStringSuite
354+
org.apache.spark.sql.CometCollationSuite
354355
fail-fast: false
355356
name: ${{ matrix.profile.name }}/${{ matrix.profile.scan_impl }} [${{ matrix.suite.name }}]
356357
runs-on: ${{ github.repository_owner == 'apache' && format('runs-on={0},family=m8a+m7a+c8a,cpu=16,image=ubuntu24-full-x64,extras=s3-cache,disk=large,tag=datafusion-comet', github.run_id) || 'ubuntu-latest' }}

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ jobs:
227227
- name: "sql"
228228
value: |
229229
org.apache.spark.sql.CometToPrettyStringSuite
230+
org.apache.spark.sql.CometCollationSuite
230231
231232
fail-fast: false
232233
name: ${{ matrix.os }}/${{ matrix.profile.name }} [${{ matrix.suite.name }}]

common/src/main/spark-4.0/org/apache/comet/shims/CometTypeShim.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@
1919

2020
package org.apache.comet.shims
2121

22-
import org.apache.spark.sql.internal.types.StringTypeWithCollation
23-
import org.apache.spark.sql.types.DataType
22+
import org.apache.spark.sql.types.{DataType, StringType}
2423

2524
trait CometTypeShim {
26-
def isStringCollationType(dt: DataType): Boolean = dt.isInstanceOf[StringTypeWithCollation]
25+
// A `StringType` carries collation metadata in Spark 4.0. Only non-default (non-UTF8_BINARY)
26+
// collations have semantics Comet's byte-level hashing/sorting/equality cannot honor. The
27+
// default `StringType` object is `StringType(UTF8_BINARY_COLLATION_ID)`, so comparing
28+
// `collationId` against that instance's id picks out non-default collations without needing
29+
// `private[sql]` helpers on `StringType`.
30+
def isStringCollationType(dt: DataType): Boolean = dt match {
31+
case st: StringType => st.collationId != StringType.collationId
32+
case _ => false
33+
}
2734
}

dev/diffs/4.0.1.diff

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -150,26 +150,6 @@ index 4410fe50912..43bcce2a038 100644
150150
case _ => Map[String, String]()
151151
}
152152
val childrenInfo = children.flatMap {
153-
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
154-
index 7aca17dcb25..8afeb3b4a2f 100644
155-
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
156-
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/listagg-collations.sql.out
157-
@@ -64,15 +64,6 @@ WithCTE
158-
+- CTERelationRef xxxx, true, [c1#x], false, false
159-
160-
161-
--- !query
162-
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
163-
--- !query analysis
164-
-Aggregate [lower(listagg(distinct collate(c1#x, utf8_lcase), null, collate(c1#x, utf8_lcase) ASC NULLS FIRST, 0, 0)) AS lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST))#x]
165-
-+- SubqueryAlias t
166-
- +- Project [col1#x AS c1#x]
167-
- +- LocalRelation [col1#x]
168-
-
169-
-
170-
-- !query
171-
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
172-
-- !query analysis
173153
diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql
174154
index 17815ed5dde..baad440b1ce 100644
175155
--- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql
@@ -230,21 +210,6 @@ index 698ca009b4f..57d774a3617 100644
230210

231211
-- Test tables
232212
CREATE table explain_temp1 (key int, val int) USING PARQUET;
233-
diff --git a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
234-
index aa3d02dc2fb..c4f878d9908 100644
235-
--- a/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
236-
+++ b/sql/core/src/test/resources/sql-tests/inputs/listagg-collations.sql
237-
@@ -5,7 +5,9 @@ WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1) FROM (VALUES ('
238-
-- Test cases with utf8_lcase. Lower expression added for determinism
239-
SELECT lower(listagg(c1) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1);
240-
WITH t(c1) AS (SELECT lower(listagg(DISTINCT col1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('A'), ('b'), ('B'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'b') FROM t;
241-
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
242-
+-- TODO https://github.com/apache/datafusion-comet/issues/1947
243-
+-- TODO fix Comet for this query
244-
+-- SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1);
245-
-- Test cases with unicode_rtrim.
246-
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t;
247-
WITH t(c1) AS (SELECT listagg(col1) WITHIN GROUP (ORDER BY col1 COLLATE unicode_rtrim) FROM (VALUES ('abc '), ('abc\n'), ('abc'), ('x'))) SELECT replace(replace(c1, ' ', ''), '\n', '$') FROM t;
248213
diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
249214
index 41fd4de2a09..162d5a817b6 100644
250215
--- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/aggregates_part3.sql
@@ -367,25 +332,6 @@ index 21a3ce1e122..f4762ab98f0 100644
367332
SET spark.sql.ansi.enabled = false;
368333

369334
-- In COMPENSATION views get invalidated if the type can't cast
370-
diff --git a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
371-
index 1f8c5822e7d..b7de4e28813 100644
372-
--- a/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
373-
+++ b/sql/core/src/test/resources/sql-tests/results/listagg-collations.sql.out
374-
@@ -40,14 +40,6 @@ struct<len(c1):int,regexp_count(c1, a):int,regexp_count(c1, b):int>
375-
2 1 1
376-
377-
378-
--- !query
379-
-SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)
380-
--- !query schema
381-
-struct<lower(listagg(DISTINCT collate(c1, utf8_lcase), NULL) WITHIN GROUP (ORDER BY collate(c1, utf8_lcase) ASC NULLS FIRST)):string collate UTF8_LCASE>
382-
--- !query output
383-
-ab
384-
-
385-
-
386-
-- !query
387-
WITH t(c1) AS (SELECT replace(listagg(DISTINCT col1 COLLATE unicode_rtrim) COLLATE utf8_binary, ' ', '') FROM (VALUES ('xbc '), ('xbc '), ('a'), ('xbc'))) SELECT len(c1), regexp_count(c1, 'a'), regexp_count(c1, 'xbc') FROM t
388-
-- !query schema
389335
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
390336
index 0f42502f1d9..e9ff802141f 100644
391337
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc}
3939
import org.apache.comet.serde.Types.{DataType => ProtoDataType}
4040
import org.apache.comet.serde.Types.DataType._
4141
import org.apache.comet.serde.literals.CometLiteral
42-
import org.apache.comet.shims.CometExprShim
42+
import org.apache.comet.shims.{CometExprShim, CometTypeShim}
4343

4444
/**
4545
* An utility object for query plan and expression serialization.
4646
*/
47-
object QueryPlanSerde extends Logging with CometExprShim {
47+
object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
4848

4949
private[comet] val arrayExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
5050
classOf[ArrayAppend] -> CometArrayAppend,
@@ -805,6 +805,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
805805
// scalastyle:on
806806
def supportedScalarSortElementType(dt: DataType): Boolean = {
807807
dt match {
808+
// Collated strings require collation-aware ordering; Comet only compares raw bytes.
809+
case st: StringType if isStringCollationType(st) => false
808810
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
809811
_: DoubleType | _: DecimalType | _: DateType | _: TimestampType | _: TimestampNTZType |
810812
_: BooleanType | _: BinaryType | _: StringType =>

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MOD
5353
import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isCometShuffleManagerEnabled, withInfos}
5454
import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
5555
import org.apache.comet.serde.operator.CometSink
56-
import org.apache.comet.shims.ShimCometShuffleExchangeExec
56+
import org.apache.comet.shims.{CometTypeShim, ShimCometShuffleExchangeExec}
5757

5858
/**
5959
* Performs a shuffle that will result in the desired partitioning.
@@ -219,6 +219,7 @@ case class CometShuffleExchangeExec(
219219
object CometShuffleExchangeExec
220220
extends CometSink[ShuffleExchangeExec]
221221
with ShimCometShuffleExchangeExec
222+
with CometTypeShim
222223
with SQLConfHelper {
223224

224225
override def getSupportLevel(op: ShuffleExchangeExec): SupportLevel = {
@@ -316,6 +317,9 @@ object CometShuffleExchangeExec
316317
* hashing complex types, see hash_funcs/utils.rs
317318
*/
318319
def supportedHashPartitioningDataType(dt: DataType): Boolean = dt match {
320+
// Collated strings require collation-aware hashing; Comet only hashes raw bytes,
321+
// which would misroute rows that compare equal under the collation.
322+
case st: StringType if isStringCollationType(st) => false
319323
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
320324
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
321325
_: TimestampNTZType | _: DateType =>
@@ -338,6 +342,8 @@ object CometShuffleExchangeExec
338342
* complex types.
339343
*/
340344
def supportedRangePartitioningDataType(dt: DataType): Boolean = dt match {
345+
// Collated strings require collation-aware ordering; Comet only compares raw bytes.
346+
case st: StringType if isStringCollationType(st) => false
341347
case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
342348
_: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
343349
_: TimestampNTZType | _: DecimalType | _: DateType =>
@@ -498,6 +504,11 @@ object CometShuffleExchangeExec
498504
reasons += s"unsupported hash partitioning expression: $expr"
499505
}
500506
}
507+
for (dt <- expressions.map(_.dataType).distinct) {
508+
if (isStringCollationType(dt)) {
509+
reasons += s"unsupported hash partitioning data type for columnar shuffle: $dt"
510+
}
511+
}
501512
case SinglePartition =>
502513
// we already checked that the input types are supported
503514
case RoundRobinPartitioning(_) =>
@@ -508,6 +519,11 @@ object CometShuffleExchangeExec
508519
reasons += s"unsupported range partitioning sort order: $o"
509520
}
510521
}
522+
for (dt <- orderings.map(_.dataType).distinct) {
523+
if (isStringCollationType(dt)) {
524+
reasons += s"unsupported range partitioning data type for columnar shuffle: $dt"
525+
}
526+
}
511527
case _ =>
512528
reasons +=
513529
s"unsupported Spark partitioning for columnar shuffle: ${partitioning.getClass.getName}"

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, with
5656
import org.apache.comet.parquet.CometParquetUtils
5757
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported}
5858
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
59-
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType}
59+
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, isStringCollationType, supportedSortType}
6060
import org.apache.comet.serde.operator.CometSink
6161

6262
/**
@@ -1386,6 +1386,14 @@ trait CometBaseAggregate {
13861386
return None
13871387
}
13881388

1389+
if (groupingExpressions.exists(expr => isStringCollationType(expr.dataType))) {
1390+
// Collation-aware grouping requires collation-aware hashing/equality; Comet only
1391+
// compares raw bytes, which would put rows that compare equal under the collation
1392+
// into different groups.
1393+
withInfo(aggregate, "Grouping on non-default collated strings is not supported")
1394+
return None
1395+
}
1396+
13891397
val groupingExprsWithInput =
13901398
groupingExpressions.map(expr => expr.name -> exprToProto(expr, child.output))
13911399

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql
21+
22+
class CometCollationSuite extends CometTestBase {
23+
24+
// Queries that group, sort, or shuffle on a non-default collated string must fall back to
25+
// Spark because Comet's shuffle/sort/aggregate compare raw bytes rather than collation-aware
26+
// keys. The shuffle-exchange rule is the primary line of defense (see #1947), so these tests
27+
// pin down the fallback reason it emits.
28+
private val hashShuffleCollationReason =
29+
"unsupported hash partitioning data type for columnar shuffle"
30+
private val rangeShuffleCollationReason =
31+
"unsupported range partitioning data type for columnar shuffle"
32+
33+
test("listagg DISTINCT with utf8_lcase collation (issue #1947)") {
34+
checkSparkAnswerAndFallbackReason(
35+
"SELECT lower(listagg(DISTINCT c1 COLLATE utf8_lcase) " +
36+
"WITHIN GROUP (ORDER BY c1 COLLATE utf8_lcase)) " +
37+
"FROM (VALUES ('a'), ('B'), ('b'), ('A')) AS t(c1)",
38+
hashShuffleCollationReason)
39+
}
40+
41+
test("DISTINCT on utf8_lcase collated string groups case-insensitively") {
42+
checkSparkAnswerAndFallbackReason(
43+
"SELECT DISTINCT c1 COLLATE utf8_lcase AS c " +
44+
"FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) ORDER BY c",
45+
hashShuffleCollationReason)
46+
}
47+
48+
test("GROUP BY utf8_lcase collated string groups case-insensitively") {
49+
checkSparkAnswerAndFallbackReason(
50+
"SELECT lower(c1 COLLATE utf8_lcase) AS k, count(*) " +
51+
"FROM (VALUES ('a'), ('A'), ('b'), ('B')) AS t(c1) " +
52+
"GROUP BY c1 COLLATE utf8_lcase ORDER BY k",
53+
hashShuffleCollationReason)
54+
}
55+
56+
test("ORDER BY utf8_lcase collated string sorts case-insensitively") {
57+
checkSparkAnswerAndFallbackReason(
58+
"SELECT c1 COLLATE utf8_lcase AS c " +
59+
"FROM (VALUES ('A'), ('b'), ('a'), ('B')) AS t(c1) ORDER BY c",
60+
rangeShuffleCollationReason)
61+
}
62+
63+
test("default UTF8_BINARY string still runs through Comet") {
64+
// Sanity check that the collation fallback does not over-block the default string type.
65+
withParquetTable(Seq(("a", 1), ("b", 2), ("a", 3)), "tbl") {
66+
checkSparkAnswerAndOperator("SELECT DISTINCT _1 FROM tbl ORDER BY _1")
67+
}
68+
}
69+
}

0 commit comments

Comments
 (0)