diff --git a/build.gradle.kts b/build.gradle.kts index 3169aa698..82800b835 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -39,6 +39,8 @@ dependencies { testImplementation("org.openrewrite:rewrite-test") testRuntimeOnly("org.junit.jupiter:junit-jupiter-engine:latest.release") testRuntimeOnly("org.openrewrite:rewrite-java-17") + testRuntimeOnly("com.squareup.misk:misk-prometheus:latest.release") + testRuntimeOnly("com.squareup.misk:misk-metrics:latest.release") testImplementation("com.github.ajalt.clikt:clikt:3.5.0") } @@ -50,4 +52,4 @@ compileKotlin.kotlinOptions { val compileTestKotlin: KotlinCompile by tasks compileTestKotlin.kotlinOptions { jvmTarget = "1.8" -} \ No newline at end of file +} diff --git a/src/main/java/org/openrewrite/kotlin/AddImport.java b/src/main/java/org/openrewrite/kotlin/AddImport.java index cea8004c4..a66840256 100644 --- a/src/main/java/org/openrewrite/kotlin/AddImport.java +++ b/src/main/java/org/openrewrite/kotlin/AddImport.java @@ -237,7 +237,39 @@ private boolean hasReference(JavaSourceFile compilationUnit) { if (member == null) { //Non-static imports, we just look for field accesses. for (NameTree t : FindTypes.find(compilationUnit, fullyQualifiedName)) { - if ((!(t instanceof J.FieldAccess) || !((J.FieldAccess) t).isFullyQualifiedClassReference(fullyQualifiedName)) && + JavaType.Class classType = JavaType.ShallowClass.build(fullyQualifiedName); + boolean foundReference = false; + boolean usingAlias = false; + if (t instanceof J.ParameterizedType) { + J.ParameterizedType pt = (J.ParameterizedType) t; + if (pt.getClazz() instanceof J.Identifier) { + String nameInSource = ((J.Identifier) pt.getClazz()).getSimpleName(); + if (alias != null) { + if ( nameInSource.equals(alias)) { + usingAlias = true; + } + } else if (nameInSource.equals(classType.getClassName())) { + foundReference = true; + } + } + } else if (t instanceof J.Identifier) { + String nameInSource = ((J.Identifier) t).getSimpleName(); + if (alias != null) { + if ( nameInSource.equals(alias)) { + usingAlias = true; + } + } else if (nameInSource.equals(classType.getClassName())) { + foundReference = true; + } + } else { + foundReference = true; + } + + if (usingAlias) { + return true; + } + + if (foundReference && (!(t instanceof J.FieldAccess) || !((J.FieldAccess) t).isFullyQualifiedClassReference(fullyQualifiedName)) && isTypeReference(t)) { return true; } diff --git a/src/main/java/org/openrewrite/kotlin/KotlinVisitor.java b/src/main/java/org/openrewrite/kotlin/KotlinVisitor.java index bc655163e..13e1851bb 100644 --- a/src/main/java/org/openrewrite/kotlin/KotlinVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/KotlinVisitor.java @@ -22,7 +22,6 @@ import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.tree.*; import org.openrewrite.kotlin.marker.*; -import org.openrewrite.kotlin.service.KotlinAutoFormatService; import org.openrewrite.kotlin.tree.K; import org.openrewrite.kotlin.tree.KContainer; import org.openrewrite.kotlin.tree.KRightPadded; @@ -66,41 +65,6 @@ public J visitCompilationUnit(J.CompilationUnit cu, P p) { throw new UnsupportedOperationException("Kotlin has a different structure for its compilation unit. See K.CompilationUnit."); } - @Override - public J2 autoFormat(J2 j, P p) { - return autoFormat(j, p, getCursor().getParentTreeCursor()); - } - - @SuppressWarnings({"ConstantConditions", "unchecked"}) - @Override - public J2 autoFormat(J2 j, @Nullable J stopAfter, P p, Cursor cursor) { - KotlinAutoFormatService service = getCursor().firstEnclosingOrThrow(JavaSourceFile.class).service(KotlinAutoFormatService.class); - return (J2) service.autoFormatVisitor(stopAfter).visit(j, p, cursor); - } - - @Override - public J2 autoFormat(J2 j, P p, Cursor cursor) { - return autoFormat(j, null, p, cursor); - } - - @Override - public J2 maybeAutoFormat(J2 before, J2 after, P p) { - return maybeAutoFormat(before, after, p, getCursor().getParentTreeCursor()); - } - - @Override - public J2 maybeAutoFormat(J2 before, J2 after, P p, Cursor cursor) { - return maybeAutoFormat(before, after, null, p, cursor); - } - - @Override - public J2 maybeAutoFormat(J2 before, J2 after, @Nullable J stopAfter, P p, Cursor cursor) { - if (before != after) { - return autoFormat(after, stopAfter, p, cursor); - } - return after; - } - public J visitAnnotatedExpression(K.AnnotatedExpression annotatedExpression, P p) { K.AnnotatedExpression ae = annotatedExpression; ae = ae.withMarkers(visitMarkers(ae.getMarkers(), p)); diff --git a/src/main/java/org/openrewrite/kotlin/format/AutoFormatVisitor.java b/src/main/java/org/openrewrite/kotlin/format/AutoFormatVisitor.java index b926b9caf..fe39cf764 100644 --- a/src/main/java/org/openrewrite/kotlin/format/AutoFormatVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/format/AutoFormatVisitor.java @@ -24,7 +24,6 @@ import org.openrewrite.java.tree.JavaSourceFile; import org.openrewrite.kotlin.KotlinIsoVisitor; import org.openrewrite.kotlin.style.*; -import org.openrewrite.kotlin.tree.K; import org.openrewrite.style.GeneralFormatStyle; import java.util.Optional; @@ -44,11 +43,6 @@ public AutoFormatVisitor(@Nullable Tree stopAfter) { this.stopAfter = stopAfter; } - @Override - public boolean isAcceptable(SourceFile sourceFile, P p) { - return sourceFile instanceof K.CompilationUnit; - } - @Override public J visit(@Nullable Tree tree, P p, Cursor cursor) { JavaSourceFile cu = (tree instanceof JavaSourceFile) ? diff --git a/src/main/java/org/openrewrite/kotlin/format/SpacesVisitor.java b/src/main/java/org/openrewrite/kotlin/format/SpacesVisitor.java index a7b4b1232..3daa5be79 100644 --- a/src/main/java/org/openrewrite/kotlin/format/SpacesVisitor.java +++ b/src/main/java/org/openrewrite/kotlin/format/SpacesVisitor.java @@ -945,7 +945,9 @@ public K.FunctionType visitFunctionType(K.FunctionType functionType, P p) { K.FunctionType kf = super.visitFunctionType(functionType, p); // handle space around arrow in function type - kf = kf.withArrow(updateSpace(kf.getArrow(), style.getOther().getAroundArrowInFunctionTypes())); + if (kf.getArrow() != null) { + kf = kf.withArrow(updateSpace(kf.getArrow(), style.getOther().getAroundArrowInFunctionTypes())); + } kf = kf.withReturnType(spaceBefore(kf.getReturnType(), style.getOther().getAroundArrowInFunctionTypes())); return kf; } diff --git a/src/main/java/org/openrewrite/kotlin/internal/KotlinPrinter.java b/src/main/java/org/openrewrite/kotlin/internal/KotlinPrinter.java index 6ca1f070a..6c20f3897 100755 --- a/src/main/java/org/openrewrite/kotlin/internal/KotlinPrinter.java +++ b/src/main/java/org/openrewrite/kotlin/internal/KotlinPrinter.java @@ -238,7 +238,7 @@ public J visitFunctionType(K.FunctionType functionType, PrintOutputCapture

p) p.append("."); } delegate.visitContainer("(", functionType.getPadding().getParameters(), JContainer.Location.TYPE_PARAMETERS, ",", ")", p); - visitSpace(functionType.getArrow(), KSpace.Location.FUNCTION_TYPE_ARROW_PREFIX, p); + visitSpace(functionType.getArrow() != null ? functionType.getArrow() : Space.SINGLE_SPACE, KSpace.Location.FUNCTION_TYPE_ARROW_PREFIX, p); p.append("->"); visit(functionType.getReturnType(), p); if (nullable) { diff --git a/src/main/java/org/openrewrite/kotlin/style/Autodetect.java b/src/main/java/org/openrewrite/kotlin/style/Autodetect.java index 0f9599e3a..0efbe46e5 100644 --- a/src/main/java/org/openrewrite/kotlin/style/Autodetect.java +++ b/src/main/java/org/openrewrite/kotlin/style/Autodetect.java @@ -418,7 +418,14 @@ public Expression visitExpression(Expression expression, IndentStatistics stats) if (statementExpressions.contains(expression)) { return expression; } - countIndents(expression.getPrefix().getWhitespace(), true, stats); + // (newline-separated) annotations on some common target are not continuations + // (newline-separated) annotations on some common target are not continuations + boolean isContinuation = !(expression instanceof J.Annotation && !( + // ...but annotations which are *arguments* to other annotations can be continuations + getCursor().getParentTreeCursor().getValue() instanceof J.Annotation + || getCursor().getParentTreeCursor().getValue() instanceof J.NewArray + )); + countIndents(expression.getPrefix().getWhitespace(), isContinuation, stats); return expression; } diff --git a/src/main/java/org/openrewrite/kotlin/tree/K.java b/src/main/java/org/openrewrite/kotlin/tree/K.java index db1637576..a886fb470 100644 --- a/src/main/java/org/openrewrite/kotlin/tree/K.java +++ b/src/main/java/org/openrewrite/kotlin/tree/K.java @@ -1011,7 +1011,7 @@ class FunctionType implements K, TypeTree, Expression { public FunctionType(UUID id, Space prefix, Markers markers, List leadingAnnotations, List modifiers, @Nullable JRightPadded receiver, - JContainer parameters, Space arrow, TypedTree returnType) { + JContainer parameters, @Nullable Space arrow, TypedTree returnType) { this.id = id; this.prefix = prefix; this.markers = markers; @@ -1071,6 +1071,7 @@ public FunctionType withParameters(List parameters) { return getPadding().withParameters(JContainer.withElementsNullable(this.parameters, parameters)); } + @Nullable // nullable for LST backwards compatibility reasons only @With @Getter Space arrow; diff --git a/src/main/kotlin/org/openrewrite/kotlin/internal/KotlinParserVisitor.kt b/src/main/kotlin/org/openrewrite/kotlin/internal/KotlinParserVisitor.kt index c74b528ea..7505ecbdb 100644 --- a/src/main/kotlin/org/openrewrite/kotlin/internal/KotlinParserVisitor.kt +++ b/src/main/kotlin/org/openrewrite/kotlin/internal/KotlinParserVisitor.kt @@ -1175,10 +1175,22 @@ class KotlinParserVisitor( markers = markers.addIfAbsent(Extension(randomId())) } - val implicitExtensionFunction = functionCall is FirImplicitInvokeCall + var hasExplicitReceiver = false + if (functionCall is FirImplicitInvokeCall) { + val explicitReceiver = functionCall.explicitReceiver + if (explicitReceiver is FirPropertyAccessExpression) { + if (explicitReceiver.explicitReceiver != null) { + hasExplicitReceiver = true + } + } + } + + var implicitExtensionFunction = functionCall is FirImplicitInvokeCall && functionCall.arguments.isNotEmpty() && functionCall.source != null && functionCall.source!!.startOffset < functionCall.calleeReference.source!!.startOffset + && !hasExplicitReceiver + if (functionCall !is FirImplicitInvokeCall || implicitExtensionFunction) { val receiver = if (implicitExtensionFunction) functionCall.arguments[0] else getReceiver(functionCall.explicitReceiver) if (receiver != null) { @@ -1194,6 +1206,13 @@ class KotlinParserVisitor( } } + if (functionCall is FirImplicitInvokeCall) { + val receiver = functionCall.explicitReceiver + if (receiver is FirPropertyAccessExpression && receiver.explicitReceiver != null) { + select = padRight(convertToExpression(receiver.explicitReceiver as FirElement, data)!!, whitespace()) + } + } + val name = visitElement(namedReference, data) as J.Identifier var typeParams: JContainer? = null if (functionCall.typeArguments.isNotEmpty()) { @@ -1413,6 +1432,18 @@ class KotlinParserVisitor( var callPsi = getPsiElement(firCall)!! callPsi = if (callPsi is KtDotQualifiedExpression || callPsi is KtSafeQualifiedExpression) callPsi.lastChild else callPsi val firArguments = if (skipFirstArgument) firCall.argumentList.arguments.subList(1, firCall.argumentList.arguments.size) else firCall.argumentList.arguments + + var hasParentheses = false + var lPAROffset = 0 + if (firCall.argumentList.source is KtRealPsiSourceElement) { + val firArgumentsSource = firCall.argumentList.source.psi + val firstChild = firArgumentsSource?.firstChild + if (firstChild != null && firstChild.node.elementType == KtTokens.LPAR) { + hasParentheses = true + lPAROffset = firstChild.node.startOffset + } + } + val flattenedExpressions = firArguments.stream() .map { e -> if (e is FirVarargArgumentsExpression) e.arguments else listOf(e) } .flatMap { it.stream() } @@ -1441,7 +1472,7 @@ class KotlinParserVisitor( cursor++ saveCursor = cursor parenOrBrace = source[cursor] - } else if ((parenOrBrace != '(' && parenOrBrace != '[') || isInfix) { + } else if (!hasParentheses && ((parenOrBrace != '(' && parenOrBrace != '[') || isInfix)) { cursor(saveCursor) containerPrefix = Space.EMPTY markers = markers.addIfAbsent(OmitParentheses(randomId())) @@ -1469,6 +1500,12 @@ class KotlinParserVisitor( for (i in flattenedExpressions.indices) { isTrailingLambda = hasTrailingLambda && i == argumentCount - 1 val expression = flattenedExpressions[i] + + // Didn't find a way to proper reset the cursor, so have to do a hard reset here + if (firCall is FirImplicitInvokeCall && hasParentheses && i == 0) { + cursor = lPAROffset + 1 + } + var expr = convertToExpression(expression, data)!! if (isTrailingLambda && expr !is J.Empty) { expr = expr.withMarkers(expr.markers.addIfAbsent(TrailingLambdaArgument(randomId()))) diff --git a/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java b/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java index 65efd003e..4048cf0df 100644 --- a/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java +++ b/src/test/java/org/openrewrite/kotlin/ChangeTypeTest.java @@ -15,6 +15,7 @@ */ package org.openrewrite.kotlin; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.Issue; @@ -22,11 +23,12 @@ import org.openrewrite.java.tree.TypeUtils; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.SourceSpec; import static org.assertj.core.api.Assertions.assertThat; import static org.openrewrite.kotlin.Assertions.kotlin; -public class ChangeTypeTest implements RewriteTest { +class ChangeTypeTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { spec.recipe(new ChangeType("a.b.Original", "x.y.Target", true)); @@ -59,6 +61,34 @@ class A { ); } + + @Test + void changeImportAlias() { + rewriteRun( + kotlin( + """ + package a.b + class Original + """), + kotlin( + """ + import a.b.Original as MyAlias + + class A { + val type : MyAlias = MyAlias() + } + """, + """ + import x.y.Target as MyAlias + + class A { + val type : MyAlias = MyAlias() + } + """ + ) + ); + } + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/42") @Test void changeTypeWithGenericArgument() { @@ -92,6 +122,40 @@ fun test(original: Target) { } ); } + + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/42") + @Test + void changeTypeWithGenericArgumentAlias() { + rewriteRun( + kotlin( + """ + package a.b + class Original + """), + kotlin( + """ + package x.y + class Target + """), + kotlin( + """ + package example + + import a.b.Original as MyAlias + + fun test(original: MyAlias) { } + """, + """ + package example + + import x.y.Target as MyAlias + + fun test(original: MyAlias) { } + """ + ) + ); + } + @Test void changeTypeWithGenericArgumentFullyQualified() { rewriteRun( @@ -243,6 +307,27 @@ fun main() { ); } + @Test + void implicitImport() { + rewriteRun( + spec -> spec.recipe(new ChangeType("java.util.ArrayList", "java.util.LinkedList", true)), + kotlin( + """ + fun main() { + val list = ArrayList() + } + """, + """ + import java.util.LinkedList + + fun main() { + val list = LinkedList() + } + """ + ) + ); + } + @Test void qualifiedReference() { rewriteRun( @@ -256,12 +341,38 @@ fun main() { } """, """ - import java.util.LinkedList as MyList + fun main() { val list2 = java.util.LinkedList() } """ + , SourceSpec::noTrim + ) + ); + } + + @Disabled + @Test + void fromLibrary() { + rewriteRun( + spec -> spec.recipe(new ChangeType( + "misk.metrics.backends.prometheus.v2.PrometheusMetrics", + "misk.metrics.v2.Metrics", true)) + .parser(KotlinParser.builder() + .classpath("misk-prometheus", "misk-metrics")) + , + kotlin( + """ + import misk.metrics.backends.prometheus.v2.PrometheusMetrics + + class A(val a: PrometheusMetrics) + """, + """ + import java.util.LinkedList as MyList + + class A(val a: Metrics) + """ ) ); } diff --git a/src/test/java/org/openrewrite/kotlin/style/AutodetectTest.java b/src/test/java/org/openrewrite/kotlin/style/AutodetectTest.java index cfcec2a02..c0b930e3f 100644 --- a/src/test/java/org/openrewrite/kotlin/style/AutodetectTest.java +++ b/src/test/java/org/openrewrite/kotlin/style/AutodetectTest.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.junitpioneer.jupiter.ExpectedToFail; import org.openrewrite.Issue; @@ -1016,4 +1017,61 @@ class Test { assertThat(tabsAndIndents.getContinuationIndent()).isEqualTo(5); } + + @Nested + class ContinuationIndentForAnnotations { + + @Test + @Issue("https://github.com/openrewrite/rewrite/issues/3568") + void ignoreSpaceBetweenAnnotations() { + var cus = kp().parse( + """ + class Test { + @SafeVarargs + @Deprecated("") + @Suppress("more", "mistakes") + fun count(vararg strings: String) { + return strings.length + } + } + """ + ); + + var detector = Autodetect.detector(); + cus.forEach(detector::sample); + var styles = detector.build(); + var tabsAndIndents = NamedStyles.merge(TabsAndIndentsStyle.class, singletonList(styles)); + + assertThat(tabsAndIndents.getIndentSize()).isEqualTo(4); + assertThat(tabsAndIndents.getContinuationIndent()) + .as("With no actual continuation indents to go off of, assume IntelliJ IDEA default of 2x the normal indent") + .isEqualTo(8); + } + + @Test + void includeAnnotationAsAnnotationArg() { + var cus = kp().parse( + """ + annotation class Foo + annotation class Foos(val value: Array) + + class Test { + @Foos( + value = [Foo()]) + fun count(vararg strings: String) { + return strings.length + } + } + """ + ); + + var detector = Autodetect.detector(); + cus.forEach(detector::sample); + var styles = detector.build(); + var tabsAndIndents = NamedStyles.merge(TabsAndIndentsStyle.class, singletonList(styles)); + + assertThat(tabsAndIndents.getIndentSize()).isEqualTo(4); + assertThat(tabsAndIndents.getContinuationIndent()).isEqualTo(3); + } + } } diff --git a/src/test/java/org/openrewrite/kotlin/tree/FunctionTypeTest.java b/src/test/java/org/openrewrite/kotlin/tree/FunctionTypeTest.java index 8f531f6ac..c1a1e57ba 100644 --- a/src/test/java/org/openrewrite/kotlin/tree/FunctionTypeTest.java +++ b/src/test/java/org/openrewrite/kotlin/tree/FunctionTypeTest.java @@ -24,6 +24,30 @@ class FunctionTypeTest implements RewriteTest { + @Issue("https://github.com/openrewrite/rewrite-kotlin/issues/326") + @Test + void functionWithFunctionTypeParameter() { + rewriteRun( + kotlin( + """ + class GradleSpigotDependencyLoaderTestBuilder( + var init: TestInitializer.() -> Unit = {} + ) { + } + + class TestInitializer( + val resourcesDir: String + ) + + fun runTest() { + val builder = GradleSpigotDependencyLoaderTestBuilder() + builder.init(TestInitializer("null")) + } + """ + ) + ); + } + @Test void nested() { rewriteRun(