diff --git a/elasticai/creator/ir/__init__.py b/elasticai/creator/ir/__init__.py
new file mode 100644
index 00000000..6473ebae
--- /dev/null
+++ b/elasticai/creator/ir/__init__.py
@@ -0,0 +1,6 @@
+from .abstract_ir_data import AbstractIRData, MandatoryField
+
+__all__ = [
+    "MandatoryField",
+    "AbstractIRData",
+]
diff --git a/elasticai/creator/ir/abstract_ir_data/__init__.py b/elasticai/creator/ir/abstract_ir_data/__init__.py
new file mode 100644
index 00000000..d5c375de
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/__init__.py
@@ -0,0 +1,7 @@
+from .abstract_ir_data import AbstractIRData
+from .mandatory_field import MandatoryField
+
+__all__ = [
+    "AbstractIRData",
+    "MandatoryField",
+]
diff --git a/elasticai/creator/ir/abstract_ir_data/_attributes_descriptor.py b/elasticai/creator/ir/abstract_ir_data/_attributes_descriptor.py
new file mode 100644
index 00000000..0383b79a
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/_attributes_descriptor.py
@@ -0,0 +1,32 @@
+from collections.abc import MutableMapping
+
+from elasticai.creator.ir.attribute import AttributeT
+
+from ._has_data import HasData
+from ._hiding_dict import _HidingDict
+
+
+class _AttributesDescriptor:
+    """Descriptor exposing the non-mandatory part of an object's `data` dict.
+
+    Reading the descriptor from an instance returns a `_HidingDict` view on
+    `instance.data`. The view hides `hidden_names` and, if the instance's
+    class provides a `_mandatory_fields()` callable, the names it returns.
+    """
+
+    def __init__(self, hidden_names: set[str]):
+        self._hidden_names = hidden_names
+
+    def __get__(
+        self, instance: HasData | None, owner: type[HasData] | None = None
+    ) -> "MutableMapping[str, AttributeT] | _AttributesDescriptor":
+        if instance is None:
+            # Accessed on the class itself: there is no `data` to wrap.
+            return self
+        hidden = set(self._hidden_names)
+        # The descriptor is created with a fixed set of names, so the fields
+        # declared by (sub)classes have to be looked up at access time.
+        mandatory_fields = getattr(type(instance), "_mandatory_fields", None)
+        if callable(mandatory_fields):
+            hidden.update(mandatory_fields())
+        return _HidingDict(hidden, instance.data)
diff --git a/elasticai/creator/ir/abstract_ir_data/_has_data.py b/elasticai/creator/ir/abstract_ir_data/_has_data.py
new file mode 100644
index 00000000..b5ff2652
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/_has_data.py
@@ -0,0 +1,11 @@
+from typing import Protocol, runtime_checkable
+
+from elasticai.creator.ir.attribute import AttributeT
+
+
+@runtime_checkable
+class HasData(Protocol):
+    """Structural type of objects that expose a `data` dictionary."""
+
+    @property
+    def data(self) -> dict[str, AttributeT]: ...
diff --git a/elasticai/creator/ir/abstract_ir_data/_hiding_dict.py b/elasticai/creator/ir/abstract_ir_data/_hiding_dict.py
new file mode 100644
index 00000000..59684fac
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/_hiding_dict.py
@@ -0,0 +1,84 @@
+import collections
+from collections.abc import Iterable
+from itertools import filterfalse
+from typing import MutableMapping, TypeVar
+
+T = TypeVar("T")
+
+
+class _HidingDict(MutableMapping[str, T]):
+    """Allows to hide keys with `hidden_names` for all read operations.
+    We use this to implement an attributes field for Nodes that looks like a dictionary, but hides
+    all mandatory fields.
+    You can still write to `HidingDict`, e.g.,
+
+    >>> d = dict(a="a", b="b")
+    >>> h = _HidingDict({"a"}, d)
+    >>> ("b",) == tuple(h.keys())
+    True
+    >>> "a" in h
+    False
+    >>> len(h)
+    1
+    >>> h["a"] = "c"
+    >>> h.data["a"]
+    'c'
+    >>> d["a"]
+    'c'
+    >>> "a" in d and "a" in h.data
+    True
+    """
+
+    def __init__(self, hidden_names: Iterable[str], data: dict) -> None:
+        self.data = data
+        self._hidden_names = set(hidden_names)
+
+    def __setitem__(self, key: str, value: T):
+        self.data[key] = value
+
+    def _is_hidden(self, name: str) -> bool:
+        return name in self._hidden_names
+
+    def __delitem__(self, key: str):
+        del self.data[key]
+
+    def __iter__(self):
+        return filterfalse(self._is_hidden, iter(self.data))
+
+    def __contains__(self, item):
+        return item not in self._hidden_names and item in self.data
+
+    def __getitem__(self, item: str) -> T:
+        # Hidden keys have to look absent for reads, in line with `__iter__`
+        # and `__contains__`.
+        if self._is_hidden(item):
+            raise KeyError(item)
+        return self.data[item]
+
+    def __len__(self):
+        # Has to agree with `__iter__`, so hidden keys are not counted.
+        return sum(1 for _ in self)
+
+    def get(self, key: str, default=None) -> T:
+        if key in self:
+            return self[key]
+        return default
+
+    def __copy__(self) -> MutableMapping[str, T]:
+        inst = self.__class__.__new__(self.__class__)
+        inst.__dict__.update(self.__dict__)
+        # Create a copy and avoid triggering descriptors
+        inst.__dict__["data"] = self.__dict__["data"].copy()
+        inst.__dict__["_hidden_names"] = set(self.__dict__["_hidden_names"])
+        return inst
+
+    def copy(self) -> MutableMapping[str, T]:
+        """Return a shallow copy that does not share its `data` dict with `self`.
+
+        Hidden keys are copied as well and stay hidden in the copy.
+        """
+        return self.__copy__()
+
+    def __repr__(self) -> str:
+        hidden_names = ", ".join(sorted(self._hidden_names))
+        return f"HidingDict({hidden_names}, data={self.data})"
diff --git a/elasticai/creator/ir/abstract_ir_data/abstract_ir_data.py b/elasticai/creator/ir/abstract_ir_data/abstract_ir_data.py
new file mode 100644
index 00000000..6a9c7859
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/abstract_ir_data.py
@@ -0,0 +1,181 @@
+import inspect
+import sys
+from abc import abstractmethod
+from collections.abc import Iterable
+from typing import Any, Callable, TypeVar
+
+from elasticai.creator.ir.attribute import AttributeT
+
+from ._attributes_descriptor import _AttributesDescriptor
+from .mandatory_field import MandatoryField, TransformableMandatoryField
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    Self = TypeVar("Self", bound="AbstractIRData")
+
+
+class AbstractIRData:
+    """
+    This class should provide a way to easily create new wrappers around dictionaries.
+
+    It is supposed to be used together with the `MandatoryField` class.
+    Every child of `AbstractIRData` is expected to have a constructor that takes a dictionary.
+    That dictionary is not copied, but instead can be shared with other Node classes.
+
+    Most of the private functions in this class deal with handling arguments of the classmethod
+    `new`, that is used to create new nodes from scratch.
+
+    The `attributes` attribute of the class provides a dict-like object, that hides all keys that are associated
+    with mandatory fields.
+
+    The purpose of this class is to provide a way to easily write new wrappers around dictionaries, that let us customize
+    access, while still allowing static type annotations.
+    """
+
+    __slots__ = ("data",)
+    attributes = _AttributesDescriptor(set())
+
+    def __init__(self: Self, data: dict[str, AttributeT]):
+        """IMPORTANT: Do not override this. If you want to create a function that creates new nodes of your subtype,
+        override the `new` method instead.
+        """
+        for k in self._mandatory_fields():
+            if k not in data:
+                raise ValueError(f"Missing mandatory field {k}")
+        self.data = data
+
+    @classmethod
+    def _do_new(cls, *args, **kwargs) -> Self:
+        """This is here for your convenience to be called in `new`."""
+        cls.__validate_arguments(args, kwargs)
+        data = cls.__turn_arguments_into_data_dict(args, kwargs)
+        return cls(data)
+
+    @classmethod
+    @abstractmethod
+    def new(cls, *args, **kwargs) -> Self:
+        """Create a new node by creating a new dictionary from args and kwargs.
+
+        Use this for creation of new nodes from inline code. This is typically also where you want to provide
+        type hints for users via `@overload`. You can delegate to the `_do_new()` method of `BaseNode`
+        """
+        ...
+
+    def as_dict(self: Self) -> dict[str, AttributeT]:
+        return self.data
+
+    @classmethod
+    def from_dict(cls: type[Self], data: dict[str, AttributeT]) -> Self:
+        return cls(data)
+
+    def __eq__(self: Self, other: object) -> bool:
+        if hasattr(other, "data") and isinstance(other.data, dict):
+            return self.data == other.data
+        else:
+            return False
+
+    @classmethod
+    def __turn_arguments_into_data_dict(
+        cls, args: tuple[Any], kwargs: dict[str, Any]
+    ) -> dict[str, Any]:
+        data = cls.__extract_attributes_from_args(args, kwargs)
+        data.update(cls.__get_kwargs_without_attributes(kwargs))
+        data.update(cls.__get_args_as_kwargs(args))
+        cls.__transform_args_with_mandatory_fields(data)
+        return data
+
+    @classmethod
+    def __get_mandatory_field_descriptors(
+        cls,
+    ) -> Iterable[tuple[str, TransformableMandatoryField]]:
+        for c in reversed(inspect.getmro(cls)):
+            for a in c.__dict__:
+                if (
+                    not a.startswith("__")
+                    and not a.endswith("__")
+                    and isinstance(c.__dict__[a], TransformableMandatoryField)
+                ):
+                    yield a, c.__dict__[a]
+
+    @classmethod
+    def _mandatory_fields(cls) -> tuple[str, ...]:
+        return tuple(name for name, _ in cls.__get_mandatory_field_descriptors())
+
+    def __attribute_keys(self: Self):
+        return tuple(k for k in self.data.keys() if k not in self._mandatory_fields())
+
+    def __repr__(self: Self):
+        mandatory_fields_repr = ", ".join(
+            f"{k}={self.data[k]}" for k in self._mandatory_fields()
+        )
+        attributes = ", ".join(
+            f"'{k}': '{self.data[k]}'" for k in self.__attribute_keys()
+        )
+        return (
+            f"{self.__class__.__name__}({mandatory_fields_repr},"
+            f" attributes={attributes})"
+        )
+
+    @classmethod
+    def __validate_arguments(cls, args: tuple[Any], kwargs: dict[str, Any]):
+        """Raise `ValueError` if `args`/`kwargs` cannot be mapped to the fields."""
+        num_total_args = len(args) + len(kwargs)
+        if num_total_args not in (
+            len(cls._mandatory_fields()),
+            len(cls._mandatory_fields()) + 1,
+        ):
+            raise ValueError(
+                f"allowed arguments are {cls._mandatory_fields()} and attributes, but"
+                f" passed args: {args} and kwargs: {kwargs}"
+            )
+        # Without this check an unknown keyword is mistaken for positional
+        # attributes later on, which fails with an unrelated IndexError.
+        unknown_names = set(kwargs) - {*cls._mandatory_fields(), "attributes"}
+        if len(unknown_names) > 0:
+            raise ValueError(
+                f"allowed arguments are {cls._mandatory_fields()} and attributes, but"
+                f" passed unknown kwargs: {unknown_names}"
+            )
+        argument_names_in_args = set(k for k, _ in zip(cls._mandatory_fields(), args))
+        arguments_specified_twice = argument_names_in_args.intersection(kwargs.keys())
+        if len(arguments_specified_twice) > 0:
+            raise ValueError(f"arguments specified twice {arguments_specified_twice}")
+
+    @classmethod
+    def __extract_attributes_from_args(
+        cls, args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> dict[str, Any]:
+        data = dict()
+        if "attributes" in kwargs:
+            if callable(kwargs["attributes"]):
+                data.update(kwargs["attributes"]())
+            else:
+                data.update(kwargs["attributes"])
+        elif len(args) == len(cls._mandatory_fields()) + 1:
+            # attributes were passed as the trailing positional argument
+            attributes = args[-1]
+            data.update(attributes)
+        return data
+
+    @classmethod
+    def __get_kwargs_without_attributes(
+        cls, kwargs: dict[str, Any]
+    ) -> dict[str, AttributeT]:
+        kwargs = {k: v for k, v in kwargs.items() if k != "attributes"}
+        return kwargs
+
+    @classmethod
+    def __transform_args_with_mandatory_fields(cls, args: dict[str, Any]) -> None:
+        set_transforms = cls.__get_field_transforms()
+        for k in args:
+            if k in set_transforms:
+                args[k] = set_transforms[k](args[k])
+
+    @classmethod
+    def __get_field_transforms(cls) -> dict[str, Callable[[Any], AttributeT]]:
+        return {k: v.set_transform for k, v in cls.__get_mandatory_field_descriptors()}
+
+    @classmethod
+    def __get_args_as_kwargs(cls, args: tuple[Any]) -> Iterable[tuple[str, Any]]:
+        return zip(cls._mandatory_fields(), args)
diff --git a/elasticai/creator/ir/abstract_ir_data/mandatory_field.py b/elasticai/creator/ir/abstract_ir_data/mandatory_field.py
new file mode 100644
index 00000000..93bab39d
--- /dev/null
+++ b/elasticai/creator/ir/abstract_ir_data/mandatory_field.py
@@ -0,0 +1,75 @@
+from collections.abc import Callable
+from typing import cast
+
+from typing_extensions import Generic, Protocol, TypeVar
+
+from elasticai.creator.ir.attribute import AttributeT
+
+from ._has_data import HasData
+
+T = TypeVar("T", bound=AttributeT)  # stored data type
+
+F = TypeVar("F", default=T)  # visible data type
+
+
+class AbstractIR(Protocol):
+    data: dict[str, AttributeT]
+
+
+class TransformableMandatoryField(Generic[T, F]):
+    """
+    A __descriptor__ that designates a mandatory field of an abstract ir data class.
+    The descriptor accesses the `data` dictionary of the owning abstract ir data object
+    to read and write values. You can use the `set_transform` and `get_transform` functions
+    to transform values during read/write accesses. `T` designates the type stored in the
+    `data` dictionary, while `F` is the type that the mandatory field receives.
+    That allows to keep dictionary of primitive (serializable) data types in memory,
+    while still providing abstract ways to manipulate that data in complex ways.
+    This is typically required when working with Nodes and Graphs to create new
+    intermediate representations and transform one graph into another.
+
+    E.g.
+
+    ```python
+    class A(AbstractIRData):
+        number: TransformableMandatoryField[str, int] = TransformableMandatoryField(set_transform=str, get_transform=int)
+
+
+    a = A({'number': "12"})
+    a.number = a.number + 3
+    print(a.data)  # {'number': "15"}
+    ```
+    """
+
+    def __init__(
+        self,
+        set_transform: Callable[[F], T],
+        get_transform: Callable[[T], F],
+    ):
+        self.set_transform = set_transform
+        self.get_transform = get_transform
+
+    def __set_name__(self, owner, name: str) -> None:
+        """
+        IMPORTANT: do not remove owner even though it's not used
+        see https://docs.python.org/3/reference/datamodel.html#descriptors for more information
+        """
+        self.name = name
+
+    def __get__(self, instance: HasData | None, owner) -> F:
+        """
+        IMPORTANT: do not remove owner even though it's not used
+        see https://docs.python.org/3/reference/datamodel.html#descriptors for more information
+        """
+        if instance is None:
+            # class-level access (e.g. `Node.name`) has no `data` to read from
+            return self  # type: ignore[return-value]
+        return self.get_transform(cast(T, instance.data[self.name]))
+
+    def __set__(self, instance: HasData, value: F) -> None:
+        instance.data[self.name] = self.set_transform(value)
+
+
+class MandatoryField(TransformableMandatoryField[T, T]):
+    def __init__(self):
+        super().__init__(lambda x: x, lambda x: x)
diff --git a/elasticai/creator/ir/attribute.py b/elasticai/creator/ir/attribute.py
new file mode 100644
index 00000000..2441d7e3
--- /dev/null
+++ b/elasticai/creator/ir/attribute.py
@@ -0,0 +1,6 @@
+from typing import TypeAlias
+
+SizeT: TypeAlias = tuple[int] | tuple[int, int] | tuple[int, int, int]
+AttributeT: TypeAlias = (
+    int | float | str | tuple["AttributeT", ...] | dict[str, "AttributeT"]
+)
diff --git a/elasticai/creator/ir/node.py b/elasticai/creator/ir/node.py
new file mode 100644
index 00000000..c550f2d5
--- /dev/null
+++ b/elasticai/creator/ir/node.py
@@ -0,0 +1,25 @@
+from typing import overload
+
+from .abstract_ir_data import AbstractIRData
+from .abstract_ir_data.mandatory_field import MandatoryField
+from .attribute import AttributeT
+
+
+class Node(AbstractIRData):
+    name: MandatoryField[str] = MandatoryField()
+    type: MandatoryField[str] = MandatoryField()
+
+    @classmethod
+    @overload
+    def new(cls, name: str, type: str, attributes: dict[str, AttributeT]) -> "Node": ...
+
+    @classmethod
+    @overload
+    def new(cls, name: str, type: str) -> "Node": ...
+
+    @classmethod
+    def new(cls, *args, **kwargs) -> "Node":
+        return cls._do_new(*args, **kwargs)
+
+    def __hash__(self):
+        return hash((self.name, self.type))
diff --git a/elasticai/creator/ir/node_test.py b/elasticai/creator/ir/node_test.py
new file mode 100644
index 00000000..f9d472af
--- /dev/null
+++ b/elasticai/creator/ir/node_test.py
@@ -0,0 +1,119 @@
+from typing import overload
+
+import pytest
+
+from elasticai.creator.ir.abstract_ir_data.mandatory_field import MandatoryField
+
+from .abstract_ir_data import AbstractIRData
+from .abstract_ir_data.mandatory_field import TransformableMandatoryField
+from .node import Node
+from .node import Node as OtherNode
+
+
+def test_can_create_new_node():
+    n = Node.new(name="my_name", type="my_type")
+    assert n.name == "my_name"
+
+
+def test_creating_new_node_without_filling_all_args_yields_error():
+    with pytest.raises(ValueError):
+        Node.new(name="my_name")
+
+
+def test_filling_name_arg_twice_leads_to_error():
+    with pytest.raises(ValueError):
+        Node.new("my_name", "my_type", name="other_name")
+
+
+def test_attributes_are_merged_into_data():
+    n = Node.new(name="my_name", type="my_type", attributes=dict(input_shape=(1, 1)))
+    assert n.data == dict(name="my_name", type="my_type", input_shape=(1, 1))
+
+
+def test_attributes_are_merged_into_data_for_call_with_positional_args_only():
+    n = Node.new("my_name", "my_type", dict(input_shape=(1, 1)))
+    assert n.data == dict(name="my_name", type="my_type", input_shape=(1, 1))
+
+
+def test_two_different_node_types_can_share_data() -> None:
+    class NewNode(AbstractIRData):
+        input_shape: MandatoryField[tuple[int, int]] = MandatoryField()
+
+        @classmethod
+        @overload
+        def new(cls, input_shape: tuple[int, int]) -> "NewNode": ...
+
+        @classmethod
+        def new(cls, *args, **kwargs) -> "NewNode":
+            return cls._do_new(*args, **kwargs)
+
+    a = Node.new(name="my_name", type="my_type", attributes=dict(input_shape=(1, 3)))
+    b = NewNode(a.data)
+    b.input_shape = (4, 3)
+    assert a.attributes["input_shape"] == b.input_shape
+
+
+def test_nodes_inherit_mandatory_fields() -> None:
+    class NewNode(Node):
+        input_shape: MandatoryField[tuple[int, int]] = MandatoryField()
+
+        @classmethod
+        @overload
+        def new(
+            cls, name: str, type: str, input_shape: tuple[int, int]
+        ) -> "NewNode": ...
+
+        @classmethod
+        def new(cls, *args, **kwargs) -> "NewNode":
+            return cls._do_new(*args, **kwargs)
+
+    n = NewNode.new(name="my_name", type="my_type", input_shape=(1, 3))
+    assert n.input_shape == (1, 3)
+
+
+def test_can_serialize_node_with_attributes():
+    n = Node.new(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
+    assert dict(name="my_node", type="my_type", a="b", c=(1, 2)) == n.as_dict()
+
+
+def test_can_deserialize_node_with_attributes():
+    n = Node.new(name="my_node", type="my_type", attributes={"a": "b", "c": (1, 2)})
+    assert n == Node.from_dict(dict(name="my_node", type="my_type", a="b", c=(1, 2)))
+
+
+def test_import_path_does_not_matter_for_equality():
+    n = Node.new(name="a", type="a_type", attributes=dict(a="b"))
+    other = OtherNode.new(name="a", type="a_type", attributes=dict(a="b"))
+    assert n == other
+
+
+def test_set_transform_is_applied_when_calling_new() -> None:
+    def set_transform(x: str) -> int:
+        return int(x)
+
+    def get_transform(x: int) -> str:
+        return str(x)
+
+    class NewNode(Node):
+        input_shape: TransformableMandatoryField[int, str] = (
+            TransformableMandatoryField(
+                set_transform=set_transform,
+                get_transform=get_transform,
+            )
+        )
+
+        @classmethod
+        @overload
+        def new(cls, name: str, type: str, input_shape: str) -> "NewNode": ...
+
+        @classmethod
+        def new(cls, *args, **kwargs) -> "NewNode":
+            return cls._do_new(*args, **kwargs)
+
+    n = NewNode.new(name="my_name", type="my_type", input_shape="3")
+    assert n.input_shape == "3"
+
+
+def test_unknown_keyword_argument_leads_to_value_error():
+    with pytest.raises(ValueError):
+        Node.new(name="my_name", type="my_type", unknown="x")
diff --git a/elasticai/creator/ir/typing.py b/elasticai/creator/ir/typing.py
new file mode 100644
index 00000000..6062af87
--- /dev/null
+++ b/elasticai/creator/ir/typing.py
@@ -0,0 +1,24 @@
+from typing import Protocol, TypeVar, runtime_checkable
+
+from typing_extensions import Self
+
+from .attribute import AttributeT
+
+
+@runtime_checkable
+class Node(Protocol):
+    name: str
+    type: str
+    data: dict[str, AttributeT]
+    attributes: dict[str, AttributeT]
+
+    @classmethod
+    def new(cls, *args, **kwargs) -> Self: ...
+
+    @classmethod
+    def from_dict(cls, d: dict[str, AttributeT]) -> Self: ...
+
+    def as_dict(self) -> dict[str, AttributeT]: ...
+
+
+NodeT = TypeVar("NodeT", bound=Node)