diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleType.java b/core/src/main/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleType.java index 56bc11e9d7a..86a12e6b0fa 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleType.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleType.java @@ -35,6 +35,7 @@ import com.google.common.collect.Streams; import com.google.errorprone.BugPattern; +import com.google.errorprone.ErrorProneFlags; import com.google.errorprone.VisitorState; import com.google.errorprone.bugpatterns.BugChecker; import com.google.errorprone.bugpatterns.BugChecker.MethodInvocationTreeMatcher; @@ -88,12 +89,7 @@ public class TruthIncompatibleType extends BugChecker implements MethodInvocatio .onDescendantOf("com.google.common.truth.extensions.proto.ProtoFluentAssertion"), instanceMethod().onDescendantOf("com.google.common.truth.extensions.proto.ProtoSubject")); - private static final Matcher SCALAR_CONTAINS = - instanceMethod() - .onDescendantOfAny( - "com.google.common.truth.IterableSubject", "com.google.common.truth.StreamSubject") - .namedAnyOf( - "contains", "containsExactly", "doesNotContain", "containsAnyOf", "containsNoneOf"); + private final Matcher scalarContains; private static final Matcher IS_ANY_OF = instanceMethod() @@ -156,8 +152,31 @@ public class TruthIncompatibleType extends BugChecker implements MethodInvocatio private final TypeCompatibility typeCompatibility; @Inject - TruthIncompatibleType(TypeCompatibility typeCompatibility) { + TruthIncompatibleType(TypeCompatibility typeCompatibility, ErrorProneFlags flags) { this.typeCompatibility = typeCompatibility; + this.scalarContains = + flags.getBoolean("TruthIncompatibleType:YetMore").orElse(true) + ? instanceMethod() + .onDescendantOfAny( + "com.google.common.truth.IterableSubject", + "com.google.common.truth.StreamSubject") + .namedAnyOf( + "contains", + "containsExactly", + "doesNotContain", + "containsAnyOf", + "containsNoneOf", + "containsAtLeast") + : instanceMethod() + .onDescendantOfAny( + "com.google.common.truth.IterableSubject", + "com.google.common.truth.StreamSubject") + .namedAnyOf( + "contains", + "containsExactly", + "doesNotContain", + "containsAnyOf", + "containsNoneOf"); } @Override @@ -278,7 +297,7 @@ private Stream matchArrayContains(MethodInvocationTree tree, Visito } private Stream matchScalarContains(MethodInvocationTree tree, VisitorState state) { - if (!SCALAR_CONTAINS.matches(tree, state)) { + if (!scalarContains.matches(tree, state)) { return Stream.empty(); } ExpressionTree receiver = getReceiver(tree); diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleTypeTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleTypeTest.java index e558e88bd01..ca8f23e9f4c 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleTypeTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/collectionincompatibletype/TruthIncompatibleTypeTest.java @@ -22,6 +22,7 @@ import static java.util.Arrays.stream; import com.google.common.collect.ImmutableList; +import com.google.common.truth.IterableSubject; import com.google.common.truth.Subject; import com.google.errorprone.CompilationTestHelper; import com.google.testing.junit.testparameterinjector.TestParameter; @@ -605,11 +606,30 @@ public void subjectExhaustiveness( .doTest(); } + @Test + public void iterableSubjectExhaustiveness( + @TestParameter(valuesProvider = IterableSubjectMethods.class) Method method) { + compilationHelper + .addSourceLines( + "Test.java", + "import static com.google.common.truth.Truth.assertThat;", + "import com.google.common.collect.ImmutableList;", + "class Test {", + " public void test(Iterable a, Long b) {", + " // BUG: Diagnostic contains:", + getOffensiveLine(method), + " }", + "}") + .doTest(); + } + private static String getOffensiveLine(Method method) { if (stream(method.getParameterTypes()).allMatch(p -> p.equals(Iterable.class))) { return format(" assertThat(a).%s(ImmutableList.of(b));", method.getName()); } else if (stream(method.getParameterTypes()).allMatch(p -> p.equals(Object.class))) { return format(" assertThat(a).%s(b);", method.getName()); + } else if (stream(method.getParameterTypes()).allMatch(p -> p.isArray())) { + return format(" assertThat(a).%s(new Long[]{b, b, b});", method.getName()); } else if (stream(method.getParameterTypes()) .allMatch(p -> p.equals(Object.class) || p.isArray())) { return format(" assertThat(a).%s(b, b, b);", method.getName()); @@ -623,17 +643,28 @@ private static String getOffensiveLine(Method method) { private static final class SubjectMethods implements TestParameterValuesProvider { @Override public ImmutableList provideValues() { - return stream(Subject.class.getDeclaredMethods()) - .filter( - m -> - Modifier.isPublic(m.getModifiers()) - && !m.getName().equals("equals") - && m.getParameterCount() > 0 - && (stream(m.getParameterTypes()).allMatch(p -> p.equals(Iterable.class)) - || stream(m.getParameterTypes()) - .allMatch(p -> p.equals(Object.class) || p.isArray()) - || stream(m.getParameterTypes()).allMatch(Class::isArray))) - .collect(toImmutableList()); + return getAssertionMethods(Subject.class); + } + } + + private static final class IterableSubjectMethods implements TestParameterValuesProvider { + @Override + public ImmutableList provideValues() { + return getAssertionMethods(IterableSubject.class); } } + + private static ImmutableList getAssertionMethods(Class clazz) { + return stream(clazz.getDeclaredMethods()) + .filter( + m -> + Modifier.isPublic(m.getModifiers()) + && !m.getName().equals("equals") + && m.getParameterCount() > 0 + && (stream(m.getParameterTypes()).allMatch(p -> p.equals(Iterable.class)) + || stream(m.getParameterTypes()) + .allMatch(p -> p.equals(Object.class) || p.isArray()) + || stream(m.getParameterTypes()).allMatch(Class::isArray))) + .collect(toImmutableList()); + } }