diff --git a/README.rst b/README.rst index ac9a74e..543933b 100644 --- a/README.rst +++ b/README.rst @@ -427,6 +427,23 @@ This will create a dataclass that implements an interface. For example:: a = Animal2(height=4.5, species='Giraffe') +This is done by populating the ``__annotations__`` attribute of all interfaces and all direct interface sub-classes +with the interface attribute names of the class. Annotation entries are not created for attributes that already exist +on the class. For example:: + + @dataclasses.dataclass + class FixedHeightAnimal(IAnimal2): + @property + def height(self): + return 12.3 + + def speak(self): + print('Hello, I am a 12.3 metre tall {}', self.height, self.species) + + a = FixedHeightAnimal(species='Dinosaur') + +Because ``height`` exists in the class definition, the ``height`` attribute is not added to the ``__annotations__`` +attribute of ``FixedHeightAnimal`` and it is ignored by the dataclass decorator. Interface Type Information ========================== diff --git a/pure_interface/interface.py b/pure_interface/interface.py index 82bf6a7..f7e5ad1 100644 --- a/pure_interface/interface.py +++ b/pure_interface/interface.py @@ -402,13 +402,16 @@ def _ensure_everything_is_abstract(attributes): return namespace, functions, interface_method_signatures, interface_attribute_names -def _ensure_annotations(names, namespace): - # annotations need to be kept in order, add base-class names first +def _ensure_annotations(names, namespace, base_interfaces): + # annotations need to be kept in order for dataclass decorator # we only want dataclass annotations for attributes that don't already exist annotations = {} + base_annos = {} + for base in reversed(base_interfaces): + base_annos.update(base.__annotations__) for name in names: if name not in annotations and name not in namespace: - annotations[name] = Any + annotations[name] = base_annos.get(name, Any) annotations.update(namespace.get('__annotations__', {})) namespace['__annotations__'] = annotations @@ -540,9 +543,9 @@ def __new__(mcs, clsname, bases, attributes, **kwargs): this_type_is_an_interface = True else: assert 'Interface' in globals() - if Interface in bases and not all(is_interface for cls, is_interface in base_types): - raise InterfaceError('All bases must be interface types when declaring an interface') this_type_is_an_interface = Interface in bases + if this_type_is_an_interface and not all(is_interface for cls, is_interface in base_types): + raise InterfaceError('All bases must be interface types when declaring an interface') interface_method_signatures = dict() interface_attribute_names = list() abstract_properties = set() @@ -566,6 +569,12 @@ def __new__(mcs, clsname, bases, attributes, **kwargs): if is_development: _check_method_signatures(attributes, clsname, interface_method_signatures) + base_interfaces = [bt for bt, is_interface in base_types if is_interface] + if interface_attribute_names and base_interfaces: + # provide interface attributes as annotations so that dataclass decorator creates all attributes + # defined on base interfaces. + _ensure_annotations(interface_attribute_names, attributes, base_interfaces) + if this_type_is_an_interface: if clsname == 'Interface' and attributes.get('__module__', '') == 'pure_interface.interface': namespace = attributes @@ -591,9 +600,6 @@ def __new__(mcs, clsname, bases, attributes, **kwargs): for bt, is_interface in base_types: if not is_interface: class_properties |= set(k for k, v in bt.__dict__.items() if _is_descriptor(v)) - if any(is_interface for bt, is_interface in base_types): - # provide interface attributes as annotations so that dataclass decorator works. - _ensure_annotations(interface_attribute_names, namespace) class_properties |= set(k for k, v in namespace.items() if _is_descriptor(v)) abstract_properties.difference_update(class_properties) partial_implementation = 'pi_partial_implementation' in namespace diff --git a/tests/test_dataclass_support.py b/tests/test_dataclass_support.py index f18c60e..306a437 100644 --- a/tests/test_dataclass_support.py +++ b/tests/test_dataclass_support.py @@ -1,9 +1,10 @@ -from dataclasses import dataclass -import unittest -from pure_interface import * - - -class IFoo(Interface): +from dataclasses import dataclass +import unittest + +from pure_interface import * + + +class IFoo(Interface): a: int b: str @@ -19,6 +20,10 @@ def foo(self): return 'a={}, b={}, c={}'.format(self.a, self.b, self.c) +class IBar(IFoo, Interface): + a: Foo + + class TestDataClasses(unittest.TestCase): def test_data_class(self): try: @@ -67,7 +72,7 @@ def foo(self): return 'a={}, b={}, c={}'.format(self.a, self.b, self.c) f = RoFoo(a=1, c=3) - self.assertEqual({'a', 'c'}, set(RoFoo.__annotations__.keys())) + self.assertEqual({'a': int, 'c': int}, RoFoo.__annotations__) self.assertEqual(1, f.a) self.assertEqual('str', f.b) @@ -80,6 +85,27 @@ def foo(self): return 'a={}, b={}, c={}'.format(self.a, self.b, self.c) f = AFoo(b='str') - self.assertEqual({'b'}, set(AFoo.__annotations__.keys())) + self.assertEqual({'b': str}, AFoo.__annotations__) self.assertEqual(10, f.a) self.assertEqual('str', f.b) + + def test_annotations_override(self): + """ ensure overridden annotations are used correctly """ + @dataclass + class Bar(IBar): + + def foo(self): + return 'a={}, b={}'.format(self.a, self.b) + + self.assertEqual({'a': int, 'b': str}, IFoo.__annotations__) + self.assertEqual({'a': Foo, 'b': str}, IBar.__annotations__) + self.assertEqual({'a': Foo, 'b': str}, Bar.__annotations__) + b = Bar(a=Foo(a=1, b='two'), b='three') + self.assertIsInstance(b.a, Foo) + + def test_non_direct_subclass(self): + """ ensure no extra annotations are added to the class""" + class Baz(Foo): + e: str + + self.assertEqual({'e': str}, Baz.__annotations__)