Skip to content

Commit

Permalink
Fix a crash in JUnitIncompatibleType
Browse files Browse the repository at this point in the history
Fixes #4291

PiperOrigin-RevId: 625359890
  • Loading branch information
cushon authored and Error Prone Team committed Apr 16, 2024
1 parent 5a7b8d9 commit a6ab21a
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 85 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2024 The Error Prone Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.errorprone.bugpatterns.collectionincompatibletype;

import static com.google.errorprone.util.ASTHelpers.getType;

import com.google.errorprone.VisitorState;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.TypeCastTree;
import com.sun.source.util.SimpleTreeVisitor;
import com.sun.tools.javac.code.Type;

/**
* A utility for handling types in cast expressions, shraed by {@link JUnitIncompatibleType} and
* {@link TruthIncompatibleType}.
*/
final class IgnoringCasts {

/**
* Returns the most specific type of a cast or its operand, for example returns {@code byte[]} for
* both {@code (byte[]) someObject} and {@code (Object) someByteArray}.
*/
static Type ignoringCasts(Tree tree, VisitorState state) {
return new SimpleTreeVisitor<Type, Void>() {
@Override
protected Type defaultAction(Tree node, Void unused) {
return getType(node);
}

@Override
public Type visitTypeCast(TypeCastTree node, Void unused) {
Type castType = getType(node);
Type expressionType = node.getExpression().accept(this, null);
return (castType.isPrimitive() || state.getTypes().isSubtype(castType, expressionType))
? castType
: expressionType;
}

@Override
public Type visitParenthesized(ParenthesizedTree node, Void unused) {
return node.getExpression().accept(this, null);
}
}.visit(tree, null);
}

private IgnoringCasts() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package com.google.errorprone.bugpatterns.collectionincompatibletype;

import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.bugpatterns.collectionincompatibletype.IgnoringCasts.ignoringCasts;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.matchers.Matchers.allOf;
import static com.google.errorprone.matchers.Matchers.anyOf;
import static com.google.errorprone.matchers.method.MethodMatchers.staticMethod;
import static com.google.errorprone.util.ASTHelpers.getType;

import com.google.errorprone.BugPattern;
import com.google.errorprone.VisitorState;
Expand All @@ -34,10 +34,6 @@
import com.google.errorprone.util.Signatures;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.TypeCastTree;
import com.sun.source.util.SimpleTreeVisitor;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Type.ArrayType;
import javax.inject.Inject;
Expand Down Expand Up @@ -76,14 +72,12 @@ public final class JUnitIncompatibleType extends BugChecker implements MethodInv
public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState state) {
var arguments = tree.getArguments();
if (ASSERT_EQUALS.matches(tree, state)) {
var typeA = getType(ignoringCasts(arguments.get(arguments.size() - 2)));
var typeB = getType(ignoringCasts(arguments.get(arguments.size() - 1)));
var typeA = ignoringCasts(arguments.get(arguments.size() - 2), state);
var typeB = ignoringCasts(arguments.get(arguments.size() - 1), state);
return checkCompatibility(tree, typeA, typeB, state);
} else if (ASSERT_ARRAY_EQUALS.matches(tree, state)) {
var typeA =
((ArrayType) getType(ignoringCasts(arguments.get(arguments.size() - 2)))).elemtype;
var typeB =
((ArrayType) getType(ignoringCasts(arguments.get(arguments.size() - 1)))).elemtype;
var typeA = ((ArrayType) ignoringCasts(arguments.get(arguments.size() - 2), state)).elemtype;
var typeB = ((ArrayType) ignoringCasts(arguments.get(arguments.size() - 1), state)).elemtype;
return checkCompatibility(tree, typeA, typeB, state);
}
return NO_MATCH;
Expand Down Expand Up @@ -113,25 +107,4 @@ private Description checkCompatibility(
targetTypeName))
.build();
}

private Tree ignoringCasts(Tree tree) {
return tree.accept(
new SimpleTreeVisitor<Tree, Void>() {
@Override
protected Tree defaultAction(Tree node, Void unused) {
return node;
}

@Override
public Tree visitTypeCast(TypeCastTree node, Void unused) {
return getType(node).isPrimitive() ? node : node.getExpression().accept(this, null);
}

@Override
public Tree visitParenthesized(ParenthesizedTree node, Void unused) {
return node.getExpression().accept(this, null);
}
},
null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.bugpatterns.collectionincompatibletype.AbstractCollectionIncompatibleTypeMatcher.extractTypeArgAsMemberOfSupertype;
import static com.google.errorprone.bugpatterns.collectionincompatibletype.IgnoringCasts.ignoringCasts;
import static com.google.errorprone.matchers.Description.NO_MATCH;
import static com.google.errorprone.matchers.Matchers.allOf;
import static com.google.errorprone.matchers.Matchers.anyOf;
Expand Down Expand Up @@ -46,10 +47,7 @@
import com.google.errorprone.util.Signatures;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.ParenthesizedTree;
import com.sun.source.tree.Tree;
import com.sun.source.tree.TypeCastTree;
import com.sun.source.util.SimpleTreeVisitor;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symbol.TypeSymbol;
import com.sun.tools.javac.code.Type;
Expand Down Expand Up @@ -202,7 +200,7 @@ private Stream<Description> matchEquality(MethodInvocationTree tree, VisitorStat
}

Type targetType =
getType(ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())));
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state);
Type sourceType = getType(getOnlyElement(tree.getArguments()));
if (isNumericType(sourceType, state) && isNumericType(targetType, state)) {
return Stream.of();
Expand All @@ -219,7 +217,7 @@ private Stream<Description> matchIsAnyOf(MethodInvocationTree tree, VisitorState
return Stream.empty();
}
Type targetType =
getType(ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())));
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state);
return matchScalarContains(tree, targetType, state);
}

Expand All @@ -233,7 +231,7 @@ private Stream<Description> matchIsIn(MethodInvocationTree tree, VisitorState st
}

Type targetType =
getType(ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())));
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state);
Type sourceType =
getIterableTypeArg(
getType(getOnlyElement(tree.getArguments())),
Expand All @@ -254,7 +252,7 @@ private Stream<Description> matchVectorContains(MethodInvocationTree tree, Visit
Type targetType =
getIterableTypeArg(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type,
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())),
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state),
state);
Type sourceType =
getIterableTypeArg(
Expand All @@ -276,7 +274,7 @@ private Stream<Description> matchArrayContains(MethodInvocationTree tree, Visito
Type targetType =
getIterableTypeArg(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type,
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())),
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state),
state);
Type sourceType = ((ArrayType) getType(getOnlyElement(tree.getArguments()))).elemtype;
return checkCompatibility(getOnlyElement(tree.getArguments()), targetType, sourceType, state);
Expand All @@ -293,7 +291,7 @@ private Stream<Description> matchScalarContains(MethodInvocationTree tree, Visit
Type targetType =
getIterableTypeArg(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type,
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())),
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state),
state);
return matchScalarContains(tree, targetType, state);
}
Expand Down Expand Up @@ -324,7 +322,7 @@ private Stream<Description> matchCorrespondence(MethodInvocationTree tree, Visit
Type targetType =
getIterableTypeArg(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type,
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments())),
ignoringCasts(getOnlyElement(((MethodInvocationTree) receiver).getArguments()), state),
state);
if (targetType == null) {
// The target collection may be raw.
Expand Down Expand Up @@ -368,16 +366,10 @@ private Stream<Description> matchMapVectorContains(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type.tsym;
Type targetKeyType =
extractTypeArgAsMemberOfSupertype(
getType(ignoringCasts(assertee)),
assertionType,
/* typeArgIndex= */ 0,
state.getTypes());
ignoringCasts(assertee, state), assertionType, /* typeArgIndex= */ 0, state.getTypes());
Type targetValueType =
extractTypeArgAsMemberOfSupertype(
getType(ignoringCasts(assertee)),
assertionType,
/* typeArgIndex= */ 1,
state.getTypes());
ignoringCasts(assertee, state), assertionType, /* typeArgIndex= */ 1, state.getTypes());
Type sourceKeyType =
extractTypeArgAsMemberOfSupertype(
getType(getOnlyElement(tree.getArguments())),
Expand Down Expand Up @@ -411,10 +403,7 @@ private Stream<Description> matchMapContainsKey(MethodInvocationTree tree, Visit
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type.tsym;
Type targetKeyType =
extractTypeArgAsMemberOfSupertype(
getType(ignoringCasts(assertee)),
assertionType,
/* typeArgIndex= */ 0,
state.getTypes());
ignoringCasts(assertee, state), assertionType, /* typeArgIndex= */ 0, state.getTypes());
return checkCompatibility(
getOnlyElement(tree.getArguments()),
targetKeyType,
Expand All @@ -437,16 +426,10 @@ private Stream<Description> matchMapScalarContains(
getOnlyElement(getSymbol((MethodInvocationTree) receiver).getParameters()).type.tsym;
Type targetKeyType =
extractTypeArgAsMemberOfSupertype(
getType(ignoringCasts(assertee)),
assertionType,
/* typeArgIndex= */ 0,
state.getTypes());
ignoringCasts(assertee, state), assertionType, /* typeArgIndex= */ 0, state.getTypes());
Type targetValueType =
extractTypeArgAsMemberOfSupertype(
getType(ignoringCasts(assertee)),
assertionType,
/* typeArgIndex= */ 1,
state.getTypes());
ignoringCasts(assertee, state), assertionType, /* typeArgIndex= */ 1, state.getTypes());
MethodSymbol methodSymbol = getSymbol(tree);
return Streams.mapWithIndex(
tree.getArguments().stream(),
Expand Down Expand Up @@ -494,32 +477,16 @@ private Stream<Description> checkCompatibility(
.build());
}

private Tree ignoringCasts(Tree tree) {
return tree.accept(
new SimpleTreeVisitor<Tree, Void>() {
@Override
protected Tree defaultAction(Tree node, Void unused) {
return node;
}

@Override
public Tree visitTypeCast(TypeCastTree node, Void unused) {
return getType(node).isPrimitive() ? node : node.getExpression().accept(this, null);
}

@Override
public Tree visitParenthesized(ParenthesizedTree node, Void unused) {
return node.getExpression().accept(this, null);
}
},
null);
}

private static Type getIterableTypeArg(Type type, Tree onlyElement, VisitorState state) {
return extractTypeArgAsMemberOfSupertype(
getType(onlyElement), type.tsym, /* typeArgIndex= */ 0, state.getTypes());
}

private static Type getIterableTypeArg(Type type, Type onlyElement, VisitorState state) {
return extractTypeArgAsMemberOfSupertype(
onlyElement, type.tsym, /* typeArgIndex= */ 0, state.getTypes());
}

private static Type getCorrespondenceTypeArg(Tree onlyElement, VisitorState state) {
return extractTypeArgAsMemberOfSupertype(
getType(onlyElement),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.google.errorprone.bugpatterns.collectionincompatibletype;

import com.google.errorprone.CompilationTestHelper;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -100,7 +99,6 @@ public void assertArrayEquals_primitiveOverloadsFine() {
.doTest();
}

@Ignore("https://github.com/google/error-prone/issues/4291")
@Test
public void assertArrayEquals_cast() {
compilationHelper
Expand All @@ -114,4 +112,20 @@ public void assertArrayEquals_cast() {
"}")
.doTest();
}

@Test
public void seesThroughCasts() {
compilationHelper
.addSourceLines(
"Test.java",
"import static org.junit.Assert.assertEquals;",
"import static org.junit.Assert.assertNotEquals;",
"class Test {",
" public void test() {",
" // BUG: Diagnostic contains:",
" assertEquals((Object) 1, (Object) 2L);",
" }",
"}")
.doTest();
}
}

0 comments on commit a6ab21a

Please sign in to comment.