Skip to content

Commit

Permalink
Always emit fully qualified code from TemplateProcessor (#55)
Browse files Browse the repository at this point in the history
* Always emit fully qualified code from `TemplateProcessor`

`TemplateProcessor` now always emits fully qualified type, field and method references regardless of what the input Java code looks like. The idea being that then a `ShortenFullyQualifiedTypeReferences` will remove unnecessary full qualifications.

* Remove unused code

* Don't print `java.lang` packages

* Let `TemplateCode` produce entire `JavaTemplate.Builder`

* Polish

* Allow full qualification to be configured

* Polish tests

---------

Co-authored-by: Tim te Beek <[email protected]>
  • Loading branch information
knutwannheden and timtebeek authored Dec 27, 2023
1 parent 10cd1b0 commit 7516697
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 114 deletions.
149 changes: 149 additions & 0 deletions src/main/java/org/openrewrite/java/template/internal/TemplateCode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright 2023 the original author or authors.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* https://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.openrewrite.java.template.internal;

import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCIdent;
import com.sun.tools.javac.tree.Pretty;

import java.io.IOException;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.*;
import java.util.stream.Collectors;

public class TemplateCode {

public static <T extends JCTree> String process(T tree, List<JCTree.JCVariableDecl> parameters, boolean fullyQualified) {
StringWriter writer = new StringWriter();
TemplateCodePrinter printer = new TemplateCodePrinter(writer, parameters, fullyQualified);
try {
printer.printExpr(tree);
StringBuilder builder = new StringBuilder("JavaTemplate\n");
builder
.append(" .builder(\"")
.append(writer.toString().replace("\\", "\\\\").replace("\"", "\\\""))
.append("\")");
if (!printer.imports.isEmpty()) {
builder.append("\n .imports(").append(printer.imports.stream().map(i -> '"' + i + '"').collect(Collectors.joining(", "))).append(")");
}
if (!printer.staticImports.isEmpty()) {
builder.append("\n .staticImports(").append(printer.staticImports.stream().map(i -> '"' + i + '"').collect(Collectors.joining(", "))).append(")");
}
List<Symbol> imports = ImportDetector.imports(tree);
String classpath = ClasspathJarNameDetector.classpathFor(tree, imports);
if (!classpath.isEmpty()) {
builder.append("\n .javaParser(JavaParser.fromJavaVersion().classpath(").append(classpath).append("))");
}
return builder.toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private static class TemplateCodePrinter extends Pretty {

private static final String PRIMITIVE_ANNOTATION = "org.openrewrite.java.template.Primitive";
private final List<JCTree.JCVariableDecl> declaredParameters;
private final boolean fullyQualified;
private final Set<JCTree.JCVariableDecl> seenParameters = new HashSet<>();
private final TreeSet<String> imports = new TreeSet<>();
private final TreeSet<String> staticImports = new TreeSet<>();

public TemplateCodePrinter(Writer writer, List<JCTree.JCVariableDecl> declaredParameters, boolean fullyQualified) {
super(writer, true);
this.declaredParameters = declaredParameters;
this.fullyQualified = fullyQualified;
}

@Override
public void visitIdent(JCIdent jcIdent) {
try {
Symbol sym = jcIdent.sym;
Optional<JCTree.JCVariableDecl> param = declaredParameters.stream().filter(p -> p.sym == sym).findFirst();
if (param.isPresent()) {
print("#{" + sym.name);
if (seenParameters.add(param.get())) {
String type = param.get().type.toString();
if (param.get().getModifiers().getAnnotations().stream().anyMatch(a -> a.attribute.type.tsym.getQualifiedName().toString().equals(PRIMITIVE_ANNOTATION))) {
type = getUnboxedPrimitive(type);
}
print(":any(" + type + ")");
}
print("}");
} else if (sym != null) {
print(sym);
} else {
print(jcIdent.name);
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

void print(Symbol sym) throws IOException {
if (sym instanceof Symbol.ClassSymbol) {
if (fullyQualified) {
print(sym.packge().fullname.contentEquals("java.lang") ? sym.name : sym.getQualifiedName());
} else {
print(sym.name);
if (!sym.packge().fullname.contentEquals("java.lang")) {
imports.add(sym.getQualifiedName().toString());
}
}
} else if (sym instanceof Symbol.MethodSymbol || sym instanceof Symbol.VarSymbol) {
if (fullyQualified) {
print(sym.owner);
print('.');
print(sym.name);
} else {
print(sym.name);
if (!sym.packge().fullname.contentEquals("java.lang")) {
staticImports.add(sym.owner.getQualifiedName() + "." + sym.name);
}
}
} else if (sym instanceof Symbol.PackageSymbol) {
print(sym.getQualifiedName());
}
}

private String getUnboxedPrimitive(String paramType) {
switch (paramType) {
case "java.lang.Boolean":
return "boolean";
case "java.lang.Byte":
return "byte";
case "java.lang.Character":
return "char";
case "java.lang.Double":
return "double";
case "java.lang.Float":
return "float";
case "java.lang.Integer":
return "int";
case "java.lang.Long":
return "long";
case "java.lang.Short":
return "short";
case "java.lang.Void":
return "void";
}
return paramType;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,30 @@
import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
import com.sun.tools.javac.tree.TreeScanner;
import com.sun.tools.javac.util.Context;
import org.openrewrite.java.template.internal.ClasspathJarNameDetector;
import org.openrewrite.java.template.internal.ImportDetector;
import org.openrewrite.java.template.internal.JavacResolution;
import org.openrewrite.java.template.internal.TemplateCode;

import javax.annotation.processing.RoundEnvironment;
import javax.annotation.processing.SupportedAnnotationTypes;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.tools.Diagnostic.Kind;
import javax.tools.JavaFileObject;
import java.io.*;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Collections.*;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

/**
* For steps to debug this annotation processor, see
* <a href="https://medium.com/@joachim.beckers/debugging-an-annotation-processor-using-intellij-idea-in-2018-cde72758b78a">this blog post</a>.
*/
@SupportedAnnotationTypes("*")
public class TemplateProcessor extends TypeAwareProcessor {
private static final String PRIMITIVE_ANNOTATION = "org.openrewrite.java.template.Primitive";

private final String javaFileContent;

public TemplateProcessor(String javaFileContent) {
this.javaFileContent = javaFileContent;
}

public TemplateProcessor() {
this(null);
}

@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment roundEnv) {
Expand Down Expand Up @@ -99,71 +90,18 @@ public void visitApply(JCTree.JCMethodInvocation tree) {

JCTree.JCLambda template = arg2 instanceof JCTree.JCLambda ? (JCTree.JCLambda) arg2 : (JCTree.JCLambda) ((JCTree.JCTypeCast) arg2).getExpression();

NavigableMap<Integer, JCTree.JCVariableDecl> parameterPositions;
List<JCTree.JCVariableDecl> parameters;
if (template.getParameters().isEmpty()) {
parameterPositions = emptyNavigableMap();
parameters = emptyList();
} else {
parameterPositions = new TreeMap<>();
Map<JCTree, JCTree> parameterResolution = res.resolveAll(context, cu, template.getParameters());
parameters = new ArrayList<>(template.getParameters().size());
for (VariableTree p : template.getParameters()) {
parameters.add((JCTree.JCVariableDecl) parameterResolution.get((JCTree) p));
}
JCTree.JCLambda resolvedTemplate = (JCTree.JCLambda) parameterResolution.get(template);

new TreeScanner() {
@Override
public void visitIdent(JCTree.JCIdent ident) {
for (JCTree.JCVariableDecl parameter : parameters) {
if (parameter.sym == ident.sym) {
parameterPositions.put(ident.getStartPosition(), parameter);
}
}
}
}.scan(resolvedTemplate.getBody());
}

try (InputStream inputStream = javaFileContent == null ?
cu.getSourceFile().openInputStream() : new ByteArrayInputStream(javaFileContent.getBytes())) {
//noinspection ResultOfMethodCallIgnored
inputStream.skip(template.getBody().getStartPosition());

byte[] templateSourceBytes = new byte[template.getBody().getEndPosition(cu.endPositions) - template.getBody().getStartPosition()];

//noinspection ResultOfMethodCallIgnored
inputStream.read(templateSourceBytes);

String templateSource = new String(templateSourceBytes);
templateSource = templateSource.replace("\\", "\\\\").replace("\"", "\\\"");

for (Map.Entry<Integer, JCTree.JCVariableDecl> paramPos : parameterPositions.descendingMap().entrySet()) {
JCTree.JCVariableDecl param = paramPos.getValue();

String typeDef = "";

// identify whether this is the leftmost occurrence of this parameter name
if (Objects.equals(parameterPositions.entrySet().stream().filter(p -> p.getValue() == param)
.map(Map.Entry::getKey)
.findFirst().orElse(null), paramPos.getKey())) {
String type = param.type.toString();
for (JCTree.JCAnnotation annotation : param.getModifiers().getAnnotations()) {
if (annotation.type.tsym.getQualifiedName().contentEquals(PRIMITIVE_ANNOTATION)) {
type = getUnboxedPrimitive(param.type.toString());
// don't generate the annotation into the source code
param.mods.annotations = com.sun.tools.javac.util.List.filter(param.mods.annotations, annotation);
}
}
typeDef = ":any(" + type + ")";
}

templateSource = templateSource.substring(0, paramPos.getKey() - template.getBody().getStartPosition()) +
"#{" + param.getName().toString() + typeDef + "}" +
templateSource.substring((paramPos.getKey() - template.getBody().getStartPosition()) +
param.name.length());
}

try {
JCTree.JCLiteral templateName = (JCTree.JCLiteral) tree.getArguments().get(1);
if (templateName.value == null) {
processingEnv.getMessager().printMessage(Kind.WARNING, "Can't compile a template with a null name.");
Expand Down Expand Up @@ -200,6 +138,8 @@ public void visitIdent(JCTree.JCIdent ident) {
}
}

String templateCode = TemplateCode.process(resolved.get(template.getBody()), parameters, false);

JavaFileObject builderFile = processingEnv.getFiler().createSourceFile(templateFqn);
try (Writer out = new BufferedWriter(builderFile.openWriter())) {
out.write("package " + classDecl.sym.packge().toString() + ";\n");
Expand Down Expand Up @@ -228,25 +168,7 @@ public void visitIdent(JCTree.JCIdent ident) {
out.write(" * @return the JavaTemplate builder.\n");
out.write(" */\n");
out.write(" public static JavaTemplate.Builder getTemplate() {\n");
out.write(" return JavaTemplate\n");
out.write(" .builder(\"" + templateSource + "\")");

List<Symbol> imports = ImportDetector.imports(resolved.get(template));
String classpath = ClasspathJarNameDetector.classpathFor(resolved.get(template), imports);
if (!classpath.isEmpty()) {
out.write("\n .javaParser(JavaParser.fromJavaVersion().classpath(" +
classpath + "))");
}

for (Symbol anImport : imports) {
if (anImport instanceof Symbol.ClassSymbol && !anImport.getQualifiedName().toString().startsWith("java.lang.")) {
out.write("\n .imports(\"" + ((Symbol.ClassSymbol) anImport).fullname.toString().replace('$', '.') + "\")");
} else if (anImport instanceof Symbol.VarSymbol || anImport instanceof Symbol.MethodSymbol) {
out.write("\n .staticImports(\"" + anImport.owner.getQualifiedName().toString().replace('$', '.') + '.' + anImport.flatName().toString() + "\")");
}
}

out.write(";\n");
out.write(" return " + indent(templateCode, 12) + ";\n");
out.write(" }\n");
out.write("}\n");
out.flush();
Expand All @@ -259,6 +181,13 @@ public void visitIdent(JCTree.JCIdent ident) {

super.visitApply(tree);
}

private String indent(String code, int width) {
char[] indent = new char[width];
Arrays.fill(indent, ' ');
String replacement = "$1" + new String(indent);
return code.replaceAll("(?m)(\\R)", replacement);
}
}.scan(cu);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,30 +224,6 @@ protected Object tryGetProxyDelegateToField(Object instance) {
}


protected String getUnboxedPrimitive(String paramType) {
switch (paramType) {
case "java.lang.Boolean":
return "boolean";
case "java.lang.Byte":
return "byte";
case "java.lang.Character":
return "char";
case "java.lang.Double":
return "double";
case "java.lang.Float":
return "float";
case "java.lang.Integer":
return "int";
case "java.lang.Long":
return "long";
case "java.lang.Short":
return "short";
case "java.lang.Void":
return "void";
}
return paramType;
}

protected String getBoxedPrimitive(String paramType) {
switch (paramType) {
case "boolean":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TemplateProcessorTest {
@ValueSource(strings = {
"Unqualified",
"FullyQualified",
"FullyQualifiedField",
"Primitive",
})
void qualification(String qualifier) {
Expand Down
19 changes: 18 additions & 1 deletion src/test/resources/template/ShouldAddClasspath.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import org.openrewrite.java.template.Primitive;
import org.slf4j.LoggerFactory;

import java.util.regex.Pattern;

import static java.util.regex.Pattern.DOTALL;
import static org.slf4j.LoggerFactory.getLogger;

public class ShouldAddClasspath {

class Unqualified {
Expand All @@ -30,7 +35,7 @@ void before(String message) {

@AfterTemplate
void after(String message) {
LoggerFactory.getLogger(message);
getLogger(message);
}
}

Expand All @@ -46,6 +51,18 @@ void after(String message) {
}
}

class FullyQualifiedField {
@BeforeTemplate
void before(String message) {
Pattern.compile(message, DOTALL);
}

@AfterTemplate
void after(String message) {
System.out.println(message);
}
}

class Primitive {
@BeforeTemplate
void before(@org.openrewrite.java.template.Primitive int i) {
Expand Down
Loading

0 comments on commit 7516697

Please sign in to comment.