Skip to content

Commit

Permalink
Prevent crashing when match arms use name of existing callable (#18449
Browse files Browse the repository at this point in the history
)

Fixes #16793. Fixes crash in #13666.

Previously mypy considered that variables in match/case patterns must be
Var's, causing a hard crash when a name of captured pattern clashes with
a name of some existing function. This PR removes such assumption about
Var and allows other nodes.
  • Loading branch information
sterliakov authored Jan 13, 2025
1 parent 469b4e4 commit 9be49b3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
19 changes: 13 additions & 6 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5402,17 +5402,21 @@ def _get_recursive_sub_patterns_map(

return sub_patterns_map

def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[Var, Type]:
all_captures: dict[Var, list[tuple[NameExpr, Type]]] = defaultdict(list)
def infer_variable_types_from_type_maps(
self, type_maps: list[TypeMap]
) -> dict[SymbolNode, Type]:
# Type maps may contain variables inherited from previous code which are not
# necessary `Var`s (e.g. a function defined earlier with the same name).
all_captures: dict[SymbolNode, list[tuple[NameExpr, Type]]] = defaultdict(list)
for tm in type_maps:
if tm is not None:
for expr, typ in tm.items():
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
assert node is not None
all_captures[node].append((expr, typ))

inferred_types: dict[Var, Type] = {}
inferred_types: dict[SymbolNode, Type] = {}
for var, captures in all_captures.items():
already_exists = False
types: list[Type] = []
Expand All @@ -5436,16 +5440,19 @@ def infer_variable_types_from_type_maps(self, type_maps: list[TypeMap]) -> dict[
new_type = UnionType.make_union(types)
# Infer the union type at the first occurrence
first_occurrence, _ = captures[0]
# If it didn't exist before ``match``, it's a Var.
assert isinstance(var, Var)
inferred_types[var] = new_type
self.infer_variable_type(var, first_occurrence, new_type, first_occurrence)
return inferred_types

def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: dict[Var, Type]) -> None:
def remove_capture_conflicts(
self, type_map: TypeMap, inferred_types: dict[SymbolNode, Type]
) -> None:
if type_map:
for expr, typ in list(type_map.items()):
if isinstance(expr, NameExpr):
node = expr.node
assert isinstance(node, Var)
if node not in inferred_types or not is_subtype(typ, inferred_types[node]):
del type_map[expr]

Expand Down
51 changes: 51 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -2471,3 +2471,54 @@ def nested_in_dict(d: dict[str, Any]) -> int:
return 0

[builtins fixtures/dict.pyi]

[case testMatchRebindsOuterFunctionName]
# flags: --warn-unreachable
from typing_extensions import Literal

def x() -> tuple[Literal["test"]]: ...

match x():
case (x,) if x == "test": # E: Incompatible types in capture pattern (pattern captures type "Literal['test']", variable has type "Callable[[], Tuple[Literal['test']]]")
reveal_type(x) # N: Revealed type is "def () -> Tuple[Literal['test']]"
case foo:
foo

[builtins fixtures/dict.pyi]

[case testMatchRebindsInnerFunctionName]
# flags: --warn-unreachable
class Some:
value: int | str
__match_args__ = ("value",)

def fn1(x: Some | int | str) -> None:
match x:
case int():
def value():
return 1
reveal_type(value) # N: Revealed type is "def () -> Any"
case str():
def value():
return 1
reveal_type(value) # N: Revealed type is "def () -> Any"
case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], Any]")
pass

def fn2(x: Some | int | str) -> None:
match x:
case int():
def value() -> str:
return ""
reveal_type(value) # N: Revealed type is "def () -> builtins.str"
case str():
def value() -> int: # E: All conditional function variants must have identical signatures \
# N: Original: \
# N: def value() -> str \
# N: Redefinition: \
# N: def value() -> int
return 1
reveal_type(value) # N: Revealed type is "def () -> builtins.str"
case Some(value): # E: Incompatible types in capture pattern (pattern captures type "Union[int, str]", variable has type "Callable[[], str]")
pass
[builtins fixtures/dict.pyi]

0 comments on commit 9be49b3

Please sign in to comment.