From a5f219d500937fa44b1e9568d105997cba99e2d0 Mon Sep 17 00:00:00 2001 From: Christopher Jolly Date: Sun, 21 Dec 2025 20:30:02 +0800 Subject: [PATCH] Add translation for string methods with char arguments Enable SQL translation for string.IndexOf, Replace, StartsWith, EndsWith, and Contains when called with char arguments. Update translators and type mapping to support char overloads, and implement corresponding tests to verify correct SQL generation. --- .../Internal/NpgsqlStringMethodTranslator.cs | 10 +- .../NpgsqlSqlTranslatingExpressionVisitor.cs | 57 +++++++++- .../StringTranslationsNpgsqlTest.cs | 101 ++++++++++++++---- 3 files changed, 143 insertions(+), 25 deletions(-) diff --git a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs index d56f63380..e6ea5c1f3 100644 --- a/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs +++ b/src/EFCore.PG/Query/ExpressionTranslators/Internal/NpgsqlStringMethodTranslator.cs @@ -39,6 +39,9 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator private static readonly MethodInfo Replace = typeof(string).GetRuntimeMethod( nameof(string.Replace), [typeof(string), typeof(string)])!; + private static readonly MethodInfo Replace_Char = typeof(string).GetRuntimeMethod( + nameof(string.Replace), [typeof(char), typeof(char)])!; + private static readonly MethodInfo Substring = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring)) .Single(m => m.GetParameters().Length == 1); @@ -204,7 +207,7 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I { var argument = arguments[0]; var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument); - + argument = _sqlExpressionFactory.ApplyTypeMapping(argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); return _sqlExpressionFactory.Subtract( _sqlExpressionFactory.Function( "strpos", @@ -218,12 +221,15 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I _sqlExpressionFactory.Constant(1)); } - if (method == Replace) + if (method == Replace || method == Replace_Char) { var oldValue = arguments[0]; var newValue = arguments[1]; var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue); + oldValue = _sqlExpressionFactory.ApplyTypeMapping(oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + newValue = _sqlExpressionFactory.ApplyTypeMapping(newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping); + return _sqlExpressionFactory.Function( "replace", [ diff --git a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs index 7211b966c..e1d7ce2ea 100644 --- a/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.PG/Query/Internal/NpgsqlSqlTranslatingExpressionVisitor.cs @@ -45,12 +45,21 @@ public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp private static readonly MethodInfo StringStartsWithMethod = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!; + private static readonly MethodInfo StringStartsWithMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!; + private static readonly MethodInfo StringEndsWithMethod = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!; + private static readonly MethodInfo StringEndsWithMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!; + private static readonly MethodInfo StringContainsMethod = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!; + private static readonly MethodInfo StringContainsMethodChar + = typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!; + private static readonly MethodInfo EscapeLikePatternParameterMethod = typeof(NpgsqlSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!; @@ -405,21 +414,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp return TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression; } - if (method == StringStartsWithMethod + if ((method == StringStartsWithMethod || method == StringStartsWithMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1)) { return translation1; } - if (method == StringEndsWithMethod + if ((method == StringEndsWithMethod || method == StringEndsWithMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2)) { return translation2; } - if (method == StringContainsMethod + if ((method == StringContainsMethod || method == StringContainsMethodChar) && TryTranslateStartsEndsWithContains( methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3)) { @@ -719,6 +728,32 @@ private bool TryTranslateStartsEndsWithContains( _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) })), + char s when !IsLikeWildChar(s) + => _sqlExpressionFactory.Like( + translatedInstance, + _sqlExpressionFactory.Constant( + methodType switch + { + StartsEndsWithContains.StartsWith => s + "%", + StartsEndsWithContains.EndsWith => "%" + s, + StartsEndsWithContains.Contains => $"%{s}%", + + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + })), + + char s => _sqlExpressionFactory.Like( + translatedInstance, + _sqlExpressionFactory.Constant( + methodType switch + { + StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%", + StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s, + StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%", + + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }), + _sqlExpressionFactory.Constant(LikeEscapeChar)), + _ => throw new UnreachableException() }; @@ -834,6 +869,22 @@ private bool TryTranslateStartsEndsWithContains( _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) }, + char s when !IsLikeWildChar(s) => methodType switch + { + StartsEndsWithContains.StartsWith => s + "%", + StartsEndsWithContains.EndsWith => "%" + s, + StartsEndsWithContains.Contains => $"%{s}%", + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }, + + char s => methodType switch + { + StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%", + StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s, + StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%", + _ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null) + }, + _ => throw new UnreachableException() }; diff --git a/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs b/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs index 2ebc6fb97..f28419be4 100644 --- a/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs +++ b/test/EFCore.PG.FunctionalTests/Query/Translations/StringTranslationsNpgsqlTest.cs @@ -131,9 +131,17 @@ WHERE strpos(b."String", 'eattl') - 1 <> -1 """); } - // TODO: #3547 - public override Task IndexOf_Char() - => Assert.ThrowsAsync(() => base.IndexOf_Char()); + public override async Task IndexOf_Char() + { + await base.IndexOf_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE strpos(b."String", 'e') - 1 <> -1 +"""); + } public override async Task IndexOf_with_empty_string() { @@ -231,9 +239,17 @@ WHERE replace(b."String", 'Sea', 'Rea') = 'Reattle' """); } - // TODO: #3547 - public override Task Replace_Char() - => AssertTranslationFailed(() => base.Replace_Char()); + public override async Task Replace_Char() + { + await base.Replace_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE replace(b."String", 'S', 'R') = 'Reattle' +"""); + } public override async Task Replace_with_empty_string() { @@ -429,9 +445,17 @@ WHERE b."String" LIKE 'Se%' """); } - // TODO: #3547 - public override Task StartsWith_Literal_Char() - => AssertTranslationFailed(() => base.StartsWith_Literal_Char()); + public override async Task StartsWith_Literal_Char() + { + await base.StartsWith_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE 'S%' +"""); + } public override async Task StartsWith_Parameter() { @@ -447,8 +471,19 @@ WHERE b."String" LIKE @pattern_startswith """); } - public override Task StartsWith_Parameter_Char() - => AssertTranslationFailed(() => base.StartsWith_Parameter_Char()); + public override async Task StartsWith_Parameter_Char() + { + await base.StartsWith_Parameter_Char(); + + AssertSql( + """ +@pattern_startswith='S%' + +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE @pattern_startswith +"""); + } public override async Task StartsWith_Column() { @@ -499,9 +534,17 @@ WHERE b."String" LIKE '%le' """); } - // TODO: #3547 - public override Task EndsWith_Literal_Char() - => AssertTranslationFailed(() => base.EndsWith_Literal_Char()); + public override async Task EndsWith_Literal_Char() + { + await base.EndsWith_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE '%e' +"""); + } public override async Task EndsWith_Parameter() { @@ -517,9 +560,19 @@ WHERE b."String" LIKE @pattern_endswith """); } - // TODO: #3547 - public override Task EndsWith_Parameter_Char() - => AssertTranslationFailed(() => base.EndsWith_Parameter_Char()); + public override async Task EndsWith_Parameter_Char() + { + await base.EndsWith_Parameter_Char(); + + AssertSql( + """ +@pattern_endswith='%e' + +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE @pattern_endswith +"""); + } public override async Task EndsWith_Column() { @@ -575,9 +628,17 @@ WHERE b."String" LIKE '%eattl%' """); } - // TODO: #3547 - public override Task Contains_Literal_Char() - => AssertTranslationFailed(() => base.Contains_Literal_Char()); + public override async Task Contains_Literal_Char() + { + await base.Contains_Literal_Char(); + + AssertSql( + """ +SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan" +FROM "BasicTypesEntities" AS b +WHERE b."String" LIKE '%e%' +"""); + } public override async Task Contains_Column() {