diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs index 171bdf58cc8..c5554336ac3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/PartialEvaluator.cs @@ -75,60 +75,137 @@ public override Expression Visit(Expression expression) protected override Expression VisitBinary(BinaryExpression node) { - if (node.NodeType == ExpressionType.AndAlso) + var leftExpression = node.Left; + var rightExpression = node.Right; + + if (leftExpression.Type == typeof(bool) && rightExpression.Type == typeof(bool)) { - var leftExpression = Visit(node.Left); - if (leftExpression is ConstantExpression constantLeftExpression ) + if (node.NodeType == ExpressionType.AndAlso) { - var value = (bool)constantLeftExpression.Value; - return value ? Visit(node.Right) : Expression.Constant(false); + leftExpression = Visit(leftExpression); + if (IsConstant(leftExpression, out var leftValue)) + { + // true && Q => Q + // false && Q => false + return leftValue ? Visit(rightExpression) : Expression.Constant(false); + } + + rightExpression = Visit(rightExpression); + if (IsConstant(rightExpression, out var rightValue)) + { + // P && true => P + // P && false => false + return rightValue ? leftExpression : Expression.Constant(false); + } + + return node.Update(leftExpression, conversion: null, rightExpression); } - var rightExpression = Visit(node.Right); - if (rightExpression is ConstantExpression constantRightExpression) + if (node.NodeType == ExpressionType.OrElse) { - var value = (bool)constantRightExpression.Value; - return value ? leftExpression : Expression.Constant(false); + leftExpression = Visit(leftExpression); + if (IsConstant(leftExpression, out var leftValue)) + { + // true || Q => true + // false || Q => Q + return leftValue ? Expression.Constant(true) : Visit(rightExpression); + } + + rightExpression = Visit(rightExpression); + if (IsConstant(rightExpression, out var rightValue)) + { + // P || true => true + // P || false => P + return rightValue ? Expression.Constant(true) : leftExpression; + } + + return node.Update(leftExpression, conversion: null, rightExpression); } + } + + return base.VisitBinary(node); + } - return node.Update(leftExpression, conversion: null, rightExpression); + protected override Expression VisitConditional(ConditionalExpression node) + { + var test = Visit(node.Test); + + if (IsConstant(test, out var testValue)) + { + // true ? A : B => A + // false ? A : B => B + return testValue ? Visit(node.IfTrue) : Visit(node.IfFalse); } - if (node.NodeType == ExpressionType.OrElse) + var ifTrue = Visit(node.IfTrue); + var ifFalse = Visit(node.IfFalse); + + if (BothAreConstant(ifTrue, ifFalse, out var ifTrueValue, out var ifFalseValue)) { - var leftExpression = Visit(node.Left); - if (leftExpression is ConstantExpression constantLeftExpression) + return (ifTrueValue, ifFalseValue) switch { - var value = (bool)constantLeftExpression.Value; - return value ? Expression.Constant(true) : Visit(node.Right); - } + (false, false) => Expression.Constant(false), // T ? false : false => false + (false, true) => Expression.Not(test), // T ? false : true => !T + (true, false) => test, // T ? true : false => T + (true, true) => Expression.Constant(true) // T ? true : true => true + }; + } + else if (IsConstant(ifTrue, out ifTrueValue)) + { + // T ? true : Q => T || Q + // T ? false : Q => !T && Q + return ifTrueValue + ? Visit(Expression.OrElse(test, ifFalse)) + : Visit(Expression.AndAlso(Expression.Not(test), ifFalse)); + } + else if (IsConstant(ifFalse, out ifFalseValue)) + { + // T ? P : true => !T || P + // T ? P : false => T && P + return ifFalseValue + ? Visit(Expression.OrElse(Expression.Not(test), ifTrue)) + : Visit(Expression.AndAlso(test, ifTrue)); + } - var rightExpression = Visit(node.Right); - if (rightExpression is ConstantExpression constantRightExpression) + return node.Update(test, ifTrue, ifFalse); + } + + protected override Expression VisitUnary(UnaryExpression node) + { + var operand = Visit(node.Operand); + + if (node.Type == typeof(bool) && + node.NodeType == ExpressionType.Not) + { + if (operand is UnaryExpression innerUnaryExpressionOperand && + innerUnaryExpressionOperand.NodeType == ExpressionType.Not) { - var value = (bool)constantRightExpression.Value; - return value ? Expression.Constant(true) : leftExpression; + // !!P => P + return innerUnaryExpressionOperand.Operand; } - - return node.Update(leftExpression, conversion: null, rightExpression); } - return base.VisitBinary(node); + return node.Update(operand); } - protected override Expression VisitConditional(ConditionalExpression node) + // private methods + private bool BothAreConstant(Expression expression1, Expression expression2, out T constantValue1, out T constantValue2) { - var test = Visit(node.Test); - if (test is ConstantExpression constantTestExpression) + if (expression1 is ConstantExpression constantExpression1 && + expression2 is ConstantExpression constantExpression2 && + constantExpression1.Type == typeof(T) && + constantExpression2.Type == typeof(T)) { - var value = (bool)constantTestExpression.Value; - return value ? Visit(node.IfTrue) : Visit(node.IfFalse); + constantValue1 = (T)constantExpression1.Value; + constantValue2 = (T)constantExpression2.Value; + return true; } - return node.Update(test, Visit(node.IfTrue), Visit(node.IfFalse)); + constantValue1 = default; + constantValue2 = default; + return false; } - // private methods private Expression Evaluate(Expression expression) { if (expression.NodeType == ExpressionType.Constant) @@ -139,6 +216,19 @@ private Expression Evaluate(Expression expression) Delegate fn = lambda.Compile(); return Expression.Constant(fn.DynamicInvoke(null), expression.Type); } + + private bool IsConstant(Expression expression, out T constantValue) + { + if (expression is ConstantExpression constantExpression1 && + constantExpression1.Type == typeof(T)) + { + constantValue = (T)constantExpression1.Value; + return true; + } + + constantValue = default; + return false; + } } private class Nominator : ExpressionVisitor diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs index 8c066109d33..ab2d9f45621 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp4337Tests.cs @@ -32,10 +32,10 @@ public class CSharp4337Tests : LinqIntegrationTest { private static (Expression>> Projection, string ExpectedStage, bool[] ExpectedResults)[] __predicate_should_use_correct_representation_test_cases = new (Expression>> Projection, string ExpectedStage, bool[] ExpectedResults)[] { - (d => new R { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$I1', 1] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }), - (d => new R { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['$S1', 'E1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }), - (d => new R { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : [1, '$I1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }), - (d => new R { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $cond : { if : { $eq : ['E1', '$S1'] }, then : true, else : false } }, _id : 0 } }", new[] { true, false }) + (d => new R { N = d.Id, V = d.I1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$I1', 1] }, _id : 0 } }", new[] { true, false }), + (d => new R { N = d.Id, V = d.S1 == E.E1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['$S1', 'E1'] }, _id : 0 } }", new[] { true, false }), + (d => new R { N = d.Id, V = E.E1 == d.I1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : [1, '$I1'] }, _id : 0 } }", new[] { true, false }), + (d => new R { N = d.Id, V = E.E1 == d.S1 ? true : false }, "{ $project : { N : '$_id', V : { $eq : ['E1', '$S1'] }, _id : 0 } }", new[] { true, false }) }; public CSharp4337Tests(ClassFixture fixture) diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5628Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5628Tests.cs new file mode 100644 index 00000000000..d4159481ccd --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5628Tests.cs @@ -0,0 +1,190 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using MongoDB.Driver.TestHelpers; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5628Tests : LinqIntegrationTest +{ + public CSharp5628Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Theory] + [InlineData(1, "{ $project : { _v : { $eq : ['$P', true] }, _id : 0 } }", new bool[] { false, false, true, true })] + [InlineData(2, "{ $project : { _v : { $ne : ['$P', false] }, _id : 0 } }", new bool[] { false, false, true, true })] + [InlineData(3, "{ $project : { _v : { $eq : ['$P', false] }, _id : 0 } }", new bool[] { true, true, false, false })] + [InlineData(4, "{ $project : { _v : { $ne : ['$P', true] }, _id : 0 } }", new bool[] { true, true, false, false })] + [InlineData(5, "{ $project : { _v : { $or : ['$P', '$Q'] }, _id : 0 } }", new bool[] { false, true, true, true })] + [InlineData(6, "{ $project : { _v : { $and : ['$P', '$Q'] }, _id : 0 } }", new bool[] { false, false, false, true })] + [InlineData(7, "{ $project : { _v : { $or : [{ $not : '$P' }, '$Q'] }, _id : 0 } }", new bool[] { true, true, false, true })] + [InlineData(8, "{ $project : { _v : { $and : [{ $not : '$P' }, '$Q'] }, _id : 0 } }", new bool[] { false, true, false, false })] + [InlineData(9, "{ $project : { _v : '$P', _id : 0 } }", new bool[] { false, false, true, true })] + [InlineData(10, "{ $project : { _v : { $not : '$P' }, _id : 0 } }", new bool[] { true, true, false, false })] + [InlineData(11, "{ $project : { _v : '$P', _id : 0 } }", new bool[] { false, false, true, true })] + [InlineData(12, "{ $project : { _v : '$P', _id : 0 } }", new bool[] { false, false, true, true })] + [InlineData(13, "{ $project : { _v : '$P', _id : 0 } }", new bool[] { false, false, true, true })] + + [InlineData(14, "{ $project : { _v : { $literal : false }, _id : 0 } }", new bool[] { false, false, false, false })] + [InlineData(15, "{ $project : { _v : { $literal : true }, _id : 0 } }", new bool[] { true, true, true, true })] + [InlineData(16, "{ $project : { _v : { $literal : true }, _id : 0 } }", new bool[] { true, true, true, true })] + [InlineData(17, "{ $project : { _v : { $literal : false }, _id : 0 } }", new bool[] { false, false, false, false })] + + [InlineData(18, "{ $project : { _v : { $ne : ['$X', '$Y'] }, _id : 0 } }", new bool[] { false, true, true, false })] + [InlineData(19, "{ $project : { _v : { $eq : ['$X', '$Y'] }, _id : 0 } }", new bool[] { true, false, false, true })] + [InlineData(20, "{ $project : { _v : { $not : { $lt : ['$X', '$Y'] } }, _id : 0 } }", new bool[] { true, false, true, true })] + [InlineData(21, "{ $project : { _v : { $not : { $gt : ['$X', '$Y'] } }, _id : 0 } }", new bool[] { true, true, false, true })] + [InlineData(22, "{ $project : { _v : { $not : { $lte : ['$X', '$Y'] } }, _id : 0 } }", new bool[] { false, false, true, false })] + [InlineData(23, "{ $project : { _v : { $not : { $gte : ['$X', '$Y'] } }, _id : 0 } }", new bool[] { false, true, false, false })] + public void Select_simplifications_should_work(int testCase, string expectedStage, bool[] expectedResults) + { + var collection = Fixture.Collection; + + // see: https://codeql.github.com/codeql-query-help/csharp/cs-simplifiable-boolean-expression/#recommendation + // not all simplifications listed there are safe for a database (because of possibly missing fields or tri-valued logic) + var queryable = testCase switch + { + 1 => collection.AsQueryable().Select(x => x.P == true), // not safe + 2 => collection.AsQueryable().Select(x => x.P != false), // not safe + 3 => collection.AsQueryable().Select(x => x.P == false), // not safe + 4 => collection.AsQueryable().Select(x => x.P != true), // not safe + 5 => collection.AsQueryable().Select(x => x.P ? true : x.Q), + 6 => collection.AsQueryable().Select(x => x.P ? x.Q : false), + 7 => collection.AsQueryable().Select(x => x.P ? x.Q : true), + 8 => collection.AsQueryable().Select(x => x.P ? false : x.Q), + 9 => collection.AsQueryable().Select(x => x.P ? true : false), + 10 => collection.AsQueryable().Select(x => x.P ? false : true), + 11 => collection.AsQueryable().Select(x => !!x.P), + 12 => collection.AsQueryable().Select(x => x.P && true), + 13 => collection.AsQueryable().Select(x => x.P || false), + + 14 => collection.AsQueryable().Select(x => x.P && false), + 15 => collection.AsQueryable().Select(x => x.P || true), + 16 => collection.AsQueryable().Select(x => x.P ? true : true), + 17 => collection.AsQueryable().Select(x => x.P ? false : false), + + 18 => collection.AsQueryable().Select(x => !(x.X == x.Y)), + 19 => collection.AsQueryable().Select(x => !(x.X != x.Y)), + 20 => collection.AsQueryable().Select(x => !(x.X < x.Y)), // not safe + 21 => collection.AsQueryable().Select(x => !(x.X > x.Y)), // not safe + 22 => collection.AsQueryable().Select(x => !(x.X <= x.Y)), // not safe + 23 => collection.AsQueryable().Select(x => !(x.X >= x.Y)), // not safe + _ => throw new ArgumentException($"Invalid test case: {testCase}") + }; + + var stages = Translate(collection, queryable); + AssertStages(stages, expectedStage); + + var results = queryable.ToList(); + results.Should().Equal(expectedResults); + } + + [Theory] + [InlineData(1, "{ $match : { P : true } }", new int[] { 3, 4 })] + [InlineData(2, "{ $match : { P : { $ne : false } } }", new int[] { 3, 4 })] + [InlineData(3, "{ $match : { P : false } }", new int[] { 1, 2, })] + [InlineData(4, "{ $match : { P : { $ne : true } } }", new int[] { 1, 2 })] + [InlineData(5, "{ $match : { $or : [{ P : true }, { Q : true }] } }", new int[] { 2, 3, 4 })] + [InlineData(6, "{ $match : { P : true, Q : true } }", new int[] { 4 })] + [InlineData(7, "{ $match : { $or : [{ P : { $ne : true } }, { Q : true }] } }", new int[] { 1, 2, 4 })] + [InlineData(8, "{ $match : { P : { $ne : true }, Q : true } }", new int[] { 2 })] + [InlineData(9, "{ $match : { P : true } }", new int[] { 3, 4 })] + [InlineData(10, "{ $match : { P : { $ne : true } } }", new int[] { 1, 2 })] + [InlineData(11, "{ $match : { P : true } }", new int[] { 3, 4 })] + [InlineData(12, "{ $match : { P : true } }", new int[] { 3, 4 })] + [InlineData(13, "{ $match : { P : true } }", new int[] { 3, 4 })] + + [InlineData(14, "{ $match : { _id : { $type : -1 } } }", new int[] { })] + [InlineData(15, null, new int[] { 1, 2, 3, 4 })] + [InlineData(16, null, new int[] { 1, 2, 3, 4 })] + [InlineData(17, "{ $match : { _id : { $type : -1 } } }", new int[] { })] + + [InlineData(18, "{ $match : { $nor : [{ $expr : { $eq : ['$X', '$Y'] } }] } }", new int[] { 2, 3 })] + [InlineData(19, "{ $match : { $nor : [{ $expr : { $ne : ['$X', '$Y'] } }] } }", new int[] { 1, 4 })] + [InlineData(20, "{ $match : { $nor : [{ $expr : { $lt : ['$X', '$Y'] } }] } }", new int[] { 1, 3, 4 })] + [InlineData(21, "{ $match : { $nor : [{ $expr : { $gt : ['$X', '$Y'] } }] } }", new int[] { 1, 2, 4 })] + [InlineData(22, "{ $match : { $nor : [{ $expr : { $lte : ['$X', '$Y'] } }] } }", new int[] { 3 })] + [InlineData(23, "{ $match : { $nor : [{ $expr : { $gte : ['$X', '$Y'] } }] } }", new int[] { 2 })] + public void Where_simplifications_should_work(int testCase, string expectedStage, int[] expectedIds) + { + var collection = Fixture.Collection; + + // see: https://codeql.github.com/codeql-query-help/csharp/cs-simplifiable-boolean-expression/#recommendation + // not all simplifications listed there are safe for a database (because of possibly missing fields or tri-valued logic) + var queryable = testCase switch + { + 1 => collection.AsQueryable().Where(x => x.P == true), // not safe + 2 => collection.AsQueryable().Where(x => x.P != false), // not safe + 3 => collection.AsQueryable().Where(x => x.P == false), // not safe + 4 => collection.AsQueryable().Where(x => x.P != true), // not safe + 5 => collection.AsQueryable().Where(x => x.P ? true : x.Q), + 6 => collection.AsQueryable().Where(x => x.P ? x.Q : false), + 7 => collection.AsQueryable().Where(x => x.P ? x.Q : true), + 8 => collection.AsQueryable().Where(x => x.P ? false : x.Q), + 9 => collection.AsQueryable().Where(x => x.P ? true : false), + 10 => collection.AsQueryable().Where(x => x.P ? false : true), + 11 => collection.AsQueryable().Where(x => !!x.P), + 12 => collection.AsQueryable().Where(x => x.P && true), + 13 => collection.AsQueryable().Where(x => x.P || false), + + 14 => collection.AsQueryable().Where(x => x.P && false), + 15 => collection.AsQueryable().Where(x => x.P || true), + 16 => collection.AsQueryable().Where(x => x.P ? true : true), + 17 => collection.AsQueryable().Where(x => x.P ? false : false), + + 18 => collection.AsQueryable().Where(x => !(x.X == x.Y)), + 19 => collection.AsQueryable().Where(x => !(x.X != x.Y)), + 20 => collection.AsQueryable().Where(x => !(x.X < x.Y)), // not safe + 21 => collection.AsQueryable().Where(x => !(x.X > x.Y)), // not safe + 22 => collection.AsQueryable().Where(x => !(x.X <= x.Y)), // not safe + 23 => collection.AsQueryable().Where(x => !(x.X >= x.Y)), // not safe + _ => throw new ArgumentException($"Invalid test case: {testCase}") + }; + + var stages = Translate(collection, queryable); + string[] expectedStages = expectedStage == null ? [] : [expectedStage]; + AssertStages(stages, expectedStages); + + var results = queryable.ToList(); + results.Select(r => r.Id).Should().Equal(expectedIds); + } + + public class C + { + public int Id { get; set; } + public bool P { get; set; } + public bool Q { get; set; } + public int X { get; set; } + public int Y { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C { Id = 1, P = false, Q = false, X = 1, Y = 1 }, + new C { Id = 2, P = false, Q = true, X = 1, Y = 2 }, + new C { Id = 3, P = true, Q = false, X = 2, Y = 1 }, + new C { Id = 4, P = true, Q = true, X = 2, Y = 2 } + ]; + } +}