diff --git a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java index 9c371a8e7..470f12183 100644 --- a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java @@ -179,7 +179,8 @@ public J visitArrayAccessExpression(KtArrayAccessExpression expression, Executio Markers markers = Markers.EMPTY; Expression selectExpr = convertToExpression(requireNonNull(expression.getArrayExpression()).accept(this, data)); JRightPadded select = padRight(selectExpr, suffix(expression.getArrayExpression())); - J.Identifier name = createIdentifier("", Space.EMPTY, methodInvocationType(expression)); + JavaType.Method type = methodInvocationType(expression); + J.Identifier name = createIdentifier("", Space.EMPTY, type); markers = markers.addIfAbsent(new IndexedAccess(randomId())); List indexExpressions = expression.getIndexExpressions(); @@ -200,7 +201,7 @@ public J visitArrayAccessExpression(KtArrayAccessExpression expression, Executio null, name, args, - methodInvocationType(expression) + type ); } @@ -1049,14 +1050,14 @@ public J visitQualifiedExpression(KtQualifiedExpression expression, ExecutionCon .withPrefix(endFixPrefixAndInfix(expression)); } else { J.Identifier identifier = (J.Identifier) selector; - return new J.FieldAccess( + return mapType(new J.FieldAccess( randomId(), deepPrefix(expression), Markers.EMPTY, receiver, padLeft(suffix(expression.getReceiverExpression()), identifier), type(expression) - ); + )); } } @@ -1257,6 +1258,7 @@ public J visitTypeAlias(KtTypeAlias typeAlias, ExecutionContext data) { typeExpression = (TypeTree) typeAlias.getTypeParameterList().accept(this, data); if (typeExpression instanceof J.ParameterizedType) { + typeExpression = mapType(typeExpression); Space prefix = name.getPrefix(); typeExpression = ((J.ParameterizedType) typeExpression).withClazz(name.withPrefix(Space.EMPTY).withPrefix(prefix)); } @@ -1702,12 +1704,12 @@ public J visitAnnotation(KtAnnotation annotation, ExecutionContext data) { rp -> rp.withAfter(prefix(findFirstChild(annotation, anno -> anno.getNode().getElementType() == KtTokens.RBRACKET)))); } - return new J.Annotation(randomId(), + return mapType(new J.Annotation(randomId(), Space.EMPTY, Markers.EMPTY.addIfAbsent(new AnnotationUseSite(randomId(), suffix(annotation.getUseSiteTarget()), isImplicitBracket)), (NameTree) annotation.getUseSiteTarget().accept(this, data), JContainer.build(beforeLBracket, rpAnnotations, Markers.EMPTY) - ); + )); } @Override @@ -1763,14 +1765,14 @@ public J visitArgument(KtValueArgument argument, ExecutionContext data) { expr ); } - return new J.Assignment( + return mapType(new J.Assignment( randomId(), deepPrefix(argument), Markers.EMPTY, name, padLeft(suffix(argument.getArgumentName()), expr), type(argument.getArgumentExpression()) - ); + )); } else if (argument.isSpread()) { Expression j = (Expression) argument.getArgumentExpression().accept(this, data); return new K.SpreadArgument( @@ -1798,10 +1800,10 @@ public J visitBinaryExpression(KtBinaryExpression expression, ExecutionContext d Expression left = convertToExpression(expression.getLeft().accept(this, data)).withPrefix(Space.EMPTY); Expression right = convertToExpression((expression.getRight()).accept(this, data)) .withPrefix(prefix(expression.getRight())); - JavaType type = type(expression); // FIXME: expressions may map to many trees due to de-sugaring. + JavaType type = type(expression); // FIXME: This requires detection of infix overrides and operator overloads. if (javaBinaryType != null) { - return new J.Binary( + return mapType(new J.Binary( randomId(), deepPrefix(expression), Markers.EMPTY, @@ -1809,18 +1811,18 @@ public J visitBinaryExpression(KtBinaryExpression expression, ExecutionContext d padLeft(prefix(operationReference), javaBinaryType), right, type - ); + )); } else if (operationReference.getOperationSignTokenType() == KtTokens.EQ) { - return new J.Assignment( + return mapType(new J.Assignment( randomId(), deepPrefix(expression), Markers.EMPTY, left, padLeft(suffix(expression.getLeft()), right), type - ); + )); } else if (assignmentOperationType != null) { - return new J.AssignmentOperation( + return mapType(new J.AssignmentOperation( randomId(), deepPrefix(expression), Markers.EMPTY, @@ -1828,9 +1830,9 @@ public J visitBinaryExpression(KtBinaryExpression expression, ExecutionContext d padLeft(prefix(operationReference), assignmentOperationType), right, type - ); + )); } else if (kotlinBinaryType != null) { - return new K.Binary( + return mapType(new K.Binary( randomId(), deepPrefix(expression), Markers.EMPTY, @@ -1839,7 +1841,7 @@ public J visitBinaryExpression(KtBinaryExpression expression, ExecutionContext d right, Space.EMPTY, type - ); + )); } return mapFunctionCall(expression, data); @@ -1907,30 +1909,17 @@ public J visitCallExpression(KtCallExpression expression, ExecutionContext data) parameters.add(padRight(convertToExpression(ktTypeProjection.accept(this, data)), suffix(ktTypeProjection))); } - JavaType javaType = type(expression); - JavaType nameType = JavaType.Unknown.getInstance(); - JavaType pt = JavaType.Unknown.getInstance(); - if (javaType instanceof JavaType.Method) { - pt = ((JavaType.Method) javaType).getReturnType(); - } else if (javaType instanceof JavaType.Variable) { - pt = ((JavaType.Variable) javaType).getType(); - } else if (javaType instanceof JavaType.Parameterized) { - pt = javaType; - } - if (pt instanceof JavaType.Parameterized) { - nameType = ((JavaType.Parameterized) pt).getType(); - } - name = new J.ParameterizedType( + name = mapType(new J.ParameterizedType( randomId(), name.getPrefix(), Markers.EMPTY, - name.withType(nameType).withPrefix(Space.EMPTY), + name.withPrefix(Space.EMPTY), JContainer.build(prefix(expression.getTypeArgumentList()), parameters, Markers.EMPTY), - pt - ); + type(expression) + )); } - return new J.NewClass( + return mapType(new J.NewClass( randomId(), deepPrefix(expression), Markers.EMPTY, @@ -1940,7 +1929,7 @@ public J visitCallExpression(KtCallExpression expression, ExecutionContext data) mapValueArgumentsMaybeWithTrailingLambda(expression.getValueArgumentList(), expression.getValueArguments(), data), null, mt - ); + )); } else if (type == null || type == PsiElementAssociations.ExpressionType.METHOD_INVOCATION) { J j = expression.getCalleeExpression().accept(this, data); JRightPadded select = null; @@ -1960,29 +1949,28 @@ public J visitCallExpression(KtCallExpression expression, ExecutionContext data) args = args.withMarkers(args.getMarkers().addIfAbsent(new OmitParentheses(randomId()))); } - JavaType.Method methodType = methodInvocationType(expression); - return new J.MethodInvocation( + return mapType(new J.MethodInvocation( randomId(), deepPrefix(expression), Markers.EMPTY, select, typeParams, - name.withType(methodType), + name, args, - methodType - ); + methodInvocationType(expression) + )); } else if (type == PsiElementAssociations.ExpressionType.QUALIFIER) { TypeTree typeTree = (TypeTree) expression.getCalleeExpression().accept(this, data); JContainer typeParams = mapTypeArguments(expression.getTypeArgumentList(), data); - return new J.ParameterizedType( + return mapType(new J.ParameterizedType( randomId(), deepPrefix(expression), Markers.EMPTY, typeTree, typeParams, type(expression) - ); + )); } else { throw new UnsupportedOperationException("ExpressionType not found: " + expression.getCalleeExpression().getText()); } @@ -2348,15 +2336,16 @@ public J visitDotQualifiedExpression(KtDotQualifiedExpression expression, Execut J.ParameterizedType pt = (J.ParameterizedType) j; if (pt != null) { pt = pt.withClazz(pt.getClazz().withPrefix(prefix(callExpression))); - J.FieldAccess newName = new J.FieldAccess( + J.FieldAccess newName = mapType(new J.FieldAccess( randomId(), receiver.getPrefix(), Markers.EMPTY, receiver.withPrefix(Space.EMPTY), padLeft(suffix(expression.getReceiverExpression()), (J.Identifier) pt.getClazz()), pt.getType() - ); + )); pt = pt.withClazz(newName); + pt = mapType(pt); } return pt; @@ -2385,6 +2374,7 @@ public J visitDotQualifiedExpression(KtDotQualifiedExpression expression, Execut pt.getType() ); pt = pt.withClazz(newName); + pt = mapType(pt); cur = cur.withClazz(pt); } methodInvocation = cur; @@ -2392,14 +2382,14 @@ public J visitDotQualifiedExpression(KtDotQualifiedExpression expression, Execut J.Identifier id = (J.Identifier) cur.getClazz(); if (id != null) { id = id.withPrefix(prefix(callExpression)); - J.FieldAccess newName = new J.FieldAccess( + J.FieldAccess newName = mapType(new J.FieldAccess( randomId(), receiver.getPrefix(), Markers.EMPTY, receiver.withPrefix(Space.EMPTY), padLeft(suffix(expression.getReceiverExpression()), id), id.getType() - ); + )); methodInvocation = cur.withClazz(newName); } } @@ -2409,14 +2399,14 @@ public J visitDotQualifiedExpression(KtDotQualifiedExpression expression, Execut return methodInvocation; } else if (expression.getSelectorExpression() instanceof KtNameReferenceExpression) { // Maybe need to type check before creating a field access. - return new J.FieldAccess( + return mapType(new J.FieldAccess( randomId(), deepPrefix(expression), Markers.EMPTY, convertToExpression(expression.getReceiverExpression().accept(this, data).withPrefix(Space.EMPTY)), padLeft(suffix(expression.getReceiverExpression()), (J.Identifier) expression.getSelectorExpression().accept(this, data)), type(expression.getSelectorExpression()) - ); + )); } else { throw new UnsupportedOperationException("Unsupported dot qualified selector: " + expression.getSelectorExpression().getClass()); } @@ -2454,14 +2444,14 @@ public J visitImportDirective(KtImportDirective importDirective, ExecutionContex reference = reference.withPrefix(suffix(importPsi)); if (reference instanceof J.Identifier) { - reference = new J.FieldAccess( + reference = mapType(new J.FieldAccess( randomId(), suffix(importPsi), Markers.EMPTY, new J.Empty(randomId(), Space.EMPTY, Markers.EMPTY), padLeft(Space.EMPTY, (J.Identifier) reference), type(importDirective) - ); + )); } return new J.Import( @@ -2541,7 +2531,6 @@ private J visitNamedFunction0(KtNamedFunction function, ExecutionContext data) { if (ktParameters.isEmpty()) { params = JContainer.build(prefix(function.getValueParameterList()), singletonList(padRight(new J.Empty(randomId(), - // TODO: fix NPE. prefix(function.getValueParameterList().getRightParenthesis()), Markers.EMPTY), Space.EMPTY) @@ -2748,19 +2737,14 @@ public J visitPrefixExpression(KtPrefixExpression expression, ExecutionContext d // FIXME: Add detection of overloads and return the appropriate trees when it is not equivalent to a J.Unary. // Returning the base type only applies when the expression is equivalent to a J.Binary. JavaType javaType = type(expression); - if (javaType instanceof JavaType.Method) { - javaType = ((JavaType.Method) javaType).getReturnType(); - } else if (javaType instanceof JavaType.Variable) { - javaType = ((JavaType.Variable) javaType).getType(); - } - return new J.Unary( + return mapType(new J.Unary( randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), type), expression.getBaseExpression().accept(this, data).withPrefix(suffix(expression.getOperationReference())), javaType - ); + )); } @Override @@ -2776,12 +2760,11 @@ public J visitPostfixExpression(KtPostfixExpression expression, ExecutionContext J j = convertToExpression(requireNonNull(expression.getBaseExpression()).accept(this, data)); IElementType referencedNameElementType = expression.getOperationReference().getReferencedNameElementType(); if (referencedNameElementType == KtTokens.EXCLEXCL) { - // j = j.withMarkers(j.getMarkers().addIfAbsent(new CheckNotNull(randomId(), prefix(expression.getOperationReference())))); - j = new K.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), K.Unary.Type.NotNull), (Expression) j, type); + j = mapType(new K.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), K.Unary.Type.NotNull), (Expression) j, type)); } else if (referencedNameElementType == KtTokens.PLUSPLUS) { - j = new J.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), J.Unary.Type.PostIncrement), (Expression) j, type); + j = mapType(new J.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), J.Unary.Type.PostIncrement), (Expression) j, type)); } else if (referencedNameElementType == KtTokens.MINUSMINUS) { - j = new J.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), J.Unary.Type.PostDecrement), (Expression) j, type); + j = mapType(new J.Unary(randomId(), deepPrefix(expression), Markers.EMPTY, padLeft(prefix(expression.getOperationReference()), J.Unary.Type.PostDecrement), (Expression) j, type)); } else { throw new UnsupportedOperationException("TODO"); } @@ -3146,7 +3129,6 @@ public J visitTypeReference(KtTypeReference typeReference, ExecutionContext data @Override public J visitUserType(KtUserType type, ExecutionContext data) { - // FIXME: must be mapped through parent element. I.E. functionType. J.Identifier name = (J.Identifier) requireNonNull(type.getReferenceExpression()).accept(this, data); if (type.getFirstChild() == type.getReferenceExpression()) { @@ -3157,7 +3139,7 @@ public J visitUserType(KtUserType type, ExecutionContext data) { if (type.getQualifier() != null) { Expression select = convertToExpression(type.getQualifier().accept(this, data)).withPrefix(prefix(type.getQualifier())); - nameTree = new J.FieldAccess(randomId(), Space.EMPTY, Markers.EMPTY, select, padLeft(suffix(type.getQualifier()), name), null); + nameTree = mapType(new J.FieldAccess(randomId(), Space.EMPTY, Markers.EMPTY, select, padLeft(suffix(type.getQualifier()), name), null)); } if (type.getTypeArgumentList() != null) { @@ -3171,14 +3153,14 @@ public J visitUserType(KtUserType type, ExecutionContext data) { } JavaType.Parameterized pt = (JavaType.Parameterized) javaType; - return new J.ParameterizedType( + return mapType(new J.ParameterizedType( randomId(), Space.EMPTY, Markers.EMPTY, - pt == null ? nameTree.withType(JavaType.Unknown.getInstance()) : nameTree.withType(pt.getType()), + nameTree, args, pt == null ? JavaType.Unknown.getInstance() : pt - ); + )); } return nameTree; @@ -3250,9 +3232,9 @@ else if (elementType == KtTokens.MUL) else if (elementType == KtTokens.DIV) return J.Binary.Type.Division; else if (elementType == KtTokens.EQEQ) - return J.Binary.Type.Equal; // TODO should this not be mapped to `Object#equals(Object)`? + return J.Binary.Type.Equal; else if (elementType == KtTokens.EXCLEQ) - return J.Binary.Type.NotEqual; // TODO should this not be mapped to `!Object#equals(Object)`? + return J.Binary.Type.NotEqual; else if (elementType == KtTokens.GT) return J.Binary.Type.GreaterThan; else if (elementType == KtTokens.GTEQ) @@ -3325,7 +3307,6 @@ private J.ControlParentheses buildIfCondition(KtIfExpression express return new J.ControlParentheses<>(randomId(), prefix(expression.getLeftParenthesis()), Markers.EMPTY, - // TODO: fix NPE. padRight(convertToExpression(requireNonNull(expression.getCondition()).accept(this, executionContext)) .withPrefix(suffix(expression.getLeftParenthesis())), prefix(expression.getRightParenthesis())) @@ -3357,25 +3338,146 @@ private J.If.Else buildIfElsePart(KtIfExpression expression) { /*==================================================================== * Type related methods * ====================================================================*/ - private T mapType(T tree) { - return mapType(tree, null); - } - @SuppressWarnings("unchecked") - private T mapType(T tree, @Nullable JavaType type) { + private T mapType(T tree) { + // TODO: polish, prevent unnecessary casts T updated = tree; /* Java trees */ if (updated instanceof J.Annotation) { - J.Annotation a = (J.Annotation) updated; - if (a.getAnnotationType() instanceof J.Identifier && a.getAnnotationType().getType() instanceof JavaType.Method) { - a = a.withAnnotationType(((J.Identifier) a.getAnnotationType()).withType(((JavaType.Method) a.getAnnotationType().getType()).getReturnType())); + if (isNotFullyQualified(((J.Annotation) updated).getAnnotationType().getType())) { + J.Annotation a = (J.Annotation) updated; + if (a.getAnnotationType().getType() instanceof JavaType.Method) { + a = a.withAnnotationType(a.getAnnotationType().withType(((JavaType.Method) a.getAnnotationType().getType()).getReturnType())); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) a; + } + } else if (updated instanceof J.Assignment) { + if (isNotFullyQualified(((J.Assignment) updated).getType())) { + J.Assignment a = (J.Assignment) updated; + if (a.getType() instanceof JavaType.Method) { + a = a.withType(((JavaType.Method) a.getType()).getReturnType()); + } else if (a.getType() instanceof JavaType.Variable) { + a = a.withType(((JavaType.Variable) a.getType()).getType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) a; + } + } else if (updated instanceof J.AssignmentOperation) { + if (isNotFullyQualified(((J.AssignmentOperation) updated).getType())) { + J.AssignmentOperation a = (J.AssignmentOperation) updated; + if (a.getType() instanceof JavaType.Method) { + a = a.withType(((JavaType.Method) a.getType()).getReturnType()); + } else if (a.getType() instanceof JavaType.Variable) { + a = a.withType(((JavaType.Variable) a.getType()).getType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) a; + } + } else if (updated instanceof J.Binary && ((J.Binary) updated).getType() != null) { + if (isNotFullyQualified(((J.Binary) updated).getType())) { + J.Binary b = (J.Binary) updated; + if (b.getType() instanceof JavaType.Method) { + b = b.withType(((JavaType.Method) b.getType()).getReturnType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) b; + } + } else if (updated instanceof J.FieldAccess) { + if (isNotFullyQualified(((J.FieldAccess) updated).getType())) { + J.FieldAccess f = (J.FieldAccess) updated; + if (f.getType() instanceof JavaType.Method) { + f = f.withType(((JavaType.Method) f.getType()).getReturnType()); + } else if (f.getType() instanceof JavaType.Variable) { + f = f.withType(((JavaType.Variable) f.getType()).getType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) f; + } + + if (((J.FieldAccess) updated).getTarget() instanceof J.Identifier && + isNotFullyQualified((((J.FieldAccess) updated).getTarget()).getType())) { + // Type association error, add marker for use in data table. + } + } else if (updated instanceof J.MethodDeclaration) { + if (!(((J.MethodDeclaration) updated).getName().getType() instanceof JavaType.Method)) { + J.MethodDeclaration m = (J.MethodDeclaration) updated; + m = m.withName(m.getName().withType(m.getMethodType())); + updated = (T) m; + } + } else if (updated instanceof J.MethodInvocation) { + if (!(((J.MethodInvocation) updated).getName().getType() instanceof JavaType.Method)) { + J.MethodInvocation m = (J.MethodInvocation) updated; + m = m.withName(m.getName().withType(m.getMethodType())); + updated = (T) m; + } + } else if (updated instanceof J.NewClass) { + J.NewClass n = (J.NewClass) updated; + if (n.getClazz() != null && n.getClazz() instanceof J.Identifier && + n.getClazz().getType() instanceof JavaType.Parameterized) { + J.Identifier clazz = (J.Identifier) n.getClazz(); + n = n.withClazz(clazz.withType(((JavaType.Parameterized) clazz.getType()).getType())); + } + updated = (T) n; + } else if (updated instanceof J.ParameterizedType) { + J.ParameterizedType p = (J.ParameterizedType) updated; + if (p.getType() != null && !(p.getType() instanceof JavaType.Parameterized)) { + if (p.getType() instanceof JavaType.Method) { + if (((JavaType.Method) p.getType()).getReturnType() instanceof JavaType.Parameterized) { + p = p.withType(((JavaType.Method) p.getType()).getReturnType()); + } else { + // Type association error, add marker for use in data table. + } + } else { + // Type association error, add marker for use in data table. + } + } + if (p.getClazz() != null && p.getClazz().getType() instanceof JavaType.Parameterized) { + p = p.withClazz(p.getClazz().withType(((JavaType.Parameterized) p.getClazz().getType()).getType())); + } + updated = (T) p; + } else if (updated instanceof J.Unary) { + if (isNotFullyQualified(((J.Unary) updated).getType())) { + J.Unary u = (J.Unary) updated; + if (u.getType() instanceof JavaType.Method) { + u = u.withType(((JavaType.Method) u.getType()).getReturnType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) u; } - updated = (T) a; } /* Kotlin trees */ + else if (updated instanceof K.Binary) { + if (isNotFullyQualified(((K.Binary) updated).getType())) { + K.Binary b = (K.Binary) updated; + if (b.getType() instanceof JavaType.Method) { + b = b.withType(((JavaType.Method) b.getType()).getReturnType()); + } else { + // Type association error, add marker for use in data table. + } + updated = (T) b; + } + } else if (updated instanceof K.Unary) { + if (isNotFullyQualified(((K.Unary) updated).getType())) { + // Type association error, add marker for use in data table. + } + } else { + throw new UnsupportedOperationException("Unsupported mapped type: " + updated.getClass().getName()); + } + return updated; } + private boolean isNotFullyQualified(@Nullable JavaType type) { + return type != null && !(type instanceof JavaType.FullyQualified); + } + @Nullable private JavaType type(@Nullable KtElement psi) { if (psi == null) { diff --git a/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt b/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt index e7c826b68..0ec5bd334 100644 --- a/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt +++ b/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt @@ -164,7 +164,7 @@ class PsiElementAssociations(val typeMapping: KotlinTypeMapping, val file: FirFi p = p.parent } - if (p == null || p is KtPackageDirective) { + if (p == null || p is KtPackageDirective || psi is KtAnnotationUseSiteTarget) { return null } diff --git a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java index df0142e06..6fcee1fd4 100644 --- a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java +++ b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java @@ -40,7 +40,6 @@ import static java.util.Collections.singletonList; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.in; import static org.openrewrite.ExecutionContext.REQUIRE_PRINT_EQUALS_INPUT; import static org.openrewrite.java.tree.JavaType.GenericTypeVariable.Variance.*; import static org.openrewrite.kotlin.Assertions.kotlin; @@ -574,7 +573,7 @@ public K.When visitWhen(K.When when, AtomicBoolean found) { "n++~kotlin.Int", "--n~kotlin.Int", "n += a~kotlin.Int", - "n = a + b~kotlin.Int{name=plus,return=kotlin.Int,parameters=[kotlin.Int]}" + "n = a + b~kotlin.Int" }, delimiter = '~') void operatorOverload(String p1, String p2) { rewriteRun( @@ -607,7 +606,7 @@ public J.Unary visitUnary(J.Unary unary, AtomicBoolean b) { @Override public J.Binary visitBinary(J.Binary binary, AtomicBoolean b) { - JavaType.Method mt = (JavaType.Method) binary.getType(); + JavaType.Class mt = (JavaType.Class) binary.getType(); if (p2.equals(mt.toString())) { found.set(true); } @@ -686,35 +685,114 @@ operator fun contains(element: Int): Boolean { val b = 1 !in listOf(2) val a = 1 !in Foo() } + """, spec -> spec.afterRecipe(cu -> new KotlinIsoVisitor() { + @Override + public K.Binary visitBinary(K.Binary binary, Integer integer) { + JavaType type = binary.getType(); + assertThat(type).isInstanceOf(JavaType.Class.class); + assertThat(((JavaType.Class) type).getFullyQualifiedName()).isEqualTo("kotlin.Boolean"); + return binary; + } + }.visit(cu, 0)) + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/464") + @Test + void fieldAccessOnSuperType() { + rewriteRun( + kotlin( + """ + open class A { + val id: Int = 0 + } + class B : A() { + fun get(): Int { + return super.id + } + } """, spec -> spec.afterRecipe(cu -> { - MethodMatcher kotlinCollection = new MethodMatcher("kotlin.collections.List contains(..)"); - AtomicBoolean kotlinCollectionFound = new AtomicBoolean(false); - MethodMatcher operatorOverload = new MethodMatcher("Foo contains(..)"); - AtomicBoolean operatorOverloadFound = new AtomicBoolean(false); + AtomicBoolean found = new AtomicBoolean(false); new KotlinIsoVisitor() { @Override - public K.Binary visitBinary(K.Binary binary, Integer integer) { - JavaType.Method methodType = (JavaType.Method) binary.getType(); - if (kotlinCollection.matches(methodType)) { - assertThat(methodType.toString()) - .isEqualTo("kotlin.collections.List{name=contains,return=kotlin.Boolean,parameters=[kotlin.Int]}"); - kotlinCollectionFound.set(true); - } - if (operatorOverload.matches(methodType)) { - assertThat(methodType.toString()) - .isEqualTo("Foo{name=contains,return=kotlin.Boolean,parameters=[kotlin.Int]}"); - operatorOverloadFound.set(true); - } - return binary; + public J.FieldAccess visitFieldAccess(J.FieldAccess fieldAccess, Integer integer) { + assertThat(fieldAccess.getType().toString()).isEqualTo("kotlin.Int"); + found.set(true); + return super.visitFieldAccess(fieldAccess, integer); + } + }.visit(cu, 0); + assertThat(found.get()).isTrue(); + }) + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/464") + @Test + void parameterizedType() { + rewriteRun( + kotlin( + """ + import java.util.ArrayList + + class Foo { + val l: ArrayList = ArrayList() + } + """, spec -> spec.afterRecipe(cu -> { + AtomicBoolean found = new AtomicBoolean(false); + new KotlinIsoVisitor() { + @Override + public J.ParameterizedType visitParameterizedType(J.ParameterizedType type, Integer integer) { + assertThat(type.getType().toString()).isEqualTo("java.util.ArrayList"); + assertThat(type.getClazz().getType().toString()).isEqualTo("java.util.ArrayList"); + found.set(true); + return super.visitParameterizedType(type, integer); } }.visit(cu, 0); - assertThat(kotlinCollectionFound.get()).isTrue(); - assertThat(operatorOverloadFound.get()).isTrue(); + assertThat(found.get()).isTrue(); }) ) ); } + @SuppressWarnings("UnusedReceiverParameter") + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/464") + @Test + void parameterizedReceiver() { + rewriteRun( + kotlin( + """ + class SomeParameterized + val SomeParameterized < Int > . receivedMember : Int + get ( ) = 42 + """, spec -> spec.afterRecipe(cu -> new KotlinIsoVisitor() { + @Override + public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Integer integer) { + assertThat(classDecl.getType().toString()).isEqualTo("SomeParameterized"); + assertThat(classDecl.getName().getType().toString()).isEqualTo("SomeParameterized"); + return super.visitClassDeclaration(classDecl, integer); + } + + @Override + public K.Property visitProperty(K.Property property, Integer integer) { + assertThat(property.getReceiver().getType().toString()).isEqualTo("SomeParameterized"); + assertThat(((J.ParameterizedType) property.getReceiver()).getClazz().getType().toString()).isEqualTo("SomeParameterized"); + return super.visitProperty(property, integer); + } + + @Override + public J.VariableDeclarations.NamedVariable visitVariable(J.VariableDeclarations.NamedVariable variable, Integer integer) { + assertThat(variable.getVariableType().toString()).isEqualTo("openRewriteFile0Kt{name=receivedMember,type=kotlin.Int}"); + assertThat(variable.getName().getType().toString()).isEqualTo("kotlin.Int"); + assertThat(variable.getName().getFieldType().toString()).isEqualTo("openRewriteFile0Kt{name=receivedMember,type=kotlin.Int}"); + return super.visitVariable(variable, integer); + } + }.visit(cu, 0)) + ) + ); + } + @Test void destructs() { rewriteRun( @@ -736,7 +814,7 @@ public K.DestructuringDeclaration visitDestructuringDeclaration(K.DestructuringD @Override public J.NewClass visitNewClass(J.NewClass newClass, AtomicBoolean found) { if ("Triple".equals(((J.Identifier) newClass.getClazz()).getSimpleName())) { - assertThat(newClass.getClazz().getType().toString()).isEqualTo("kotlin.Triple"); + assertThat(newClass.getClazz().getType().toString()).isEqualTo("kotlin.Triple"); assertThat(newClass.getConstructorType().toString()).isEqualTo("kotlin.Triple{name=,return=kotlin.Triple,parameters=[kotlin.Int,kotlin.Int,kotlin.Int]}"); } return super.visitNewClass(newClass, found);