Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: cached properties implementation using descriptor rather than ge… #1357

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 38 additions & 66 deletions src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,34 @@
_DEFAULT_ON_SETATTR = setters.pipe(setters.convert, setters.validate)


class Desc:
def __init__(self, func, old_desc):
if isinstance(old_desc, Desc):
old_desc = old_desc.old_desc
self.old_desc = old_desc
self.func = func
functools.update_wrapper(self, func)
# self.attrname = func.__name__
self.__doc__ = func.__doc__

def __get__(self, obj, objtype=None):
print(f" getting {obj}, {objtype}")
try:
result = self.old_desc.__get__(obj)
print(f"resulting in {result}")
return result
except TypeError:
pass
except AttributeError as e:
print(f"exceptions = {e}")
if obj is None:
return self
print(f"{self.old_desc=}")
result = self.func(obj)
self.old_desc.__set__(obj, result)
return result


class _Nothing(enum.Enum):
"""
Sentinel to indicate the lack of a value when `None` is ambiguous.
Expand Down Expand Up @@ -477,64 +505,6 @@ def _transform_attrs(
return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map))


def _make_cached_property_getattr(cached_properties, original_getattr, cls):
lines = [
# Wrapped to get `__class__` into closure cell for super()
# (It will be replaced with the newly constructed class after construction).
"def wrapper(_cls):",
" __class__ = _cls",
" def __getattr__(self, item, cached_properties=cached_properties, original_getattr=original_getattr, _cached_setattr_get=_cached_setattr_get):",
" func = cached_properties.get(item)",
" if func is not None:",
" result = func(self)",
" _setter = _cached_setattr_get(self)",
" _setter(item, result)",
" return result",
]
if original_getattr is not None:
lines.append(
" return original_getattr(self, item)",
)
else:
lines.extend(
[
" try:",
" return super().__getattribute__(item)",
" except AttributeError:",
" if not hasattr(super(), '__getattr__'):",
" raise",
" return super().__getattr__(item)",
" original_error = f\"'{self.__class__.__name__}' object has no attribute '{item}'\"",
" raise AttributeError(original_error)",
]
)

lines.extend(
[
" return __getattr__",
"__getattr__ = wrapper(_cls)",
]
)

unique_filename = _generate_unique_filename(cls, "getattr")

glob = {
"cached_properties": cached_properties,
"_cached_setattr_get": _OBJ_SETATTR.__get__,
"original_getattr": original_getattr,
}

return _make_method(
"__getattr__",
"\n".join(lines),
unique_filename,
glob,
locals={
"_cls": cls,
},
)


def _frozen_setattrs(self, name, value):
"""
Attached to frozen classes as __setattr__.
Expand Down Expand Up @@ -767,6 +737,7 @@ def _create_slots_class(self):
# Traverse the MRO to collect existing slots
# and check for an existing __weakref__.
existing_slots = {}
existing_cached_property = []
weakref_inherited = False
for base_cls in self._cls.__mro__[1:-1]:
if base_cls.__dict__.get("__weakref__", None) is not None:
Expand All @@ -777,6 +748,9 @@ def _create_slots_class(self):
for name in getattr(base_cls, "__slots__", [])
}
)
existing_cached_property.update(
getattr(base_cls, "__attrs_cached_properties__", [])
)

base_names = set(self._base_names)

Expand All @@ -795,29 +769,25 @@ def _create_slots_class(self):
if isinstance(cached_property, functools.cached_property)
}

cd["__attrs_cached_properties__"] = list(cached_properties.keys())

# Collect methods with a `__class__` reference that are shadowed in the new class.
# To know to update them.
property_calls = {}
additional_closure_functions_to_update = []
if cached_properties:
class_annotations = _get_annotations(self._cls)
for name, func in cached_properties.items():
# Add cached properties to names for slotting.
names += (name,)
# Clear out function from class to avoid clashing.

del cd[name]
additional_closure_functions_to_update.append(func)
annotation = inspect.signature(func).return_annotation
if annotation is not inspect.Parameter.empty:
class_annotations[name] = annotation

original_getattr = cd.get("__getattr__")
if original_getattr is not None:
additional_closure_functions_to_update.append(original_getattr)

cd["__getattr__"] = _make_cached_property_getattr(
cached_properties, original_getattr, self._cls
)

# We only add the names of attributes that aren't inherited.
# Setting __slots__ to inherited attributes wastes memory.
slot_names = [name for name in names if name not in base_names]
Expand All @@ -842,6 +812,8 @@ def _create_slots_class(self):

# Create new class based on old class and our methods.
cls = type(self._cls)(self._cls.__name__, self._cls.__bases__, cd)
for name, func in cached_properties.items():
setattr(cls, name, Desc(func, getattr(cls, name)))

# The following is a fix for
# <https://github.com/python-attrs/attrs/issues/102>.
Expand Down
19 changes: 19 additions & 0 deletions tests/test_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,25 @@ def f_2(self):
assert obj.f_2 == 2


def test_slots_cached_property_retains_doc():
"""
Cached property's docstring is retained.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
"""
This is a docstring.
"""
return self.x

assert "This is a docstring." in A.f.__doc__


@attr.s(slots=True)
class A:
x = attr.ib()
Expand Down
Loading