From a0befe7677128cc29cf7c5f0d3785e66e8b94d00 Mon Sep 17 00:00:00 2001 From: Patrick Date: Tue, 13 Sep 2022 09:36:26 -0700 Subject: [PATCH] Add display name and some polish --- .../java/spring/boot3/PreciseBeanType.java | 106 ++++++++++-------- 1 file changed, 57 insertions(+), 49 deletions(-) diff --git a/src/main/java/org/openrewrite/java/spring/boot3/PreciseBeanType.java b/src/main/java/org/openrewrite/java/spring/boot3/PreciseBeanType.java index d0679550c..2002a4d9a 100644 --- a/src/main/java/org/openrewrite/java/spring/boot3/PreciseBeanType.java +++ b/src/main/java/org/openrewrite/java/spring/boot3/PreciseBeanType.java @@ -33,12 +33,17 @@ public class PreciseBeanType extends Recipe { @Override public String getDisplayName() { + return "Bean methods should return concrete types"; + } + + @Override + public String getDescription() { return "Replace Bean method return types with concrete types being returned. This is required for Spring 6 AOT"; } @Override - protected TreeVisitor getSingleSourceApplicableTest() { - return new UsesType(BEAN); + protected UsesType getSingleSourceApplicableTest() { + return new UsesType<>(BEAN); } @Override @@ -47,48 +52,44 @@ protected TreeVisitor getVisitor() { @Override public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext executionContext) { J.MethodDeclaration m = super.visitMethodDeclaration(method, executionContext); - if (isBeanMethod(m)) { - Object o = getCursor().pollMessage(MSG_KEY); - if (o != null) { - if (!o.equals(method.getReturnTypeExpression().getType())) { - if (o instanceof JavaType.FullyQualified) { - JavaType.FullyQualified actualType = (JavaType.FullyQualified) o; - if (m.getReturnTypeExpression() instanceof J.Identifier) { - J.Identifier identifierReturnExpr = (J.Identifier) m.getReturnTypeExpression(); - maybeAddImport(actualType); - if (identifierReturnExpr.getType() instanceof JavaType.FullyQualified) { - maybeRemoveImport((JavaType.FullyQualified) identifierReturnExpr.getType()); - } - m = m.withReturnTypeExpression(identifierReturnExpr - .withType(actualType) - .withSimpleName(actualType.getClassName()) - ); - } else if (m.getReturnTypeExpression() instanceof J.ParameterizedType) { - J.ParameterizedType parameterizedType = (J.ParameterizedType) m.getReturnTypeExpression(); - maybeAddImport(actualType); - if (parameterizedType.getType() instanceof JavaType.FullyQualified) { - maybeRemoveImport((JavaType.FullyQualified) parameterizedType.getType()); - } - m = m.withReturnTypeExpression(parameterizedType - .withType(actualType) - .withClazz(TypeTree.build(actualType.getClassName()).withType(actualType)) - ); - } + Object o = getCursor().pollMessage(MSG_KEY); + if (o != null && (method.getReturnTypeExpression() != null && !o.equals(method.getReturnTypeExpression().getType())) && isBeanMethod(m)) { + if (o instanceof JavaType.FullyQualified) { + JavaType.FullyQualified actualType = (JavaType.FullyQualified) o; + if (m.getReturnTypeExpression() instanceof J.Identifier) { + J.Identifier identifierReturnExpr = (J.Identifier) m.getReturnTypeExpression(); + maybeAddImport(actualType); + if (identifierReturnExpr.getType() instanceof JavaType.FullyQualified) { + maybeRemoveImport((JavaType.FullyQualified) identifierReturnExpr.getType()); + } + m = m.withReturnTypeExpression(identifierReturnExpr + .withType(actualType) + .withSimpleName(actualType.getClassName()) + ); + } else if (m.getReturnTypeExpression() instanceof J.ParameterizedType) { + J.ParameterizedType parameterizedType = (J.ParameterizedType) m.getReturnTypeExpression(); + maybeAddImport(actualType); + if (parameterizedType.getType() instanceof JavaType.FullyQualified) { + maybeRemoveImport((JavaType.FullyQualified) parameterizedType.getType()); + } + m = m.withReturnTypeExpression(parameterizedType + .withType(actualType) + .withClazz(TypeTree.build(actualType.getClassName()).withType(actualType)) + ); + } - } else if (o instanceof JavaType.Array) { - JavaType.Array actualType = (JavaType.Array) o; - if (m.getReturnTypeExpression() instanceof J.ArrayType && actualType.getElemType() instanceof JavaType.FullyQualified) { - JavaType.FullyQualified actualElementType = (JavaType.FullyQualified) actualType.getElemType(); - J.ArrayType arrayType = (J.ArrayType) m.getReturnTypeExpression(); - maybeAddImport(actualElementType); - if (arrayType.getElementType() instanceof JavaType.FullyQualified) { - maybeRemoveImport((JavaType.FullyQualified) arrayType.getElementType()); - } - m = m.withReturnTypeExpression(arrayType - .withElementType(TypeTree.build(actualElementType.getClassName()).withType(actualType)) - ); - } + } else if (o instanceof JavaType.Array) { + JavaType.Array actualType = (JavaType.Array) o; + if (m.getReturnTypeExpression() instanceof J.ArrayType && actualType.getElemType() instanceof JavaType.FullyQualified) { + JavaType.FullyQualified actualElementType = (JavaType.FullyQualified) actualType.getElemType(); + J.ArrayType arrayType = (J.ArrayType) m.getReturnTypeExpression(); + maybeAddImport(actualElementType); + if (arrayType.getElementType() instanceof JavaType.FullyQualified) { + maybeRemoveImport((JavaType.FullyQualified) arrayType.getElementType()); } + m = m.withReturnTypeExpression(arrayType + .withElementType(TypeTree.build(actualElementType.getClassName()).withType(actualType)) + ); } } } @@ -96,17 +97,24 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex } private boolean isBeanMethod(J.MethodDeclaration m) { - return m.getLeadingAnnotations().stream().anyMatch(a -> TypeUtils.isOfClassType(a.getType(), BEAN)); + for (J.Annotation leadingAnnotation : m.getLeadingAnnotations()) { + if (TypeUtils.isOfClassType(leadingAnnotation.getType(), BEAN)) { + return true; + } + } + return false; } @Override public J.Return visitReturn(J.Return _return, ExecutionContext executionContext) { - Cursor methodCursor = getCursor(); - while (methodCursor != null && !(methodCursor.getValue() instanceof J.Lambda || methodCursor.getValue() instanceof J.MethodDeclaration)) { - methodCursor = methodCursor.getParent(); - } - if (methodCursor != null && methodCursor.getValue() instanceof J.MethodDeclaration) { - methodCursor.putMessage(MSG_KEY, _return.getExpression().getType()); + if (_return.getExpression() != null && _return.getExpression().getType() != null) { + Cursor methodCursor = getCursor(); + while (methodCursor != null && !(methodCursor.getValue() instanceof J.Lambda || methodCursor.getValue() instanceof J.MethodDeclaration)) { + methodCursor = methodCursor.getParent(); + } + if (methodCursor != null && methodCursor.getValue() instanceof J.MethodDeclaration) { + methodCursor.putMessage(MSG_KEY, _return.getExpression().getType()); + } } return super.visitReturn(_return, executionContext); }