Skip to content

Commit

Permalink
Fix isinstance on non-types (#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Mar 10, 2024
1 parent a0a9b70 commit 0977bb7
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 28 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Fix narrowing on `isinstance` calls with arguments that are not
instance of `typing`, such as unions and certain typing special forms (#747)
- Detect invalid calls to `isinstance` (#747)
- Support calls to `TypeVar` and several other typing constructs in
code that is not executed (e.g., under `if TYPE_CHECKING`) (#746)
- Fix spurious errors for the class-based syntax for creating
Expand Down
11 changes: 2 additions & 9 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
from .find_unused import used
from .functions import FunctionDefNode
from .node_visitor import ErrorContext
from .safe import is_instance_of_typing_name, is_typing_name
from .safe import is_instance_of_typing_name, is_typing_name, is_union
from .signature import (
ANY_SIGNATURE,
ELLIPSIS_PARAM,
Expand Down Expand Up @@ -119,11 +119,6 @@
if TYPE_CHECKING:
from .name_check_visitor import NameCheckVisitor

try:
from types import UnionType
except ImportError:
UnionType = None


CONTEXT_MANAGER_TYPES = (typing.ContextManager, contextlib.AbstractContextManager)
ASYNC_CONTEXT_MANAGER_TYPES = (
Expand Down Expand Up @@ -1182,9 +1177,7 @@ def _value_of_origin_args(
_type_from_runtime(arg, ctx, allow_unpack=True) for arg in args
]
return _make_sequence_value(tuple, args_vals, ctx)
elif is_typing_name(origin, "Union") or (
UnionType is not None and origin is UnionType
):
elif is_union(origin):
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
elif origin is Callable or is_typing_name(origin, "Callable"):
if len(args) == 0:
Expand Down
62 changes: 45 additions & 17 deletions pyanalyze/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import (
Callable,
Dict,
Iterable,
NewType,
Optional,
Sequence,
Expand All @@ -26,7 +27,7 @@
from .format_strings import parse_format_string
from .predicates import IsAssignablePredicate
from .runtime import is_compatible
from .safe import hasattr_static, safe_isinstance, safe_issubclass
from .safe import hasattr_static, is_union, safe_isinstance, safe_issubclass
from .signature import (
ANY_SIGNATURE,
CallContext,
Expand Down Expand Up @@ -120,15 +121,19 @@ def _issubclass_impl(ctx: CallContext) -> Value:
varname = ctx.varname_for_arg("cls")
if varname is None or not isinstance(class_or_tuple, KnownValue):
return TypedValue(bool)
if isinstance(class_or_tuple.val, type):
narrowed_type = SubclassValue(TypedValue(class_or_tuple.val))
elif isinstance(class_or_tuple.val, tuple) and all(
isinstance(elt, type) for elt in class_or_tuple.val
):
vals = [SubclassValue(TypedValue(elt)) for elt in class_or_tuple.val]
narrowed_type = unite_values(*vals)
else:
try:
narrowed_types = list(_resolve_isinstance_arg(class_or_tuple.val))
except _CannotResolve as e:
ctx.show_error(
f'Second argument to "issubclass" must be a type, union,'
f' or tuple of types, not "{e.args[0]!r}"',
ErrorCode.incompatible_argument,
arg="class_or_tuple",
)
return TypedValue(bool)
narrowed_type = unite_values(
*[SubclassValue(TypedValue(typ)) for typ in narrowed_types]
)
predicate = IsAssignablePredicate(narrowed_type, ctx.visitor, positive_only=False)
constraint = Constraint(varname, ConstraintType.predicate, True, predicate)
return annotate_with_constraint(TypedValue(bool), constraint)
Expand All @@ -139,20 +144,43 @@ def _isinstance_impl(ctx: CallContext) -> Value:
varname = ctx.varname_for_arg("obj")
if varname is None or not isinstance(class_or_tuple, KnownValue):
return TypedValue(bool)
if isinstance(class_or_tuple.val, type):
narrowed_type = TypedValue(class_or_tuple.val)
elif isinstance(class_or_tuple.val, tuple) and all(
isinstance(elt, type) for elt in class_or_tuple.val
):
vals = [TypedValue(elt) for elt in class_or_tuple.val]
narrowed_type = unite_values(*vals)
else:
try:
narrowed_types = list(_resolve_isinstance_arg(class_or_tuple.val))
except _CannotResolve as e:
ctx.show_error(
f'Second argument to "isinstance" must be a type, union,'
f' or tuple of types, not "{e.args[0]!r}"',
ErrorCode.incompatible_argument,
arg="class_or_tuple",
)
return TypedValue(bool)
narrowed_type = unite_values(*[TypedValue(typ) for typ in narrowed_types])
predicate = IsAssignablePredicate(narrowed_type, ctx.visitor, positive_only=False)
constraint = Constraint(varname, ConstraintType.predicate, True, predicate)
return annotate_with_constraint(TypedValue(bool), constraint)


class _CannotResolve(Exception):
pass


def _resolve_isinstance_arg(val: object) -> Iterable[type]:
if safe_isinstance(val, type):
yield val
elif safe_isinstance(val, tuple):
for elt in val:
yield from _resolve_isinstance_arg(elt)
else:
origin = typing_extensions.get_origin(val)
if is_union(origin):
for arg in typing_extensions.get_args(val):
yield from _resolve_isinstance_arg(arg)
elif safe_isinstance(origin, type):
yield origin
else:
raise _CannotResolve(val)


def _constraint_from_isinstance(
varname: Optional[VarnameWithOrigin], class_or_tuple: Value
) -> AbstractConstraint:
Expand Down
6 changes: 5 additions & 1 deletion pyanalyze/patma.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@
except ImportError:
# 3.9 and lower
MatchAs = MatchClass = MatchMapping = Any
MatchOr = MatchSequence = MatchSingleton = MatchStar = MatchValue = Any
MatchOr = MatchSequence = MatchSingleton = MatchValue = Any

# Avoid false positive errors on isinstance() in 3.8/3.9 self check
class MatchStar(ast.AST):
pass


# For these types, a single class subpattern matches the whole thing
Expand Down
10 changes: 10 additions & 0 deletions pyanalyze/safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ def is_typing_name(obj: object, name: str) -> bool:
return safe_in(obj, names)


try:
from types import UnionType
except ImportError:
UnionType = None


def is_union(obj: object) -> bool:
return is_typing_name(obj, "Union") or (UnionType is not None and obj is UnionType)


def is_instance_of_typing_name(obj: object, name: str) -> bool:
objs, _ = _fill_typing_name_cache(name)
return isinstance(obj, objs)
Expand Down
58 changes: 57 additions & 1 deletion pyanalyze/test_stacked_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .name_check_visitor import build_stacked_scopes
from .stacked_scopes import ScopeType, uniq_chain
from .test_name_check_visitor import TestNameCheckVisitorBase
from .test_node_visitor import assert_passes
from .test_node_visitor import assert_passes, skip_before
from .value import (
NO_RETURN_VALUE,
UNINITIALIZED_VALUE,
Expand Down Expand Up @@ -932,6 +932,8 @@ def capybara(x):
# Don't widen the type to A.
assert_is_value(x, TypedValue(B))

@assert_passes()
def test_isinstance_multiple_types(self):
def kerodon(cond1, cond2, val, lst: list): # E: missing_generic_parameters
if cond1:
x = int(val)
Expand All @@ -958,6 +960,8 @@ def kerodon(cond1, cond2, val, lst: list): # E: missing_generic_parameters
else:
assert_is_value(x, TypedValue(list))

@assert_passes()
def test_complex_boolean(self):
def paca(cond1, cond2):
if cond1:
x = True
Expand All @@ -973,6 +977,58 @@ def paca(cond1, cond2):
else:
assert_is_value(x, KnownValue(False))

@assert_passes()
def test_isinstance_mapping(self):
from typing import Any, Mapping, Union

from typing_extensions import assert_type

class A: ...

def takes_mapping(x: Mapping[str, Any]) -> None: ...

def foo(x: Union[A, Mapping[str, Any]]) -> None:
# This is tricky because Mapping is not an instance of type.
if isinstance(x, Mapping):
assert_type(x, Mapping[str, Any])
else:
assert_type(x, A)

@skip_before((3, 10))
@assert_passes()
def test_isinstance_union(self):
from typing import Union

from typing_extensions import assert_type

def foo(x: Union[int, str, range]) -> None:
if isinstance(x, int | str):
assert_type(x, int | str)
else:
assert_type(x, range)
if isinstance(x, Union[int, range]):
assert_type(x, int | range)
else:
assert_type(x, str)

@assert_passes()
def test_isinstance_nested_tuple(self):
from typing import Union

from typing_extensions import assert_type

def foo(x: Union[int, str, range]) -> None:
if isinstance(x, (((int,), (str,)),)):
assert_type(x, Union[int, str])
else:
assert_type(x, range)

@assert_passes()
def test_isinstance_bad_arg(self):
def capybara(x):
if isinstance(x, 1): # E: incompatible_argument
pass

@assert_passes()
def test_double_index(self):
from typing import Optional, Union
Expand Down

0 comments on commit 0977bb7

Please sign in to comment.