Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,19 @@ private enum CountCheckStatus
HasCount,
}

private enum LinqPredicateCheckStatus
{
Unknown,
Any,
Count,
WhereAny,
WhereCount,
Single,
SingleOrDefault,
WhereSingle,
WhereSingleOrDefault,
}

internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey);

/// <summary>
Expand Down Expand Up @@ -268,6 +281,56 @@ private static void AnalyzeInvocationOperation(OperationAnalysisContext context,
case "AreNotEqual":
AnalyzeAreEqualOrAreNotEqualInvocation(context, firstArgument, isAreEqualInvocation: false, objectTypeSymbol);
break;
case "IsNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: true);
break;

case "IsNotNull":
AnalyzeIsNullOrIsNotNullInvocation(context, firstArgument, isNullCheck: false);
break;
}
}

private static void AnalyzeIsNullOrIsNotNullInvocation(OperationAnalysisContext context, IOperation argument, bool isNullCheck)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");

// Check for Single/SingleOrDefault patterns
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
argument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus is LinqPredicateCheckStatus.Single or
LinqPredicateCheckStatus.SingleOrDefault or
LinqPredicateCheckStatus.WhereSingle or
LinqPredicateCheckStatus.WhereSingleOrDefault &&
linqCollectionExpr != null)
{
// For Assert.IsNotNull(enumerable.Single[OrDefault](...)) -> Assert.ContainsSingle
// For Assert.IsNull(enumerable.Single[OrDefault](...)) -> Assert.DoesNotContain
string properAssertMethod = isNullCheck ? "DoesNotContain" : "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, predicateExpr != null ? CodeFixModeAddArgument : CodeFixModeSimple);

ImmutableArray<Location> additionalLocations = predicateExpr != null
? ImmutableArray.Create(
argument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation())
: ImmutableArray.Create(
argument.Syntax.GetLocation(),
linqCollectionExpr.GetLocation());

context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: additionalLocations,
properties: properties.ToImmutable(),
properAssertMethod,
isNullCheck ? "IsNull" : "IsNotNull"));
}
}

Expand Down Expand Up @@ -519,6 +582,146 @@ private static ComparisonCheckStatus RecognizeComparisonCheck(
return ComparisonCheckStatus.Unknown;
}

private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck(
IOperation operation,
out SyntaxNode? collectionExpression,
out SyntaxNode? predicateExpression,
out IOperation? countOperation)
{
collectionExpression = null;
predicateExpression = null;
countOperation = null;

// Check for enumerable.Any(predicate)
// Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate
if (operation is IInvocationOperation anyInvocation &&
anyInvocation.TargetMethod.Name == "Any" &&
anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
anyInvocation.Arguments.Length == 2)
{
collectionExpression = anyInvocation.Arguments[0].Value.Syntax;
predicateExpression = anyInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Any;
}

// Check for enumerable.Count(predicate)
if (operation is IInvocationOperation countInvocation &&
countInvocation.TargetMethod.Name == "Count" &&
countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
countInvocation.Arguments.Length == 2)
{
collectionExpression = countInvocation.Arguments[0].Value.Syntax;
predicateExpression = countInvocation.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.Count;
}

// Check for enumerable.Where(predicate).Any()
if (operation is IInvocationOperation whereAnyInvocation &&
whereAnyInvocation.TargetMethod.Name == "Any" &&
whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereAnyInvocation.Arguments.Length == 1 &&
whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation &&
whereInvocation.TargetMethod.Name == "Where" &&
whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation.Arguments.Length == 2)
{
collectionExpression = whereInvocation.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereAny;
}

// Check for enumerable.Where(predicate).Count()
if (operation is IInvocationOperation whereCountInvocation &&
whereCountInvocation.TargetMethod.Name == "Count" &&
whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereCountInvocation.Arguments.Length == 1 &&
whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 &&
whereInvocation2.TargetMethod.Name == "Where" &&
whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation2.Arguments.Length == 2)
{
collectionExpression = whereInvocation2.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation2.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.WhereCount;
}

// Check for enumerable.Where(predicate).Single()
if (operation is IInvocationOperation whereSingleInvocation &&
whereSingleInvocation.TargetMethod.Name == "Single" &&
whereSingleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleInvocation.Arguments.Length == 1 &&
whereSingleInvocation.Arguments[0].Value is IInvocationOperation whereInvocation3 &&
whereInvocation3.TargetMethod.Name == "Where" &&
whereInvocation3.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation3.Arguments.Length == 2)
{
collectionExpression = whereInvocation3.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation3.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingle;
}

// Check for enumerable.Where(predicate).SingleOrDefault()
if (operation is IInvocationOperation whereSingleOrDefaultInvocation &&
whereSingleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
whereSingleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereSingleOrDefaultInvocation.Arguments.Length == 1 &&
whereSingleOrDefaultInvocation.Arguments[0].Value is IInvocationOperation whereInvocation4 &&
whereInvocation4.TargetMethod.Name == "Where" &&
whereInvocation4.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation4.Arguments.Length == 2)
{
collectionExpression = whereInvocation4.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation4.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereSingleOrDefault;
}

// Check for enumerable.Single(predicate)
if (operation is IInvocationOperation singleInvocation &&
singleInvocation.TargetMethod.Name == "Single" &&
singleInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Single;
}
else if (singleInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleInvocation.Instance?.Syntax ?? singleInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.Single;
}
}

// Check for enumerable.SingleOrDefault(predicate)
if (operation is IInvocationOperation singleOrDefaultInvocation &&
singleOrDefaultInvocation.TargetMethod.Name == "SingleOrDefault" &&
singleOrDefaultInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable")
{
if (singleOrDefaultInvocation.Arguments.Length == 2)
{
// Extension method with predicate
collectionExpression = singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = singleOrDefaultInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.SingleOrDefault;
}
else if (singleOrDefaultInvocation.Arguments.Length == 1)
{
// Instance method or extension without predicate
collectionExpression = singleOrDefaultInvocation.Instance?.Syntax ?? singleOrDefaultInvocation.Arguments[0].Value.Syntax;
predicateExpression = null;
return LinqPredicateCheckStatus.SingleOrDefault;
}
}

return LinqPredicateCheckStatus.Unknown;
}

private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");
Expand Down Expand Up @@ -555,6 +758,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Check for LINQ predicate patterns that suggest Contains/DoesNotContain
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
conditionArgument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null)
{
// For Any() and Where().Any() patterns
if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));
return;
}
}

// Check for string method patterns: myString.StartsWith/EndsWith/Contains(...)
StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr);
if (stringMethodStatus != StringMethodCheckStatus.Unknown)
Expand Down Expand Up @@ -624,6 +857,54 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable)
if (conditionArgument is IBinaryOperation binaryOp &&
binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan)
{
if (binaryOp.LeftOperand is IInvocationOperation countInvocation &&
binaryOp.RightOperand.ConstantValue.HasValue &&
binaryOp.RightOperand.ConstantValue.Value is int intValue &&
intValue == 0 &&
countInvocation.TargetMethod.Name == "Count")
{
SyntaxNode? countCollectionExpr = null;
SyntaxNode? countPredicateExpr = null;

if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1)
{
countCollectionExpr = countInvocation.Instance.Syntax;
countPredicateExpr = countInvocation.Arguments[0].Value.Syntax;
}
else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2)
{
countCollectionExpr = countInvocation.Arguments[0].Value.Syntax;
countPredicateExpr = countInvocation.Arguments[1].Value.Syntax;
}

if (countCollectionExpr != null && countPredicateExpr != null)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);

context.ReportDiagnostic(
context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
countPredicateExpr.GetLocation(),
countCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));

return;
}
}
}

// Check for comparison patterns: a > b, a >= b, a < b, a <= b
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
if (comparisonStatus != ComparisonCheckStatus.Unknown)
Expand Down Expand Up @@ -722,6 +1003,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont
{
if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue))
{
// Check for LINQ predicate patterns that suggest ContainsSingle
LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck(
actualArgumentValue!,
out SyntaxNode? linqCollectionExpr2,
out SyntaxNode? predicateExpr2,
out _);

if (isAreEqualInvocation &&
linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount &&
linqCollectionExpr2 != null &&
predicateExpr2 != null &&
expectedArgument.ConstantValue.HasValue &&
expectedArgument.ConstantValue.Value is int expectedCountValue &&
expectedCountValue == 1)
{
// We have Assert.AreEqual(1, enumerable.Count(predicate))
// We want Assert.ContainsSingle(predicate, enumerable)
string properAssertMethod = "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
actualArgumentValue.Syntax.GetLocation(),
predicateExpr2.GetLocation(),
linqCollectionExpr2.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
"AreEqual"));
return;
}

// Check if we're comparing a count/length property
CountCheckStatus countStatus = RecognizeCountCheck(
expectedArgument,
Expand Down
3 changes: 2 additions & 1 deletion test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable<TestResult> testResults
Assert.Contains(methodName, message.Text);
Assert.Contains("TestInitialize", message.Text);
Assert.Contains("TestCleanup", message.Text);
Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
// Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
Assert.DoesNotContain(message.Text.Contains, shouldNotContain);
}

private static void ValidateInitializeAndCleanup(IEnumerable<TestResult> testResults, Func<TestResultMessage, bool> messageFilter)
Expand Down
Loading