Skip to content

Commit

Permalink
Allow Any to match sequence patterns in match/case (#18448)
Browse files Browse the repository at this point in the history
Fixes #17095 (comment, the primary issue was already fixed somewhere
before). Fixes #16272. Fixes #12532. Fixes #12770.

Prior to this PR mypy did not consider that `Any` can match any
patterns, including sequence patterns (e.g. `case [_]`). This PR allows
matching `Any` against any such patterns.
  • Loading branch information
sterliakov authored Jan 13, 2025
1 parent 9685171 commit ee364ce
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 77 deletions.
5 changes: 4 additions & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,8 @@ def should_self_match(self, typ: Type) -> bool:
return False

def can_match_sequence(self, typ: ProperType) -> bool:
if isinstance(typ, AnyType):
return True
if isinstance(typ, UnionType):
return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
for other in self.non_sequence_match_types:
Expand Down Expand Up @@ -763,6 +765,8 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
"""
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, AnyType):
return outer_type
if isinstance(proper_type, UnionType):
types = [
self.construct_sequence_child(item, inner_type)
Expand All @@ -772,7 +776,6 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
return make_simplified_union(types)
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, TupleType):
proper_type = tuple_fallback(proper_type)
assert isinstance(proper_type, Instance)
Expand Down
164 changes: 88 additions & 76 deletions mypyc/test-data/irbuild-match.test
Original file line number Diff line number Diff line change
Expand Up @@ -1378,14 +1378,15 @@ def f(x):
r15 :: bit
r16 :: bool
r17 :: native_int
r18, rest :: object
r19 :: str
r20 :: object
r21 :: str
r22 :: object
r23 :: object[1]
r24 :: object_ptr
r25, r26 :: object
r18 :: object
r19, rest :: list
r20 :: str
r21 :: object
r22 :: str
r23 :: object
r24 :: object[1]
r25 :: object_ptr
r26, r27 :: object
L0:
r0 = CPySequence_Check(x)
r1 = r0 != 0
Expand Down Expand Up @@ -1414,21 +1415,23 @@ L3:
L4:
r17 = r2 - 0
r18 = PySequence_GetSlice(x, 2, r17)
rest = r18
r19 = cast(list, r18)
rest = r19
L5:
r19 = 'matched'
r20 = builtins :: module
r21 = 'print'
r22 = CPyObject_GetAttr(r20, r21)
r23 = [r19]
r24 = load_address r23
r25 = _PyObject_Vectorcall(r22, r24, 1, 0)
keep_alive r19
r20 = 'matched'
r21 = builtins :: module
r22 = 'print'
r23 = CPyObject_GetAttr(r21, r22)
r24 = [r20]
r25 = load_address r24
r26 = _PyObject_Vectorcall(r23, r25, 1, 0)
keep_alive r20
goto L7
L6:
L7:
r26 = box(None, 1)
return r26
r27 = box(None, 1)
return r27

[case testMatchSequenceWithStarPatternInTheMiddle_python3_10]
def f(x):
match x:
Expand All @@ -1455,14 +1458,15 @@ def f(x):
r16 :: bit
r17 :: bool
r18 :: native_int
r19, rest :: object
r20 :: str
r21 :: object
r22 :: str
r23 :: object
r24 :: object[1]
r25 :: object_ptr
r26, r27 :: object
r19 :: object
r20, rest :: list
r21 :: str
r22 :: object
r23 :: str
r24 :: object
r25 :: object[1]
r26 :: object_ptr
r27, r28 :: object
L0:
r0 = CPySequence_Check(x)
r1 = r0 != 0
Expand Down Expand Up @@ -1492,21 +1496,23 @@ L3:
L4:
r18 = r2 - 1
r19 = PySequence_GetSlice(x, 1, r18)
rest = r19
r20 = cast(list, r19)
rest = r20
L5:
r20 = 'matched'
r21 = builtins :: module
r22 = 'print'
r23 = CPyObject_GetAttr(r21, r22)
r24 = [r20]
r25 = load_address r24
r26 = _PyObject_Vectorcall(r23, r25, 1, 0)
keep_alive r20
r21 = 'matched'
r22 = builtins :: module
r23 = 'print'
r24 = CPyObject_GetAttr(r22, r23)
r25 = [r21]
r26 = load_address r25
r27 = _PyObject_Vectorcall(r24, r26, 1, 0)
keep_alive r21
goto L7
L6:
L7:
r27 = box(None, 1)
return r27
r28 = box(None, 1)
return r28

[case testMatchSequenceWithStarPatternAtTheStart_python3_10]
def f(x):
match x:
Expand All @@ -1530,14 +1536,15 @@ def f(x):
r17 :: bit
r18 :: bool
r19 :: native_int
r20, rest :: object
r21 :: str
r22 :: object
r23 :: str
r24 :: object
r25 :: object[1]
r26 :: object_ptr
r27, r28 :: object
r20 :: object
r21, rest :: list
r22 :: str
r23 :: object
r24 :: str
r25 :: object
r26 :: object[1]
r27 :: object_ptr
r28, r29 :: object
L0:
r0 = CPySequence_Check(x)
r1 = r0 != 0
Expand Down Expand Up @@ -1568,21 +1575,23 @@ L3:
L4:
r19 = r2 - 2
r20 = PySequence_GetSlice(x, 0, r19)
rest = r20
r21 = cast(list, r20)
rest = r21
L5:
r21 = 'matched'
r22 = builtins :: module
r23 = 'print'
r24 = CPyObject_GetAttr(r22, r23)
r25 = [r21]
r26 = load_address r25
r27 = _PyObject_Vectorcall(r24, r26, 1, 0)
keep_alive r21
r22 = 'matched'
r23 = builtins :: module
r24 = 'print'
r25 = CPyObject_GetAttr(r23, r24)
r26 = [r22]
r27 = load_address r26
r28 = _PyObject_Vectorcall(r25, r27, 1, 0)
keep_alive r22
goto L7
L6:
L7:
r28 = box(None, 1)
return r28
r29 = box(None, 1)
return r29

[case testMatchBuiltinClassPattern_python3_10]
def f(x):
match x:
Expand Down Expand Up @@ -1634,14 +1643,15 @@ def f(x):
r2 :: native_int
r3, r4 :: bit
r5 :: native_int
r6, rest :: object
r7 :: str
r8 :: object
r9 :: str
r10 :: object
r11 :: object[1]
r12 :: object_ptr
r13, r14 :: object
r6 :: object
r7, rest :: list
r8 :: str
r9 :: object
r10 :: str
r11 :: object
r12 :: object[1]
r13 :: object_ptr
r14, r15 :: object
L0:
r0 = CPySequence_Check(x)
r1 = r0 != 0
Expand All @@ -1654,21 +1664,23 @@ L1:
L2:
r5 = r2 - 0
r6 = PySequence_GetSlice(x, 0, r5)
rest = r6
r7 = cast(list, r6)
rest = r7
L3:
r7 = 'matched'
r8 = builtins :: module
r9 = 'print'
r10 = CPyObject_GetAttr(r8, r9)
r11 = [r7]
r12 = load_address r11
r13 = _PyObject_Vectorcall(r10, r12, 1, 0)
keep_alive r7
r8 = 'matched'
r9 = builtins :: module
r10 = 'print'
r11 = CPyObject_GetAttr(r9, r10)
r12 = [r8]
r13 = load_address r12
r14 = _PyObject_Vectorcall(r11, r13, 1, 0)
keep_alive r8
goto L5
L4:
L5:
r14 = box(None, 1)
return r14
r15 = box(None, 1)
return r15

[case testMatchTypeAnnotatedNativeClass_python3_10]
class A:
a: int
Expand Down
32 changes: 32 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -2439,3 +2439,35 @@ def foo(x: T) -> T:
return out

[builtins fixtures/isinstance.pyi]

[case testMatchSequenceReachableFromAny]
# flags: --warn-unreachable
from typing import Any

def maybe_list(d: Any) -> int:
match d:
case []:
return 0
case [[_]]:
return 1
case [_]:
return 1
case _:
return 2

def with_guard(d: Any) -> None:
match d:
case [s] if isinstance(s, str):
reveal_type(s) # N: Revealed type is "builtins.str"
match d:
case (s,) if isinstance(s, str):
reveal_type(s) # N: Revealed type is "builtins.str"

def nested_in_dict(d: dict[str, Any]) -> int:
match d:
case {"src": ["src"]}:
return 1
case _:
return 0

[builtins fixtures/dict.pyi]

0 comments on commit ee364ce

Please sign in to comment.