diff --git a/src/main/java/org/openrewrite/staticanalysis/ReplaceStringBuilderWithString.java b/src/main/java/org/openrewrite/staticanalysis/ReplaceStringBuilderWithString.java index 964beea3e..2879b22f8 100644 --- a/src/main/java/org/openrewrite/staticanalysis/ReplaceStringBuilderWithString.java +++ b/src/main/java/org/openrewrite/staticanalysis/ReplaceStringBuilderWithString.java @@ -16,9 +16,10 @@ package org.openrewrite.staticanalysis; import org.openrewrite.*; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.PartProvider; import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.*; import org.openrewrite.marker.Markers; @@ -33,8 +34,6 @@ public class ReplaceStringBuilderWithString extends Recipe { private static final MethodMatcher STRING_BUILDER_APPEND = new MethodMatcher("java.lang.StringBuilder append(..)"); private static final MethodMatcher STRING_BUILDER_TO_STRING = new MethodMatcher("java.lang.StringBuilder toString()"); - private static J.Parentheses parenthesesTemplate; - private static J.MethodInvocation stringValueOfTemplate; @Override public String getDisplayName() { @@ -55,149 +54,118 @@ public Duration getEstimatedEffortPerOccurrence() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(Preconditions.and(new UsesMethod<>(STRING_BUILDER_APPEND), new UsesMethod<>(STRING_BUILDER_TO_STRING)), new JavaVisitor() { - @Override - public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); - - if (STRING_BUILDER_TO_STRING.matches(method)) { - List methodCallsChain = new ArrayList<>(); - List arguments = new ArrayList<>(); - boolean isFlattenable = flatMethodInvocationChain(method, methodCallsChain, arguments); - if (!isFlattenable) { - return m; - } - - Collections.reverse(arguments); - adjustExpressions(arguments); - if (arguments.isEmpty()) { - return m; - } + return Preconditions.check( + Preconditions.and( + new UsesMethod<>(STRING_BUILDER_APPEND), + new UsesMethod<>(STRING_BUILDER_TO_STRING)), + new StringBuilderToAppendVisitor() + ); + } - Expression additive = ChainStringBuilderAppendCalls.additiveExpression(arguments) - .withPrefix(method.getPrefix()); + private static class StringBuilderToAppendVisitor extends JavaVisitor { + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); + + if (STRING_BUILDER_TO_STRING.matches(method)) { + List methodCallsChain = new ArrayList<>(); + List arguments = new ArrayList<>(); + boolean isFlattenable = flatMethodInvocationChain(method, methodCallsChain, arguments); + if (!isFlattenable || arguments.isEmpty()) { + return m; + } - if (isAMethodSelect(method)) { - additive = wrapExpression(additive); - } + Collections.reverse(arguments); + arguments = adjustExpressions(method, arguments); - return additive; + Expression additive = ChainStringBuilderAppendCalls.additiveExpression(arguments).withPrefix(method.getPrefix()); + if (isAMethodSelect(method)) { + additive = new J.Parentheses<>(randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(additive)); } - return m; - } - // Check if a method call is a select of another method call - private boolean isAMethodSelect(J.MethodInvocation method) { - Cursor parent = getCursor().getParent(2); // 2 means skip right padded cursor - if (parent == null || !(parent.getValue() instanceof J.MethodInvocation)) { - return false; - } - return ((J.MethodInvocation) parent.getValue()).getSelect() == method; + return additive; } - }); - } + return m; + } - private J.Literal toStringLiteral(J.Literal input) { - if (input.getType() == JavaType.Primitive.String) { - return input; + // Check if a method call is a select of another method call + private boolean isAMethodSelect(J.MethodInvocation method) { + Cursor parent = getCursor().getParent(2); // 2 means skip right padded cursor + if (parent == null || !(parent.getValue() instanceof J.MethodInvocation)) { + return false; + } + return ((J.MethodInvocation) parent.getValue()).getSelect() == method; } - String value = input.getValueSource(); - return new J.Literal(randomId(), Space.EMPTY, Markers.EMPTY, value, - "\"" + value + "\"", null, JavaType.Primitive.String); - } + private J.Literal toStringLiteral(J.Literal input) { + if (input.getType() == JavaType.Primitive.String) { + return input; + } - private void adjustExpressions(List arguments) { - for (int i = 0; i < arguments.size(); i++) { - if (i == 0) { - // the first expression must be a String type to support case like `new StringBuilder().append(1)` - if (!TypeUtils.isString(arguments.get(0).getType())) { - if (arguments.get(0) instanceof J.Literal) { - // wrap by "" - arguments.set(0, toStringLiteral((J.Literal) arguments.get(0))); - } else { - J.MethodInvocation stringValueOf = getStringValueOfMethodInvocationTemplate() - .withArguments(Collections.singletonList(arguments.get(0))) - .withPrefix(arguments.get(0).getPrefix()); - arguments.set(0, stringValueOf); + String value = input.getValueSource(); + return new J.Literal(randomId(), Space.EMPTY, Markers.EMPTY, value, + "\"" + value + "\"", null, JavaType.Primitive.String); + } + + private List adjustExpressions(J.MethodInvocation method, List arguments) { + return ListUtils.map(arguments, (i, arg) -> { + if (i == 0) { + if (!TypeUtils.isString(arg.getType())) { + if (arg instanceof J.Literal) { + return toStringLiteral((J.Literal) arg); + } else { + return JavaTemplate.builder("String.valueOf(#{any()})").build() + .apply(getCursor(), method.getCoordinates().replace(), arg) + .withPrefix(arg.getPrefix()); + } } + } else if (!(arg instanceof J.Identifier || arg instanceof J.Literal || arg instanceof J.MethodInvocation)) { + return new J.Parentheses<>(randomId(), Space.EMPTY, Markers.EMPTY, JRightPadded.build(arg)); } - } else { - // wrap by parentheses to support case like `.append(1+2)` - Expression arg = arguments.get(i); - if (!(arg instanceof J.Identifier || arg instanceof J.Literal || arg instanceof J.MethodInvocation)) { - arguments.set(i, wrapExpression(arg)); - } - } + return arg; + }); } - } - /** - * Return true if the method calls chain is like "new StringBuilder().append("A")....append("B");" - * - * @param method a StringBuilder.toString() method call - * @param methodChain output methods chain - * @param arguments output expression list to be chained by '+'. - */ - private static boolean flatMethodInvocationChain(J.MethodInvocation method, - List methodChain, - List arguments - ) { - Expression select = method.getSelect(); - while (select != null) { - methodChain.add(select); - if (!(select instanceof J.MethodInvocation)) { - break; - } + /** + * Return true if the method calls chain is like "new StringBuilder().append("A")....append("B");" + * + * @param method a StringBuilder.toString() method call + * @param methodChain output methods chain + * @param arguments output expression list to be chained by '+'. + */ + private boolean flatMethodInvocationChain(J.MethodInvocation method, List methodChain, List arguments) { + Expression select = method.getSelect(); + while (select != null) { + methodChain.add(select); + if (!(select instanceof J.MethodInvocation)) { + break; + } - J.MethodInvocation selectMethod = (J.MethodInvocation) select; - select = selectMethod.getSelect(); + J.MethodInvocation selectMethod = (J.MethodInvocation) select; + select = selectMethod.getSelect(); - if (!STRING_BUILDER_APPEND.matches(selectMethod)) { - return false; - } + if (!STRING_BUILDER_APPEND.matches(selectMethod)) { + return false; + } - List args = selectMethod.getArguments(); - if (args.size() != 1) { - return false; - } else { - arguments.add(args.get(0)); + List args = selectMethod.getArguments(); + if (args.size() != 1) { + return false; + } else { + arguments.add(args.get(0)); + } } - } - if (select instanceof J.NewClass && - ((J.NewClass) select).getClazz() != null && - TypeUtils.isOfClassType(((J.NewClass) select).getClazz().getType(), "java.lang.StringBuilder")) { - J.NewClass nc = (J.NewClass) select; - if (nc.getArguments().size() == 1 && TypeUtils.isString(nc.getArguments().get(0).getType())) { - arguments.add(nc.getArguments().get(0)); + if (select instanceof J.NewClass && + ((J.NewClass) select).getClazz() != null && + TypeUtils.isOfClassType(((J.NewClass) select).getClazz().getType(), "java.lang.StringBuilder")) { + J.NewClass nc = (J.NewClass) select; + if (nc.getArguments().size() == 1 && TypeUtils.isString(nc.getArguments().get(0).getType())) { + arguments.add(nc.getArguments().get(0)); + } + return true; } - return true; + return false; } - return false; - } - - public static J.Parentheses getParenthesesTemplate() { - if (parenthesesTemplate == null) { - parenthesesTemplate = PartProvider.buildPart("class B { void foo() { (\"A\" + \"B\").length(); } } ", J.Parentheses.class); - } - return parenthesesTemplate; - } - - public static J.MethodInvocation getStringValueOfMethodInvocationTemplate() { - if (stringValueOfTemplate == null) { - stringValueOfTemplate = PartProvider.buildPart("class C {\n" + - " void foo() {\n" + - " Object obj = 1 + 2;\n" + - " String.valueOf(obj);\n" + - " }\n" + - "}", - J.MethodInvocation.class); - } - return stringValueOfTemplate; - } - - public static J.Parentheses wrapExpression(Expression exp) { - return getParenthesesTemplate().withTree(exp).withPrefix(exp.getPrefix()); } }