diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqlExpressions/SqliteAggregateFunctionExpression.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqlExpressions/SqliteAggregateFunctionExpression.cs new file mode 100644 index 00000000000..c79a96cbf2f --- /dev/null +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqlExpressions/SqliteAggregateFunctionExpression.cs @@ -0,0 +1,218 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +// ReSharper disable once CheckNamespace +namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SqliteAggregateFunctionExpression : SqlExpression +{ + private static ConstructorInfo? _quotingConstructor; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SqliteAggregateFunctionExpression( + string name, + IReadOnlyList arguments, + IReadOnlyList orderings, + bool nullable, + IEnumerable argumentsPropagateNullability, + Type type, + RelationalTypeMapping? typeMapping) + : base(type, typeMapping) + { + Name = name; + Arguments = arguments.ToList(); + Orderings = orderings; + IsNullable = nullable; + ArgumentsPropagateNullability = argumentsPropagateNullability.ToList(); + } + + /// + /// The name of the aggregate SQL function, e.g. group_concat. + /// + public virtual string Name { get; } + + /// + /// The arguments passed to the aggregate function. + /// + public virtual IReadOnlyList Arguments { get; } + + /// + /// The orderings applied to the aggregated input, rendered inside the function call as + /// group_concat(value, separator ORDER BY ...). + /// + public virtual IReadOnlyList Orderings { get; } + + /// + /// Whether the expression is nullable. + /// + public virtual bool IsNullable { get; } + + /// + /// For each argument, whether a value propagates to a result. + /// + public virtual IReadOnlyList ArgumentsPropagateNullability { get; } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + SqlExpression[]? arguments = null; + for (var i = 0; i < Arguments.Count; i++) + { + var visitedArgument = (SqlExpression)visitor.Visit(Arguments[i]); + if (visitedArgument != Arguments[i] && arguments is null) + { + arguments = new SqlExpression[Arguments.Count]; + + for (var j = 0; j < i; j++) + { + arguments[j] = Arguments[j]; + } + } + + if (arguments is not null) + { + arguments[i] = visitedArgument; + } + } + + OrderingExpression[]? orderings = null; + for (var i = 0; i < Orderings.Count; i++) + { + var visitedOrdering = (OrderingExpression)visitor.Visit(Orderings[i]); + if (visitedOrdering != Orderings[i] && orderings is null) + { + orderings = new OrderingExpression[Orderings.Count]; + + for (var j = 0; j < i; j++) + { + orderings[j] = Orderings[j]; + } + } + + if (orderings is not null) + { + orderings[i] = visitedOrdering; + } + } + + return arguments is not null || orderings is not null + ? new SqliteAggregateFunctionExpression( + Name, + arguments ?? Arguments, + orderings ?? Orderings, + IsNullable, + ArgumentsPropagateNullability, + Type, + TypeMapping) + : this; + } + + /// + /// Applies the given type mapping, returning a new expression. + /// + public virtual SqliteAggregateFunctionExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping) + => new( + Name, + Arguments, + Orderings, + IsNullable, + ArgumentsPropagateNullability, + Type, + typeMapping ?? TypeMapping); + + /// + /// Returns a new expression with the given arguments and orderings, or this instance if nothing changed. + /// + public virtual SqliteAggregateFunctionExpression Update( + IReadOnlyList arguments, + IReadOnlyList orderings) + => (ReferenceEquals(arguments, Arguments) || arguments.SequenceEqual(Arguments)) + && (ReferenceEquals(orderings, Orderings) || orderings.SequenceEqual(Orderings)) + ? this + : new SqliteAggregateFunctionExpression( + Name, + arguments, + orderings, + IsNullable, + ArgumentsPropagateNullability, + Type, + TypeMapping); + + /// + public override Expression Quote() + => New( + _quotingConstructor ??= typeof(SqliteAggregateFunctionExpression).GetConstructor( + [ + typeof(string), typeof(IReadOnlyList), typeof(IReadOnlyList), typeof(bool), + typeof(IEnumerable), typeof(Type), typeof(RelationalTypeMapping) + ])!, + Constant(Name), + NewArrayInit(typeof(SqlExpression), initializers: Arguments.Select(a => a.Quote())), + NewArrayInit(typeof(OrderingExpression), Orderings.Select(o => o.Quote())), + Constant(IsNullable), + NewArrayInit(typeof(bool), initializers: ArgumentsPropagateNullability.Select(n => Constant(n))), + Constant(Type), + RelationalExpressionQuotingUtilities.QuoteTypeMapping(TypeMapping)); + + /// + protected override void Print(ExpressionPrinter expressionPrinter) + { + expressionPrinter.Append(Name); + + expressionPrinter.Append("("); + expressionPrinter.VisitCollection(Arguments); + + if (Orderings.Count > 0) + { + expressionPrinter.Append(" ORDER BY "); + expressionPrinter.VisitCollection(Orderings); + } + + expressionPrinter.Append(")"); + } + + /// + public override bool Equals(object? obj) + => obj is SqliteAggregateFunctionExpression sqliteAggregateFunctionExpression && Equals(sqliteAggregateFunctionExpression); + + private bool Equals(SqliteAggregateFunctionExpression? other) + => ReferenceEquals(this, other) + || other is not null + && base.Equals(other) + && Name == other.Name + && Arguments.SequenceEqual(other.Arguments) + && Orderings.SequenceEqual(other.Orderings); + + /// + public override int GetHashCode() + { + var hash = new HashCode(); + hash.Add(base.GetHashCode()); + hash.Add(Name); + + for (var i = 0; i < Arguments.Count; i++) + { + hash.Add(Arguments[i]); + } + + for (var i = 0; i < Orderings.Count; i++) + { + hash.Add(Orderings[i]); + } + + return hash.ToHashCode(); + } +} diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQuerySqlGenerator.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQuerySqlGenerator.cs index 26c92c1f900..176c694c7f1 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteQuerySqlGenerator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteQuerySqlGenerator.cs @@ -36,6 +36,10 @@ protected override Expression VisitExtension(Expression extensionExpression) GenerateJsonEach(jsonEachExpression); return extensionExpression; + case SqliteAggregateFunctionExpression aggregateFunctionExpression: + GenerateAggregateFunction(aggregateFunctionExpression); + return extensionExpression; + default: return base.VisitExtension(extensionExpression); } @@ -174,6 +178,40 @@ private void GenerateRegexp(RegexpExpression regexpExpression, bool negated = fa Visit(regexpExpression.Pattern); } + private void GenerateAggregateFunction(SqliteAggregateFunctionExpression aggregateFunctionExpression) + { + Sql.Append(aggregateFunctionExpression.Name).Append("("); + + for (var i = 0; i < aggregateFunctionExpression.Arguments.Count; i++) + { + if (i > 0) + { + Sql.Append(", "); + } + + Visit(aggregateFunctionExpression.Arguments[i]); + } + + // Unlike SQL Server's "WITHIN GROUP (ORDER BY ...)", SQLite renders the ordering inside the function + // parentheses: group_concat(value, separator ORDER BY ...). Supported since SQLite 3.44.0. + if (aggregateFunctionExpression.Orderings.Count > 0) + { + Sql.Append(" ORDER BY "); + + for (var i = 0; i < aggregateFunctionExpression.Orderings.Count; i++) + { + if (i > 0) + { + Sql.Append(", "); + } + + Visit(aggregateFunctionExpression.Orderings[i]); + } + } + + Sql.Append(")"); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs index 470dd3af33c..70a39db78bc 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs @@ -41,6 +41,8 @@ protected override SqlExpression VisitCustomSqlExpression( { GlobExpression globExpression => VisitGlob(globExpression, allowOptimizedExpansion, out nullable), RegexpExpression regexpExpression => VisitRegexp(regexpExpression, allowOptimizedExpansion, out nullable), + SqliteAggregateFunctionExpression aggregateFunctionExpression + => VisitAggregateFunction(aggregateFunctionExpression, allowOptimizedExpansion, out nullable), _ => base.VisitCustomSqlExpression(sqlExpression, allowOptimizedExpansion, out nullable) }; @@ -84,6 +86,67 @@ protected virtual SqlExpression VisitRegexp( return regexpExpression.Update(match, pattern); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual SqlExpression VisitAggregateFunction( + SqliteAggregateFunctionExpression aggregateFunctionExpression, + bool allowOptimizedExpansion, + out bool nullable) + { + nullable = aggregateFunctionExpression.IsNullable; + + SqlExpression[]? arguments = null; + for (var i = 0; i < aggregateFunctionExpression.Arguments.Count; i++) + { + var visitedArgument = Visit(aggregateFunctionExpression.Arguments[i], out _); + if (visitedArgument != aggregateFunctionExpression.Arguments[i] && arguments is null) + { + arguments = new SqlExpression[aggregateFunctionExpression.Arguments.Count]; + + for (var j = 0; j < i; j++) + { + arguments[j] = aggregateFunctionExpression.Arguments[j]; + } + } + + if (arguments is not null) + { + arguments[i] = visitedArgument; + } + } + + OrderingExpression[]? orderings = null; + for (var i = 0; i < aggregateFunctionExpression.Orderings.Count; i++) + { + var ordering = aggregateFunctionExpression.Orderings[i]; + var visitedOrdering = ordering.Update(Visit(ordering.Expression, out _)); + if (visitedOrdering != aggregateFunctionExpression.Orderings[i] && orderings is null) + { + orderings = new OrderingExpression[aggregateFunctionExpression.Orderings.Count]; + + for (var j = 0; j < i; j++) + { + orderings[j] = aggregateFunctionExpression.Orderings[j]; + } + } + + if (orderings is not null) + { + orderings[i] = visitedOrdering; + } + } + + return arguments is not null || orderings is not null + ? aggregateFunctionExpression.Update( + arguments ?? aggregateFunctionExpression.Arguments, + orderings ?? aggregateFunctionExpression.Orderings) + : aggregateFunctionExpression; + } + /// protected override SqlExpression VisitSqlFunction( SqlFunctionExpression sqlFunctionExpression, diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs index 1f048b447ea..925b955b1b9 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteStringAggregateMethodTranslator.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; // ReSharper disable once CheckNamespace @@ -14,6 +15,10 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal; /// public class SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpressionFactory) : IAggregateMethodCallTranslator { + // group_concat supports an in-function ORDER BY clause since SQLite 3.44.0. + private readonly bool _isOrderedAggregateSupported + = new Version(new SqliteConnection().ServerVersion) >= new Version(3, 44); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -47,9 +52,12 @@ public class SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpr return null; } - // SQLite does not support input ordering on aggregate methods. Since ordering matters very much for translating, if the user - // specified an ordering we refuse to translate (but to error than to ignore in this case). - if (source.Orderings.Count > 0) + // SQLite's group_concat() accepts only a single argument when DISTINCT is used, so it cannot be combined + // with the separator that string.Join/Concat always supply ("DISTINCT aggregates must have exactly one + // argument"). In-aggregate ORDER BY additionally requires SQLite 3.44.0. Fall back to client evaluation + // rather than emit SQL that fails at execution time. + if (source.IsDistinct + || (source.Orderings.Count > 0 && !_isOrderedAggregateSupported)) { return null; } @@ -75,17 +83,33 @@ public class SqliteStringAggregateMethodTranslator(ISqlExpressionFactory sqlExpr sqlExpression = new DistinctExpression(sqlExpression); } - // group_concat returns null when there are no rows (or non-null values), but string.Join returns an empty string. - return sqlExpressionFactory.Coalesce( - sqlExpressionFactory.Function( + var functionArguments = new[] + { + sqlExpression, + sqlExpressionFactory.ApplyTypeMapping(separator, sqlExpression.TypeMapping) + }; + + // SQLite supports ORDER BY inside aggregate functions since 3.44.0: group_concat(value, separator ORDER BY ...). + // When the user specified an ordering we emit our custom expression that renders it; otherwise a plain function call. + SqlExpression aggregate = source.Orderings.Count == 0 + ? sqlExpressionFactory.Function( "group_concat", - [ - sqlExpression, - sqlExpressionFactory.ApplyTypeMapping(separator, sqlExpression.TypeMapping) - ], + functionArguments, nullable: true, argumentsPropagateNullability: Statics.FalseArrays[2], - typeof(string)), + typeof(string)) + : new SqliteAggregateFunctionExpression( + "group_concat", + functionArguments, + source.Orderings, + nullable: true, + argumentsPropagateNullability: Statics.FalseArrays[2], + typeof(string), + sqlExpression.TypeMapping); + + // group_concat returns null when there are no rows (or non-null values), but string.Join returns an empty string. + return sqlExpressionFactory.Coalesce( + aggregate, sqlExpressionFactory.Constant(string.Empty, typeof(string)), sqlExpression.TypeMapping); } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/StringTranslationsSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/StringTranslationsSqliteTest.cs index 9fb628170f0..abd69dfb756 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/StringTranslationsSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/StringTranslationsSqliteTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Text.RegularExpressions; +using Microsoft.EntityFrameworkCore.Sqlite.Internal; using Microsoft.EntityFrameworkCore.TestModels.BasicTypesModel; namespace Microsoft.EntityFrameworkCore.Query.Translations; @@ -1418,25 +1419,38 @@ GROUP BY "b"."Int" public override async Task Join_with_ordering() { - // SQLite does not support input ordering on aggregate methods; the below does client evaluation. + // group_concat with an ORDER BY clause requires SQLite 3.44.0. + Assert.SkipUnless(SqliteTestEnvironment.VersionAtLeast3_44, "Requires SQLite 3.44.0 for ORDER BY in aggregate functions."); + await base.Join_with_ordering(); AssertSql( """ -SELECT "b1"."Int", "b0"."String", "b0"."Id" -FROM ( - SELECT "b"."Int" - FROM "BasicTypesEntities" AS "b" - GROUP BY "b"."Int" -) AS "b1" -LEFT JOIN "BasicTypesEntities" AS "b0" ON "b1"."Int" = "b0"."Int" -ORDER BY "b1"."Int", "b0"."Id" DESC +SELECT "b"."Int" AS "Key", COALESCE(group_concat("b"."String", '|' ORDER BY "b"."Id" DESC), '') AS "Strings" +FROM "BasicTypesEntities" AS "b" +GROUP BY "b"."Int" """); } public override Task Join_non_aggregate() => AssertTranslationFailed(() => base.Join_non_aggregate()); + [Fact] + public virtual async Task Join_with_distinct() + { + // SQLite's group_concat() accepts only a single argument when DISTINCT is used, so it cannot be combined + // with the separator that string.Join always supplies ("DISTINCT aggregates must have exactly one argument"). + // We therefore don't translate it; the query then falls back to APPLY, which SQLite does not support. + Assert.Equal( + SqliteStrings.ApplyNotSupported, + (await Assert.ThrowsAsync( + () => AssertQuery( + ss => ss.Set() + .GroupBy(c => c.Int) + .Select(g => new { g.Key, Strings = string.Join("|", g.Select(e => e.String).Distinct()) }), + elementSorter: x => x.Key))).Message); + } + #endregion Join #region Concatenation diff --git a/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqliteTestEnvironment.cs b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqliteTestEnvironment.cs index 175a1467091..b544643ca7e 100644 --- a/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqliteTestEnvironment.cs +++ b/test/EFCore.Sqlite.FunctionalTests/TestUtilities/SqliteTestEnvironment.cs @@ -30,4 +30,10 @@ private static readonly Lazy CurrentVersionLazy /// public static bool VersionAtLeast3_35 => CurrentVersionLazy.Value is { } v && v >= new Version(3, 35, 0); + + /// + /// SQLite version >= 3.44.0 (required for ORDER BY inside aggregate functions, e.g. group_concat). + /// + public static bool VersionAtLeast3_44 + => CurrentVersionLazy.Value is { } v && v >= new Version(3, 44, 0); }