From 04b1a717f70f0c0f78740a5f2056352b2b3e45df Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Tue, 26 Mar 2024 08:38:57 -0700 Subject: [PATCH] Fix constraints on OR --- docs/changelog.md | 1 + pyanalyze/name_check_visitor.py | 10 +++++++--- pyanalyze/test_never.py | 27 ++++++++++++++++++++++++++- pyanalyze/test_stacked_scopes.py | 16 ++++++++++++++++ 4 files changed, 50 insertions(+), 4 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 9875de02..87ee7182 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Fix type narrowing for certain conditionals using `or` (#755) - Fix incorrect `undefined_name` errors when a class is nested in a nested function and uses a name from the outer function (#750) - Fix incorrect `possibly_undefined_name` error on certain uses of the diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 8b914655..40370b67 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -3433,12 +3433,16 @@ def visit_BoolOp(self, node: ast.BoolOp) -> Value: values.append(constrain_value(new_value, TRUTHY_CONSTRAINT)) self.scopes.combine_subscopes(scopes) - constraint_cls = AndConstraint if is_and else OrConstraint - constraint = constraint_cls.make(reversed(out_constraints)) out = unite_values(*values) if definite_value is not None: out = annotate_value(out, [DefiniteValueExtension(definite_value)]) - return annotate_with_constraint(out, constraint) + if is_and: + constraint = AndConstraint.make(reversed(out_constraints)) + return annotate_with_constraint(out, constraint) + else: + # For OR conditions, no need to add a constraint here; we'll + # return a Union and extract_constraints() will combine them. + return out def visit_Compare(self, node: ast.Compare) -> Value: nodes = [node.left, *node.comparators] diff --git a/pyanalyze/test_never.py b/pyanalyze/test_never.py index 352e878d..8ab89e01 100644 --- a/pyanalyze/test_never.py +++ b/pyanalyze/test_never.py @@ -81,7 +81,7 @@ def capybara(x: Union[int, str]) -> None: assert_never(x) @assert_passes() - def test_enum(self): + def test_enum_in(self): import enum from typing_extensions import assert_never @@ -96,6 +96,31 @@ def capybara(x: Capybara) -> None: else: assert_never(x) + @assert_passes() + def test_enum_is_or(self): + import enum + + from typing_extensions import Literal, assert_never, assert_type + + class Capybara(enum.Enum): + hydrochaeris = 1 + isthmius = 2 + hesperotiganites = 3 + + def neochoerus(x: Capybara) -> None: + if x is Capybara.hydrochaeris or x is Capybara.isthmius: + assert_type(x, Literal[Capybara.hydrochaeris, Capybara.isthmius]) + else: + assert_type(x, Literal[Capybara.hesperotiganites]) + + def capybara(x: Capybara) -> None: + if x is Capybara.hydrochaeris or x is Capybara.isthmius: + assert_type(x, Literal[Capybara.hydrochaeris, Capybara.isthmius]) + elif x is Capybara.hesperotiganites: + assert_type(x, Literal[Capybara.hesperotiganites]) + else: + assert_never(x) + @assert_passes() def test_literal(self): from typing_extensions import Literal, assert_never, assert_type diff --git a/pyanalyze/test_stacked_scopes.py b/pyanalyze/test_stacked_scopes.py index 28c3d675..57da5099 100644 --- a/pyanalyze/test_stacked_scopes.py +++ b/pyanalyze/test_stacked_scopes.py @@ -977,6 +977,22 @@ def paca(cond1, cond2): else: assert_is_value(x, KnownValue(False)) + @assert_passes() + def test_two_booleans_is(self): + def capybara(cond1: bool, cond2: bool) -> None: + from typing import Union + + from typing_extensions import Literal, assert_type + + if (cond1 is True) or (cond2 is True): + # Ideally we'd elide the "Literal[False]" but this isn't + # wrong. + assert_type(cond1, Union[bool, Literal[False]]) + assert_type(cond2, bool) + else: + assert_type(cond1, Literal[False]) + assert_type(cond2, Literal[False]) + @assert_passes() def test_isinstance_mapping(self): from typing import Any, Mapping, Union