diff --git a/Meziantou.FluentAssertionsAnalyzers.Tests/NUnit4ToFluentAssertionsAnalyzerUnitTests.cs b/Meziantou.FluentAssertionsAnalyzers.Tests/NUnit4ToFluentAssertionsAnalyzerUnitTests.cs index 33d6f42..6e1a87a 100644 --- a/Meziantou.FluentAssertionsAnalyzers.Tests/NUnit4ToFluentAssertionsAnalyzerUnitTests.cs +++ b/Meziantou.FluentAssertionsAnalyzers.Tests/NUnit4ToFluentAssertionsAnalyzerUnitTests.cs @@ -564,4 +564,68 @@ public void MyTest() } """); } + + [Theory] + [InlineData(@"ClassicAssert.AreEqual(test, 6)", @"test.Should().Be(6)")] + [InlineData(@"ClassicAssert.AreNotEqual(test, 6)", @"test.Should().NotBe(6)")] + [InlineData(@"ClassicAssert.AreEqual(test, ""test"")", @"test.Should().Be(""test"")")] + [InlineData(@"ClassicAssert.AreNotEqual(test, ""test"")", @"test.Should().NotBe(""test"")")] + public Task Assert_Inverted_Asserts(string code, string fix) + { + return Assert( + $$""" + using NUnit.Framework.Legacy; + + class Test + { + public void MyTest(object test) + { + [|{{code}}|]; + } + } + """, + $$""" + using FluentAssertions; + using NUnit.Framework.Legacy; + + class Test + { + public void MyTest(object test) + { + {{fix}}; + } + } + """); + } + + [Theory] + [InlineData(@"ClassicAssert.AreEqual(test, 6.0, delta: 2.0)", @"test.Should().BeApproximately(6.0, 2.0)")] + [InlineData(@"ClassicAssert.AreEqual(test, 6.0)", @"test.Should().Be(6.0)")] + public Task Assert_Inverted_Asserts_double(string code, string fix) + { + return Assert( + $$""" + using NUnit.Framework.Legacy; + + class Test + { + public void MyTest(double test) + { + [|{{code}}|]; + } + } + """, + $$""" + using FluentAssertions; + using NUnit.Framework.Legacy; + + class Test + { + public void MyTest(double test) + { + {{fix}}; + } + } + """); + } } diff --git a/Meziantou.FluentAssertionsAnalyzers/NunitAssertAnalyzerCodeFixProvider.cs b/Meziantou.FluentAssertionsAnalyzers/NunitAssertAnalyzerCodeFixProvider.cs index ba81352..18b4f49 100644 --- a/Meziantou.FluentAssertionsAnalyzers/NunitAssertAnalyzerCodeFixProvider.cs +++ b/Meziantou.FluentAssertionsAnalyzers/NunitAssertAnalyzerCodeFixProvider.cs @@ -178,18 +178,17 @@ private static async Task Rewrite(Document document, SyntaxNode nodeTo { if (methodName is "AreEqual") { - if (method.Parameters[0].Type.SpecialType == SpecialType.System_Double) - { - result = rewrite.UsingShould(arguments[1], "BeApproximately", ArgumentList(arguments[0], arguments.Skip(2))); - } - else - { - result = rewrite.UsingShould(arguments[1], "Be", ArgumentList(arguments[0], arguments.Skip(2))); - } + var (left, right) = GetLeftRight(arguments, semanticModel, cancellationToken); + + var useBeApproximately = semanticModel.GetTypeInfo(left.Expression, cancellationToken).Type?.SpecialType == SpecialType.System_Double + && arguments.FirstOrDefault(x => x.NameColon?.Name.Identifier.ValueText is "delta") is not null; + + result = rewrite.UsingShould(right, useBeApproximately ? "BeApproximately" : "Be", ArgumentList(left, arguments.Skip(2))); } else if (methodName is "AreNotEqual") { - result = rewrite.UsingShould(arguments[1], "NotBe", ArgumentList(arguments[0], arguments.Skip(2))); + var (left, right) = GetLeftRight(arguments, semanticModel, cancellationToken); + result = rewrite.UsingShould(right, "NotBe", ArgumentList(left, arguments.Skip(2))); } else if (methodName is "AreSame") { @@ -855,6 +854,28 @@ bool IsGenericMethod(out ITypeSymbol typeArgument, ITypeSymbol root, params stri return document; } + private static (ArgumentSyntax left, ArgumentSyntax right) GetLeftRight(SeparatedSyntaxList arguments, SemanticModel semanticModel, CancellationToken cancellationToken) + { + var left = arguments[0]; + var right = arguments[1]; + var leftValue = semanticModel.GetConstantValue(left.Expression, cancellationToken); + var rightValue = semanticModel.GetConstantValue(right.Expression, cancellationToken); + + // Don't invert if both are constant + if (leftValue.HasValue && rightValue.HasValue) + { + return (left, right); + } + + // Invert if right is constant + if (rightValue.HasValue) + { + return (right, left); + } + + return (left, right); + } + private static ITypeSymbol GetConstantTypeValue(SemanticModel semanticModel, ExpressionSyntax expression, CancellationToken cancellationToken) { var operation = semanticModel.GetOperation(expression, cancellationToken);