Skip to content

Commit

Permalink
Use a dict to keep track of TypedDict fields in semanal
Browse files Browse the repository at this point in the history
Useful for #7435
  • Loading branch information
hauntsaninja committed Dec 30, 2024
1 parent ac6151a commit c677b1c
Showing 1 changed file with 50 additions and 52 deletions.
102 changes: 50 additions & 52 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Final
from typing import Collection, Final

from mypy import errorcodes as codes, message_registry
from mypy.errorcodes import ErrorCode
Expand Down Expand Up @@ -97,21 +97,23 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
existing_info = None
if isinstance(defn.analyzed, TypedDictExpr):
existing_info = defn.analyzed.info

field_types: dict[str, Type] | None
if (
len(defn.base_type_exprs) == 1
and isinstance(defn.base_type_exprs[0], RefExpr)
and defn.base_type_exprs[0].fullname in TPDICT_NAMES
):
# Building a new TypedDict
fields, types, statements, required_keys, readonly_keys = (
field_types, statements, required_keys, readonly_keys = (
self.analyze_typeddict_classdef_fields(defn)
)
if fields is None:
if field_types is None:
return True, None # Defer
if self.api.is_func_scope() and "@" not in defn.name:
defn.name += "@" + str(defn.line)
info = self.build_typeddict_typeinfo(
defn.name, fields, types, required_keys, readonly_keys, defn.line, existing_info
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand Down Expand Up @@ -154,26 +156,24 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
else:
self.fail("All bases of a new TypedDict must be TypedDict types", defn)

keys: list[str] = []
types = []
field_types = {}
required_keys = set()
readonly_keys = set()
# Iterate over bases in reverse order so that leftmost base class' keys take precedence
for base in reversed(typeddict_bases):
self.add_keys_and_types_from_base(
base, keys, types, required_keys, readonly_keys, defn
base, field_types, required_keys, readonly_keys, defn
)
(new_keys, new_types, new_statements, new_required_keys, new_readonly_keys) = (
self.analyze_typeddict_classdef_fields(defn, keys)
(new_field_types, new_statements, new_required_keys, new_readonly_keys) = (
self.analyze_typeddict_classdef_fields(defn, oldfields=field_types)
)
if new_keys is None:
if new_field_types is None:
return True, None # Defer
keys.extend(new_keys)
types.extend(new_types)
field_types.update(new_field_types)
required_keys.update(new_required_keys)
readonly_keys.update(new_readonly_keys)
info = self.build_typeddict_typeinfo(
defn.name, keys, types, required_keys, readonly_keys, defn.line, existing_info
defn.name, field_types, required_keys, readonly_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand All @@ -184,8 +184,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> tuple[bool, TypeInfo | N
def add_keys_and_types_from_base(
self,
base: Expression,
keys: list[str],
types: list[Type],
field_types: dict[str, Type],
required_keys: set[str],
readonly_keys: set[str],
ctx: Context,
Expand Down Expand Up @@ -224,10 +223,10 @@ def add_keys_and_types_from_base(
with state.strict_optional_set(self.options.strict_optional):
valid_items = self.map_items_to_base(valid_items, tvars, base_args)
for key in base_items:
if key in keys:
if key in field_types:
self.fail(TYPEDDICT_OVERRIDE_MERGE.format(key), ctx)
keys.extend(valid_items.keys())
types.extend(valid_items.values())

field_types.update(valid_items)
required_keys.update(base_typed_dict.required_keys)
readonly_keys.update(base_typed_dict.readonly_keys)

Expand Down Expand Up @@ -280,8 +279,8 @@ def map_items_to_base(
return mapped_items

def analyze_typeddict_classdef_fields(
self, defn: ClassDef, oldfields: list[str] | None = None
) -> tuple[list[str] | None, list[Type], list[Statement], set[str], set[str]]:
self, defn: ClassDef, oldfields: Collection[str] | None = None
) -> tuple[dict[str, Type] | None, list[Statement], set[str], set[str]]:
"""Analyze fields defined in a TypedDict class definition.
This doesn't consider inherited fields (if any). Also consider totality,
Expand All @@ -294,9 +293,19 @@ def analyze_typeddict_classdef_fields(
part of a TypedDict definition
* Set of required keys
"""
fields: list[str] = []
types: list[Type] = []
fields: dict[str, Type] = {}
readonly_keys = set()
required_keys = set()
statements: list[Statement] = []

total: bool | None = True
for key in defn.keywords:
if key == "total":
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
continue
for_function = ' for "__init_subclass__" of "TypedDict"'
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)

for stmt in defn.defs.body:
if not isinstance(stmt, AssignmentStmt):
# Still allow pass or ... (for empty TypedDict's) and docstrings
Expand All @@ -320,10 +329,11 @@ def analyze_typeddict_classdef_fields(
self.fail(f'Duplicate TypedDict key "{name}"', stmt)
continue
# Append stmt, name, and type in this case...
fields.append(name)
statements.append(stmt)

field_type: Type
if stmt.unanalyzed_type is None:
types.append(AnyType(TypeOfAny.unannotated))
field_type = AnyType(TypeOfAny.unannotated)
else:
analyzed = self.api.anal_type(
stmt.unanalyzed_type,
Expand All @@ -333,38 +343,27 @@ def analyze_typeddict_classdef_fields(
prohibit_special_class_field_types="TypedDict",
)
if analyzed is None:
return None, [], [], set(), set() # Need to defer
types.append(analyzed)
return None, [], set(), set() # Need to defer
field_type = analyzed
if not has_placeholder(analyzed):
stmt.type = self.extract_meta_info(analyzed, stmt)[0]

field_type, required, readonly = self.extract_meta_info(field_type)
fields[name] = field_type

if (total or required is True) and required is not False:
required_keys.add(name)
if readonly:
readonly_keys.add(name)

# ...despite possible minor failures that allow further analysis.
if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax:
self.fail(TPDICT_CLASS_ERROR, stmt)
elif not isinstance(stmt.rvalue, TempNode):
# x: int assigns rvalue to TempNode(AnyType())
self.fail("Right hand side values are not supported in TypedDict", stmt)
total: bool | None = True
if "total" in defn.keywords:
total = require_bool_literal_argument(self.api, defn.keywords["total"], "total", True)
if defn.keywords and defn.keywords.keys() != {"total"}:
for_function = ' for "__init_subclass__" of "TypedDict"'
for key in defn.keywords:
if key == "total":
continue
self.msg.unexpected_keyword_argument_for_function(for_function, key, defn)

res_types = []
readonly_keys = set()
required_keys = set()
for field, t in zip(fields, types):
typ, required, readonly = self.extract_meta_info(t)
res_types.append(typ)
if (total or required is True) and required is not False:
required_keys.add(field)
if readonly:
readonly_keys.add(field)

return fields, res_types, statements, required_keys, readonly_keys
return fields, statements, required_keys, readonly_keys

def extract_meta_info(
self, typ: Type, context: Context | None = None
Expand Down Expand Up @@ -433,7 +432,7 @@ def check_typeddict(
name += "@" + str(call.line)
else:
name = var_name = "TypedDict@" + str(call.line)
info = self.build_typeddict_typeinfo(name, [], [], set(), set(), call.line, None)
info = self.build_typeddict_typeinfo(name, {}, set(), set(), call.line, None)
else:
if var_name is not None and name != var_name:
self.fail(
Expand Down Expand Up @@ -473,7 +472,7 @@ def check_typeddict(
if isinstance(node.analyzed, TypedDictExpr):
existing_info = node.analyzed.info
info = self.build_typeddict_typeinfo(
name, items, types, required_keys, readonly_keys, call.line, existing_info
name, dict(zip(items, types)), required_keys, readonly_keys, call.line, existing_info
)
info.line = node.line
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
Expand Down Expand Up @@ -578,8 +577,7 @@ def fail_typeddict_arg(
def build_typeddict_typeinfo(
self,
name: str,
items: list[str],
types: list[Type],
item_types: dict[str, Type],
required_keys: set[str],
readonly_keys: set[str],
line: int,
Expand All @@ -594,7 +592,7 @@ def build_typeddict_typeinfo(
assert fallback is not None
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
typeddict_type = TypedDictType(
dict(zip(items, types)), required_keys, readonly_keys, fallback
item_types, required_keys, readonly_keys, fallback
)
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.process_placeholder(
Expand Down

0 comments on commit c677b1c

Please sign in to comment.