Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ForwardRef #215

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cattr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .converters import Converter, GenConverter, UnstructureStrategy
from .gen import override
from ._compat import resolve_types

__all__ = (
"global_converter",
Expand All @@ -11,6 +12,7 @@
"Converter",
"GenConverter",
"override",
"resolve_types",
)


Expand Down
86 changes: 83 additions & 3 deletions src/cattr/_compat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import sys
from dataclasses import MISSING
from dataclasses import fields as dataclass_fields
from dataclasses import is_dataclass
from typing import Any, Dict, FrozenSet, List
from dataclasses import is_dataclass, make_dataclass
from dataclasses import Field as DataclassField
from typing import Any, Dict, FrozenSet, List, Optional
from typing import Mapping as TypingMapping
from typing import MutableMapping as TypingMutableMapping
from typing import MutableSequence as TypingMutableSequence
Expand All @@ -13,7 +14,8 @@

from attr import NOTHING, Attribute, Factory
from attr import fields as attrs_fields
from attr import resolve_types
from attr import resolve_types as attrs_resolve_types
from attr import has as attrs_has

version_info = sys.version_info[0:3]
is_py37 = version_info[:2] == (3, 7)
Expand Down Expand Up @@ -373,3 +375,81 @@ def copy_with(type, args):

def is_generic_attrs(type):
return is_generic(type) and has(type.__origin__)


def resolve_types(
cls: Any,
globalns: Optional[Dict[str, Any]] = None,
localns: Optional[Dict[str, Any]] = None,
):
"""
More generic version of `attrs.resolve_types`.

While `attrs.resolve_types` resolves ForwardRefs
only for for the fields of a `attrs` classes (and
fails otherwise), this `resolve_types` also
supports dataclasses and type aliases.

Even though often ForwardRefs outside of classes as e.g.
in type aliases can generally not be resolved automatically
(i.e. without explicit `globalns`, and `localns` context),
this is indeed sometimes possible and supported by Python.
This is for instance the case if the (internal) `module`
parameter of `ForwardRef` is set or we are dealing with
ForwardRefs in `TypedDict` or `NewType` types.
There may also be additions to typing.py module that there
will be more non-class types where ForwardRefs can automatically
be resolved.

See
https://bugs.python.org/issue41249
https://bugs.python.org/issue46369
https://bugs.python.org/issue46373
"""
allfields: List[Union[Attribute, DataclassField]] = []

if attrs_has(cls):
try:
attrs_resolve_types(cls, globalns, localns)
except NameError:
# ignore if ForwardRef cannot be resolved.
# We still want to allow manual registration of
# ForwardRefs (which will work with unevaluated ForwardRefs)
pass
allfields = fields(cls)
else:
if not is_dataclass(cls):
# we cannot call get_type_hints on type aliases
# directly, so put it in a field of a helper
# dataclass.
cls = make_dataclass("_resolve_helper", [("test", cls)])

# prevent resolving from cls.__module__ (which is what
# get_type_hints does if localns/globalns == None), as
# it would not be correct here.
# See: https://stackoverflow.com/questions/49457441
if globalns is None:
globalns = {}
if localns is None:
localns = {}
else:
allfields = dataclass_fields(cls)

try:
type_hints = get_type_hints(cls, globalns, localns)
for field in allfields:
field.type = type_hints.get(field.name, field.type)
except NameError:
pass
if not is_py39_plus:
# 3.8 and before did not recursively resolve ForwardRefs
# (likely a Python bug). Hence with PEP 563 (where all type
# annotations are initially treated as ForwardRefs) we
# need twice evaluation to properly resolve explicit ForwardRefs
fieldlist = [(field.name, field.type) for field in allfields]
cls2 = make_dataclass("_resolve_helper2", fieldlist)
cls2.__module__ = cls.__module__
try:
get_type_hints(cls2, globalns, localns)
except NameError:
pass
91 changes: 78 additions & 13 deletions src/cattr/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from dataclasses import Field
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, ForwardRef, Optional
from typing import Tuple, Type, TypeVar, Union

from attr import Attribute
from attr import has as attrs_has
from attr import resolve_types

from ._compat import (
FrozenSetSubscriptable,
Expand Down Expand Up @@ -35,6 +35,7 @@
is_sequence,
is_tuple,
is_union_type,
resolve_types,
)
from .disambiguators import create_uniq_field_dis_func
from .dispatch import MultiStrategyDispatch
Expand Down Expand Up @@ -141,6 +142,11 @@ def __init__(
(_subclass(Enum), self._unstructure_enum),
(has, self._unstructure_attrs),
(is_union_type, self._unstructure_union),
(
lambda o: o.__class__ is ForwardRef,
self._gen_unstructure_forwardref,
True,
),
]
)

Expand Down Expand Up @@ -173,6 +179,11 @@ def __init__(
),
(is_optional, self._structure_optional),
(has, self._structure_attrs),
(
lambda o: o.__class__ is ForwardRef,
self._gen_structure_forwardref,
True,
),
]
)
# Strings are sequences.
Expand Down Expand Up @@ -215,22 +226,31 @@ def register_unstructure_hook(
The converter function should take an instance of the class and return
its Python equivalent.
"""
if attrs_has(cls):
resolve_types(cls)
resolve_types(cls)
if is_union_type(cls):
self._unstructure_func.register_func_list(
[(lambda t: t == cls, func)]
)
else:
self._unstructure_func.register_cls_list([(cls, func)])
singledispatch_ok = isinstance(cls, type) and not is_generic(cls)
self._unstructure_func.register_cls_list(
[(cls, func)], direct=not singledispatch_ok
)

def register_unstructure_hook_func(
self, check_func: Callable[[Any], bool], func: Callable[[T], Any]
):
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
"""
self._unstructure_func.register_func_list([(check_func, func)])

def factory_func(cls: T) -> Callable[[T], Any]:
resolve_types(cls)
return func

self._unstructure_func.register_func_list(
[(check_func, factory_func, True)]
)

def register_unstructure_hook_factory(
self,
Expand All @@ -246,7 +266,14 @@ def register_unstructure_hook_factory(
A factory is a callable that, given a type, produces an unstructuring
hook for that type. This unstructuring hook will be cached.
"""
self._unstructure_func.register_func_list([(predicate, factory, True)])

def factory_func(cls: T) -> Callable[[Any], Any]:
resolve_types(cls)
return factory(cls)

self._unstructure_func.register_func_list(
[(predicate, factory_func, True)]
)

def register_structure_hook(
self, cl: Any, func: Callable[[Any, Type[T]], T]
Expand All @@ -260,13 +287,15 @@ def register_structure_hook(
and return the instance of the class. The type may seem redundant, but
is sometimes needed (for example, when dealing with generic classes).
"""
if attrs_has(cl):
resolve_types(cl)
resolve_types(cl)
if is_union_type(cl):
self._union_struct_registry[cl] = func
self._structure_func.clear_cache()
else:
self._structure_func.register_cls_list([(cl, func)])
singledispatch_ok = isinstance(cl, type) and not is_generic(cl)
self._structure_func.register_cls_list(
[(cl, func)], direct=not singledispatch_ok
)

def register_structure_hook_func(
self,
Expand All @@ -276,12 +305,19 @@ def register_structure_hook_func(
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
"""
self._structure_func.register_func_list([(check_func, func)])

def factory_func(cls: T) -> Callable[[Any, Type[T]], T]:
resolve_types(cls)
return func

self._structure_func.register_func_list(
[(check_func, factory_func, True)]
)

def register_structure_hook_factory(
self,
predicate: Callable[[Any], bool],
factory: Callable[[Any], Callable[[Any], Any]],
factory: Callable[[Any], Callable[[Any, Type[T]], T]],
) -> None:
"""
Register a hook factory for a given predicate.
Expand All @@ -292,7 +328,14 @@ def register_structure_hook_factory(
A factory is a callable that, given a type, produces a structuring
hook for that type. This structuring hook will be cached.
"""
self._structure_func.register_func_list([(predicate, factory, True)])

def factory_func(cls: T) -> Callable[[Any, Type[T]], T]:
resolve_types(cls)
return factory(cls)

self._structure_func.register_func_list(
[(predicate, factory_func, True)]
)

def structure(self, obj: Any, cl: Type[T]) -> T:
"""Convert unstructured Python data structures to structured data."""
Expand Down Expand Up @@ -355,6 +398,17 @@ def _unstructure_union(self, obj):
"""
return self._unstructure_func.dispatch(obj.__class__)(obj)

def _gen_unstructure_forwardref(self, cl):
if not cl.__forward_evaluated__:
raise ValueError(
f"ForwardRef({cl.__forward_arg__!r}) is not resolved."
" Consider resolving the parent type alias"
" manually with `cattr.resolve_types`"
" in the defining module or by registering a hook."
)
cl = cl.__forward_value__
return lambda o: self._unstructure_func.dispatch(cl)(o)

# Python primitives to classes.

def _structure_error(self, _, cl):
Expand Down Expand Up @@ -557,6 +611,17 @@ def _structure_tuple(self, obj, tup: Type[T]):
for t, e in zip(tup_params, obj)
)

def _gen_structure_forwardref(self, cl):
if not cl.__forward_evaluated__:
raise ValueError(
f"ForwardRef({cl.__forward_arg__!r}) is not resolved."
" Consider resolving the parent type alias"
" manually with `cattr.resolve_types`"
" in the defining module or by registering a hook."
)
cl = cl.__forward_value__
return lambda o, t: self._structure_func.dispatch(cl)(o, cl)

@staticmethod
def _get_dis_func(union):
# type: (Type) -> Callable[..., Type]
Expand Down
13 changes: 6 additions & 7 deletions src/cattr/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

import attr
from attr import NOTHING, resolve_types
from attr import NOTHING

from ._compat import (
adapted_fields,
Expand All @@ -24,6 +24,7 @@
is_annotated,
is_bare,
is_generic,
resolve_types,
)
from ._generics import deep_copy_with

Expand Down Expand Up @@ -63,9 +64,8 @@ def make_dict_unstructure_fn(
origin = get_origin(cl)
attrs = adapted_fields(origin or cl) # type: ignore

if any(isinstance(a.type, str) for a in attrs):
# PEP 563 annotations - need to be resolved.
resolve_types(cl)
# PEP 563 annotations and ForwardRefs - need to be resolved.
resolve_types(cl)

mapping = {}
if is_generic(cl):
Expand Down Expand Up @@ -245,9 +245,8 @@ def make_dict_structure_fn(
attrs = adapted_fields(cl)
is_dc = is_dataclass(cl)

if any(isinstance(a.type, str) for a in attrs):
# PEP 563 annotations - need to be resolved.
resolve_types(cl)
# PEP 563 annotations and ForwardRefs - need to be resolved.
resolve_types(cl)

lines.append(f"def {fn_name}(o, *_):")
lines.append(" res = {")
Expand Down
37 changes: 37 additions & 0 deletions tests/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import List, Tuple

from attrs import define

from cattr import resolve_types


@dataclass
class DClass:
ival: "IntType_1"
ilist: List["IntType_2"]


@define
class AClass:
ival: "IntType_3"
ilist: List["IntType_4"]


@define
class ModuleClass:
a: int


IntType_1 = int
IntType_2 = int
IntType_3 = int
IntType_4 = int

RecursiveTypeAliasM = List[Tuple[ModuleClass, "RecursiveTypeAliasM"]]
RecursiveTypeAliasM_1 = List[Tuple[ModuleClass, "RecursiveTypeAliasM_1"]]
RecursiveTypeAliasM_2 = List[Tuple[ModuleClass, "RecursiveTypeAliasM_2"]]

resolve_types(RecursiveTypeAliasM, globals(), locals())
resolve_types(RecursiveTypeAliasM_1, globals(), locals())
resolve_types(RecursiveTypeAliasM_2, globals(), locals())
Loading