diff --git a/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt b/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt index d9831e1ad..f77ebc2e6 100644 --- a/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt +++ b/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt @@ -28,12 +28,15 @@ import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.declarations.impl.FirOuterClassTypeParameterRef import org.jetbrains.kotlin.fir.declarations.utils.isLocal import org.jetbrains.kotlin.fir.declarations.utils.modality +import org.jetbrains.kotlin.fir.declarations.utils.nameOrSpecialName import org.jetbrains.kotlin.fir.declarations.utils.visibility import org.jetbrains.kotlin.fir.expressions.* import org.jetbrains.kotlin.fir.java.declarations.FirJavaField import org.jetbrains.kotlin.fir.references.FirErrorNamedReference import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference import org.jetbrains.kotlin.fir.references.toResolvedBaseSymbol +import org.jetbrains.kotlin.fir.resolve.dfa.DfaInternals +import org.jetbrains.kotlin.fir.resolve.dfa.symbol import org.jetbrains.kotlin.fir.resolve.providers.toSymbol import org.jetbrains.kotlin.fir.resolve.toFirRegularClass import org.jetbrains.kotlin.fir.resolve.toSymbol @@ -611,7 +614,7 @@ class KotlinTypeMapping( return methodInvocationType(fir, signature) } - @OptIn(SymbolInternals::class) + @OptIn(SymbolInternals::class, DfaInternals::class) fun methodInvocationType(function: FirFunctionCall, signature: String): JavaType.Method? { val sym = function.calleeReference.toResolvedBaseSymbol() ?: return null val receiver = if (sym is FirFunctionSymbol<*>) sym.receiverParameter else null @@ -703,9 +706,12 @@ class KotlinTypeMapping( if (function.toResolvedCallableSymbol()?.receiverParameter != null) { paramTypes!!.add(type(function.toResolvedCallableSymbol()?.receiverParameter!!.typeRef)) } - for (param: FirExpression? in function.arguments) { - if (param != null) { - paramTypes!!.add(type(param.typeRef)) + for ((index, p) in (function.toResolvedCallableSymbol()?.fir as FirFunction).valueParameters.withIndex()) { + val t = type(p.returnTypeRef) + if (t !is GenericTypeVariable) { + paramTypes!!.add(t) + } else { + paramTypes!!.add(type((function.arguments[index]).typeRef)) } } method.unsafeSet( diff --git a/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeSignatureBuilder.kt b/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeSignatureBuilder.kt index f4a8c03c6..47f9322d2 100644 --- a/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeSignatureBuilder.kt +++ b/src/main/kotlin/org/openrewrite/kotlin/KotlinTypeSignatureBuilder.kt @@ -20,6 +20,7 @@ import org.jetbrains.kotlin.fir.* import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.declarations.impl.FirOuterClassTypeParameterRef import org.jetbrains.kotlin.fir.declarations.utils.classId +import org.jetbrains.kotlin.fir.declarations.utils.nameOrSpecialName import org.jetbrains.kotlin.fir.expressions.* import org.jetbrains.kotlin.fir.references.FirErrorNamedReference import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference @@ -366,13 +367,19 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession, private val return sig.toString() } + @OptIn(SymbolInternals::class) private fun methodCallArgumentSignature(function: FirFunctionCall): String { val genericArgumentTypes = StringJoiner(",", "[", "]") if (function.toResolvedCallableSymbol()?.receiverParameter != null) { genericArgumentTypes.add(signature(function.toResolvedCallableSymbol()?.receiverParameter!!.typeRef)) } - for (p in function.arguments) { - genericArgumentTypes.add(signature(p.typeRef, function)) + for ((index, p) in (function.toResolvedCallableSymbol()?.fir as FirFunction).valueParameters.withIndex()) { + val sig = signature(p.returnTypeRef, function) + if (sig.startsWith("Generic{")) { + genericArgumentTypes.add(signature((function.arguments[index]).typeRef, function)) + } else { + genericArgumentTypes.add(sig) + } } return genericArgumentTypes.toString() } diff --git a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java index 57bd94755..ec11c09e4 100644 --- a/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java +++ b/src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java @@ -382,7 +382,7 @@ void coneTypeProjection() { @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) { if (methodMatcher.matches(method)) { - assertThat(method.getMethodType().toString()).isEqualTo("kotlin.collections.MutableList{name=addAll,return=kotlin.Boolean,parameters=[kotlin.collections.List]}"); + assertThat(method.getMethodType().toString()).isEqualTo("kotlin.collections.MutableList{name=addAll,return=kotlin.Boolean,parameters=[kotlin.collections.Collection]}"); found.set(true); } return super.visitMethodInvocation(method, found); @@ -436,7 +436,7 @@ void genericIntersectionType() { public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) { if (methodMatcher.matches(method)) { assertThat(method.getMethodType().toString()) - .isEqualTo("kotlin.collections.CollectionsKt{name=listOf,return=kotlin.collections.List & java.io.Serializable}>>,parameters=[kotlin.Array & java.io.Serializable}>}>]}"); + .isEqualTo("kotlin.collections.CollectionsKt{name=listOf,return=kotlin.collections.List & java.io.Serializable}>>,parameters=[kotlin.Array]}"); found.set(true); } return super.visitMethodInvocation(method, found); @@ -462,7 +462,7 @@ void implicitInvoke() { @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean atomicBoolean) { if (matcher.matches(method)) { - assertThat(method.getMethodType().toString()).isEqualTo("kotlin.Function1, kotlin.Unit>{name=invoke,return=kotlin.Unit,parameters=[kotlin.collections.List]}"); + assertThat(method.getMethodType().toString()).isEqualTo("kotlin.Function1, kotlin.Unit>{name=invoke,return=kotlin.Unit,parameters=[kotlin.collections.Collection]}"); found.set(true); } return super.visitMethodInvocation(method, atomicBoolean); @@ -502,6 +502,33 @@ public J.FieldAccess visitFieldAccess(J.FieldAccess fieldAccess, AtomicBoolean f ); } + @Test + void println() { + //noinspection RemoveRedundantQualifierName + rewriteRun( + kotlin( + """ + fun method() { + println("foo") + } + """, spec -> spec.afterRecipe(cu -> { + AtomicBoolean found = new AtomicBoolean(false); + new KotlinIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean atomicBoolean) { + if ("println".equals(method.getSimpleName())) { + assertThat(method.getMethodType().toString()).isEqualTo("kotlin.io.ConsoleKt{name=println,return=kotlin.Unit,parameters=[kotlin.Any]}"); + found.set(true); + } + return super.visitMethodInvocation(method, atomicBoolean); + } + }.visit(cu, found); + assertThat(found.get()).isTrue(); + }) + ) + ); + } + @SuppressWarnings({"KotlinConstantConditions", "UnusedUnaryOperator", "RedundantExplicitType"}) @Test void whenExpression() {