diff --git a/amaranth/lib/wiring.py b/amaranth/lib/wiring.py index e34d8955c..75eac1c37 100644 --- a/amaranth/lib/wiring.py +++ b/amaranth/lib/wiring.py @@ -491,6 +491,19 @@ def __repr__(self): return super().__repr__() +def _gettypeattr(obj, attr): + # Resolve the attribute on the object's class, without triggering the descriptor protocol for + # attributes that are class methods, etc. + for cls in type(obj).__mro__: + try: + return cls.__dict__[attr] + except KeyError: + pass + # In case there is `__getattr__` on the metaclass, or just to generate an `AttributeError` with + # the standard message. + return type(obj).attr + + # To simplify implementation and reduce API surface area `FlippedSignature` is made final. This # restriction could be lifted if there is a compelling use case. @final @@ -524,15 +537,29 @@ def __eq__(self, other): is_compliant = Signature.is_compliant # FIXME: document this logic + + # Because we would like to forward attribute access (other than what is explicitly overridden) + # to the unflipped signature, including access via e.g. @property-decorated functions, we have + # to reimplement the Python decorator protocol here. Note that in all of these functions, there + # are two possible exits via `except AttributeError`: from `getattr` and from `.__get__()`. + def __getattr__(self, name): - value = getattr(self.__unflipped, name) - if inspect.ismethod(value): - return types.MethodType(value.__func__, self) - else: - return value + try: # descriptor first + return _gettypeattr(self.__unflipped, name).__get__(self, type(self.__unflipped)) + except AttributeError: + return getattr(self.__unflipped, name) def __setattr__(self, name, value): - return setattr(self.__unflipped, name, value) + try: # descriptor first + _gettypeattr(self.__unflipped, name).__set__(self, value) + except AttributeError: + setattr(self.__unflipped, name, value) + + def __delattr__(self, name): + try: # descriptor first + _gettypeattr(self.__unflipped, name).__delete__(self) + except AttributeError: + delattr(self.__unflipped, name) def create(self, *, path=()): return flipped(self.__unflipped.create(path=path)) @@ -566,18 +593,35 @@ def __eq__(self, other): return type(self) is type(other) and self.__unflipped == other.__unflipped # FIXME: document this logic + + # See the note in ``FlippedSignature``. In addition, these accessors also handle flipping of + # an interface member. + def __getattr__(self, name): - value = getattr(self.__unflipped, name) - if inspect.ismethod(value): - return types.MethodType(value.__func__, self) - elif name in self.__unflipped.signature.members and \ - self.__unflipped.signature.members[name].is_signature: - return flipped(value) + if (name in self.__unflipped.signature.members and + self.__unflipped.signature.members[name].is_signature): + return flipped(getattr(self.__unflipped, name)) else: - return value + try: # descriptor first + return _gettypeattr(self.__unflipped, name).__get__(self, type(self.__unflipped)) + except AttributeError: + return getattr(self.__unflipped, name) def __setattr__(self, name, value): - return setattr(self.__unflipped, name, value) + if (name in self.__unflipped.signature.members and + self.__unflipped.signature.members[name].is_signature): + setattr(self.__unflipped, name, flipped(value)) + else: + try: # descriptor first + _gettypeattr(self.__unflipped, name).__set__(self, value) + except AttributeError: + setattr(self.__unflipped, name, value) + + def __delattr__(self, name): + try: # descriptor first + _gettypeattr(self.__unflipped, name).__delete__(self) + except AttributeError: + delattr(self.__unflipped, name) def __repr__(self): return f"flipped({self.__unflipped!r})" diff --git a/tests/test_lib_wiring.py b/tests/test_lib_wiring.py index a1f52626d..75a058323 100644 --- a/tests/test_lib_wiring.py +++ b/tests/test_lib_wiring.py @@ -530,7 +530,7 @@ def test_repr(self): sig = Signature({"a": In(1)}).flip() self.assertEqual(repr(sig), "Signature({'a': In(1)}).flip()") - def test_getattr_setattr(self): + def test_getsetdelattr(self): class S(Signature): def __init__(self): super().__init__({}) @@ -539,12 +539,61 @@ def __init__(self): def f(self2): self.assertIsInstance(self2, FlippedSignature) return "f()" + sig = S() fsig = sig.flip() self.assertEqual(fsig.x, 1) self.assertEqual(fsig.f(), "f()") fsig.y = 2 self.assertEqual(sig.y, 2) + del fsig.y + self.assertFalse(hasattr(sig, "y")) + + def test_getsetdelattr_property(self): + class S(Signature): + def __init__(self): + super().__init__({}) + self.x_get_type = None + self.x_set_type = None + self.x_set_val = None + self.x_del_type = None + + @property + def x(self): + self.x_get_type = type(self) + + @x.setter + def x(self, val): + self.x_set_type = type(self) + self.x_set_val = val + + @x.deleter + def x(self): + self.x_del_type = type(self) + + sig = S() + fsig = sig.flip() + fsig.x + fsig.x = 1 + del fsig.x + # Tests both attribute access through the descriptor, and attribute setting without one! + self.assertEqual(sig.x_get_type, type(fsig)) + self.assertEqual(sig.x_set_type, type(fsig)) + self.assertEqual(sig.x_set_val, 1) + self.assertEqual(sig.x_del_type, type(fsig)) + + def test_classmethod(self): + x_type = None + class S(Signature): + @classmethod + def x(cls): + nonlocal x_type + x_type = cls + + sig = S({}) + fsig = sig.flip() + fsig.x() + self.assertEqual(x_type, S) class InterfaceTestCase(unittest.TestCase): @@ -564,8 +613,8 @@ def test_basic(self): r"^flipped\(<.+?\.Interface object at .+>\)$") self.assertIs(flipped(tintf), intf) - def test_getattr_setattr(self): - class I(Interface): + def test_getsetdelattr(self): + class I: signature = Signature({}) def __init__(self): @@ -574,27 +623,82 @@ def __init__(self): def f(self2): self.assertIsInstance(self2, FlippedInterface) return "f()" + intf = I() - tintf = flipped(intf) - self.assertEqual(tintf.x, 1) - self.assertEqual(tintf.f(), "f()") - tintf.y = 2 + fintf = flipped(intf) + self.assertEqual(fintf.x, 1) + self.assertEqual(fintf.f(), "f()") + fintf.y = 2 self.assertEqual(intf.y, 2) + del fintf.y + self.assertFalse(hasattr(intf, "y")) + + def test_getsetdelattr_property(self): + class I: + signature = Signature({}) + + def __init__(self): + self.x_get_type = None + self.x_set_type = None + self.x_set_val = None + self.x_del_type = None + + @property + def x(self): + self.x_get_type = type(self) + + @x.setter + def x(self, val): + self.x_set_type = type(self) + self.x_set_val = val + + @x.deleter + def x(self): + self.x_del_type = type(self) + + intf = I() + fintf = flipped(intf) + fintf.x + fintf.x = 1 + del fintf.x + # Tests both attribute access through the descriptor, and attribute setting without one! + self.assertEqual(intf.x_get_type, type(fintf)) + self.assertEqual(intf.x_set_type, type(fintf)) + self.assertEqual(intf.x_set_val, 1) + self.assertEqual(intf.x_del_type, type(fintf)) + + def test_classmethod(self): + x_type = None + class I: + signature = Signature({}) + + def __init__(self): + pass + + @classmethod + def x(cls): + nonlocal x_type + x_type = cls + + intf = I() + fintf = flipped(intf) + fintf.x() + self.assertEqual(x_type, I) def test_flipped_wrong(self): with self.assertRaisesRegex(TypeError, r"^flipped\(\) can only flip an interface object, not Signature\({}\)$"): flipped(Signature({})) - + def test_create_subclass_flipped(self): class CustomInterface(Interface): def custom_method(self): return 69 - + class CustomSignature(Signature): def create(self, *, path=()): return CustomInterface(self, path=path) - + flipped_interface = CustomSignature({}).flip().create() self.assertTrue(hasattr(flipped_interface, "custom_method")) @@ -630,6 +734,10 @@ def __init__(self): self.assertEqual(ifsub.g.members["h"].flow, In) self.assertEqual(flipped(ifsub).g.members["h"].flow, In) + # This should be a no-op! That requires hooking ``__setattr__``. + flipped(ifsub).a = flipped(ifsub).a + self.assertEqual(ifsub.a.signature.members["f"].flow, In) + class ConnectTestCase(unittest.TestCase): def test_arg_handles_and_signature_attr(self):