diff --git a/mypy/checker.py b/mypy/checker.py index 62acfc9e3abe..81f125e89ead 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5309,10 +5309,16 @@ def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]: def visit_for_stmt(self, s: ForStmt) -> None: """Type check a for statement.""" + lvalue_type, b, c = self.check_lvalue(s.index) + if lvalue_type is not None: + context: Type | None = self.named_generic_type("typing.Iterable", [lvalue_type]) + else: + context = None + if s.is_async: - iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr) + iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr, context) else: - iterator_type, item_type = self.analyze_iterable_item_type(s.expr) + iterator_type, item_type = self.analyze_iterable_item_type(s.expr, context) s.inferred_item_type = item_type s.inferred_iterator_type = iterator_type @@ -5324,10 +5330,12 @@ def visit_for_stmt(self, s: ForStmt) -> None: ), ) - def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: + def analyze_async_iterable_item_type( + self, expr: Expression, context: Type | None = None + ) -> tuple[Type, Type]: """Analyse async iterable expression and return iterator and iterator item types.""" echk = self.expr_checker - iterable = echk.accept(expr) + iterable = echk.accept(expr, context) iterator = echk.check_method_call_by_name("__aiter__", iterable, [], [], expr)[0] awaitable = echk.check_method_call_by_name("__anext__", iterator, [], [], expr)[0] item_type = echk.check_awaitable_expr( @@ -5335,10 +5343,12 @@ def analyze_async_iterable_item_type(self, expr: Expression) -> tuple[Type, Type ) return iterator, item_type - def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]: + def analyze_iterable_item_type( + self, expr: Expression, context: Type | None = None + ) -> tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" iterator, iterable = self.analyze_iterable_item_type_without_expression( - self.expr_checker.accept(expr), context=expr + self.expr_checker.accept(expr, context), context=expr ) int_type = self.analyze_range_native_int_type(expr) if int_type: diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index c71d83324694..936c844abc9b 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -307,6 +307,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: narrowed_inner_types = [] inner_rest_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): + # TODO: for loop type context should narrow on "assignment"? + assert inner_type is not None + (narrowed_inner_type, inner_rest_type) = ( self.chk.conditional_types_with_intersection( inner_type, [get_type_range(new_inner_type)], o, default=inner_type diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 7219d5d5e708..a3ea1c5c7165 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -2317,15 +2317,15 @@ def decompose_union_helper( rest_items.append(item) exit_block = BasicBlock() result = Register(result_type) - for i, item in enumerate(fast_items): + for i, inst in enumerate(fast_items): more_types = i < len(fast_items) - 1 or rest_items if more_types: # We are not at the final item so we need one more branch - op = self.isinstance_native(obj, item.class_ir, line) + op = self.isinstance_native(obj, inst.class_ir, line) true_block, false_block = BasicBlock(), BasicBlock() self.add_bool_branch(op, true_block, false_block) self.activate_block(true_block) - coerced = self.coerce(obj, item, line) + coerced = self.coerce(obj, inst, line) temp = process_item(coerced) temp2 = self.coerce(temp, result_type, line) self.add(Assign(result, temp2)) diff --git a/mypyc/test-data/irbuild-set.test b/mypyc/test-data/irbuild-set.test index c42a1fa74a75..43134f7bad38 100644 --- a/mypyc/test-data/irbuild-set.test +++ b/mypyc/test-data/irbuild-set.test @@ -736,10 +736,8 @@ def not_precomputed() -> None: [out] def precomputed(): r0 :: set - r1, r2 :: object - r3 :: str - _ :: object - r4 :: bit + r1, r2, _ :: object + r3 :: bit L0: r0 = frozenset({'False', 'None', 'True'}) r1 = PyObject_GetIter(r0) @@ -747,12 +745,11 @@ L1: r2 = PyIter_Next(r1) if is_error(r2) goto L4 else goto L2 L2: - r3 = cast(str, r2) - _ = r3 + _ = r2 L3: goto L1 L4: - r4 = CPy_NoErrOccurred() + r3 = CPy_NoErrOccurred() L5: return 1 def precomputed2(): diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 17ae6d9934b7..41c4136afc55 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1495,3 +1495,12 @@ def g(b: Optional[str]) -> None: z: Callable[[], str] = lambda: reveal_type(b) # N: Revealed type is "builtins.str" f2(lambda: reveal_type(b)) # N: Revealed type is "builtins.str" lambda: reveal_type(b) # N: Revealed type is "builtins.str" + +[case testInferenceForForLoops] +from typing import Literal + +def func2() -> None: + b: Literal["foo", "bar", "baz"] + for b in ["foo", "bar"]: + # TODO: this should narrow to "foo" | "bar" ideally? + reveal_type(b) # N: Revealed type is "Union[Literal['foo'], Literal['bar'], Literal['baz']]" diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 42b5a05ab39a..670a079abb73 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -1149,11 +1149,11 @@ for x, (y, z) in [(A(), (B(), C()))]: a = x b = y c = z -for xx, yy, zz in [(A(), B())]: # E: Need more than 2 values to unpack (3 expected) +for x2, y2, z2 in [(A(), B())]: # E: Need more than 2 values to unpack (3 expected) pass -for xx, (yy, zz) in [(A(), B())]: # E: "B" object is not iterable +for x3, (y3, z3) in [(A(), B())]: # E: "B" object is not iterable pass -for xxx, yyy in [(None, None)]: +for x4, y4 in [(None, None)]: pass [builtins fixtures/for.pyi]