Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various type mapping fixes. #333

Merged
merged 1 commit into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions src/main/java/org/openrewrite/kotlin/KotlinTypeGoat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@ import java.lang.Object
// TODO: FIX ME. Files needs to declare fields and methods to assert type mapping.
@AnnotationWithRuntimeRetention
@AnnotationWithSourceRetention
abstract class KotlinTypeGoat<T, S> {
//abstract class KotlinTypeGoat<T, S> where S: PT<S>, S: C {
abstract class KotlinTypeGoat<T, S> where S: PT<S>, S: C {
val parameterizedField: PT<TypeA> = object : PT<TypeA> {}

val field: Int = 10

val gettableField: Int
get() = 10

var settableField : String = ""
set ( value ) {
var field: Int = 10
get() = field
set(value) {
field = value
}

// abstract class InheritedKotlinTypeGoat<T, U> : KotlinTypeGoat<T, U>() where U : PT<U>, U : C
abstract class InheritedKotlinTypeGoat<T, U> : KotlinTypeGoat<T, U>() where U : PT<U>, U : C

enum class EnumTypeA {
FOO, BAR(),
Expand Down Expand Up @@ -70,11 +65,11 @@ abstract class KotlinTypeGoat<T, S> {
abstract fun inner(n: C.Inner)
abstract fun enumTypeA(n: EnumTypeA)
abstract fun enumTypeB(n: EnumTypeB)
// abstract fun <U> inheritedJavaTypeGoat(n: InheritedKotlinTypeGoat<T, U>): InheritedKotlinTypeGoat<T, U> where U : PT<U>, U : C
// abstract fun <U> genericIntersection(n: U): U where U : TypeA, U : PT<U>, U : C
abstract fun <U> inheritedKotlinTypeGoat(n: InheritedKotlinTypeGoat<T, U>): InheritedKotlinTypeGoat<T, U> where U : PT<U>, U : C
abstract fun <U> genericIntersection(n: U): U where U : TypeA, U : PT<U>, U : C
abstract fun genericT(n: T): T // remove after signatures are common.

// abstract fun <U> recursiveIntersection(n: U) where U : KotlinTypeGoat.Extension<U>, U : Intersection<U>
abstract fun <U> recursiveIntersection(n: U) where U : Extension<U>, U : Intersection<U>

abstract fun javaType(n: Object)
}
Expand All @@ -85,9 +80,9 @@ interface C {

interface PT<T>

//internal interface Intersection<T> where T : KotlinTypeGoat.Extension<T>, T : Intersection<T> {
// val intersectionType: T
//}
internal interface Intersection<T> where T : KotlinTypeGoat.Extension<T>, T : Intersection<T> {
val intersectionType: T
}

@Retention(AnnotationRetention.SOURCE)
internal annotation class AnnotationWithSourceRetention
Expand Down
49 changes: 42 additions & 7 deletions src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaConstr
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaMethod
import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.types.Variance
import org.openrewrite.Incubating
import org.openrewrite.java.JavaTypeMapping
import org.openrewrite.java.internal.JavaTypeCache
Expand Down Expand Up @@ -212,6 +213,7 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
): JavaType.FullyQualified {
val firClass: FirClass
var resolvedTypeRef: FirResolvedTypeRef? = null
var typeArguments: Array<out ConeTypeProjection>? = null
if (classType is FirResolvedTypeRef) {
// The resolvedTypeRef is used to create parameterized types.
resolvedTypeRef = classType
Expand Down Expand Up @@ -244,6 +246,9 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
return JavaType.Unknown.getInstance()
}
}
} else if (classType is ConeClassLikeType) {
firClass = classType.toRegularClassSymbol(firSession)!!.fir
typeArguments = classType.typeArguments
} else {
firClass = classType as FirClass
}
Expand Down Expand Up @@ -373,8 +378,12 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
if (pt == null) {
pt = JavaType.Parameterized(null, null, null)
typeCache.put(signature, pt)
val typeParameters: MutableList<JavaType> = ArrayList(firClass.typeParameters.size)
if (resolvedTypeRef != null && resolvedTypeRef.type.typeArguments.isNotEmpty()) {
val typeParameters: MutableList<JavaType> = ArrayList(typeArguments?.size ?: firClass.typeParameters.size)
if (typeArguments != null) {
for (typeArgument: ConeTypeProjection in typeArguments) {
typeParameters.add(type(typeArgument))
}
} else if (resolvedTypeRef != null && resolvedTypeRef.type.typeArguments.isNotEmpty()) {
for (typeArgument: ConeTypeProjection in resolvedTypeRef.type.typeArguments) {
typeParameters.add(type(typeArgument))
}
Expand Down Expand Up @@ -974,7 +983,6 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
): JavaType? {
var resolvedType: JavaType? = JavaType.Unknown.getInstance()

// TODO: fix for multiple bounds.
val isGeneric = type is ConeKotlinTypeProjectionIn ||
type is ConeKotlinTypeProjectionOut ||
type is ConeStarProjection ||
Expand All @@ -999,14 +1007,35 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
typeCache.put(signature, gtv)
if (type is ConeKotlinTypeProjectionIn) {
variance = JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
val classSymbol = type.type.toRegularClassSymbol(firSession)
bounds = ArrayList(1)
bounds.add(if (classSymbol != null) type(classSymbol.fir) else JavaType.Unknown.getInstance())
bounds.add(type(type.type))
} else if (type is ConeKotlinTypeProjectionOut) {
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
val classSymbol = type.type.toRegularClassSymbol(firSession)
bounds = ArrayList(1)
bounds.add(if (classSymbol != null) type(classSymbol.fir) else JavaType.Unknown.getInstance())
bounds.add(type(type.type))
} else if (type is ConeTypeParameterType) {
val classifierSymbol: FirClassifierSymbol<*>? = type.lookupTag.toSymbol(firSession)
if (classifierSymbol is FirTypeParameterSymbol) {
variance = when (classifierSymbol.variance) {
Variance.INVARIANT -> {
if (classifierSymbol.resolvedBounds.none { it !is FirImplicitNullableAnyTypeRef }) JavaType.GenericTypeVariable.Variance.INVARIANT else JavaType.GenericTypeVariable.Variance.COVARIANT
}

Variance.IN_VARIANCE -> {
JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
}

Variance.OUT_VARIANCE -> {
JavaType.GenericTypeVariable.Variance.COVARIANT
}
}
bounds = ArrayList(classifierSymbol.resolvedBounds.size)
for (bound: FirResolvedTypeRef in classifierSymbol.resolvedBounds) {
if (bound !is FirImplicitNullableAnyTypeRef) {
bounds.add(type(bound))
}
}
}
}
gtv.unsafeSet(name, variance, bounds)
resolvedType = gtv
Expand All @@ -1032,6 +1061,10 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
typeCache.put(signature, JavaType.Unknown.getInstance())
return JavaType.Unknown.getInstance()
}
if (signatureBuilder.signature(classSymbol.fir) != signature) {
// The signature contains generic bounded types and needs to be resolved.
return classType(coneClassLikeType, signature, ownerSymbol)
}
return type(classSymbol.fir, ownerSymbol)
}

Expand All @@ -1054,6 +1087,8 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
} else if ("in" == typeParameter.variance.label) {
variance = JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
} else {
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
}
}
gtv.unsafeSet(gtv.name, variance, bounds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,15 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
return "Generic{$name}"
}
val s = StringBuilder("Generic{").append(name)
val boundSigs = StringJoiner(", ")
val boundSigs = StringJoiner(" & ")
for (bound in typeParameter.bounds) {
if (bound !is FirImplicitNullableAnyTypeRef) {
boundSigs.add(signature(bound))
}
}
val boundSigStr = boundSigs.toString()
if (!boundSigStr.isEmpty()) {
s.append(": ").append(boundSigStr)
s.append(" extends ").append(boundSigStr)
}
typeVariableNameStack!!.remove(name)
return s.append("}").toString()
Expand All @@ -330,7 +330,6 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
* Generate a ConeTypeProject signature.
*/
private fun coneTypeProjectionSignature(type: ConeTypeProjection): String {
val typeSignature: String
val s = StringBuilder()
if (type is ConeKotlinTypeProjectionIn) {
val (type1) = type
Expand Down Expand Up @@ -364,10 +363,14 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
s.append(">")
}
} else if (type is ConeTypeParameterType) {
s.append("Generic{")
typeSignature = convertKotlinFqToJavaFq(type.toString())
s.append(typeSignature)
s.append("}")
val symbol: FirClassifierSymbol<*>? = type.lookupTag.toSymbol(firSession)
if (symbol != null) {
s.append(signature(symbol))
} else {
s.append("Generic{")
s.append(convertKotlinFqToJavaFq(type.toString()))
s.append("}")
}
} else if (type is ConeFlexibleType) {
s.append(signature(type.lowerBound))
} else if (type is ConeDefinitelyNotNullType) {
Expand Down Expand Up @@ -494,7 +497,7 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
return "Generic{$name}"
}
val s = StringBuilder("Generic{").append(name)
val boundSigs = StringJoiner(", ")
val boundSigs = StringJoiner(" & ")
for (type in typeParameter.upperBounds) {
if (type.classifier != null && "java.lang.Object" != type.classifierQualifiedName) {
boundSigs.add(signature(type))
Expand Down
50 changes: 32 additions & 18 deletions src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.openrewrite.kotlin;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ExpectedToFail;
Expand All @@ -41,7 +40,7 @@
public class KotlinTypeMappingTest {
private static final String goat = StringUtils.readFully(KotlinTypeMappingTest.class.getResourceAsStream("/KotlinTypeGoat.kt"));

private static final J.ClassDeclaration goatClassDeclaration;
private static final K.ClassDeclaration goatClassDeclaration;

static {
InMemoryExecutionContext ctx = new InMemoryExecutionContext();
Expand All @@ -53,8 +52,12 @@ public class KotlinTypeMappingTest {
.parse(ctx, goat)
.findFirst()
.get())
.getClasses()
.get(0)
.getStatements()
.stream()
.filter(K.ClassDeclaration.class::isInstance)
.findFirst()
.map(K.ClassDeclaration.class::cast)
.orElseThrow()
);
}

Expand All @@ -71,14 +74,24 @@ public JavaType.Method methodType(String methodName) {
}

public J.VariableDeclarations getField(String fieldName) {
return goatClassDeclaration.getBody().getStatements().stream()
.filter(org.openrewrite.java.tree.J.VariableDeclarations.class::isInstance)
return goatClassDeclaration.getClassDeclaration().getBody().getStatements().stream()
.filter(it -> it instanceof org.openrewrite.java.tree.J.VariableDeclarations || it instanceof K.Property)
.map(it -> it instanceof K.Property ? ((K.Property) it).getVariableDeclarations() : (J.VariableDeclarations) it)
.map(J.VariableDeclarations.class::cast)
.filter(mv -> mv.getVariables().stream().anyMatch(v -> v.getSimpleName().equals(fieldName)))
.findFirst()
.orElse(null);
}

public K.Property getProperty(String fieldName) {
return goatClassDeclaration.getClassDeclaration().getBody().getStatements().stream()
.filter(it -> it instanceof K.Property)
.map(K.Property.class::cast)
.filter(mv -> mv.getVariableDeclarations().getVariables().stream().anyMatch(v -> v.getSimpleName().equals(fieldName)))
.findFirst()
.orElse(null);
}

public JavaType firstMethodParameter(String methodName) {
return methodType(methodName).getParameterTypes().get(0);
}
Expand All @@ -90,13 +103,19 @@ void extendsKotlinAny() {

@Test
void fieldType() {
J.VariableDeclarations.NamedVariable variable = getField("field").getVariables().get(0);
K.Property property = getProperty("field");
J.VariableDeclarations.NamedVariable variable = property.getVariableDeclarations().getVariables().get(0);
J.Identifier id = variable.getName();
assertThat(variable.getType()).isEqualTo(id.getType());
assertThat(id.getFieldType()).isInstanceOf(JavaType.Variable.class);
assertThat(id.getFieldType().toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat{name=field,type=kotlin.Int}");
assertThat(id.getType()).isInstanceOf(JavaType.Class.class);
assertThat(id.getType().toString()).isEqualTo("kotlin.Int");

assertThat(property.getGetter().getMethodType().toString().substring(property.getGetter().getMethodType().toString().indexOf("openRewriteFileKt"))).isEqualTo("openRewriteFileKt{name=accessor,return=kotlin.Int,parameters=[]}");
assertThat(property.getGetter().getMethodType()).isEqualTo(property.getGetter().getName().getType());
assertThat(property.getSetter().getMethodType().toString().substring(property.getGetter().getMethodType().toString().indexOf("openRewriteFileKt"))).isEqualTo("openRewriteFileKt{name=accessor,return=kotlin.Unit,parameters=[kotlin.Int]}");
assertThat(property.getSetter().getMethodType()).isEqualTo(property.getSetter().getName().getType());
}

@Test
Expand Down Expand Up @@ -154,7 +173,6 @@ void genericContravariant() {
isEqualTo("org.openrewrite.kotlin.C");
}

@Disabled("Requires parsing intersection types")
@Test
void genericMultipleBounds() {
List<JavaType> typeParameters = goatType.getTypeParameters();
Expand All @@ -174,7 +192,6 @@ void genericUnbounded() {
assertThat(generic.getBounds()).isEmpty();
}

@Disabled
@Test
void genericRecursive() {
JavaType.Parameterized param = (JavaType.Parameterized) firstMethodParameter("genericRecursive");
Expand All @@ -196,23 +213,21 @@ void innerClass() {
assertThat(clazz.getFullyQualifiedName()).isEqualTo("org.openrewrite.kotlin.C$Inner");
}

@Disabled("Requires parsing intersection types")
@Test
void inheritedJavaTypeGoat() {
JavaType.Parameterized clazz = (JavaType.Parameterized) firstMethodParameter("InheritedKotlinTypeGoat");
JavaType.Parameterized clazz = (JavaType.Parameterized) firstMethodParameter("inheritedKotlinTypeGoat");
assertThat(clazz.getTypeParameters().get(0).toString()).isEqualTo("Generic{T}");
assertThat(clazz.getTypeParameters().get(1).toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}");
assertThat(clazz.toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat$InheritedKotlinTypeGoat<Generic{T}, Generic{U extends org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}>");
}

@Disabled("Requires parsing intersection types")
@Test
void genericIntersectionType() {
JavaType.GenericTypeVariable clazz = (JavaType.GenericTypeVariable) firstMethodParameter("genericIntersection");
assertThat(clazz.getBounds().get(0).toString()).isEqualTo("org.openrewrite.java.JavaTypeGoat$TypeA");
assertThat(clazz.getBounds().get(1).toString()).isEqualTo("org.openrewrite.java.PT<Generic{U extends org.openrewrite.java.JavaTypeGoat$TypeA & org.openrewrite.java.C}>");
assertThat(clazz.getBounds().get(2).toString()).isEqualTo("org.openrewrite.java.C");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.java.JavaTypeGoat$TypeA & org.openrewrite.java.PT<Generic{U}> & org.openrewrite.java.C}");
assertThat(clazz.getBounds().get(0).toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat$TypeA");
assertThat(clazz.getBounds().get(1).toString()).isEqualTo("org.openrewrite.kotlin.PT<Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$TypeA & org.openrewrite.kotlin.C}>");
assertThat(clazz.getBounds().get(2).toString()).isEqualTo("org.openrewrite.kotlin.C");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$TypeA & org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}");
}

@Test
Expand Down Expand Up @@ -254,11 +269,10 @@ void ignoreSourceRetentionAnnotations() {
assertThat(clazzMethod.getAnnotations().get(0).getClassName()).isEqualTo("AnnotationWithRuntimeRetention");
}

@Disabled("Requires parsing intersection types")
@Test
void recursiveIntersection() {
JavaType.GenericTypeVariable clazz = TypeUtils.asGeneric(firstMethodParameter("recursiveIntersection"));
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.java.JavaTypeGoat$Extension<Generic{U}> & org.openrewrite.java.Intersection<Generic{U}>}");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$Extension<Generic{U}> & org.openrewrite.kotlin.Intersection<Generic{U}>}");
}

@Test
Expand Down
Loading