org.openrewrite
rewrite-core
diff --git a/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Node.java b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Node.java
new file mode 100644
index 0000000000..f5c9553139
--- /dev/null
+++ b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Node.java
@@ -0,0 +1,126 @@
+package tech.picnic.errorprone.refaster.runner;
+
+import static java.util.Comparator.comparingInt;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSortedSet;
+import com.google.common.collect.Maps;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * A node in an immutable tree.
+ *
+ * The tree's edges are string-labeled, while its leaves store values of type {@code T}.
+ */
+@AutoValue
+abstract class Node {
+ // XXX: Review: should this method accept a `SetMultimap>`, or should
+ // there be such an overload?
+ static Node create(
+ Set values, Function super T, ? extends Set extends Set>> pathExtractor) {
+ Builder tree = Builder.create();
+ tree.register(values, pathExtractor);
+ return tree.build();
+ }
+
+ abstract ImmutableMap> children();
+
+ abstract ImmutableList values();
+
+ // XXX: Consider having `RefasterRuleSelector` already collect the candidate edges into a
+ // `SortedSet`, as that would likely speed up `ImmutableSortedSet#copyOf`.
+ // XXX: If this ^ proves worthwhile, then the test code and benchmark should be updated
+ // accordingly.
+ void collectReachableValues(Set candidateEdges, Consumer sink) {
+ collectReachableValues(ImmutableSortedSet.copyOf(candidateEdges).asList(), sink);
+ }
+
+ private void collectReachableValues(ImmutableList candidateEdges, Consumer sink) {
+ values().forEach(sink);
+
+ if (candidateEdges.isEmpty() || children().isEmpty()) {
+ return;
+ }
+
+ /*
+ * For performance reasons we iterate over the smallest set of edges. In case there are fewer
+ * children than candidate edges we iterate over the former, at the cost of not pruning the set
+ * of candidate edges if a transition is made.
+ */
+ int candidateEdgeCount = candidateEdges.size();
+ if (children().size() < candidateEdgeCount) {
+ for (Map.Entry> e : children().entrySet()) {
+ if (candidateEdges.contains(e.getKey())) {
+ e.getValue().collectReachableValues(candidateEdges, sink);
+ }
+ }
+ } else {
+ for (int i = 0; i < candidateEdgeCount; i++) {
+ Node child = children().get(candidateEdges.get(i));
+ if (child != null) {
+ child.collectReachableValues(candidateEdges.subList(i + 1, candidateEdgeCount), sink);
+ }
+ }
+ }
+ }
+
+ @AutoValue
+ @SuppressWarnings("AutoValueImmutableFields" /* Type is used only during `Node` construction. */)
+ abstract static class Builder {
+ private static Builder create() {
+ return new AutoValue_Node_Builder<>(new HashMap<>(), new ArrayList<>());
+ }
+
+ abstract Map> children();
+
+ abstract List values();
+
+ /**
+ * Registers all paths to each of the given values.
+ *
+ * Shorter paths are registered first, so that longer paths can be skipped if a strict prefix
+ * leads to the same value.
+ */
+ private void register(
+ Set values, Function super T, ? extends Set extends Set>> pathsExtractor) {
+ for (T value : values) {
+ List extends Set> paths = new ArrayList<>(pathsExtractor.apply(value));
+ /*
+ * We sort paths by length ascending, so that in case of two paths where one is an initial
+ * prefix of the other, only the former is encoded (thus saving some space).
+ */
+ paths.sort(comparingInt(Set::size));
+ paths.forEach(path -> registerPath(value, ImmutableList.sortedCopyOf(path)));
+ }
+ }
+
+ private void registerPath(T value, ImmutableList path) {
+ if (values().contains(value)) {
+ /* Another (shorter) path already leads to this value. */
+ return;
+ }
+
+ if (path.isEmpty()) {
+ values().add(value);
+ } else {
+ children()
+ .computeIfAbsent(path.get(0), k -> create())
+ .registerPath(value, path.subList(1, path.size()));
+ }
+ }
+
+ private Node build() {
+ return new AutoValue_Node<>(
+ ImmutableMap.copyOf(Maps.transformValues(children(), Builder::build)),
+ ImmutableList.copyOf(values()));
+ }
+ }
+}
diff --git a/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Refaster.java b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Refaster.java
index 94abfc23a8..8d5754cd7c 100644
--- a/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Refaster.java
+++ b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/Refaster.java
@@ -21,7 +21,6 @@
import com.google.errorprone.BugPattern;
import com.google.errorprone.BugPattern.SeverityLevel;
import com.google.errorprone.CodeTransformer;
-import com.google.errorprone.CompositeCodeTransformer;
import com.google.errorprone.ErrorProneFlags;
import com.google.errorprone.ErrorProneOptions.Severity;
import com.google.errorprone.SubContext;
@@ -39,6 +38,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import javax.inject.Inject;
@@ -64,8 +64,9 @@ public final class Refaster extends BugChecker implements CompilationUnitTreeMat
private static final long serialVersionUID = 1L;
+ // XXX: Review this suppression.
@SuppressWarnings({"java:S1948", "serial"} /* Concrete instance will be `Serializable`. */)
- private final CodeTransformer codeTransformer;
+ private final RefasterRuleSelector ruleSelector;
/** Instantiates a default {@link Refaster} instance. */
public Refaster() {
@@ -80,16 +81,29 @@ public Refaster() {
@Inject
@VisibleForTesting
public Refaster(ErrorProneFlags flags) {
- codeTransformer = createCompositeCodeTransformer(flags);
+ ruleSelector = createRefasterRuleSelector(flags);
}
@CanIgnoreReturnValue
@Override
public Description matchCompilationUnit(CompilationUnitTree tree, VisitorState state) {
+ Set candidateTransformers = ruleSelector.selectCandidateRules(tree);
+
/* First, collect all matches. */
+ SubContext context = new SubContext(state.context);
List matches = new ArrayList<>();
- codeTransformer.apply(state.getPath(), new SubContext(state.context), matches::add);
-
+ for (CodeTransformer transformer : candidateTransformers) {
+ try {
+ transformer.apply(state.getPath(), context, matches::add);
+ } catch (LinkageError e) {
+ // XXX: This `try/catch` block handles the issue described and resolved in
+ // https://github.com/google/error-prone/pull/2456. Drop this block once that change is
+ // released.
+ // XXX: Find a way to identify that we're running Picnic's Error Prone fork and disable this
+ // fallback if so, as it might hide other bugs.
+ return Description.NO_MATCH;
+ }
+ }
/* Then apply them. */
applyMatches(matches, ((JCCompilationUnit) tree).endPositions, state);
@@ -193,10 +207,12 @@ private static Stream getReplacements(
return description.fixes.stream().flatMap(fix -> fix.getReplacements(endPositions).stream());
}
- private static CodeTransformer createCompositeCodeTransformer(ErrorProneFlags flags) {
+ // XXX: Add a flag to disable the optimized `RefasterRuleSelector`. That would allow us to verify
+ // that we're not prematurely pruning rules.
+ private static RefasterRuleSelector createRefasterRuleSelector(ErrorProneFlags flags) {
ImmutableListMultimap allTransformers =
CodeTransformers.getAllCodeTransformers();
- return CompositeCodeTransformer.compose(
+ return RefasterRuleSelector.create(
flags
.get(INCLUDED_RULES_PATTERN_FLAG)
.map(Pattern::compile)
diff --git a/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/RefasterRuleSelector.java b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/RefasterRuleSelector.java
new file mode 100644
index 0000000000..b0b22e2b97
--- /dev/null
+++ b/refaster-runner/src/main/java/tech/picnic/errorprone/refaster/runner/RefasterRuleSelector.java
@@ -0,0 +1,513 @@
+package tech.picnic.errorprone.refaster.runner;
+
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static java.util.Collections.newSetFromMap;
+import static java.util.stream.Collectors.toCollection;
+
+import com.google.common.collect.ImmutableCollection;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.errorprone.CodeTransformer;
+import com.google.errorprone.CompositeCodeTransformer;
+import com.google.errorprone.refaster.BlockTemplate;
+import com.google.errorprone.refaster.ExpressionTemplate;
+import com.google.errorprone.refaster.RefasterRule;
+import com.google.errorprone.refaster.UAnyOf;
+import com.google.errorprone.refaster.UExpression;
+import com.google.errorprone.refaster.UStatement;
+import com.google.errorprone.refaster.UStaticIdent;
+import com.google.errorprone.refaster.annotation.BeforeTemplate;
+import com.sun.source.tree.AssignmentTree;
+import com.sun.source.tree.BinaryTree;
+import com.sun.source.tree.ClassTree;
+import com.sun.source.tree.CompilationUnitTree;
+import com.sun.source.tree.CompoundAssignmentTree;
+import com.sun.source.tree.ExpressionTree;
+import com.sun.source.tree.IdentifierTree;
+import com.sun.source.tree.MemberReferenceTree;
+import com.sun.source.tree.MemberSelectTree;
+import com.sun.source.tree.MethodTree;
+import com.sun.source.tree.PackageTree;
+import com.sun.source.tree.Tree;
+import com.sun.source.tree.UnaryTree;
+import com.sun.source.tree.VariableTree;
+import com.sun.source.util.TreeScanner;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.IdentityHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.jspecify.annotations.Nullable;
+import tech.picnic.errorprone.refaster.AnnotatedCompositeCodeTransformer;
+
+// XXX: Add some examples of which source files would match what templates in the tree.
+// XXX: Consider this text in general.
+/**
+ * A {@link RefasterRuleSelector} algorithm that selects Refaster templates based on the content of
+ * a {@link CompilationUnitTree}.
+ *
+ * The algorithm consists of the following steps:
+ *
+ *
+ * - Create a {@link Node tree} structure based on the provided Refaster templates.
+ *
+ * - Extract all identifiers from the {@link BeforeTemplate}s.
+ *
- Sort identifiers lexicographically and collect into a set.
+ *
- Add a path to the tree based on the sorted identifiers.
+ *
+ * - Extract all identifiers from the {@link CompilationUnitTree} and sort them
+ * lexicographically.
+ *
- Traverse the tree based on the identifiers from the {@link CompilationUnitTree}. Every node
+ * can contain Refaster templates. Once a node is we found a candidate Refaster template that
+ * might match some code and will therefore be added to the list of candidates.
+ *
+ *
+ * This is an example to explain the algorithm. Consider the templates with identifiers; {@code
+ * T1 = [A, B, C]}, {@code T2 = [B]}, and {@code T3 = [B, D]}. This will result in the following
+ * tree structure:
+ *
+ *
{@code
+ *
+ * ├── A
+ * │ └── B
+ * │ └── C -- T1
+ * └── B -- T2
+ * └── D -- T3
+ * }
+ *
+ * The tree is traversed based on the identifiers in the {@link CompilationUnitTree}. When a node
+ * containing a template is reached, we can be certain that the identifiers from the {@link
+ * BeforeTemplate} are at least present in the {@link CompilationUnitTree}.
+ *
+ *
Since the identifiers are sorted, we can skip parts of the {@link Node tree} while we are
+ * traversing it. Instead of trying to match all Refaster templates against every expression in a
+ * {@link CompilationUnitTree} we now only matching a subset of the templates that at least have a
+ * chance of matching. As a result, the performance of Refaster increases significantly.
+ */
+final class RefasterRuleSelector {
+ private final Node codeTransformers;
+
+ private RefasterRuleSelector(Node codeTransformers) {
+ this.codeTransformers = codeTransformers;
+ }
+
+ /**
+ * Instantiates a new {@link RefasterRuleSelector} backed by the given {@link CodeTransformer}s.
+ */
+ static RefasterRuleSelector create(ImmutableCollection refasterRules) {
+ Map>> ruleIdentifiersByTransformer =
+ indexRuleIdentifiers(refasterRules);
+ return new RefasterRuleSelector(
+ Node.create(ruleIdentifiersByTransformer.keySet(), ruleIdentifiersByTransformer::get));
+ }
+
+ /**
+ * Retrieves a set of Refaster templates that can possibly match based on a {@link
+ * CompilationUnitTree}.
+ *
+ * @param tree The {@link CompilationUnitTree} for which candidate Refaster templates are
+ * selected.
+ * @return Set of Refaster templates that can possibly match in the provided {@link
+ * CompilationUnitTree}.
+ */
+ Set selectCandidateRules(CompilationUnitTree tree) {
+ Set candidateRules = newSetFromMap(new IdentityHashMap<>());
+ codeTransformers.collectReachableValues(extractSourceIdentifiers(tree), candidateRules::add);
+ return candidateRules;
+ }
+
+ private static Map>> indexRuleIdentifiers(
+ ImmutableCollection codeTransformers) {
+ IdentityHashMap>> identifiers =
+ new IdentityHashMap<>();
+ for (CodeTransformer transformer : codeTransformers) {
+ collectRuleIdentifiers(transformer, identifiers);
+ }
+ return identifiers;
+ }
+
+ private static void collectRuleIdentifiers(
+ CodeTransformer codeTransformer,
+ Map>> identifiers) {
+ if (codeTransformer instanceof CompositeCodeTransformer) {
+ for (CodeTransformer transformer :
+ ((CompositeCodeTransformer) codeTransformer).transformers()) {
+ collectRuleIdentifiers(transformer, identifiers);
+ }
+ } else if (codeTransformer instanceof AnnotatedCompositeCodeTransformer) {
+ AnnotatedCompositeCodeTransformer annotatedTransformer =
+ (AnnotatedCompositeCodeTransformer) codeTransformer;
+ for (Map.Entry>> e :
+ indexRuleIdentifiers(annotatedTransformer.transformers()).entrySet()) {
+ identifiers.put(annotatedTransformer.withTransformers(e.getKey()), e.getValue());
+ }
+ } else if (codeTransformer instanceof RefasterRule) {
+ identifiers.put(
+ codeTransformer, extractRuleIdentifiers((RefasterRule, ?>) codeTransformer));
+ } else {
+ /* Unrecognized `CodeTransformer` types are indexed such that they always apply. */
+ identifiers.put(codeTransformer, ImmutableSet.of(ImmutableSet.of()));
+ }
+ }
+
+ // XXX: Consider decomposing `RefasterRule`s such that each rule has exactly one
+ // `@BeforeTemplate`.
+ private static ImmutableSet> extractRuleIdentifiers(
+ RefasterRule, ?> refasterRule) {
+ ImmutableSet.Builder> results = ImmutableSet.builder();
+
+ for (Object template : RefasterIntrospection.getBeforeTemplates(refasterRule)) {
+ if (template instanceof ExpressionTemplate) {
+ UExpression expr = RefasterIntrospection.getExpression((ExpressionTemplate) template);
+ results.addAll(extractRuleIdentifiers(ImmutableList.of(expr)));
+ } else if (template instanceof BlockTemplate) {
+ ImmutableList statements =
+ RefasterIntrospection.getTemplateStatements((BlockTemplate) template);
+ results.addAll(extractRuleIdentifiers(statements));
+ } else {
+ throw new IllegalStateException(
+ String.format("Unexpected template type '%s'", template.getClass()));
+ }
+ }
+
+ return results.build();
+ }
+
+ // XXX: Consider interning the strings (once a benchmark is in place).
+ private static ImmutableSet> extractRuleIdentifiers(
+ ImmutableList extends Tree> trees) {
+ List> identifierCombinations = new ArrayList<>();
+ identifierCombinations.add(new HashSet<>());
+ TemplateIdentifierExtractor.INSTANCE.scan(trees, identifierCombinations);
+ return identifierCombinations.stream().map(ImmutableSet::copyOf).collect(toImmutableSet());
+ }
+
+ private static Set extractSourceIdentifiers(Tree tree) {
+ Set identifiers = new HashSet<>();
+ SourceIdentifierExtractor.INSTANCE.scan(tree, identifiers);
+ return identifiers;
+ }
+
+ /**
+ * Returns a unique string representation of the given {@link Tree.Kind}.
+ *
+ * @return A string representation of the operator, if known
+ * @throws IllegalArgumentException If the given input is not supported.
+ */
+ // XXX: Extend list to cover remaining cases; at least for any `Kind` that may appear in a
+ // Refaster template. (E.g. keywords such as `if`, `instanceof`, `new`, ...)
+ private static String treeKindToString(Tree.Kind kind) {
+ return switch (kind) {
+ case ASSIGNMENT -> "=";
+ case POSTFIX_INCREMENT -> "x++";
+ case PREFIX_INCREMENT -> "++x";
+ case POSTFIX_DECREMENT -> "x--";
+ case PREFIX_DECREMENT -> "--x";
+ case UNARY_PLUS -> "+x";
+ case UNARY_MINUS -> "-x";
+ case BITWISE_COMPLEMENT -> "~";
+ case LOGICAL_COMPLEMENT -> "!";
+ case MULTIPLY -> "*";
+ case DIVIDE -> "/";
+ case REMAINDER -> "%";
+ case PLUS -> "+";
+ case MINUS -> "-";
+ case LEFT_SHIFT -> "<<";
+ case RIGHT_SHIFT -> ">>";
+ case UNSIGNED_RIGHT_SHIFT -> ">>>";
+ case LESS_THAN -> "<";
+ case GREATER_THAN -> ">";
+ case LESS_THAN_EQUAL -> "<=";
+ case GREATER_THAN_EQUAL -> ">=";
+ case EQUAL_TO -> "==";
+ case NOT_EQUAL_TO -> "!=";
+ case AND -> "&";
+ case XOR -> "^";
+ case OR -> "|";
+ case CONDITIONAL_AND -> "&&";
+ case CONDITIONAL_OR -> "||";
+ case MULTIPLY_ASSIGNMENT -> "*=";
+ case DIVIDE_ASSIGNMENT -> "/=";
+ case REMAINDER_ASSIGNMENT -> "%=";
+ case PLUS_ASSIGNMENT -> "+=";
+ case MINUS_ASSIGNMENT -> "-=";
+ case LEFT_SHIFT_ASSIGNMENT -> "<<=";
+ case RIGHT_SHIFT_ASSIGNMENT -> ">>=";
+ case UNSIGNED_RIGHT_SHIFT_ASSIGNMENT -> ">>>=";
+ case AND_ASSIGNMENT -> "&=";
+ case XOR_ASSIGNMENT -> "^=";
+ case OR_ASSIGNMENT -> "|=";
+ default -> throw new IllegalStateException("Cannot convert Tree.Kind to a String: " + kind);
+ };
+ }
+
+ private static final class RefasterIntrospection {
+ // XXX: Update `ErrorProneRuntimeClasspath` to not suggest inaccessible types.
+ @SuppressWarnings("ErrorProneRuntimeClasspath")
+ private static final String UCLASS_IDENT_FQCN = "com.google.errorprone.refaster.UClassIdent";
+
+ private static final Class> UCLASS_IDENT = getClass(UCLASS_IDENT_FQCN);
+ private static final Method METHOD_REFASTER_RULE_BEFORE_TEMPLATES =
+ getMethod(RefasterRule.class, "beforeTemplates");
+ private static final Method METHOD_EXPRESSION_TEMPLATE_EXPRESSION =
+ getMethod(ExpressionTemplate.class, "expression");
+ private static final Method METHOD_BLOCK_TEMPLATE_TEMPLATE_STATEMENTS =
+ getMethod(BlockTemplate.class, "templateStatements");
+ private static final Method METHOD_USTATIC_IDENT_CLASS_IDENT =
+ getMethod(UStaticIdent.class, "classIdent");
+ private static final Method METHOD_UCLASS_IDENT_GET_TOP_LEVEL_CLASS =
+ getMethod(UCLASS_IDENT, "getTopLevelClass");
+ private static final Method METHOD_UANY_OF_EXPRESSIONS = getMethod(UAnyOf.class, "expressions");
+
+ static boolean isUClassIdent(IdentifierTree tree) {
+ return UCLASS_IDENT.isInstance(tree);
+ }
+
+ static ImmutableList> getBeforeTemplates(RefasterRule, ?> refasterRule) {
+ return invokeMethod(METHOD_REFASTER_RULE_BEFORE_TEMPLATES, refasterRule);
+ }
+
+ static UExpression getExpression(ExpressionTemplate template) {
+ return invokeMethod(METHOD_EXPRESSION_TEMPLATE_EXPRESSION, template);
+ }
+
+ static ImmutableList getTemplateStatements(BlockTemplate template) {
+ return invokeMethod(METHOD_BLOCK_TEMPLATE_TEMPLATE_STATEMENTS, template);
+ }
+
+ static IdentifierTree getClassIdent(UStaticIdent tree) {
+ return invokeMethod(METHOD_USTATIC_IDENT_CLASS_IDENT, tree);
+ }
+
+ // Arguments to this method must actually be of the package-private type `UClassIdent`.
+ static String getTopLevelClass(IdentifierTree uClassIdent) {
+ return invokeMethod(METHOD_UCLASS_IDENT_GET_TOP_LEVEL_CLASS, uClassIdent);
+ }
+
+ static ImmutableList getExpressions(UAnyOf tree) {
+ return invokeMethod(METHOD_UANY_OF_EXPRESSIONS, tree);
+ }
+
+ private static Class> getClass(String fqcn) {
+ try {
+ return RefasterIntrospection.class.getClassLoader().loadClass(fqcn);
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(String.format("Failed to load class `%s`", fqcn), e);
+ }
+ }
+
+ private static Method getMethod(Class> clazz, String methodName) {
+ try {
+ Method method = clazz.getDeclaredMethod(methodName);
+ method.setAccessible(true);
+ return method;
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(
+ String.format("No method `%s` on class `%s`", methodName, clazz.getName()), e);
+ }
+ }
+
+ @SuppressWarnings({"TypeParameterUnusedInFormals", "unchecked"})
+ private static T invokeMethod(Method method, Object instance) {
+ try {
+ return (T) method.invoke(instance);
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ throw new IllegalStateException(String.format("Failed to invoke method `%s`", method), e);
+ }
+ }
+ }
+
+ private static final class TemplateIdentifierExtractor
+ extends TreeScanner<@Nullable Void, List>> {
+ private static final TemplateIdentifierExtractor INSTANCE = new TemplateIdentifierExtractor();
+
+ @Override
+ public @Nullable Void visitIdentifier(
+ IdentifierTree node, List> identifierCombinations) {
+ // XXX: Also include the package name if not `java.lang`; it must be present.
+ if (RefasterIntrospection.isUClassIdent(node)) {
+ for (Set ids : identifierCombinations) {
+ ids.add(getSimpleName(RefasterIntrospection.getTopLevelClass(node)));
+ ids.add(getIdentifier(node));
+ }
+ } else if (node instanceof UStaticIdent) {
+ IdentifierTree subNode = RefasterIntrospection.getClassIdent((UStaticIdent) node);
+ for (Set ids : identifierCombinations) {
+ ids.add(getSimpleName(RefasterIntrospection.getTopLevelClass(subNode)));
+ ids.add(getIdentifier(subNode));
+ ids.add(node.getName().toString());
+ }
+ }
+
+ return null;
+ }
+
+ private static String getIdentifier(IdentifierTree tree) {
+ return getSimpleName(tree.getName().toString());
+ }
+
+ private static String getSimpleName(String fqcn) {
+ int index = fqcn.lastIndexOf('.');
+ return index < 0 ? fqcn : fqcn.substring(index + 1);
+ }
+
+ @Override
+ public @Nullable Void visitMemberReference(
+ MemberReferenceTree node, List> identifierCombinations) {
+ super.visitMemberReference(node, identifierCombinations);
+ String id = node.getName().toString();
+ identifierCombinations.forEach(ids -> ids.add(id));
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitMemberSelect(
+ MemberSelectTree node, List> identifierCombinations) {
+ super.visitMemberSelect(node, identifierCombinations);
+ String id = node.getIdentifier().toString();
+ identifierCombinations.forEach(ids -> ids.add(id));
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitAssignment(
+ AssignmentTree node, List> identifierCombinations) {
+ registerOperator(node, identifierCombinations);
+ return super.visitAssignment(node, identifierCombinations);
+ }
+
+ @Override
+ public @Nullable Void visitCompoundAssignment(
+ CompoundAssignmentTree node, List> identifierCombinations) {
+ registerOperator(node, identifierCombinations);
+ return super.visitCompoundAssignment(node, identifierCombinations);
+ }
+
+ @Override
+ public @Nullable Void visitUnary(UnaryTree node, List> identifierCombinations) {
+ registerOperator(node, identifierCombinations);
+ return super.visitUnary(node, identifierCombinations);
+ }
+
+ @Override
+ public @Nullable Void visitBinary(BinaryTree node, List> identifierCombinations) {
+ registerOperator(node, identifierCombinations);
+ return super.visitBinary(node, identifierCombinations);
+ }
+
+ private static void registerOperator(
+ ExpressionTree node, List> identifierCombinations) {
+ String id = treeKindToString(node.getKind());
+ identifierCombinations.forEach(ids -> ids.add(id));
+ }
+
+ @Override
+ public @Nullable Void visitOther(Tree node, List> identifierCombinations) {
+ if (node instanceof UAnyOf) {
+ List> base = copy(identifierCombinations);
+ identifierCombinations.clear();
+
+ for (UExpression expr : RefasterIntrospection.getExpressions((UAnyOf) node)) {
+ List> branch = copy(base);
+ scan(expr, branch);
+ identifierCombinations.addAll(branch);
+ }
+ }
+
+ return null;
+ }
+
+ private static List> copy(List> identifierCombinations) {
+ return identifierCombinations.stream()
+ .map(HashSet::new)
+ .collect(toCollection(ArrayList::new));
+ }
+ }
+
+ private static final class SourceIdentifierExtractor
+ extends TreeScanner<@Nullable Void, Set> {
+ private static final SourceIdentifierExtractor INSTANCE = new SourceIdentifierExtractor();
+
+ @Override
+ public @Nullable Void visitPackage(PackageTree node, Set identifiers) {
+ /* Refaster rules never match package declarations. */
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitClass(ClassTree node, Set identifiers) {
+ /*
+ * Syntactic details of a class declaration other than the definition of its members do not
+ * need to be reflected in a Refaster rule for it to apply to the class's code.
+ */
+ return scan(node.getMembers(), identifiers);
+ }
+
+ @Override
+ public @Nullable Void visitMethod(MethodTree node, Set identifiers) {
+ /*
+ * Syntactic details of a method declaration other than its body do not need to be reflected
+ * in a Refaster rule for it to apply to the method's code.
+ */
+ return scan(node.getBody(), identifiers);
+ }
+
+ @Override
+ public @Nullable Void visitVariable(VariableTree node, Set identifiers) {
+ /* A variable's modifiers and name do not influence where a Refaster rule matches. */
+ return reduce(scan(node.getInitializer(), identifiers), scan(node.getType(), identifiers));
+ }
+
+ @Override
+ public @Nullable Void visitIdentifier(IdentifierTree node, Set identifiers) {
+ identifiers.add(node.getName().toString());
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitMemberReference(MemberReferenceTree node, Set identifiers) {
+ super.visitMemberReference(node, identifiers);
+ identifiers.add(node.getName().toString());
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitMemberSelect(MemberSelectTree node, Set identifiers) {
+ super.visitMemberSelect(node, identifiers);
+ identifiers.add(node.getIdentifier().toString());
+ return null;
+ }
+
+ @Override
+ public @Nullable Void visitAssignment(AssignmentTree node, Set identifiers) {
+ registerOperator(node, identifiers);
+ return super.visitAssignment(node, identifiers);
+ }
+
+ @Override
+ public @Nullable Void visitCompoundAssignment(
+ CompoundAssignmentTree node, Set identifiers) {
+ registerOperator(node, identifiers);
+ return super.visitCompoundAssignment(node, identifiers);
+ }
+
+ @Override
+ public @Nullable Void visitUnary(UnaryTree node, Set identifiers) {
+ registerOperator(node, identifiers);
+ return super.visitUnary(node, identifiers);
+ }
+
+ @Override
+ public @Nullable Void visitBinary(BinaryTree node, Set identifiers) {
+ registerOperator(node, identifiers);
+ return super.visitBinary(node, identifiers);
+ }
+
+ private static void registerOperator(ExpressionTree node, Set identifiers) {
+ identifiers.add(treeKindToString(node.getKind()));
+ }
+ }
+}
diff --git a/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeBenchmark.java b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeBenchmark.java
new file mode 100644
index 0000000000..327592848d
--- /dev/null
+++ b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeBenchmark.java
@@ -0,0 +1,83 @@
+package tech.picnic.errorprone.refaster.runner;
+
+import static com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap;
+import static java.util.function.Function.identity;
+
+import com.google.common.collect.ImmutableListMultimap;
+import com.jakewharton.nopen.annotation.Open;
+import java.util.Collection;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.regex.Pattern;
+import java.util.stream.Stream;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.runner.Runner;
+import org.openjdk.jmh.runner.RunnerException;
+import org.openjdk.jmh.runner.options.OptionsBuilder;
+import tech.picnic.errorprone.refaster.runner.NodeTestCase.NodeTestCaseEntry;
+
+@Open
+@State(Scope.Benchmark)
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@Fork(jvmArgs = {"-Xms1G", "-Xmx1G"})
+@Warmup(iterations = 5)
+@Measurement(iterations = 10)
+public class NodeBenchmark {
+ @SuppressWarnings("NullAway" /* Initialized by `@Setup` method. */)
+ private ImmutableListMultimap, NodeTestCaseEntry> testCases;
+
+ public static void main(String[] args) throws RunnerException {
+ // XXX: Update `ErrorProneRuntimeClasspath` to allow same-package `Class` references.
+ @SuppressWarnings("ErrorProneRuntimeClasspath")
+ String testRegex = Pattern.quote(NodeBenchmark.class.getCanonicalName());
+ new Runner(new OptionsBuilder().include(testRegex).forks(1).build()).run();
+ }
+
+ @Setup
+ public final void setUp() {
+ Random random = new Random(0);
+
+ testCases =
+ Stream.of(
+ NodeTestCase.generate(100, 5, 10, 10, random),
+ NodeTestCase.generate(100, 5, 10, 100, random),
+ NodeTestCase.generate(100, 5, 10, 1000, random),
+ NodeTestCase.generate(1000, 10, 20, 10, random),
+ NodeTestCase.generate(1000, 10, 20, 100, random),
+ NodeTestCase.generate(1000, 10, 20, 1000, random),
+ NodeTestCase.generate(1000, 10, 20, 10000, random))
+ .collect(
+ flatteningToImmutableListMultimap(
+ identity(), testCase -> testCase.generateTestCaseEntries(random)));
+ }
+
+ @Benchmark
+ public final void create(Blackhole bh) {
+ for (NodeTestCase testCase : testCases.keySet()) {
+ bh.consume(testCase.buildTree());
+ }
+ }
+
+ @Benchmark
+ public final void collectReachableValues(Blackhole bh) {
+ for (Map.Entry, Collection>> e :
+ testCases.asMap().entrySet()) {
+ Node tree = e.getKey().buildTree();
+ for (NodeTestCaseEntry testCaseEntry : e.getValue()) {
+ tree.collectReachableValues(testCaseEntry.candidateEdges(), bh::consume);
+ }
+ }
+ }
+}
diff --git a/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTest.java b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTest.java
new file mode 100644
index 0000000000..82a9f2be6c
--- /dev/null
+++ b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTest.java
@@ -0,0 +1,47 @@
+package tech.picnic.errorprone.refaster.runner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+
+import com.google.common.collect.ImmutableSet;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Stream;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+final class NodeTest {
+ private static Stream collectReachableValuesTestCases() {
+ Random random = new Random(0);
+
+ return Stream.of(
+ NodeTestCase.generate(0, 0, 0, 0, random),
+ NodeTestCase.generate(1, 1, 1, 1, random),
+ NodeTestCase.generate(2, 2, 2, 10, random),
+ NodeTestCase.generate(10, 2, 5, 10, random),
+ NodeTestCase.generate(10, 2, 5, 100, random),
+ NodeTestCase.generate(100, 5, 10, 100, random),
+ NodeTestCase.generate(100, 5, 10, 1000, random))
+ .flatMap(
+ testCase -> {
+ Node tree = testCase.buildTree();
+ return testCase
+ .generateTestCaseEntries(random)
+ .map(e -> arguments(tree, e.candidateEdges(), e.reachableValues()));
+ });
+ }
+
+ @MethodSource("collectReachableValuesTestCases")
+ @ParameterizedTest
+ void collectReachableValues(
+ Node tree,
+ ImmutableSet candidateEdges,
+ Collection expectedReachable) {
+ List actualReachable = new ArrayList<>();
+ tree.collectReachableValues(candidateEdges, actualReachable::add);
+ assertThat(actualReachable).hasSameElementsAs(expectedReachable);
+ }
+}
diff --git a/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTestCase.java b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTestCase.java
new file mode 100644
index 0000000000..298f8116e3
--- /dev/null
+++ b/refaster-runner/src/test/java/tech/picnic/errorprone/refaster/runner/NodeTestCase.java
@@ -0,0 +1,140 @@
+package tech.picnic.errorprone.refaster.runner;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static com.google.common.collect.ImmutableSetMultimap.flatteningToImmutableSetMultimap;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.collectingAndThen;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.ImmutableCollection;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSetMultimap;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Random;
+import java.util.stream.Stream;
+
+@AutoValue
+abstract class NodeTestCase {
+ static NodeTestCase generate(
+ int entryCount, int maxPathCount, int maxPathLength, int pathValueDomainSize, Random random) {
+ return random
+ .ints(entryCount)
+ .boxed()
+ .collect(
+ collectingAndThen(
+ flatteningToImmutableSetMultimap(
+ identity(),
+ i ->
+ random
+ .ints(random.nextInt(maxPathCount + 1))
+ .mapToObj(
+ p ->
+ random
+ .ints(
+ random.nextInt(maxPathLength + 1),
+ 0,
+ pathValueDomainSize)
+ .mapToObj(String::valueOf)
+ .collect(toImmutableSet()))),
+ AutoValue_NodeTestCase::new));
+ }
+
+ abstract ImmutableSetMultimap> input();
+
+ final Node buildTree() {
+ return Node.create(input().keySet(), input()::get);
+ }
+
+ final Stream> generateTestCaseEntries(Random random) {
+ return generatePathTestCases(input(), random);
+ }
+
+ private static Stream> generatePathTestCases(
+ ImmutableSetMultimap> treeInput, Random random) {
+ ImmutableSet allEdges =
+ treeInput.values().stream().flatMap(ImmutableSet::stream).collect(toImmutableSet());
+
+ return Stream.concat(
+ Stream.of(ImmutableSet.of()), shuffle(treeInput.values(), random).stream())
+ // XXX: Use `random.nextInt(20, 100)` once we no longer target JDK 11. (And consider
+ // introducing a Refaster template for this case.)
+ .limit(20 + random.nextInt(80))
+ .flatMap(edges -> generateVariations(edges, allEdges, "unused", random))
+ .distinct()
+ .map(edges -> createTestCaseEntry(treeInput, edges));
+ }
+
+ private static Stream> generateVariations(
+ ImmutableSet baseEdges, ImmutableSet allEdges, T unusedEdge, Random random) {
+ Optional knownEdge = selectRandomElement(allEdges, random);
+
+ return Stream.of(
+ random.nextBoolean() ? null : baseEdges,
+ random.nextBoolean() ? null : shuffle(baseEdges, random),
+ random.nextBoolean() ? null : insertValue(baseEdges, unusedEdge, random),
+ baseEdges.isEmpty() || random.nextBoolean()
+ ? null
+ : randomStrictSubset(baseEdges, random),
+ baseEdges.isEmpty() || random.nextBoolean()
+ ? null
+ : insertValue(randomStrictSubset(baseEdges, random), unusedEdge, random),
+ baseEdges.isEmpty() || random.nextBoolean()
+ ? null
+ : knownEdge
+ .map(edge -> insertValue(randomStrictSubset(baseEdges, random), edge, random))
+ .orElse(null))
+ .filter(Objects::nonNull);
+ }
+
+ private static Optional selectRandomElement(ImmutableSet collection, Random random) {
+ return collection.isEmpty()
+ ? Optional.empty()
+ : Optional.of(collection.asList().get(random.nextInt(collection.size())));
+ }
+
+ private static ImmutableSet shuffle(ImmutableCollection values, Random random) {
+ List allValues = new ArrayList<>(values);
+ Collections.shuffle(allValues, random);
+ return ImmutableSet.copyOf(allValues);
+ }
+
+ private static ImmutableSet insertValue(
+ ImmutableSet values, T extraValue, Random random) {
+ List allValues = new ArrayList<>(values);
+ allValues.add(random.nextInt(values.size() + 1), extraValue);
+ return ImmutableSet.copyOf(allValues);
+ }
+
+ private static ImmutableSet randomStrictSubset(ImmutableSet values, Random random) {
+ checkArgument(!values.isEmpty(), "Cannot select strict subset of random collection");
+
+ List allValues = new ArrayList<>(values);
+ Collections.shuffle(allValues, random);
+ return ImmutableSet.copyOf(allValues.subList(0, random.nextInt(allValues.size())));
+ }
+
+ private static NodeTestCaseEntry createTestCaseEntry(
+ ImmutableSetMultimap> treeInput, ImmutableSet edges) {
+ return new AutoValue_NodeTestCase_NodeTestCaseEntry<>(
+ edges,
+ treeInput.asMap().entrySet().stream()
+ .filter(e -> e.getValue().stream().anyMatch(edges::containsAll))
+ .map(Map.Entry::getKey)
+ .collect(toImmutableList()));
+ }
+
+ @AutoValue
+ abstract static class NodeTestCaseEntry {
+ abstract ImmutableSet candidateEdges();
+
+ abstract ImmutableList reachableValues();
+ }
+}
diff --git a/refaster-support/src/main/java/tech/picnic/errorprone/refaster/AnnotatedCompositeCodeTransformer.java b/refaster-support/src/main/java/tech/picnic/errorprone/refaster/AnnotatedCompositeCodeTransformer.java
index c7429291fb..e01cb35cb4 100644
--- a/refaster-support/src/main/java/tech/picnic/errorprone/refaster/AnnotatedCompositeCodeTransformer.java
+++ b/refaster-support/src/main/java/tech/picnic/errorprone/refaster/AnnotatedCompositeCodeTransformer.java
@@ -46,7 +46,12 @@ public abstract class AnnotatedCompositeCodeTransformer implements CodeTransform
abstract String packageName();
- abstract ImmutableList transformers();
+ /**
+ * Return The {@link CodeTransformer}s to which to delegate.
+ *
+ * @return The ordered {@link CodeTransformer}s to which to delegate.
+ */
+ public abstract ImmutableList transformers();
@Override
@SuppressWarnings("java:S3038" /* All AutoValue properties must be specified explicitly. */)
@@ -67,6 +72,18 @@ public static AnnotatedCompositeCodeTransformer create(
return new AutoValue_AnnotatedCompositeCodeTransformer(packageName, transformers, annotations);
}
+ /**
+ * Returns a new {@link AnnotatedCompositeCodeTransformer} similar to this one, but with the
+ * specified transformers.
+ *
+ * @param transformers The replacement transformers.
+ * @return A derivative {@link AnnotatedCompositeCodeTransformer}.
+ */
+ public AnnotatedCompositeCodeTransformer withTransformers(CodeTransformer... transformers) {
+ return new AutoValue_AnnotatedCompositeCodeTransformer(
+ packageName(), ImmutableList.copyOf(transformers), annotations());
+ }
+
@Override
public final void apply(TreePath path, Context context, DescriptionListener listener) {
for (CodeTransformer transformer : transformers()) {