diff --git a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java index 3ac00568b..1e2845c0f 100644 --- a/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/internal/KotlinTreeParserVisitor.java @@ -32,6 +32,7 @@ import org.jetbrains.kotlin.fir.expressions.FirFunctionCall; import org.jetbrains.kotlin.fir.expressions.FirStringConcatenationCall; import org.jetbrains.kotlin.fir.references.FirResolvedCallableReference; +import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference; import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol; import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol; import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol; @@ -3313,6 +3314,9 @@ private JavaType.Method methodInvocationType(PsiElement psi) { if (firElement instanceof FirFunctionCall) { return psiElementAssociations.getTypeMapping().methodInvocationType((FirFunctionCall) firElement, psiElementAssociations.getFile().getSymbol()); } + if (firElement instanceof FirResolvedNamedReference) { + throw new UnsupportedOperationException("FIXME"); + } return null; } diff --git a/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt b/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt index c2b45f643..47c342da4 100644 --- a/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt +++ b/src/main/kotlin/org/openrewrite/kotlin/internal/PsiElementAssociations.kt @@ -21,10 +21,7 @@ import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.fir.FirElement import org.jetbrains.kotlin.fir.declarations.FirDeclaration import org.jetbrains.kotlin.fir.declarations.FirFile -import org.jetbrains.kotlin.fir.expressions.FirExpression -import org.jetbrains.kotlin.fir.expressions.FirFunctionCall -import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier -import org.jetbrains.kotlin.fir.expressions.FirReturnExpression +import org.jetbrains.kotlin.fir.expressions.* import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition import org.jetbrains.kotlin.fir.expressions.impl.FirSingleExpressionBlock import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference @@ -36,8 +33,11 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol import org.jetbrains.kotlin.fir.types.FirResolvedTypeRef import org.jetbrains.kotlin.fir.visitors.FirDefaultVisitor import org.jetbrains.kotlin.psi +import org.jetbrains.kotlin.psi.KtArrayAccessExpression import org.jetbrains.kotlin.psi.KtDeclaration import org.jetbrains.kotlin.psi.KtExpression +import org.jetbrains.kotlin.psi.KtPostfixExpression +import org.jetbrains.kotlin.psi.KtPrefixExpression import org.openrewrite.java.tree.JavaType import org.openrewrite.kotlin.KotlinTypeMapping @@ -157,10 +157,6 @@ class PsiElementAssociations(val typeMapping: KotlinTypeMapping, val file: FirFi var p = psi while (p != null && !elementMap.containsKey(p)) { p = p.parent - // don't skip KtDotQualifiedExpression for field access -// if (p is KtDotQualifiedExpression) { -// return null -// } } if (p == null) { @@ -171,8 +167,16 @@ class PsiElementAssociations(val typeMapping: KotlinTypeMapping, val file: FirFi val directFirInfos = allFirInfos.filter { filter.invoke(it.fir) } return if (directFirInfos.isNotEmpty()) directFirInfos[0].fir - else if (allFirInfos.isNotEmpty()) - allFirInfos[0].fir + else if (allFirInfos.isNotEmpty()) { + return when (psi) { + is KtPrefixExpression -> allFirInfos.first { it.fir is FirVariableAssignment }.fir + is KtPostfixExpression -> allFirInfos.first { it.fir is FirResolvedTypeRef }.fir + is KtArrayAccessExpression -> allFirInfos.first { it.fir is FirResolvedNamedReference && it.fir.name.asString() == "get" }.fir + else -> { + allFirInfos[0].fir + } + } + } else null } diff --git a/src/test/java/org/openrewrite/kotlin/tree/AnnotationTest.java b/src/test/java/org/openrewrite/kotlin/tree/AnnotationTest.java index 14e87b988..1579e5590 100644 --- a/src/test/java/org/openrewrite/kotlin/tree/AnnotationTest.java +++ b/src/test/java/org/openrewrite/kotlin/tree/AnnotationTest.java @@ -383,10 +383,8 @@ void lastAnnotations() { kotlin( """ annotation class A - annotation class B - @A - internal @B class Foo + internal @A class Foo """, spec -> spec.afterRecipe(cu -> { Optional s = cu.getStatements().stream() diff --git a/src/test/java/org/openrewrite/kotlin/tree/ClassDeclarationTest.java b/src/test/java/org/openrewrite/kotlin/tree/ClassDeclarationTest.java index 8cef5d29a..b1468a61f 100644 --- a/src/test/java/org/openrewrite/kotlin/tree/ClassDeclarationTest.java +++ b/src/test/java/org/openrewrite/kotlin/tree/ClassDeclarationTest.java @@ -136,7 +136,11 @@ interface C { class Inner { } } - """ + """, spec -> spec.afterRecipe(cu -> { + assertThat(cu.getStatements().stream() + .anyMatch(it -> it instanceof J.ClassDeclaration && + ((J.ClassDeclaration) it).getKind() == J.ClassDeclaration.Kind.Type.Interface)).isEqualTo(true); + }) ) ); } @@ -151,14 +155,24 @@ void modifierOrdering() { @Test void annotationClass() { rewriteRun( - kotlin("annotation class A") - ); + kotlin("annotation class A", + spec -> spec.afterRecipe(cu -> { + assertThat(cu.getStatements().stream() + .anyMatch(it -> it instanceof J.ClassDeclaration && + ((J.ClassDeclaration) it).getKind() == J.ClassDeclaration.Kind.Type.Annotation)).isEqualTo(true); + })) + ); } @Test void enumClass() { rewriteRun( - kotlin("enum class A") + kotlin("enum class A", + spec -> spec.afterRecipe(cu -> { + assertThat(cu.getStatements().stream() + .anyMatch(it -> it instanceof J.ClassDeclaration && + ((J.ClassDeclaration) it).getKind() == J.ClassDeclaration.Kind.Type.Enum)).isEqualTo(true); + })) ); }