Skip to content

Commit

Permalink
Uncommented previously un-parsable code in KotlinTypeGoat. (#333)
Browse files Browse the repository at this point in the history
Fixed type mapping for bounded generic types.
Added support for generic types with multiple bounds in KotlinTypeMapping.
Added type mapping test for K.Property.
Updated signature builder to match java type signatures.
  • Loading branch information
traceyyoshima authored Oct 7, 2023
1 parent 9c61e46 commit cd91d57
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 86 deletions.
27 changes: 11 additions & 16 deletions src/main/java/org/openrewrite/kotlin/KotlinTypeGoat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@ import java.lang.Object
// TODO: FIX ME. Files needs to declare fields and methods to assert type mapping.
@AnnotationWithRuntimeRetention
@AnnotationWithSourceRetention
abstract class KotlinTypeGoat<T, S> {
//abstract class KotlinTypeGoat<T, S> where S: PT<S>, S: C {
abstract class KotlinTypeGoat<T, S> where S: PT<S>, S: C {
val parameterizedField: PT<TypeA> = object : PT<TypeA> {}

val field: Int = 10

val gettableField: Int
get() = 10

var settableField : String = ""
set ( value ) {
var field: Int = 10
get() = field
set(value) {
field = value
}

// abstract class InheritedKotlinTypeGoat<T, U> : KotlinTypeGoat<T, U>() where U : PT<U>, U : C
abstract class InheritedKotlinTypeGoat<T, U> : KotlinTypeGoat<T, U>() where U : PT<U>, U : C

enum class EnumTypeA {
FOO, BAR(),
Expand Down Expand Up @@ -70,11 +65,11 @@ abstract class KotlinTypeGoat<T, S> {
abstract fun inner(n: C.Inner)
abstract fun enumTypeA(n: EnumTypeA)
abstract fun enumTypeB(n: EnumTypeB)
// abstract fun <U> inheritedJavaTypeGoat(n: InheritedKotlinTypeGoat<T, U>): InheritedKotlinTypeGoat<T, U> where U : PT<U>, U : C
// abstract fun <U> genericIntersection(n: U): U where U : TypeA, U : PT<U>, U : C
abstract fun <U> inheritedKotlinTypeGoat(n: InheritedKotlinTypeGoat<T, U>): InheritedKotlinTypeGoat<T, U> where U : PT<U>, U : C
abstract fun <U> genericIntersection(n: U): U where U : TypeA, U : PT<U>, U : C
abstract fun genericT(n: T): T // remove after signatures are common.

// abstract fun <U> recursiveIntersection(n: U) where U : KotlinTypeGoat.Extension<U>, U : Intersection<U>
abstract fun <U> recursiveIntersection(n: U) where U : Extension<U>, U : Intersection<U>

abstract fun javaType(n: Object)
}
Expand All @@ -85,9 +80,9 @@ interface C {

interface PT<T>

//internal interface Intersection<T> where T : KotlinTypeGoat.Extension<T>, T : Intersection<T> {
// val intersectionType: T
//}
internal interface Intersection<T> where T : KotlinTypeGoat.Extension<T>, T : Intersection<T> {
val intersectionType: T
}

@Retention(AnnotationRetention.SOURCE)
internal annotation class AnnotationWithSourceRetention
Expand Down
49 changes: 42 additions & 7 deletions src/main/kotlin/org/openrewrite/kotlin/KotlinTypeMapping.kt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaConstr
import org.jetbrains.kotlin.load.java.structure.impl.classFiles.BinaryJavaMethod
import org.jetbrains.kotlin.load.kotlin.JvmPackagePartSource
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.types.Variance
import org.openrewrite.Incubating
import org.openrewrite.java.JavaTypeMapping
import org.openrewrite.java.internal.JavaTypeCache
Expand Down Expand Up @@ -212,6 +213,7 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
): JavaType.FullyQualified {
val firClass: FirClass
var resolvedTypeRef: FirResolvedTypeRef? = null
var typeArguments: Array<out ConeTypeProjection>? = null
if (classType is FirResolvedTypeRef) {
// The resolvedTypeRef is used to create parameterized types.
resolvedTypeRef = classType
Expand Down Expand Up @@ -244,6 +246,9 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
return JavaType.Unknown.getInstance()
}
}
} else if (classType is ConeClassLikeType) {
firClass = classType.toRegularClassSymbol(firSession)!!.fir
typeArguments = classType.typeArguments
} else {
firClass = classType as FirClass
}
Expand Down Expand Up @@ -373,8 +378,12 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
if (pt == null) {
pt = JavaType.Parameterized(null, null, null)
typeCache.put(signature, pt)
val typeParameters: MutableList<JavaType> = ArrayList(firClass.typeParameters.size)
if (resolvedTypeRef != null && resolvedTypeRef.type.typeArguments.isNotEmpty()) {
val typeParameters: MutableList<JavaType> = ArrayList(typeArguments?.size ?: firClass.typeParameters.size)
if (typeArguments != null) {
for (typeArgument: ConeTypeProjection in typeArguments) {
typeParameters.add(type(typeArgument))
}
} else if (resolvedTypeRef != null && resolvedTypeRef.type.typeArguments.isNotEmpty()) {
for (typeArgument: ConeTypeProjection in resolvedTypeRef.type.typeArguments) {
typeParameters.add(type(typeArgument))
}
Expand Down Expand Up @@ -974,7 +983,6 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
): JavaType? {
var resolvedType: JavaType? = JavaType.Unknown.getInstance()

// TODO: fix for multiple bounds.
val isGeneric = type is ConeKotlinTypeProjectionIn ||
type is ConeKotlinTypeProjectionOut ||
type is ConeStarProjection ||
Expand All @@ -999,14 +1007,35 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
typeCache.put(signature, gtv)
if (type is ConeKotlinTypeProjectionIn) {
variance = JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
val classSymbol = type.type.toRegularClassSymbol(firSession)
bounds = ArrayList(1)
bounds.add(if (classSymbol != null) type(classSymbol.fir) else JavaType.Unknown.getInstance())
bounds.add(type(type.type))
} else if (type is ConeKotlinTypeProjectionOut) {
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
val classSymbol = type.type.toRegularClassSymbol(firSession)
bounds = ArrayList(1)
bounds.add(if (classSymbol != null) type(classSymbol.fir) else JavaType.Unknown.getInstance())
bounds.add(type(type.type))
} else if (type is ConeTypeParameterType) {
val classifierSymbol: FirClassifierSymbol<*>? = type.lookupTag.toSymbol(firSession)
if (classifierSymbol is FirTypeParameterSymbol) {
variance = when (classifierSymbol.variance) {
Variance.INVARIANT -> {
if (classifierSymbol.resolvedBounds.none { it !is FirImplicitNullableAnyTypeRef }) JavaType.GenericTypeVariable.Variance.INVARIANT else JavaType.GenericTypeVariable.Variance.COVARIANT
}

Variance.IN_VARIANCE -> {
JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
}

Variance.OUT_VARIANCE -> {
JavaType.GenericTypeVariable.Variance.COVARIANT
}
}
bounds = ArrayList(classifierSymbol.resolvedBounds.size)
for (bound: FirResolvedTypeRef in classifierSymbol.resolvedBounds) {
if (bound !is FirImplicitNullableAnyTypeRef) {
bounds.add(type(bound))
}
}
}
}
gtv.unsafeSet(name, variance, bounds)
resolvedType = gtv
Expand All @@ -1032,6 +1061,10 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
typeCache.put(signature, JavaType.Unknown.getInstance())
return JavaType.Unknown.getInstance()
}
if (signatureBuilder.signature(classSymbol.fir) != signature) {
// The signature contains generic bounded types and needs to be resolved.
return classType(coneClassLikeType, signature, ownerSymbol)
}
return type(classSymbol.fir, ownerSymbol)
}

Expand All @@ -1054,6 +1087,8 @@ class KotlinTypeMapping(typeCache: JavaTypeCache, firSession: FirSession) : Java
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
} else if ("in" == typeParameter.variance.label) {
variance = JavaType.GenericTypeVariable.Variance.CONTRAVARIANT
} else {
variance = JavaType.GenericTypeVariable.Variance.COVARIANT
}
}
gtv.unsafeSet(gtv.name, variance, bounds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,15 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
return "Generic{$name}"
}
val s = StringBuilder("Generic{").append(name)
val boundSigs = StringJoiner(", ")
val boundSigs = StringJoiner(" & ")
for (bound in typeParameter.bounds) {
if (bound !is FirImplicitNullableAnyTypeRef) {
boundSigs.add(signature(bound))
}
}
val boundSigStr = boundSigs.toString()
if (!boundSigStr.isEmpty()) {
s.append(": ").append(boundSigStr)
s.append(" extends ").append(boundSigStr)
}
typeVariableNameStack!!.remove(name)
return s.append("}").toString()
Expand All @@ -330,7 +330,6 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
* Generate a ConeTypeProject signature.
*/
private fun coneTypeProjectionSignature(type: ConeTypeProjection): String {
val typeSignature: String
val s = StringBuilder()
if (type is ConeKotlinTypeProjectionIn) {
val (type1) = type
Expand Down Expand Up @@ -364,10 +363,14 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
s.append(">")
}
} else if (type is ConeTypeParameterType) {
s.append("Generic{")
typeSignature = convertKotlinFqToJavaFq(type.toString())
s.append(typeSignature)
s.append("}")
val symbol: FirClassifierSymbol<*>? = type.lookupTag.toSymbol(firSession)
if (symbol != null) {
s.append(signature(symbol))
} else {
s.append("Generic{")
s.append(convertKotlinFqToJavaFq(type.toString()))
s.append("}")
}
} else if (type is ConeFlexibleType) {
s.append(signature(type.lowerBound))
} else if (type is ConeDefinitelyNotNullType) {
Expand Down Expand Up @@ -494,7 +497,7 @@ class KotlinTypeSignatureBuilder(private val firSession: FirSession) : JavaTypeS
return "Generic{$name}"
}
val s = StringBuilder("Generic{").append(name)
val boundSigs = StringJoiner(", ")
val boundSigs = StringJoiner(" & ")
for (type in typeParameter.upperBounds) {
if (type.classifier != null && "java.lang.Object" != type.classifierQualifiedName) {
boundSigs.add(signature(type))
Expand Down
50 changes: 32 additions & 18 deletions src/test/java/org/openrewrite/kotlin/KotlinTypeMappingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.openrewrite.kotlin;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ExpectedToFail;
Expand All @@ -41,7 +40,7 @@
public class KotlinTypeMappingTest {
private static final String goat = StringUtils.readFully(KotlinTypeMappingTest.class.getResourceAsStream("/KotlinTypeGoat.kt"));

private static final J.ClassDeclaration goatClassDeclaration;
private static final K.ClassDeclaration goatClassDeclaration;

static {
InMemoryExecutionContext ctx = new InMemoryExecutionContext();
Expand All @@ -53,8 +52,12 @@ public class KotlinTypeMappingTest {
.parse(ctx, goat)
.findFirst()
.get())
.getClasses()
.get(0)
.getStatements()
.stream()
.filter(K.ClassDeclaration.class::isInstance)
.findFirst()
.map(K.ClassDeclaration.class::cast)
.orElseThrow()
);
}

Expand All @@ -71,14 +74,24 @@ public JavaType.Method methodType(String methodName) {
}

public J.VariableDeclarations getField(String fieldName) {
return goatClassDeclaration.getBody().getStatements().stream()
.filter(org.openrewrite.java.tree.J.VariableDeclarations.class::isInstance)
return goatClassDeclaration.getClassDeclaration().getBody().getStatements().stream()
.filter(it -> it instanceof org.openrewrite.java.tree.J.VariableDeclarations || it instanceof K.Property)
.map(it -> it instanceof K.Property ? ((K.Property) it).getVariableDeclarations() : (J.VariableDeclarations) it)
.map(J.VariableDeclarations.class::cast)
.filter(mv -> mv.getVariables().stream().anyMatch(v -> v.getSimpleName().equals(fieldName)))
.findFirst()
.orElse(null);
}

public K.Property getProperty(String fieldName) {
return goatClassDeclaration.getClassDeclaration().getBody().getStatements().stream()
.filter(it -> it instanceof K.Property)
.map(K.Property.class::cast)
.filter(mv -> mv.getVariableDeclarations().getVariables().stream().anyMatch(v -> v.getSimpleName().equals(fieldName)))
.findFirst()
.orElse(null);
}

public JavaType firstMethodParameter(String methodName) {
return methodType(methodName).getParameterTypes().get(0);
}
Expand All @@ -90,13 +103,19 @@ void extendsKotlinAny() {

@Test
void fieldType() {
J.VariableDeclarations.NamedVariable variable = getField("field").getVariables().get(0);
K.Property property = getProperty("field");
J.VariableDeclarations.NamedVariable variable = property.getVariableDeclarations().getVariables().get(0);
J.Identifier id = variable.getName();
assertThat(variable.getType()).isEqualTo(id.getType());
assertThat(id.getFieldType()).isInstanceOf(JavaType.Variable.class);
assertThat(id.getFieldType().toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat{name=field,type=kotlin.Int}");
assertThat(id.getType()).isInstanceOf(JavaType.Class.class);
assertThat(id.getType().toString()).isEqualTo("kotlin.Int");

assertThat(property.getGetter().getMethodType().toString().substring(property.getGetter().getMethodType().toString().indexOf("openRewriteFileKt"))).isEqualTo("openRewriteFileKt{name=accessor,return=kotlin.Int,parameters=[]}");
assertThat(property.getGetter().getMethodType()).isEqualTo(property.getGetter().getName().getType());
assertThat(property.getSetter().getMethodType().toString().substring(property.getGetter().getMethodType().toString().indexOf("openRewriteFileKt"))).isEqualTo("openRewriteFileKt{name=accessor,return=kotlin.Unit,parameters=[kotlin.Int]}");
assertThat(property.getSetter().getMethodType()).isEqualTo(property.getSetter().getName().getType());
}

@Test
Expand Down Expand Up @@ -154,7 +173,6 @@ void genericContravariant() {
isEqualTo("org.openrewrite.kotlin.C");
}

@Disabled("Requires parsing intersection types")
@Test
void genericMultipleBounds() {
List<JavaType> typeParameters = goatType.getTypeParameters();
Expand All @@ -174,7 +192,6 @@ void genericUnbounded() {
assertThat(generic.getBounds()).isEmpty();
}

@Disabled
@Test
void genericRecursive() {
JavaType.Parameterized param = (JavaType.Parameterized) firstMethodParameter("genericRecursive");
Expand All @@ -196,23 +213,21 @@ void innerClass() {
assertThat(clazz.getFullyQualifiedName()).isEqualTo("org.openrewrite.kotlin.C$Inner");
}

@Disabled("Requires parsing intersection types")
@Test
void inheritedJavaTypeGoat() {
JavaType.Parameterized clazz = (JavaType.Parameterized) firstMethodParameter("InheritedKotlinTypeGoat");
JavaType.Parameterized clazz = (JavaType.Parameterized) firstMethodParameter("inheritedKotlinTypeGoat");
assertThat(clazz.getTypeParameters().get(0).toString()).isEqualTo("Generic{T}");
assertThat(clazz.getTypeParameters().get(1).toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}");
assertThat(clazz.toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat$InheritedKotlinTypeGoat<Generic{T}, Generic{U extends org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}>");
}

@Disabled("Requires parsing intersection types")
@Test
void genericIntersectionType() {
JavaType.GenericTypeVariable clazz = (JavaType.GenericTypeVariable) firstMethodParameter("genericIntersection");
assertThat(clazz.getBounds().get(0).toString()).isEqualTo("org.openrewrite.java.JavaTypeGoat$TypeA");
assertThat(clazz.getBounds().get(1).toString()).isEqualTo("org.openrewrite.java.PT<Generic{U extends org.openrewrite.java.JavaTypeGoat$TypeA & org.openrewrite.java.C}>");
assertThat(clazz.getBounds().get(2).toString()).isEqualTo("org.openrewrite.java.C");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.java.JavaTypeGoat$TypeA & org.openrewrite.java.PT<Generic{U}> & org.openrewrite.java.C}");
assertThat(clazz.getBounds().get(0).toString()).isEqualTo("org.openrewrite.kotlin.KotlinTypeGoat$TypeA");
assertThat(clazz.getBounds().get(1).toString()).isEqualTo("org.openrewrite.kotlin.PT<Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$TypeA & org.openrewrite.kotlin.C}>");
assertThat(clazz.getBounds().get(2).toString()).isEqualTo("org.openrewrite.kotlin.C");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$TypeA & org.openrewrite.kotlin.PT<Generic{U}> & org.openrewrite.kotlin.C}");
}

@Test
Expand Down Expand Up @@ -254,11 +269,10 @@ void ignoreSourceRetentionAnnotations() {
assertThat(clazzMethod.getAnnotations().get(0).getClassName()).isEqualTo("AnnotationWithRuntimeRetention");
}

@Disabled("Requires parsing intersection types")
@Test
void recursiveIntersection() {
JavaType.GenericTypeVariable clazz = TypeUtils.asGeneric(firstMethodParameter("recursiveIntersection"));
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.java.JavaTypeGoat$Extension<Generic{U}> & org.openrewrite.java.Intersection<Generic{U}>}");
assertThat(clazz.toString()).isEqualTo("Generic{U extends org.openrewrite.kotlin.KotlinTypeGoat$Extension<Generic{U}> & org.openrewrite.kotlin.Intersection<Generic{U}>}");
}

@Test
Expand Down
Loading

0 comments on commit cd91d57

Please sign in to comment.