diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs index d56f63380..e6ea5c1f3 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs @@ -39,6 +39,9 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator private static readonly MethodInfo Replace = typeof(string).GetRuntimeMethod( nameof(string.Replace), [typeof(string), typeof(string)])!; + private static readonly MethodInfo Replace_Char = typeof(string).GetRuntimeMethod( + nameof(string.Replace), [typeof(char), typeof(char)])!; + private static readonly MethodInfo Substring = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring)) .Single(m => m.GetParameters().Length == 1); @@ -204,7 +207,7 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I { var argument = arguments[0]; var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument); - + argument = _sqlExpressionFactory.ApplyTypeMapping(argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); return _sqlExpressionFactory.Subtract( _sqlExpressionFactory.Function( "strpos", @@ -218,12 +221,15 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I _sqlExpressionFactory.Constant(1)); } - if (method == Replace) + if (method == Replace || method == Replace_Char) { var oldValue = arguments[0]; var newValue = arguments[1]; var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue); + oldValue = _sqlExpressionFactory.ApplyTypeMapping(oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + newValue = _sqlExpressionFactory.ApplyTypeMapping(newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + return _sqlExpressionFactory.Function( "replace", [ diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 7211b966c..e1d7ce2ea 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -45,12 +45,21 @@ public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp private static readonly MethodInfo StringStartsWithMethod = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!; + private static readonly MethodInfo StringStartsWithMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!; + private static readonly MethodInfo StringEndsWithMethod = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!; + private static readonly MethodInfo StringEndsWithMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!; + private static readonly MethodInfo StringContainsMethod = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!; + private static readonly MethodInfo StringContainsMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!; + private static readonly MethodInfo EscapeLikePatternParameterMethod = typeof(NpgsqlSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!; @@ -405,21 +414,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression; } - if (method == StringStartsWithMethod + if ((method == StringStartsWithMethod || method == StringStartsWithMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1)) { return translation1; } - if (method == StringEndsWithMethod + if ((method == StringEndsWithMethod || method == StringEndsWithMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2)) { return translation2; } - if (method == StringContainsMethod + if ((method == StringContainsMethod || method == StringContainsMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3)) { @@ -719,6 +728,32 @@ private bool TryTranslateStartsEndsWithContains( _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) })), + char s when !IsLikeWildChar(s) + => _sqlExpressionFactory.Like( + translatedInstance, + _sqlExpressionFactory.Constant( + methodType switch + { + StartsEndsWithContains.StartsWith => s + "%", + StartsEndsWithContains.EndsWith => "%" + s, + StartsEndsWithContains.Contains => $"%{s}%", + + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + })), + + char s => _sqlExpressionFactory.Like( + translatedInstance, + _sqlExpressionFactory.Constant( + methodType switch + { + StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%", + StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s, + StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%", + + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }), + _sqlExpressionFactory.Constant(LikeEscapeChar)), + _ => throw new UnreachableException() }; @@ -834,6 +869,22 @@ private bool TryTranslateStartsEndsWithContains( _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) }, + char s when !IsLikeWildChar(s) => methodType switch + { + StartsEndsWithContains.StartsWith => s + "%", + StartsEndsWithContains.EndsWith => "%" + s, + StartsEndsWithContains.Contains => $"%{s}%", + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }, + + char s => methodType switch + { + StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%", + StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s, + StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%", + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }, + _ => throw new UnreachableException() }; diff --git a/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs index 2ebc6fb97..f28419be4 100644 --- a/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs @@ -131,9 +131,17 @@ WHERE strpos(b."String", 'eattl') - 1 <> -1 """); } - // TODO: #3547 - public override Task IndexOf_Char() - => Assert.ThrowsAsync(() => base.IndexOf_Char()); + public override async Task IndexOf_Char() + { + await base.IndexOf_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE strpos(b."String", 'e') - 1 <> -1 +"""); + } public override async Task IndexOf_with_empty_string() { @@ -231,9 +239,17 @@ WHERE replace(b."String", 'Sea', 'Rea') = 'Reattle' """); } - // TODO: #3547 - public override Task Replace_Char() - => AssertTranslationFailed(() => base.Replace_Char()); + public override async Task Replace_Char() + { + await base.Replace_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE replace(b."String", 'S', 'R') = 'Reattle' +"""); + } public override async Task Replace_with_empty_string() { @@ -429,9 +445,17 @@ WHERE b."String" LIKE 'Se%' """); } - // TODO: #3547 - public override Task StartsWith_Literal_Char() - => AssertTranslationFailed(() => base.StartsWith_Literal_Char()); + public override async Task StartsWith_Literal_Char() + { + await base.StartsWith_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE 'S%' +"""); + } public override async Task StartsWith_Parameter() { @@ -447,8 +471,19 @@ WHERE b."String" LIKE @pattern_startswith """); } - public override Task StartsWith_Parameter_Char() - => AssertTranslationFailed(() => base.StartsWith_Parameter_Char()); + public override async Task StartsWith_Parameter_Char() + { + await base.StartsWith_Parameter_Char(); + + AssertSql( + """ +@pattern_startswith='S%' + +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE @pattern_startswith +"""); + } public override async Task StartsWith_Column() { @@ -499,9 +534,17 @@ WHERE b."String" LIKE '%le' """); } - // TODO: #3547 - public override Task EndsWith_Literal_Char() - => AssertTranslationFailed(() => base.EndsWith_Literal_Char()); + public override async Task EndsWith_Literal_Char() + { + await base.EndsWith_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE '%e' +"""); + } public override async Task EndsWith_Parameter() { @@ -517,9 +560,19 @@ WHERE b."String" LIKE @pattern_endswith """); } - // TODO: #3547 - public override Task EndsWith_Parameter_Char() - => AssertTranslationFailed(() => base.EndsWith_Parameter_Char()); + public override async Task EndsWith_Parameter_Char() + { + await base.EndsWith_Parameter_Char(); + + AssertSql( + """ +@pattern_endswith='%e' + +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE @pattern_endswith +"""); + } public override async Task EndsWith_Column() { @@ -575,9 +628,17 @@ WHERE b."String" LIKE '%eattl%' """); } - // TODO: #3547 - public override Task Contains_Literal_Char() - => AssertTranslationFailed(() => base.Contains_Literal_Char()); + public override async Task Contains_Literal_Char() + { + await base.Contains_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE '%e%' +"""); + } public override async Task Contains_Column() {