diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java index 420d4eacb..c346a5ff2 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java @@ -17,7 +17,6 @@ import lombok.EqualsAndHashCode; import lombok.Value; -import org.jspecify.annotations.Nullable; import org.openrewrite.Tree; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; @@ -27,11 +26,29 @@ import static java.util.Collections.singletonList; +/** + * A visitor that identifies and addresses potential issues related to + * the use of {@code equals} methods in Java, particularly to avoid + * null pointer exceptions when comparing strings. + *

+ * This visitor looks for method invocations of {@code equals}, + * {@code equalsIgnoreCase}, {@code compareTo}, and {@code contentEquals}, + * and performs optimizations to ensure null checks are correctly applied. + *

+ * For more details, refer to the PMD best practices: + * Literals First in Comparisons + * + * @param

The type of the parent context used for visiting the AST. + */ @Value @EqualsAndHashCode(callSuper = false) public class EqualsAvoidsNullVisitor

extends JavaVisitor

{ - private static final MethodMatcher STRING_EQUALS = new MethodMatcher("String equals(java.lang.Object)"); - private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher("String equalsIgnoreCase(java.lang.String)"); + + MethodMatcher EQUALS = new MethodMatcher("java.lang.String equals(java.lang.Object)"); + MethodMatcher EQUALS_IGNORE_CASE = new MethodMatcher("java.lang.String equalsIgnoreCase(java.lang.String)"); + MethodMatcher COMPARE_TO = new MethodMatcher("java.lang.String compareTo(java.lang.String)"); + MethodMatcher COMPARE_TO_IGNORE_CASE = new MethodMatcher("java.lang.String compareToIgnoreCase(java.lang.String)"); + MethodMatcher CONTENT_EQUALS = new MethodMatcher("java.lang.String contentEquals(java.lang.CharSequence)"); EqualsAvoidsNullStyle style; @@ -45,22 +62,28 @@ public J visitMethodInvocation(J.MethodInvocation method, P p) { if (m.getSelect() == null) { return m; } - - if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) && - m.getArguments().get(0) instanceof J.Literal && - !(m.getSelect() instanceof J.Literal)) { - Tree parent = getCursor().getParentTreeCursor().getValue(); + if (!(m.getSelect() instanceof J.Literal) + && m.getArguments().get(0) instanceof J.Literal + && (EQUALS.matches(m) + || !style.getIgnoreEqualsIgnoreCase() + && EQUALS_IGNORE_CASE.matches(m) + || COMPARE_TO.matches(m) + || COMPARE_TO_IGNORE_CASE.matches(m) + || CONTENT_EQUALS.matches(m))) { + final Object parent = getCursor().getParentTreeCursor().getValue(); if (parent instanceof J.Binary) { - J.Binary binary = (J.Binary) parent; - if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) { - J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); - if ((isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), m.getSelect())) || - (isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) { + final J.Binary binary = (J.Binary) parent; + if (binary.getLeft() instanceof J.Binary + && binary.getOperator() == J.Binary.Type.And) { + final J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); + if (isNullLiteral(potentialNullCheck.getLeft()) + && matchesSelect(potentialNullCheck.getRight(), m.getSelect()) + || isNullLiteral(potentialNullCheck.getRight()) + && matchesSelect(potentialNullCheck.getLeft(), m.getSelect())) { doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); } } } - if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) { return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY, m.getSelect(), @@ -68,11 +91,10 @@ public J visitMethodInvocation(J.MethodInvocation method, P p) { m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE), JavaType.Primitive.Boolean); } else { - m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix())) + return m.withSelect(m.getArguments().get(0).withPrefix(m.getSelect().getPrefix())) .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); } } - return m; } @@ -88,14 +110,6 @@ private static class RemoveUnnecessaryNullCheck

extends JavaVisitor

{ private final J.Binary scope; boolean done; - @Override - public @Nullable J visit(@Nullable Tree tree, P p) { - if (done) { - return (J) tree; - } - return super.visit(tree, p); - } - public RemoveUnnecessaryNullCheck(J.Binary scope) { this.scope = scope; } @@ -106,7 +120,6 @@ public J visitBinary(J.Binary binary, P p) { done = true; return binary.getRight().withPrefix(Space.EMPTY); } - return super.visitBinary(binary, p); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index 68bb55236..a78507a3c 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -39,18 +39,24 @@ void invertConditional() { """ public class A { { - String s = null; - if(s.equals("test")) {} - if(s.equalsIgnoreCase("test")) {} + String s = "LiteralsFirstInComparisons"; + System.out.println(s.compareTo("LiteralsFirstInComparisons")); + System.out.println(s.compareToIgnoreCase("LiteralsFirstInComparisons")); + System.out.println(s.contentEquals("LiteralsFirstInComparisons")); + System.out.println(s.equals("LiteralsFirstInComparisons")); + System.out.println(s.equalsIgnoreCase("LiteralsFirstInComparisons")); } } """, """ public class A { { - String s = null; - if("test".equals(s)) {} - if("test".equalsIgnoreCase(s)) {} + String s = "LiteralsFirstInComparisons"; + System.out.println("LiteralsFirstInComparisons".compareTo(s)); + System.out.println("LiteralsFirstInComparisons".compareToIgnoreCase(s)); + System.out.println("LiteralsFirstInComparisons".contentEquals(s)); + System.out.println("LiteralsFirstInComparisons".equals(s)); + System.out.println("LiteralsFirstInComparisons".equalsIgnoreCase(s)); } } """ @@ -67,8 +73,8 @@ void removeUnnecessaryNullCheck() { public class A { { String s = null; - if(s != null && s.equals("test")) {} - if(null != s && s.equals("test")) {} + if(s != null && s.equals("LiteralsFirstInComparisons")) {} + if(null != s && s.equals("LiteralsFirstInComparisons")) {} } } """, @@ -76,8 +82,8 @@ public class A { public class A { { String s = null; - if("test".equals(s)) {} - if("test".equals(s)) {} + if("LiteralsFirstInComparisons".equals(s)) {} + if("LiteralsFirstInComparisons".equals(s)) {} } } """ @@ -88,8 +94,8 @@ public class A { @Test void nullLiteral() { rewriteRun( - //language=java - java(""" + //language=java + java(""" public class A { void foo(String s) { if(s.equals(null)) { @@ -97,8 +103,8 @@ void foo(String s) { } } """, - """ - + """ + public class A { void foo(String s) { if(s == null) {