diff --git a/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java b/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java index 4264fcf0986..6030207335c 100644 --- a/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java +++ b/rewrite-java-test/src/test/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitorTest.java @@ -733,4 +733,37 @@ void foo(String a) { rewriteRun(java(beforeJava, template.formatted(after))); } } + + @Issue("https://github.com/openrewrite/rewrite-feature-flags/issues/40") + @Test + void simplifyStringLiteralEqualsStringLiteral() { + rewriteRun( + java( + """ + class A { + { + String foo = "foo"; + if ("foo".equals("foo")) {} + if (foo.equals(foo)) {} + if (foo.equals("foo")) {} + if ("foo".equals(foo)) {} + if ("foo".equals("bar")) {} + } + } + """, + """ + class A { + { + String foo = "foo"; + if (true) {} + if (true) {} + if (foo.equals("foo")) {} + if ("foo".equals(foo)) {} + if (false) {} + } + } + """ + ) + ); + } } diff --git a/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java b/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java index 742fcca4a04..71bc99f825c 100644 --- a/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java +++ b/rewrite-java/src/main/java/org/openrewrite/java/cleanup/SimplifyBooleanExpressionVisitor.java @@ -240,6 +240,7 @@ private J.Binary.Type maybeNegate(J.Binary.Type operator) { } private final MethodMatcher isEmpty = new MethodMatcher("java.lang.String isEmpty()"); + private final MethodMatcher equals = new MethodMatcher("java.lang.String equals(java.lang.Object)"); @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) { @@ -250,6 +251,15 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext execu select instanceof J.Literal && select.getType() == JavaType.Primitive.String) { return booleanLiteral(method, J.Literal.isLiteralValue(select, "")); + } else if (equals.matches(asMethod)) { + Expression arg = asMethod.getArguments().get(0); + if (arg instanceof J.Literal && select instanceof J.Literal) { + return booleanLiteral(method, ((J.Literal) select).getValue().equals(((J.Literal) arg).getValue())); + } else if (arg instanceof J.Identifier && select instanceof J.Identifier) { + return booleanLiteral(method, SemanticallyEqual.areEqual(select, arg)); + } if (arg instanceof J.FieldAccess && select instanceof J.FieldAccess) { + return booleanLiteral(method, SemanticallyEqual.areEqual(select, arg)); + } } return j; }