Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public class NpgsqlStringMethodTranslator : IMethodCallTranslator
private static readonly MethodInfo Replace = typeof(string).GetRuntimeMethod(
nameof(string.Replace), [typeof(string), typeof(string)])!;

private static readonly MethodInfo Replace_Char = typeof(string).GetRuntimeMethod(
nameof(string.Replace), [typeof(char), typeof(char)])!;

private static readonly MethodInfo Substring = typeof(string).GetTypeInfo().GetDeclaredMethods(nameof(string.Substring))
.Single(m => m.GetParameters().Length == 1);

Expand Down Expand Up @@ -204,7 +207,7 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I
{
var argument = arguments[0];
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, argument);

argument = _sqlExpressionFactory.ApplyTypeMapping(argument, argument.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
return _sqlExpressionFactory.Subtract(
_sqlExpressionFactory.Function(
"strpos",
Expand All @@ -218,12 +221,15 @@ public NpgsqlStringMethodTranslator(NpgsqlTypeMappingSource typeMappingSource, I
_sqlExpressionFactory.Constant(1));
}

if (method == Replace)
if (method == Replace || method == Replace_Char)
{
var oldValue = arguments[0];
var newValue = arguments[1];
var stringTypeMapping = ExpressionExtensions.InferTypeMapping(instance!, oldValue, newValue);

oldValue = _sqlExpressionFactory.ApplyTypeMapping(oldValue, oldValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);
newValue = _sqlExpressionFactory.ApplyTypeMapping(newValue, newValue.Type == typeof(char) ? CharTypeMapping.Default : stringTypeMapping);

return _sqlExpressionFactory.Function(
"replace",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,21 @@ public class NpgsqlSqlTranslatingExpressionVisitor : RelationalSqlTranslatingExp
private static readonly MethodInfo StringStartsWithMethod
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(string)])!;

private static readonly MethodInfo StringStartsWithMethodChar
= typeof(string).GetRuntimeMethod(nameof(string.StartsWith), [typeof(char)])!;

private static readonly MethodInfo StringEndsWithMethod
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(string)])!;

private static readonly MethodInfo StringEndsWithMethodChar
= typeof(string).GetRuntimeMethod(nameof(string.EndsWith), [typeof(char)])!;

private static readonly MethodInfo StringContainsMethod
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(string)])!;

private static readonly MethodInfo StringContainsMethodChar
= typeof(string).GetRuntimeMethod(nameof(string.Contains), [typeof(char)])!;

private static readonly MethodInfo EscapeLikePatternParameterMethod =
typeof(NpgsqlSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ConstructLikePatternParameter))!;

Expand Down Expand Up @@ -405,21 +414,21 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return TranslateCubeToSubset(sqlCubeInstance, sqlIndexes) ?? QueryCompilationContext.NotTranslatedExpression;
}

if (method == StringStartsWithMethod
if ((method == StringStartsWithMethod || method == StringStartsWithMethodChar)
&& TryTranslateStartsEndsWithContains(
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.StartsWith, out var translation1))
{
return translation1;
}

if (method == StringEndsWithMethod
if ((method == StringEndsWithMethod || method == StringEndsWithMethodChar)
&& TryTranslateStartsEndsWithContains(
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.EndsWith, out var translation2))
{
return translation2;
}

if (method == StringContainsMethod
if ((method == StringContainsMethod || method == StringContainsMethodChar)
&& TryTranslateStartsEndsWithContains(
methodCallExpression.Object!, methodCallExpression.Arguments[0], StartsEndsWithContains.Contains, out var translation3))
{
Expand Down Expand Up @@ -719,6 +728,32 @@ private bool TryTranslateStartsEndsWithContains(
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
})),

char s when !IsLikeWildChar(s)
=> _sqlExpressionFactory.Like(
translatedInstance,
_sqlExpressionFactory.Constant(
methodType switch
{
StartsEndsWithContains.StartsWith => s + "%",
StartsEndsWithContains.EndsWith => "%" + s,
StartsEndsWithContains.Contains => $"%{s}%",

_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
})),

char s => _sqlExpressionFactory.Like(
translatedInstance,
_sqlExpressionFactory.Constant(
methodType switch
{
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",

_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
}),
_sqlExpressionFactory.Constant(LikeEscapeChar)),

_ => throw new UnreachableException()
};

Expand Down Expand Up @@ -834,6 +869,22 @@ private bool TryTranslateStartsEndsWithContains(
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
},

char s when !IsLikeWildChar(s) => methodType switch
{
StartsEndsWithContains.StartsWith => s + "%",
StartsEndsWithContains.EndsWith => "%" + s,
StartsEndsWithContains.Contains => $"%{s}%",
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
},

char s => methodType switch
{
StartsEndsWithContains.StartsWith => LikeEscapeChar + s + "%",
StartsEndsWithContains.EndsWith => "%" + LikeEscapeChar + s,
StartsEndsWithContains.Contains => $"%{LikeEscapeChar}{s}%",
_ => throw new ArgumentOutOfRangeException(nameof(methodType), methodType, null)
},

_ => throw new UnreachableException()
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,17 @@ WHERE strpos(b."String", 'eattl') - 1 <> -1
""");
}

// TODO: #3547
public override Task IndexOf_Char()
=> Assert.ThrowsAsync<InvalidCastException>(() => base.IndexOf_Char());
public override async Task IndexOf_Char()
{
await base.IndexOf_Char();

AssertSql(
"""
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE strpos(b."String", 'e') - 1 <> -1
""");
}

public override async Task IndexOf_with_empty_string()
{
Expand Down Expand Up @@ -231,9 +239,17 @@ WHERE replace(b."String", 'Sea', 'Rea') = 'Reattle'
""");
}

// TODO: #3547
public override Task Replace_Char()
=> AssertTranslationFailed(() => base.Replace_Char());
public override async Task Replace_Char()
{
await base.Replace_Char();

AssertSql(
"""
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE replace(b."String", 'S', 'R') = 'Reattle'
""");
}

public override async Task Replace_with_empty_string()
{
Expand Down Expand Up @@ -429,9 +445,17 @@ WHERE b."String" LIKE 'Se%'
""");
}

// TODO: #3547
public override Task StartsWith_Literal_Char()
=> AssertTranslationFailed(() => base.StartsWith_Literal_Char());
public override async Task StartsWith_Literal_Char()
{
await base.StartsWith_Literal_Char();

AssertSql(
"""
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE b."String" LIKE 'S%'
""");
}

public override async Task StartsWith_Parameter()
{
Expand All @@ -447,8 +471,19 @@ WHERE b."String" LIKE @pattern_startswith
""");
}

public override Task StartsWith_Parameter_Char()
=> AssertTranslationFailed(() => base.StartsWith_Parameter_Char());
public override async Task StartsWith_Parameter_Char()
{
await base.StartsWith_Parameter_Char();

AssertSql(
"""
@pattern_startswith='S%'

SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE b."String" LIKE @pattern_startswith
""");
}

public override async Task StartsWith_Column()
{
Expand Down Expand Up @@ -499,9 +534,17 @@ WHERE b."String" LIKE '%le'
""");
}

// TODO: #3547
public override Task EndsWith_Literal_Char()
=> AssertTranslationFailed(() => base.EndsWith_Literal_Char());
public override async Task EndsWith_Literal_Char()
{
await base.EndsWith_Literal_Char();

AssertSql(
"""
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE b."String" LIKE '%e'
""");
}

public override async Task EndsWith_Parameter()
{
Expand All @@ -517,9 +560,19 @@ WHERE b."String" LIKE @pattern_endswith
""");
}

// TODO: #3547
public override Task EndsWith_Parameter_Char()
=> AssertTranslationFailed(() => base.EndsWith_Parameter_Char());
public override async Task EndsWith_Parameter_Char()
{
await base.EndsWith_Parameter_Char();

AssertSql(
"""
@pattern_endswith='%e'

SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE b."String" LIKE @pattern_endswith
""");
}

public override async Task EndsWith_Column()
{
Expand Down Expand Up @@ -575,9 +628,17 @@ WHERE b."String" LIKE '%eattl%'
""");
}

// TODO: #3547
public override Task Contains_Literal_Char()
=> AssertTranslationFailed(() => base.Contains_Literal_Char());
public override async Task Contains_Literal_Char()
{
await base.Contains_Literal_Char();

AssertSql(
"""
SELECT b."Id", b."Bool", b."Byte", b."ByteArray", b."DateOnly", b."DateTime", b."DateTimeOffset", b."Decimal", b."Double", b."Enum", b."FlagsEnum", b."Float", b."Guid", b."Int", b."Long", b."Short", b."String", b."TimeOnly", b."TimeSpan"
FROM "BasicTypesEntities" AS b
WHERE b."String" LIKE '%e%'
""");
}

public override async Task Contains_Column()
{
Expand Down