Skip to content

Commit

Permalink
Make JavaTemplate engine extensible (#3475)
Browse files Browse the repository at this point in the history
* Add new `JvmParser` as supertype to `JavaParser`

To allow the `JavaTemplate` mechanism to be used for other languages like Kotlin, the provided parser builder must provide a way to add a classpath entry. This is because the internal `__M__` and `__P__` types are required for the parameter substitution and rather than including them as source code in the compilation unit, the parser will load them from the classpath.

To this end this PR introduces new `JvmParser` and `JvmParser.Builder` types and the `JavaParser.Builder#classpath()` methods are moved to `JvmParser.Builder`. In order to not increase the API surface area, no corresponding getter was added to `JvmParser.Builder`. Instead, there is a new `Internals` class providing a static `getClasspath(JvmParser.Builder)` method.

Issue: openrewrite/rewrite-kotlin#218

* Correct bounds on `JvmParser.Builder`

* Make `JvmParser.Builder` constructor protected

* No need for initializer block when templating context-free expression

* Expose more of `JavaTemplate` engine API for extension

* Remove `JvmParser` and `Internals` again

* Make `Substitutions` extensible

* Lazily initialize `TEMPLATE_CLASSPATH_DIR` at runtime

* Add `BlockStatementTemplateGenerator#TEMPLATE_INTERNAL_IMPORTS`

* Correction to `Substitutions` after merging

* Rename new `addClasspath()` to `classpathEntry()`

* Remove unused `JavaTemplate#parameterCount`

* Update rewrite-java/src/main/java/org/openrewrite/java/internal/template/__M__.java

* Update rewrite-java/src/main/java/org/openrewrite/java/internal/template/__P__.java

* Revert accidental visibility change

* Fix problem with classpath in `JavaParser.Builder`
  • Loading branch information
knutwannheden authored Mar 4, 2024
1 parent 675f9d5 commit 3936ba1
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 55 deletions.
14 changes: 12 additions & 2 deletions rewrite-java/src/main/java/org/openrewrite/java/JavaParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ static List<Path> dependenciesFromResources(ExecutionContext ctx, String... arti
}

if (!missingArtifactNames.isEmpty()) {
//noinspection ConstantValue
throw new IllegalArgumentException("Unable to find classpath resource dependencies beginning with: " +
missingArtifactNames.stream().map(a -> "'" + a + "'").sorted().collect(joining(", ", "", ".\n")) +
"The caller is of type " + caller.getName() + ".\n" +
Expand Down Expand Up @@ -367,6 +366,16 @@ public B classpath(Collection<Path> classpath) {
return (B) this;
}

// internal method which doesn't overwrite the classpath but just amends it
B addClasspathEntry(Path entry) {
if (this.classpath.isEmpty()) {
this.classpath = Collections.singletonList(entry);
} else {
this.classpath.add(entry);
}
return (B) this;
}

public B classpath(String... artifactNames) {
this.artifactNames = Arrays.asList(artifactNames);
this.classpath = Collections.emptyList();
Expand Down Expand Up @@ -394,7 +403,8 @@ public B styles(Iterable<? extends NamedStyles> styles) {

protected Collection<Path> resolvedClasspath() {
if (!artifactNames.isEmpty()) {
classpath = JavaParser.dependenciesFromClasspath(artifactNames.toArray(new String[0]));
classpath = new ArrayList<>(classpath);
classpath.addAll(JavaParser.dependenciesFromClasspath(artifactNames.toArray(new String[0])));
artifactNames = Collections.emptyList();
}
return classpath;
Expand Down
57 changes: 49 additions & 8 deletions rewrite-java/src/main/java/org/openrewrite/java/JavaTemplate.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,67 @@
import org.openrewrite.Cursor;
import org.openrewrite.Incubating;
import org.openrewrite.internal.StringUtils;
import org.openrewrite.internal.lang.Nullable;
import org.openrewrite.java.internal.template.JavaTemplateJavaExtension;
import org.openrewrite.java.internal.template.JavaTemplateParser;
import org.openrewrite.java.internal.template.Substitutions;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaCoordinates;
import org.openrewrite.template.SourceTemplate;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;

@SuppressWarnings("unused")
public class JavaTemplate implements SourceTemplate<J, JavaCoordinates> {

@Nullable
private static Path TEMPLATE_CLASSPATH_DIR;

protected static Path getTemplateClasspathDir() {
if (TEMPLATE_CLASSPATH_DIR == null) {
try {
TEMPLATE_CLASSPATH_DIR = Files.createTempDirectory("java-template");
Path templateDir = Files.createDirectories(TEMPLATE_CLASSPATH_DIR.resolve("org/openrewrite/java/internal/template"));
try (InputStream in = JavaTemplateParser.class.getClassLoader().getResourceAsStream("org/openrewrite/java/internal/template/__M__.class")) {
assert in != null;
Files.copy(in, templateDir.resolve("__M__.class"));
}
try (InputStream in = JavaTemplateParser.class.getClassLoader().getResourceAsStream("org/openrewrite/java/internal/template/__P__.class")) {
assert in != null;
Files.copy(in, templateDir.resolve("__P__.class"));
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return TEMPLATE_CLASSPATH_DIR;
}

@Getter
private final String code;

private final Consumer<String> onAfterVariableSubstitution;
private final JavaTemplateParser templateParser;

private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> javaParser, String code, Set<String> imports,
private JavaTemplate(boolean contextSensitive, JavaParser.Builder<?, ?> parser, String code, Set<String> imports,
Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate) {
this(code, onAfterVariableSubstitution, new JavaTemplateParser(contextSensitive, augmentClasspath(parser), onAfterVariableSubstitution, onBeforeParseTemplate, imports));
}

private static JavaParser.Builder<?,?> augmentClasspath(JavaParser.Builder<?,?> parserBuilder) {
return parserBuilder.addClasspathEntry(getTemplateClasspathDir());
}

protected JavaTemplate(String code, Consumer<String> onAfterVariableSubstitution, JavaTemplateParser templateParser) {
this.code = code;
this.onAfterVariableSubstitution = onAfterVariableSubstitution;
this.templateParser = new JavaTemplateParser(contextSensitive, javaParser, onAfterVariableSubstitution, onBeforeParseTemplate, imports);
this.templateParser = templateParser;
}

@Override
Expand All @@ -54,7 +91,7 @@ public <J2 extends J> J2 apply(Cursor scope, JavaCoordinates coordinates, Object
throw new IllegalArgumentException("`scope` must point to a J instance.");
}

Substitutions substitutions = new Substitutions(code, parameters);
Substitutions substitutions = substitutions(parameters);
String substitutedTemplate = substitutions.substitute();
onAfterVariableSubstitution.accept(substitutedTemplate);

Expand All @@ -64,6 +101,10 @@ public <J2 extends J> J2 apply(Cursor scope, JavaCoordinates coordinates, Object
.visit(scope.getValue(), 0, scope.getParentOrThrow());
}

protected Substitutions substitutions(Object[] parameters) {
return new Substitutions(code, parameters);
}

@Incubating(since = "8.0.0")
public static boolean matches(String template, Cursor cursor) {
return JavaTemplate.builder(template).build().matches(cursor);
Expand Down Expand Up @@ -117,14 +158,14 @@ public static class Builder {

private boolean contextSensitive;

private JavaParser.Builder<?, ?> javaParser = JavaParser.fromJavaVersion();
private JavaParser.Builder<?, ?> parser = org.openrewrite.java.JavaParser.fromJavaVersion();

private Consumer<String> onAfterVariableSubstitution = s -> {
};
private Consumer<String> onBeforeParseTemplate = s -> {
};

Builder(String code) {
protected Builder(String code) {
this.code = code.trim();
}

Expand Down Expand Up @@ -173,8 +214,8 @@ private void validateImport(String typeName) {
}
}

public Builder javaParser(JavaParser.Builder<?, ?> javaParser) {
this.javaParser = javaParser;
public Builder javaParser(JavaParser.Builder<?, ?> parser) {
this.parser = parser;
return this;
}

Expand All @@ -189,7 +230,7 @@ public Builder doBeforeParseTemplate(Consumer<String> beforeParseTemplate) {
}

public JavaTemplate build() {
return new JavaTemplate(contextSensitive, javaParser, code, imports,
return new JavaTemplate(contextSensitive, parser.clone(), code, imports,
onAfterVariableSubstitution, onBeforeParseTemplate);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,9 @@
public class BlockStatementTemplateGenerator {
private static final String TEMPLATE_COMMENT = "__TEMPLATE__";
private static final String STOP_COMMENT = "__TEMPLATE_STOP__";
static final String EXPR_STATEMENT_PARAM = "class __P__ {" +
" static native <T> T p();" +
" static native <T> T[] arrp();" +
" static native boolean booleanp();" +
" static native byte bytep();" +
" static native char charp();" +
" static native double doublep();" +
" static native int intp();" +
" static native long longp();" +
" static native short shortp();" +
" static native float floatp();" +
"}";
private static final String METHOD_INVOCATION_STUBS = "class __M__ {" +
" static native Object any(Object o);" +
" static native Object any(java.util.function.Predicate<Boolean> o);" +
" static native <T> Object anyT();" +
"}";

private final Set<String> imports;
protected static final String TEMPLATE_INTERNAL_IMPORTS = "import org.openrewrite.java.internal.template.__M__;\nimport org.openrewrite.java.internal.template.__P__;\n";

protected final Set<String> imports;
private final boolean contextSensitive;

public String template(Cursor cursor, String template, Space.Location location, JavaCoordinates.Mode mode) {
Expand Down Expand Up @@ -219,7 +203,7 @@ private void template(Cursor cursor, J prior, StringBuilder before, StringBuilde
}

@SuppressWarnings("DataFlowIssue")
private void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, StringBuilder after) {
protected void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, StringBuilder after) {
if (j instanceof J.Lambda) {
throw new IllegalArgumentException(
"Templating a lambda requires a cursor so that it can be properly parsed and type-attributed. " +
Expand All @@ -229,10 +213,10 @@ private void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, Strin
"Templating a method reference requires a cursor so that it can be properly parsed and type-attributed. " +
"Mark this template as context-sensitive by calling JavaTemplate.Builder#contextSensitive().");
} else if (j instanceof Expression && !(j instanceof J.Assignment)) {
before.insert(0, "class Template {{\n");
before.insert(0, "class Template {\n");
before.append("Object o = ");
after.append(";");
after.append("\n}}");
after.append("\n}");
} else if ((j instanceof J.MethodDeclaration || j instanceof J.VariableDeclarations || j instanceof J.Block || j instanceof J.ClassDeclaration)
&& cursor.getValue() instanceof J.Block
&& (cursor.getParent().getValue() instanceof J.ClassDeclaration || cursor.getParent().getValue() instanceof J.NewClass)) {
Expand All @@ -252,7 +236,7 @@ private void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, Strin
after.append("\n}}");
}

before.insert(0, EXPR_STATEMENT_PARAM + METHOD_INVOCATION_STUBS);
before.insert(0, TEMPLATE_INTERNAL_IMPORTS);
for (String anImport : imports) {
before.insert(0, anImport);
}
Expand All @@ -262,7 +246,7 @@ private void contextFreeTemplate(Cursor cursor, J j, StringBuilder before, Strin
private void contextTemplate(Cursor cursor, J prior, StringBuilder before, StringBuilder after, J insertionPoint, JavaCoordinates.Mode mode) {
J j = cursor.getValue();
if (j instanceof JavaSourceFile) {
before.insert(0, EXPR_STATEMENT_PARAM + METHOD_INVOCATION_STUBS);
before.insert(0, TEMPLATE_INTERNAL_IMPORTS);

JavaSourceFile cu = (JavaSourceFile) j;
for (J.Import anImport : cu.getImports()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.InMemoryExecutionContext;
import org.openrewrite.Parser;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.internal.PropertyPlaceholderHelper;
import org.openrewrite.java.JavaParser;
Expand Down Expand Up @@ -52,23 +53,35 @@ public class JavaTemplateParser {
@Language("java")
private static final String SUBSTITUTED_ANNOTATION = "@java.lang.annotation.Documented public @interface SubAnnotation { int value(); }";

private final JavaParser.Builder<?, ?> parser;
private final Parser.Builder parser;
private final Consumer<String> onAfterVariableSubstitution;
private final Consumer<String> onBeforeParseTemplate;
private final Set<String> imports;
private final boolean contextSensitive;
private final BlockStatementTemplateGenerator statementTemplateGenerator;
private final AnnotationTemplateGenerator annotationTemplateGenerator;

public JavaTemplateParser(boolean contextSensitive, JavaParser.Builder<?, ?> parser, Consumer<String> onAfterVariableSubstitution,
public JavaTemplateParser(boolean contextSensitive, Parser.Builder parser, Consumer<String> onAfterVariableSubstitution,
Consumer<String> onBeforeParseTemplate, Set<String> imports) {
this(
parser,
onAfterVariableSubstitution,
onBeforeParseTemplate,
imports,
contextSensitive,
new BlockStatementTemplateGenerator(imports, contextSensitive),
new AnnotationTemplateGenerator(imports)
);
}

protected JavaTemplateParser(Parser.Builder parser, Consumer<String> onAfterVariableSubstitution, Consumer<String> onBeforeParseTemplate, Set<String> imports, boolean contextSensitive, BlockStatementTemplateGenerator statementTemplateGenerator, AnnotationTemplateGenerator annotationTemplateGenerator) {
this.parser = parser;
this.onAfterVariableSubstitution = onAfterVariableSubstitution;
this.onBeforeParseTemplate = onBeforeParseTemplate;
this.imports = imports;
this.contextSensitive = contextSensitive;
this.statementTemplateGenerator = new BlockStatementTemplateGenerator(imports, contextSensitive);
this.annotationTemplateGenerator = new AnnotationTemplateGenerator(imports);
this.statementTemplateGenerator = statementTemplateGenerator;
this.annotationTemplateGenerator = annotationTemplateGenerator;
}

public List<Statement> parseParameters(Cursor cursor, String template) {
Expand Down Expand Up @@ -241,7 +254,7 @@ private JavaSourceFile compileTemplate(@Language("java") String stub) {
ExecutionContext ctx = new InMemoryExecutionContext();
ctx.putMessage(JavaParser.SKIP_SOURCE_SET_TYPE_GENERATION, true);
ctx.putMessage(ExecutionContext.REQUIRE_PRINT_EQUALS_INPUT, false);
JavaParser jp = parser.clone().build();
Parser jp = parser.build();
return (stub.contains("@SubAnnotation") ?
jp.reset().parse(ctx, stub, SUBSTITUTED_ANNOTATION) :
jp.reset().parse(ctx, stub))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,12 @@ private String substituteTypedPattern(String key, int index, TemplateParameterPa
}
}

s = "(/*__p" + index + "__*/new ";

StringBuilder extraDim = new StringBuilder();
int dimensions = 1;
for (; arrayType.getElemType() instanceof JavaType.Array; arrayType = (JavaType.Array) arrayType.getElemType()) {
extraDim.append("[0]");
}

if (arrayType.getElemType() instanceof JavaType.Primitive) {
s += ((JavaType.Primitive) arrayType.getElemType()).getKeyword();
} else if (arrayType.getElemType() instanceof JavaType.FullyQualified) {
s += ((JavaType.FullyQualified) arrayType.getElemType()).getFullyQualifiedName().replace("$", ".");
dimensions++;
}

s += "[0]" + extraDim + ")";
s = "(" + newArrayParameter(arrayType.getElemType(), dimensions, index) + ")";
} else if ("any".equals(matcherName)) {
JavaType type;
if (param != null) {
Expand All @@ -151,11 +143,10 @@ private String substituteTypedPattern(String key, int index, TemplateParameterPa
}

String fqn = getTypeName(type);
JavaType.Primitive primitive = type instanceof JavaType.Primitive ? (JavaType.Primitive) type : null;
s = "__P__." + (primitive == null || primitive.equals(JavaType.Primitive.String) ?
"<" + fqn + ">/*__p" + index + "__*/p()" :
"/*__p" + index + "__*/" + fqn + "p()"
);
JavaType.Primitive primitive = JavaType.Primitive.fromKeyword(fqn);
s = primitive == null || primitive.equals(JavaType.Primitive.String) ?
newObjectParameter(fqn, index) :
newPrimitiveParameter(fqn, index);

parameters[index] = ((J) parameter).withPrefix(Space.EMPTY);
} else {
Expand All @@ -164,6 +155,27 @@ private String substituteTypedPattern(String key, int index, TemplateParameterPa
return s;
}

protected String newObjectParameter(String fqn, int index) {
return "__P__." + "<" + fqn + ">/*__p" + index + "__*/p()";
}

protected String newPrimitiveParameter(String fqn, int index) {
return "__P__./*__p" + index + "__*/" + fqn + "p()";
}

protected String newArrayParameter(JavaType elemType, int dimensions, int index) {
StringBuilder builder = new StringBuilder("/*__p" + index + "__*/" + "new ");
if (elemType instanceof JavaType.Primitive) {
builder.append(((JavaType.Primitive) elemType).getKeyword());
} else if (elemType instanceof JavaType.FullyQualified) {
builder.append(((JavaType.FullyQualified) elemType).getFullyQualifiedName().replace("$", "."));
}
for (int i = 0; i < dimensions; i++) {
builder.append("[0]");
}
return builder.toString();
}

private String getTypeName(@Nullable JavaType type) {
if (type == null) {
return "java.lang.Object";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2023 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openrewrite.java.internal.template;

public class __M__ {
public static native Object any(Object o);

public static native Object any(java.util.function.Predicate<Boolean> o);

public static native <T> Object anyT();
}
Loading

0 comments on commit 3936ba1

Please sign in to comment.