Implement support for returning TypedDict for dataclasses.asdict #8583

Open
wants to merge 46 commits into
base: master
46 commits
bb9f051
Implement support for returning TypedDict for dataclasses.asdict
syastrov Mar 26, 2020
e2f9f06
Remove redundant test. Fix comment typo.
syastrov Mar 26, 2020
2f6ec2d
Test for cases where dataclasses.asdict is called on non-dataclass in…
syastrov Mar 26, 2020
a9779e2
Clean up tests, and test more edge-cases.
syastrov Mar 26, 2020
c5d0a15
Remove no-longer-needed type: ignore on CheckerPluginInterface.module…
syastrov Mar 26, 2020
454431d
Make typeddicts non-total.
syastrov Mar 26, 2020
d809e8b
Address some of review comments (formatting, docs, code nitpicks, rem…
syastrov Mar 30, 2020
4d195cc
Simplify: Remove _transform_type_args and remove unneeded None check …
syastrov Mar 30, 2020
b33798c
Fix unused import
syastrov Mar 30, 2020
fe19bc9
Add fine-grained test for dataclasses.asdict.
syastrov Mar 30, 2020
6328a7c
Oops, add forgotten fine-grained dataclasses test. And remove redunda…
syastrov Mar 30, 2020
4694299
Only import the module containing TypedDict fallback if dataclasses i…
syastrov Apr 1, 2020
9c29081
Only enable TypedDict for Python >= 3.8.
syastrov Apr 2, 2020
d7df77a
Refactor asdict implementation to use TypeTranslator instead of recur…
syastrov Apr 2, 2020
e9a56ba
Made TypedDicts returned by asdict total again.
syastrov Apr 2, 2020
2e5240e
Fixed test after total change.
syastrov Apr 2, 2020
52a1c27
Make code a bit more readable, and a bit more robust.
syastrov Apr 2, 2020
43f174c
Fix typo
syastrov Apr 2, 2020
227ba90
After refactoring to use TypeTranslator, ensure Callable and Type[..]…
syastrov Apr 2, 2020
d12c665
Address second review comments.
syastrov Apr 8, 2020
45e72d7
Fix return type
syastrov Apr 8, 2020
a03f033
Try to address more review comments and fix flake8
syastrov Jun 3, 2020
b4d7e15
Add fine grained deps test to help debug asdict dependencies.
syastrov Jun 3, 2020
d96d977
Fix some asdict tests missing tuple dependency
syastrov Jun 3, 2020
441b665
Revert "Fix some asdict tests missing tuple dependency"
syastrov Jun 4, 2020
344ca6a
Don't need dep on typing_extensions
syastrov Jun 4, 2020
c8858fa
Checker lookup_fully_qualified_or_none: Don't raise KeyError, return …
syastrov Jun 9, 2020
20d7716
Add dependencies for asdict on the referenced dataclasses and its att…
syastrov Jun 9, 2020
5fec41b
Fix fine-grained no-cache test by adding correct dep on dataclass attrs.
syastrov Aug 18, 2020
5862c16
remove unused imports
syastrov Aug 18, 2020
26b7393
Merge branch 'master' into dataclasses-asdict
syastrov Feb 17, 2021
d7e0310
Remove error when passing a "non-dataclass" to asdict to reduce false…
syastrov Feb 17, 2021
080c00c
Fix flake8
syastrov Feb 17, 2021
9e45f8f
Fix asdict tests (require using python version 3.7 minimum).
syastrov Feb 17, 2021
38b466a
Merge branch 'master' into dataclasses-asdict
syastrov Aug 19, 2021
74ebc6f
Fix tests for quoting changes
syastrov Aug 19, 2021
aef274a
Merge branch 'master' into dataclasses-asdict
97littleleaf11 Nov 17, 2021
79f25db
Merge
97littleleaf11 Jan 18, 2022
f54e503
Fix
97littleleaf11 Jan 18, 2022
9820cfc
Add fixture for tests
97littleleaf11 Jan 18, 2022
9f49cac
Add fixture for tests
97littleleaf11 Jan 18, 2022
fcd1ff5
Add fixture for tests
97littleleaf11 Jan 18, 2022
9fa6a6b
Merge from master
97littleleaf11 Jan 18, 2022
6e1585d
Fix
97littleleaf11 Jan 18, 2022
d6f9170
Merge branch 'master' of https://github.com/python/mypy into HEAD
97littleleaf11 Jan 19, 2022
e226562
Test for a workaround
97littleleaf11 Jan 19, 2022
6 changes: 5 additions & 1 deletion docs/source/additional_features.rst
@@ -72,9 +72,13 @@ and :pep:`557`.
Caveats/Known Issues
====================

Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace` and :py:func:`~dataclasses.asdict`,
Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace`,
have imprecise (too permissive) types. This will be fixed in future releases.

Calls to :py:func:`~dataclasses.asdict` will return a ``TypedDict`` based on the original dataclass
definition, transforming it recursively. There are, however, some limitations. In particular, a precise return type
cannot be inferred for recursive dataclasses or for calls where ``dict_factory`` is set.
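
To make the documented behaviour concrete, here is a minimal sketch of the shape this PR aims to infer. `Point`, `Path`, `PointDict` and `PathDict` are hypothetical names invented for illustration; the TypedDicts built by the plugin are anonymous, so the two hand-written classes below are only stand-ins for them.

```python
from dataclasses import asdict, dataclass
from typing import List, TypedDict


@dataclass
class Point:
    x: int
    y: int


@dataclass
class Path:
    name: str
    points: List[Point]


# Hypothetical hand-written equivalents of the anonymous TypedDicts.
# Nested dataclasses are transformed recursively, including inside lists.
class PointDict(TypedDict):
    x: int
    y: int


class PathDict(TypedDict):
    name: str
    points: List[PointDict]


path = Path("diagonal", [Point(0, 0), Point(1, 1)])

# At runtime this is a plain nested dict. With this PR applied, mypy would be
# expected to infer a type equivalent to PathDict rather than Dict[str, Any].
as_dict = asdict(path)
print(as_dict)
```

Under that assumption, an access such as `as_dict["points"][0]["x"]` would be checked as `int`, and a misspelled key would be reported instead of passing silently.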

Mypy does not yet recognize aliases of :py:func:`dataclasses.dataclass <dataclasses.dataclass>`, and will
probably never recognize dynamically computed decorators. The following examples
do **not** work:
2 changes: 1 addition & 1 deletion misc/proper_plugin.py
@@ -131,7 +131,7 @@ def proper_types_hook(ctx: FunctionContext) -> Type:


def get_proper_type_instance(ctx: FunctionContext) -> Instance:
types = ctx.api.modules['mypy.types'] # type: ignore
types = ctx.api.modules['mypy.types']
proper_type_info = types.names['ProperType']
assert isinstance(proper_type_info.node, TypeInfo)
return Instance(proper_type_info.node, [])
42 changes: 42 additions & 0 deletions mypy/checker.py
@@ -5118,6 +5118,21 @@ def named_type(self, name: str) -> Instance:
any_type = AnyType(TypeOfAny.from_omitted_generics)
return Instance(node, [any_type] * len(node.defn.type_vars))

def named_type_or_none(self, qualified_name: str,
args: Optional[List[Type]] = None) -> Optional[Instance]:
sym = self.lookup_fully_qualified_or_none(qualified_name)
if not sym:
return None
node = sym.node
if isinstance(node, TypeAlias):
assert isinstance(node.target, Instance) # type: ignore
node = node.target.type
assert isinstance(node, TypeInfo), node
if args is not None:
# TODO: assert len(args) == len(node.defn.type_vars)
return Instance(node, args)
return Instance(node, [AnyType(TypeOfAny.unannotated)] * len(node.defn.type_vars))

def named_generic_type(self, name: str, args: List[Type]) -> Instance:
"""Return an instance with the given name and type arguments.

@@ -5129,6 +5144,13 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance:
# TODO: assert len(args) == len(info.defn.type_vars)
return Instance(info, args)

def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None:
if target is None:
target = self.tscope.current_target()

cur_module_node = self.modules[self.tscope.current_module_id()]
cur_module_node.plugin_deps.setdefault(trigger, set()).add(target)
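
The bookkeeping in `add_plugin_dependency` boils down to a mapping from trigger to a set of targets. The following is a minimal sketch of that pattern only: `plugin_deps` here is a plain module-level dict standing in for the per-module attribute, the trigger and target strings are made up, and the real method additionally defaults `target` to the current scope.

```python
from typing import Dict, Set

# Hypothetical stand-in for a module's plugin_deps: trigger -> targets to recheck.
plugin_deps: Dict[str, Set[str]] = {}


def add_plugin_dependency(trigger: str, target: str) -> None:
    """Record that `target` depends on `trigger`."""
    # setdefault creates the set on first use; using a set removes duplicates.
    plugin_deps.setdefault(trigger, set()).add(target)


# Illustrative trigger and target names, written by hand.
add_plugin_dependency("<m.Point.x>", "app.render")
add_plugin_dependency("<m.Point.x>", "app.export")
add_plugin_dependency("<m.Point.x>", "app.render")  # duplicate, ignored by the set
add_plugin_dependency("<m.Point.y>", "app.render")
print(plugin_deps)
```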

def lookup_typeinfo(self, fullname: str) -> TypeInfo:
# Assume that the name refers to a class.
sym = self.lookup_qualified(fullname)
@@ -5200,6 +5222,26 @@ def lookup_qualified(self, name: str) -> SymbolTableNode:
msg = "Failed qualified lookup: '{}' (fullname = '{}')."
raise KeyError(msg.format(last, name))

def lookup_fully_qualified_or_none(self, fullname: str) -> Optional[SymbolTableNode]:
"""Lookup a fully qualified name that refers to a module-level definition.

Don't assume that the name is defined. This happens in the global namespace --
the local module namespace is ignored. This does not dereference indirect
refs.

Note that this can't be used for names nested in class namespaces.
"""
# TODO: unify/clean-up/simplify lookup methods, see #4157.
# TODO: support nested classes (but consider performance impact,
# we might keep the module level only lookup for things like 'builtins.int').
assert '.' in fullname
module, name = fullname.rsplit('.', maxsplit=1)
if module not in self.modules:
return None
filenode = self.modules[module]
result = filenode.names.get(name)
return result
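
The lookup above can be sketched outside of mypy as follows. This is an illustration under stated assumptions: `MODULES` is a hypothetical stand-in for the checker's `modules` mapping, and plain strings stand in for symbol table nodes.

```python
from typing import Dict, Optional

# Hypothetical stand-in: module name -> {top-level symbol name -> symbol node}.
MODULES: Dict[str, Dict[str, str]] = {
    "builtins": {"int": "<int node>", "str": "<str node>"},
    "pkg.mod": {"Outer": "<Outer node>"},
}


def lookup_fully_qualified_or_none(fullname: str) -> Optional[str]:
    """Return the module-level symbol for `fullname`, or None if it is not defined."""
    assert "." in fullname
    # Everything before the last dot is treated as the module name.
    module, name = fullname.rsplit(".", maxsplit=1)
    if module not in MODULES:
        return None
    return MODULES[module].get(name)


print(lookup_fully_qualified_or_none("builtins.int"))
print(lookup_fully_qualified_or_none("typing_extensions._TypedDict"))
print(lookup_fully_qualified_or_none("pkg.mod.Outer.Inner"))
```

The last call shows the limitation stated in the docstring: for a nested class, the part before the final dot (`pkg.mod.Outer`) is not a module, so the lookup returns `None` rather than raising `KeyError`.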

@contextmanager
def enter_partial_types(self, *, is_function: bool = False,
is_class: bool = False) -> Iterator[None]:
14 changes: 14 additions & 0 deletions mypy/plugin.py
@@ -212,6 +212,7 @@ class CheckerPluginInterface:
docstrings in checker.py for more details.
"""

modules: Dict[str, MypyFile]
msg: MessageBuilder
options: Options
path: str
@@ -234,6 +235,19 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance:
"""Construct an instance of a builtin type with given type arguments."""
raise NotImplementedError

@abstractmethod
def named_type_or_none(self, qualified_name: str,
args: Optional[List[Type]] = None) -> Optional[Instance]:
raise NotImplementedError

@abstractmethod
def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None:
"""Specify semantic dependencies for generated methods/variables.

See the same function on SemanticAnalyzerPluginInterface for more details.
"""
raise NotImplementedError


@trait
class SemanticAnalyzerPluginInterface:
20 changes: 16 additions & 4 deletions mypy/plugins/common.py
@@ -1,13 +1,16 @@
from typing import List, Optional, Union
from collections import OrderedDict
from typing import List, Optional, Union, Set

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface, CheckerPluginInterface
from mypy.semanal import set_callable_name
from mypy.semanal_typeddict import get_anonymous_typeddict_type
from mypy.types import (
CallableType, Overloaded, Type, TypeVarType, deserialize_type, get_proper_type,
CallableType, Overloaded, Type, deserialize_type, get_proper_type,
TypedDictType, TypeVarType
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
@@ -184,8 +187,17 @@ def add_attribute_to_class(


def deserialize_and_fixup_type(
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
data: Union[str, JsonDict],
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ


def make_anonymous_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
                             required_keys: Set[str]) -> TypedDictType:
    fallback = get_anonymous_typeddict_type(api)
    assert fallback is not None
    return TypedDictType(fields, required_keys=required_keys,
                         fallback=fallback)
118 changes: 110 additions & 8 deletions mypy/plugins/dataclasses.py
@@ -1,24 +1,30 @@
"""Plugin that provides support for dataclasses."""

from typing import Dict, List, Set, Tuple, Optional
from collections import OrderedDict
from typing import Dict, List, Set, Tuple, Optional, Union

from typing_extensions import Final

from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_STAR, ARG_STAR2, MDEF,
Argument, AssignmentStmt, CallExpr, Context, Expression, JsonDict,
Argument, AssignmentStmt, CallExpr, Context, Expression, JsonDict,
NameExpr, RefExpr, SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr,
PlaceholderNode
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, FunctionContext, CheckerPluginInterface
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class,
add_method, _get_decorator_bool_argument, make_anonymous_typeddict,
deserialize_and_fixup_type, add_attribute_to_class
)
from mypy.typeops import map_type_from_supertype
from mypy.type_visitor import TypeTranslator
from mypy.types import (
Type, Instance, NoneType, TypeVarType, CallableType, TupleType, LiteralType,
get_proper_type, AnyType, TypeOfAny,
get_proper_type, AnyType, TypeOfAny, TypeAliasType, TypeType
)
from mypy.server.trigger import make_wildcard_trigger
from mypy.server.trigger import make_wildcard_trigger, make_trigger

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {
@@ -34,6 +40,10 @@
SELF_TVAR_NAME: Final = "_DT"


def is_type_dataclass(info: TypeInfo) -> bool:
    return 'dataclass' in info.metadata


class DataclassAttribute:
def __init__(
self,
@@ -90,7 +100,8 @@ def serialize(self) -> JsonDict:

@classmethod
def deserialize(
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
cls, info: TypeInfo, data: JsonDict,
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
) -> 'DataclassAttribute':
data = data.copy()
if data.get('kw_only') is None:
@@ -390,7 +401,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass' not in info.metadata:
if not is_type_dataclass(info):
continue

super_attrs = []
@@ -546,3 +557,94 @@ def _collect_field_args(expr: Expression,
args[name] = arg
return True, args
return False, {}


def asdict_callback(ctx: FunctionContext) -> Type:
    """Return a TypedDict for asdict calls on a dataclass instance, where possible."""
    positional_arg_types = ctx.arg_types[0]

    if positional_arg_types:
        dataclass_instance = get_proper_type(positional_arg_types[0])
        if isinstance(dataclass_instance, Instance):
            if is_type_dataclass(dataclass_instance.type):
                if len(ctx.arg_types) == 1:
                    # Can only infer a more precise type for calls where dict_factory is not set.
                    return _asdictify(ctx.api, dataclass_instance)

    return ctx.default_return_type
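
As a rough illustration of which calls the callback is meant to refine, the sketch below contrasts the two call shapes at runtime. The `User` dataclass is a hypothetical example; the comments describe the behaviour this PR is intended to give, not a verified mypy output.

```python
from collections import OrderedDict
from dataclasses import asdict, dataclass


@dataclass
class User:
    name: str
    age: int


user = User("ada", 36)

# No dict_factory: the callback would try to return a TypedDict with
# the keys "name" and "age".
plain = asdict(user)

# dict_factory is set: the result type depends on the factory, so the
# callback leaves the default return type untouched.
ordered = asdict(user, dict_factory=OrderedDict)

print(type(plain).__name__, type(ordered).__name__)
```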


class AsDictVisitor(TypeTranslator):
    def __init__(self, api: CheckerPluginInterface) -> None:
        self.api = api
        self.seen_dataclasses: Set[str] = set()

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        return t.copy_modified(args=[a.accept(self) for a in t.args])

    def visit_instance(self, t: Instance) -> Type:
        info = t.type
        any_type = AnyType(TypeOfAny.implementation_artifact)
        if is_type_dataclass(info):
            if info.fullname in self.seen_dataclasses:
                # Recursive types not supported, so fall back to Dict[str, Any]
                # Note: Would be nicer to fall back to default_return_type, but that is Any
                # (due to overloads?)
                return self.api.named_generic_type(
                    'builtins.dict', [self.api.named_generic_type('builtins.str', []), any_type])
            attrs = info.metadata['dataclass']['attributes']
            fields: OrderedDict[str, Type] = OrderedDict()
            self.seen_dataclasses.add(info.fullname)
            for data in attrs:
                attr = DataclassAttribute.deserialize(info, data, self.api)
                self.api.add_plugin_dependency(make_trigger(info.fullname + "." + attr.name))
                # TODO: attr.name should be available
                sym_node = info.names.get(attr.name, None)
                if sym_node is None:
                    continue
                attr_type = sym_node.type
                assert attr_type is not None
                fields[attr.name] = attr_type.accept(self)
            self.seen_dataclasses.remove(info.fullname)
            return make_anonymous_typeddict(self.api, fields=fields,
                                            required_keys=set(fields.keys()))
        elif info.has_base('builtins.list'):
            supertype = map_instance_to_supertype(t, self.api.named_generic_type(
                'builtins.list', [any_type]).type)
            return self.api.named_generic_type('builtins.list',
                                               self.translate_types(supertype.args))
        elif info.has_base('builtins.dict'):
            supertype = map_instance_to_supertype(t, self.api.named_generic_type(
                'builtins.dict', [any_type, any_type]).type)
            return self.api.named_generic_type('builtins.dict',
                                               self.translate_types(supertype.args))
        return t

    def visit_tuple_type(self, t: TupleType) -> Type:
        if t.partial_fallback.type.is_named_tuple:
            # For namedtuples, return Any. To properly support transforming namedtuples,
            # we would have to generate a partial_fallback type for the TupleType and add it
            # to the symbol table. It's not currently possible to do this via the
            # CheckerPluginInterface. Ideally it would use the same code as
            # NamedTupleAnalyzer.build_namedtuple_typeinfo.
            return AnyType(TypeOfAny.implementation_artifact)
        # Note: Tuple subclasses not supported, hence overriding the fallback
        return t.copy_modified(items=self.translate_types(t.items),
                               fallback=self.api.named_generic_type('builtins.tuple', []))

    def visit_callable_type(self, t: CallableType) -> Type:
        # Leave e.g. Callable[[SomeDataclass], SomeDataclass] alone
        return t

    def visit_type_type(self, t: TypeType) -> Type:
        # Leave e.g. Type[SomeDataclass] alone
        return t


def _asdictify(api: CheckerPluginInterface, typ: Type) -> Type:
    """Convert dataclasses into TypedDicts, recursively looking into built-in containers.

    It will look for dataclasses inside of tuples, lists, and dicts and convert them to
    TypedDicts.
    """
    return typ.accept(AsDictVisitor(api))
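
The recursion guard in `AsDictVisitor.visit_instance` can be sketched on a toy type model. Everything below is a simplified stand-in, not mypy's API: `Leaf`, `ListOf`, `Record`, `FALLBACK` and this `asdictify` function are hypothetical names, and plain dicts and strings replace `TypedDictType` and `Instance`.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Set


@dataclass
class Leaf:
    """A type that is left untouched, such as int or str."""
    name: str


@dataclass
class ListOf:
    """Toy model of list[item]."""
    item: object


@dataclass
class Record:
    """Toy model of a dataclass: its full name and its attribute types."""
    fullname: str
    attrs: Dict[str, object] = field(default_factory=dict)


FALLBACK = "Dict[str, Any]"


def asdictify(node: object, seen: Optional[Set[str]] = None) -> object:
    """Translate records into nested dict shapes, guarding against recursion."""
    seen = set() if seen is None else seen
    if isinstance(node, Record):
        if node.fullname in seen:
            # Already being expanded further up the call stack: give up on precision.
            return FALLBACK
        seen.add(node.fullname)
        shape = {name: asdictify(typ, seen) for name, typ in node.attrs.items()}
        # Remove on the way out so sibling uses of the same record stay precise.
        seen.remove(node.fullname)
        return shape
    if isinstance(node, ListOf):
        return [asdictify(node.item, seen)]
    assert isinstance(node, Leaf)
    return node.name


point = Record("m.Point", {"x": Leaf("int"), "y": Leaf("int")})
segment = Record("m.Segment", {"start": point, "end": point})

tree = Record("m.Tree", {"value": Leaf("int")})
tree.attrs["children"] = ListOf(tree)  # self-referential record

print(asdictify(segment))
print(asdictify(tree))
```

Removing the name from `seen` after the fields are translated is what lets `Point` appear twice in `Segment` without being treated as recursive; only a record that is still being expanded triggers the fallback.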
8 changes: 6 additions & 2 deletions mypy/plugins/default.py
@@ -2,7 +2,9 @@
from typing import Callable, Optional, List

from mypy import message_registry
from mypy.nodes import Expression, StrExpr, IntExpr, DictExpr, UnaryExpr
from mypy.nodes import (
Expression, StrExpr, IntExpr, DictExpr, UnaryExpr
)
from mypy.plugin import (
Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext,
CheckerPluginInterface,
@@ -22,7 +24,7 @@ class DefaultPlugin(Plugin):

def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes, singledispatch
from mypy.plugins import ctypes, singledispatch, dataclasses

if fullname in ('contextlib.contextmanager', 'contextlib.asynccontextmanager'):
return contextmanager_callback
@@ -32,6 +34,8 @@ def get_function_hook(self, fullname: str
return ctypes.array_constructor_callback
elif fullname == 'functools.singledispatch':
return singledispatch.create_singledispatch_function_callback
elif fullname == 'dataclasses.asdict':
return dataclasses.asdict_callback
return None

def get_method_signature_hook(self, fullname: str
18 changes: 12 additions & 6 deletions mypy/semanal_typeddict.py
@@ -1,11 +1,12 @@
"""Semantic analysis of TypedDict definitions."""

from mypy.backports import OrderedDict
from typing import Optional, List, Set, Tuple
from typing import Optional, List, Set, Tuple, Union
from typing_extensions import Final

from mypy.plugin import CheckerPluginInterface
from mypy.types import (
Type, AnyType, TypeOfAny, TypedDictType, TPDICT_NAMES, RequiredType,
Type, AnyType, TypeOfAny, TypedDictType, TPDICT_NAMES, RequiredType, Instance
)
from mypy.nodes import (
CallExpr, TypedDictExpr, Expression, NameExpr, Context, StrExpr, BytesExpr, UnicodeExpr,
@@ -362,10 +363,7 @@ def build_typeddict_typeinfo(self, name: str, items: List[str],
types: List[Type],
required_keys: Set[str],
line: int) -> TypeInfo:
# Prefer typing then typing_extensions if available.
fallback = (self.api.named_type_or_none('typing._TypedDict', []) or
self.api.named_type_or_none('typing_extensions._TypedDict', []) or
self.api.named_type_or_none('mypy_extensions._TypedDict', []))
fallback = get_anonymous_typeddict_type(self.api)
assert fallback is not None
info = self.api.basic_new_typeinfo(name, fallback, line)
info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys,
@@ -383,3 +381,11 @@ def fail(self, msg: str, ctx: Context, *, code: Optional[ErrorCode] = None) -> N

def note(self, msg: str, ctx: Context) -> None:
self.api.note(msg, ctx)


def get_anonymous_typeddict_type(
        api: Union[SemanticAnalyzerInterface, CheckerPluginInterface]) -> Optional[Instance]:
    # Prefer typing then typing_extensions if available.
    return (api.named_type_or_none('typing._TypedDict', []) or
            api.named_type_or_none('typing_extensions._TypedDict', []) or
            api.named_type_or_none('mypy_extensions._TypedDict', []))
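
The preference order used for the fallback can be sketched as a first-match search. Note that this is a deliberately different formulation from the `or` chain above: it tests `is not None` explicitly, which gives the same result only when every successful lookup is truthy. `first_available`, `CANDIDATES` and the `available` dict are hypothetical names for illustration.

```python
from typing import Callable, Optional, Sequence


def first_available(lookup: Callable[[str], Optional[str]],
                    candidates: Sequence[str]) -> Optional[str]:
    """Return the first candidate that the lookup can resolve, or None."""
    for fullname in candidates:
        found = lookup(fullname)
        if found is not None:
            return found
    return None


# Same preference order as in the helper above.
CANDIDATES = (
    "typing._TypedDict",
    "typing_extensions._TypedDict",
    "mypy_extensions._TypedDict",
)

# Hypothetical environment in which only typing_extensions provides _TypedDict.
available = {"typing_extensions._TypedDict": "<fallback from typing_extensions>"}
chosen = first_available(available.get, CANDIDATES)
print(chosen)
```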