diff --git a/docs/source/api_reference/syft.service.dataset.dataset.rst b/docs/source/api_reference/syft.service.dataset.dataset.rst index 120d15361ba..55be8888665 100644 --- a/docs/source/api_reference/syft.service.dataset.dataset.rst +++ b/docs/source/api_reference/syft.service.dataset.dataset.rst @@ -36,7 +36,7 @@ syft.service.dataset.dataset CreateDataset Dataset DatasetUpdate - TupleDict + DictTuple diff --git a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb index f63a3c81d31..608b2971b97 100644 --- a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb +++ b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb @@ -48,7 +48,7 @@ "metadata": {}, "outputs": [], "source": [ - "node = sy.orchestra.launch(name=\"private-data-example-domain-1\",port=8040, reset=True)" + "node = sy.orchestra.launch(name=\"private-data-example-domain-1\", port=\"auto\", reset=True)" ] }, { @@ -158,6 +158,16 @@ "client.datasets" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "af495cad", + "metadata": {}, + "outputs": [], + "source": [ + "client.datasets[\"my dataset\"]" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -265,6 +275,18 @@ "source": [ "## High Side vs Low Side" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c13cdaa2", + "metadata": {}, + "outputs": [], + "source": [ + "# Cleanup local domain server\n", + "if node.node_type.value == \"python\":\n", + " node.land()" + ] } ], "metadata": { @@ -283,7 +305,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.11.5" }, "toc": { "base_numbering": 1, diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 0f050949e0d..a3c753a25cb 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ 
-1,5 +1,7 @@ # stdlib from collections import defaultdict +from collections.abc import MutableMapping +from collections.abc import MutableSequence import hashlib import json import os @@ -21,6 +23,7 @@ from ..service.response import SyftError from ..service.response import SyftException from ..service.response import SyftSuccess +from ..types.dicttuple import DictTuple from ..types.syft_object import SyftBaseObject PROTOCOL_STATE_FILENAME = "protocol_version.json" @@ -316,20 +319,25 @@ def check_or_stage_protocol() -> Result[SyftSuccess, SyftError]: def debox_arg_and_migrate(arg: Any, protocol_state: dict): """Debox the argument based on whether it is iterable or single entity.""" - box_to_result_type = None - - if type(arg) in OkErr: - box_to_result_type = type(arg) - arg = arg.value + constructor = None + extra_args = [] single_entity = False - is_tuple = isinstance(arg, tuple) - if isinstance(arg, (list, tuple)): + if isinstance(arg, OkErr): + constructor = type(arg) + arg = arg.value + + if isinstance(arg, MutableMapping): + iterable_keys = arg.keys() + elif isinstance(arg, MutableSequence): + iterable_keys = range(len(arg)) + elif isinstance(arg, tuple): iterable_keys = range(len(arg)) + constructor = type(arg) + if isinstance(arg, DictTuple): + extra_args.append(arg.keys()) arg = list(arg) - elif isinstance(arg, dict): - iterable_keys = arg.keys() else: iterable_keys = range(1) arg = [arg] @@ -349,9 +357,8 @@ def debox_arg_and_migrate(arg: Any, protocol_state: dict): arg[key] = _object wrapped_arg = arg[0] if single_entity else arg - wrapped_arg = tuple(wrapped_arg) if is_tuple else wrapped_arg - if box_to_result_type is not None: - wrapped_arg = box_to_result_type(wrapped_arg) + if constructor is not None: + wrapped_arg = constructor(wrapped_arg, *extra_args) return wrapped_arg diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index c0ae2159761..b5d9bbf2b32 100644 --- 
a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -345,7 +345,7 @@ "DatasetPageView": { "1": { "version": 1, - "hash": "68c7a0c3e7796fdabb8f732c6d150ec4a8071ce78d69b30da18393afdcea1e59", + "hash": "6741bd16dc6089d9deea37b1bd4e895152d1a0c163b8bdfe45280b9bfc4a1354", "action": "add" } }, diff --git a/packages/syft/src/syft/serde/recursive_primitives.py b/packages/syft/src/syft/serde/recursive_primitives.py index efcf6ca210b..8bc0dbdc640 100644 --- a/packages/syft/src/syft/serde/recursive_primitives.py +++ b/packages/syft/src/syft/serde/recursive_primitives.py @@ -1,6 +1,7 @@ # stdlib from collections import OrderedDict from collections import defaultdict +from collections.abc import Iterable from collections.abc import Mapping from enum import Enum from enum import EnumMeta @@ -71,16 +72,20 @@ def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection: return iterable_type(values) -def serialize_kv(map: Mapping) -> bytes: +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +def _serialize_kv_pairs(size: int, kv_pairs: Iterable[tuple[_KT, _VT]]) -> bytes: # relative from .serialize import _serialize message = kv_iterable_schema.new_message() - message.init("keys", len(map)) - message.init("values", len(map)) + message.init("keys", size) + message.init("values", size) - for index, (k, v) in enumerate(map.items()): + for index, (k, v) in enumerate(kv_pairs): message.keys[index] = _serialize(k, to_bytes=True) serialized = _serialize(v, to_bytes=True) chunk_bytes(serialized, index, message.values) @@ -88,6 +93,10 @@ def serialize_kv(map: Mapping) -> bytes: return message.to_bytes() +def serialize_kv(map: Mapping) -> bytes: + return _serialize_kv_pairs(len(map), map.items()) + + def get_deserialized_kv_pairs(blob: bytes) -> List[Any]: # relative from .deserialize import _deserialize diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 
edc08dee1b0..1abfe2d9cdc 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -27,12 +27,15 @@ import zmq.green as zmq # relative -from ..types.tupledict import TupleDict +from ..types.dicttuple import DictTuple +from ..types.dicttuple import _Meta as _DictTupleMetaClass from .deserialize import _deserialize as deserialize +from .recursive_primitives import _serialize_kv_pairs from .recursive_primitives import deserialize_kv +from .recursive_primitives import deserialize_type from .recursive_primitives import recursive_serde_register from .recursive_primitives import recursive_serde_register_type -from .recursive_primitives import serialize_kv +from .recursive_primitives import serialize_type from .serialize import _serialize as serialize recursive_serde_register( @@ -128,10 +131,20 @@ def deserialize_series(blob: bytes) -> Series: deserialize=lambda x: Timestamp(deserialize(x, from_bytes=True)), ) + +def _serialize_dicttuple(x: DictTuple) -> bytes: + return _serialize_kv_pairs(size=len(x), kv_pairs=zip(x.keys(), x)) + + +recursive_serde_register( + _DictTupleMetaClass, + serialize=serialize_type, + deserialize=deserialize_type, +) recursive_serde_register( - TupleDict, - serialize=serialize_kv, - deserialize=functools.partial(deserialize_kv, TupleDict), + DictTuple, + serialize=_serialize_dicttuple, + deserialize=functools.partial(deserialize_kv, DictTuple), ) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 42f1d6734d7..f482e4ba67e 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -26,13 +26,13 @@ from ...serde.serializable import serializable from ...store.document_store import PartitionKey from ...types.datetime import DateTime +from ...types.dicttuple import DictTuple from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import 
SyftObject from ...types.transforms import TransformContext from ...types.transforms import generate_id from ...types.transforms import transform from ...types.transforms import validate_url -from ...types.tupledict import TupleDict from ...types.uid import UID from ...util import options from ...util.colors import ON_SURFACE_HIGHEST @@ -525,11 +525,8 @@ def action_ids(self) -> List[UID]: return data @property - def assets(self) -> TupleDict[str, Asset]: - data = TupleDict() - for asset in self.asset_list: - data[asset.name] = asset - return data + def assets(self) -> DictTuple[str, Asset]: + return DictTuple((asset.name, asset) for asset in self.asset_list) def _old_repr_markdown_(self) -> str: _repr_str = f"Syft Dataset: {self.name}\n" @@ -606,7 +603,7 @@ class DatasetPageView(SyftObject): __canonical_name__ = "DatasetPageView" __version__ = SYFT_OBJECT_VERSION_1 - datasets: TupleDict[str, Dataset] + datasets: DictTuple[str, Dataset] total: int diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index c01d4ca4840..2bce6a10ea4 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -1,5 +1,5 @@ # stdlib -from itertools import islice +from collections.abc import Collection from typing import List from typing import Optional from typing import Union @@ -7,7 +7,7 @@ # relative from ...serde.serializable import serializable from ...store.document_store import DocumentStore -from ...types.tupledict import TupleDict +from ...types.dicttuple import DictTuple from ...types.uid import UID from ...util.telemetry import instrument from ..action.action_permissions import ActionObjectPermission @@ -31,30 +31,40 @@ from .dataset_stash import DatasetStash -def _paginate_dataset_collection( - results: TupleDict[str, Dataset], +def _paginate_collection( + collection: Collection, page_size: Optional[int] = 0, page_index: 
Optional[int] = 0, -) -> DatasetPageView: - if page_size is None or page_size <= 0: - return results +) -> Optional[slice]: + if page_size is None or page_size <= 0: + return None # If chunk size is defined, then split list into evenly sized chunks total = len(collection) page_index = 0 if page_index is None else page_index if page_size > total or page_index >= total // page_size or page_index < 0: - pass - else: - results = TupleDict( - islice( - results.items(), - page_size * page_index, - min(page_size * (page_index + 1), total), - ) - ) + return None - return DatasetPageView(datasets=results, total=total) + start = page_size * page_index + stop = min(page_size * (page_index + 1), total) + return slice(start, stop) + + +def _paginate_dataset_collection( + datasets: Collection[Dataset], + page_size: Optional[int] = 0, + page_index: Optional[int] = 0, +) -> Union[DictTuple[str, Dataset], DatasetPageView]: + slice_ = _paginate_collection(datasets, page_size=page_size, page_index=page_index) + chunk = datasets[slice_] if slice_ is not None else datasets + results = DictTuple((dataset.name, dataset) for dataset in chunk) + + return ( + results + if slice_ is None + else DatasetPageView(datasets=results, total=len(datasets)) + ) @instrument @@ -104,7 +114,7 @@ def get_all( context: AuthedServiceContext, page_size: Optional[int] = 0, page_index: Optional[int] = 0, - ) -> Union[DatasetPageView, TupleDict[str, Dataset], SyftError]: + ) -> Union[DatasetPageView, DictTuple[str, Dataset], SyftError]: """Get a Dataset""" result = self.stash.get_all(context.credentials) if not result.is_ok(): @@ -112,17 +122,11 @@ def get_all( datasets = result.ok() - results = TupleDict() for dataset in datasets: dataset.node_uid = context.node.id - results[dataset.name] = dataset - - return ( - results - if page_size <= 0 or page_size is None - else _paginate_dataset_collection( - results, page_size=page_size, page_index=page_index - ) + + return 
_paginate_dataset_collection( + datasets=datasets, page_size=page_size, page_index=page_index ) @service_method( @@ -141,11 +145,9 @@ def search( if isinstance(results, SyftError): return results - filtered_results = TupleDict( - (dataset_name, dataset) - for dataset_name, dataset in results.items() - if name in dataset_name - ) + filtered_results = [ + dataset for dataset_name, dataset in results.items() if name in dataset_name + ] return _paginate_dataset_collection( filtered_results, page_size=page_size, page_index=page_index diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py new file mode 100644 index 00000000000..77ed0b1f0a9 --- /dev/null +++ b/packages/syft/src/syft/types/dicttuple.py @@ -0,0 +1,232 @@ +# stdlib +from collections import OrderedDict +from collections import deque +from collections.abc import Collection +from collections.abc import Iterable +from collections.abc import KeysView +from collections.abc import Mapping +from types import MappingProxyType +from typing import Callable +from typing import Generic +from typing import Optional +from typing import SupportsIndex +from typing import TypeVar +from typing import Union +from typing import overload + +# third party +from typing_extensions import Self + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# To correctly implement the creation of a DictTuple instance with +# DictTuple(key_value_pairs: Iterable[tuple[_KT, _VT]]), +# implementing just __init__ and __new__ is not enough. +# +# We need to extract the keys and values from keys_value_pairs +# and pass the values to __new__ (needs to call tuple.__new__(values) to create the tuple) +# and the keys to __init__ to create the mapping _KT -> int, +# but we can only iterate over key_value_pairs once since keys_value_pairs, +# being just an Iterable, might be ephemeral. 
+# +# Implementing just __new__ and __init__ for DictTuple is not enough since when +# calling DictTuple(key_value_pairs), __new__(keys_and_values) and __init__(key_value_pairs) +# are called in 2 separate function calls. If keys_and_values are ephemeral, like a generator, +# by the time it's passed to __init__() it's already been exhausted and there is no way to +# extract the keys out to create the mapping. +# +# Thus it is necessary to override __call__ of the metaclass +# to customize the way __new__ and __init__ work together, by iterating over key_value_pairs +# once to extract both keys and values, then passing keys to __new__, values to __init__ +# within the same function call. +class _Meta(type): + @overload + def __call__(cls: type[_T]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Mapping[_KT, _VT]) -> _T: + ... + + @overload + def __call__(cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT]) -> _T: + ... + + @overload + def __call__( + cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT] + ) -> _T: + ... 
+ + def __call__( + cls: type[_T], + __value: Optional[Iterable] = None, + __key: Optional[Union[Callable, Collection]] = None, + /, + ) -> _T: + if __value is None and __key is None: + obj = cls.__new__(cls) + obj.__init__() + return obj + + elif isinstance(__value, Mapping) and __key is None: + obj = cls.__new__(cls, __value.values()) + obj.__init__(__value.keys()) + + return obj + + elif isinstance(__value, Iterable) and __key is None: + keys = OrderedDict() + values = deque() + + for i, (k, v) in enumerate(__value): + keys[k] = i + values.append(v) + + obj = cls.__new__(cls, values) + obj.__init__(keys) + + return obj + + elif isinstance(__value, Iterable) and isinstance(__key, Iterable): + keys = OrderedDict((k, i) for i, k in enumerate(__key)) + + obj = cls.__new__(cls, __value) + obj.__init__(keys) + + return obj + + elif isinstance(__value, Iterable) and isinstance(__key, Callable): + obj = cls.__new__(cls, __value) + obj.__init__(__key) + + return obj + + raise NotImplementedError + + +class DictTuple(tuple[_VT, ...], Generic[_KT, _VT], metaclass=_Meta): + """ + OVERVIEW + + tuple with support for dict-like __getitem__(key) + + dict_tuple = DictTuple({"x": 1, "y": 2}) + + dict_tuple["x"] == 1 + + dict_tuple["y"] == 2 + + dict_tuple[0] == 1 + + dict_tuple[1] == 2 + + everything else, e.g. __contains__, __iter__, behaves similarly to a tuple + + + CREATION + + DictTuple(iterable) -> DictTuple([("x", 1), ("y", 2)]) + + DictTuple(mapping) -> DictTuple({"x": 1, "y": 2}) + + DictTuple(values, keys) -> DictTuple([1, 2], ["x", "y"]) + + + IMPLEMENTATION DETAILS + + DictTuple[_KT, _VT] is essentially a tuple[_VT, ...] that maintains an immutable Mapping[_KT, int] + from the key to the tuple index internally. + + For example DictTuple({"x": 12, "y": 34}) is just a tuple (12, 34) with a {"x": 0, "y": 1} mapping. + + types.MappingProxyType is used for the mapping for immutability. 
+ """ + + __mapping: MappingProxyType[_KT, int] + + # These overloads are copied from _Meta.__call__ just for IDE hints + @overload + def __init__(self) -> None: + ... + + @overload + def __init__(self, __value: Iterable[tuple[_KT, _VT]]) -> None: + ... + + @overload + def __init__(self, __value: Mapping[_KT, _VT]) -> None: + ... + + @overload + def __init__(self, __value: Iterable[_VT], __key: Collection[_KT]) -> None: + ... + + @overload + def __init__(self, __value: Iterable[_VT], __key: Callable[[_VT], _KT]) -> None: + ... + + def __init__(self, __value=None, /): + if isinstance(__value, MappingProxyType): + self.__mapping = __value + elif isinstance(__value, Mapping): + self.__mapping = MappingProxyType(__value) + elif isinstance(__value, Iterable): + self.__mapping = MappingProxyType( + OrderedDict((k, i) for i, k in enumerate(__value)) + ) + elif isinstance(__value, Callable): + self.__mapping = MappingProxyType( + OrderedDict((__value(v), i) for i, v in enumerate(self)) + ) + + super().__init__() + + if len(self.__mapping) != len(self): + raise ValueError("`__keys` and `__values` do not have the same length") + + if any(isinstance(k, SupportsIndex) for k in self.__mapping.keys()): + raise ValueError( + "values of `__keys` should not have type `int`, " + "or implement `__index__()`" + ) + + @overload + def __getitem__(self, __key: _KT) -> _VT: + ... + + @overload + def __getitem__(self, __key: slice) -> Self: + ... + + @overload + def __getitem__(self, __key: SupportsIndex) -> _VT: + ... 
+ + def __getitem__(self, __key, /): + if isinstance(__key, slice): + return self.__class__( + super().__getitem__(__key), + list(self.__mapping.keys()).__getitem__(__key), + ) + + if isinstance(__key, SupportsIndex): + return super().__getitem__(__key) + + return super().__getitem__(self.__mapping[__key]) + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}{super().__repr__()}" + + def keys(self) -> KeysView[_KT]: + return self.__mapping.keys() + + def items(self) -> Iterable[tuple[_KT, _VT]]: + return zip(self.__mapping.keys(), self) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index b3226adeaf0..ee474273814 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -1,6 +1,8 @@ # stdlib from collections import defaultdict from collections.abc import Mapping +from collections.abc import MutableMapping +from collections.abc import MutableSequence from collections.abc import Set from hashlib import sha256 import inspect @@ -38,6 +40,7 @@ from ..util.util import aggressive_set_attr from ..util.util import full_name_with_qualname from ..util.util import get_qualname_for +from .dicttuple import DictTuple from .syft_metaclass import Empty from .syft_metaclass import PartialModelMetaclass from .uid import UID @@ -778,6 +781,7 @@ def list_dict_repr_html(self) -> str: aggressive_set_attr(type([]), "_repr_html_", list_dict_repr_html) aggressive_set_attr(type({}), "_repr_html_", list_dict_repr_html) aggressive_set_attr(type(set()), "_repr_html_", list_dict_repr_html) +aggressive_set_attr(tuple, "_repr_html_", list_dict_repr_html) class StorableObjectType: @@ -838,20 +842,25 @@ def __init__(self, *args, **kwargs) -> None: def attach_attribute_to_syft_object(result: Any, attr_dict: Dict[str, Any]) -> Any: - box_to_result_type = None - - if type(result) in OkErr: - box_to_result_type = type(result) - result = result.value + constructor = None + 
extra_args = [] single_entity = False - is_tuple = isinstance(result, tuple) - if isinstance(result, (list, tuple)): + if isinstance(result, OkErr): + constructor = type(result) + result = result.value + + if isinstance(result, MutableMapping): + iterable_keys = result.keys() + elif isinstance(result, MutableSequence): + iterable_keys = range(len(result)) + elif isinstance(result, tuple): iterable_keys = range(len(result)) + constructor = type(result) + if isinstance(result, DictTuple): + extra_args.append(result.keys()) result = list(result) - elif isinstance(result, Mapping): - iterable_keys = result.keys() else: iterable_keys = range(1) result = [result] @@ -872,8 +881,7 @@ def attach_attribute_to_syft_object(result: Any, attr_dict: Dict[str, Any]) -> A result[key] = _object wrapped_result = result[0] if single_entity else result - wrapped_result = tuple(wrapped_result) if is_tuple else wrapped_result - if box_to_result_type is not None: - wrapped_result = box_to_result_type(wrapped_result) + if constructor is not None: + wrapped_result = constructor(wrapped_result, *extra_args) return wrapped_result diff --git a/packages/syft/src/syft/types/tupledict.py b/packages/syft/src/syft/types/tupledict.py deleted file mode 100644 index f55eda9ed6c..00000000000 --- a/packages/syft/src/syft/types/tupledict.py +++ /dev/null @@ -1,21 +0,0 @@ -# stdlib -from collections import OrderedDict -from typing import Iterator -from typing import TypeVar -from typing import Union - -_KT = TypeVar("_KT") -_VT = TypeVar("_VT") - - -class TupleDict(OrderedDict[_KT, _VT]): - def __getitem__(self, key: Union[int, _KT]) -> _VT: - if isinstance(key, int): - return list(self.values())[key] - return super().__getitem__(key) - - def __len__(self) -> int: - return len(self.keys()) - - def __iter__(self) -> Iterator[_VT]: - yield from self.values() diff --git a/packages/syft/tests/syft/types/dicttuple_test.py b/packages/syft/tests/syft/types/dicttuple_test.py new file mode 100644 index 
00000000000..eb5f0947881 --- /dev/null +++ b/packages/syft/tests/syft/types/dicttuple_test.py @@ -0,0 +1,245 @@ +# stdlib +from collections.abc import Collection +from collections.abc import Iterable +from collections.abc import Mapping +from functools import cached_property +from itertools import chain +from itertools import combinations +from typing import Any +from typing import Callable +from typing import Generator +from typing import Generic +from typing import Optional +from typing import TypeVar +from typing import Union + +# third party +import pytest +from typing_extensions import Self + +# syft absolute +from syft.types.dicttuple import DictTuple + + +def test_dict_tuple_not_subclassing_mapping(): + assert not issubclass(DictTuple, Mapping) + + +# different ways to create a DictTuple +SIMPLE_TEST_CASES = [ + DictTuple({"x": 1, "y": 2}), + DictTuple([("x", 1), ("y", 2)]), + DictTuple([1, 2], ["x", "y"]), +] + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_should_iter_over_value(dict_tuple: DictTuple) -> None: + values = [] + for v in dict_tuple: + values.append(v) + + assert values == [1, 2] + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_indexing(dict_tuple: DictTuple) -> None: + assert dict_tuple[0] == 1 + assert dict_tuple[1] == 2 + assert dict_tuple["x"] == 1 + assert dict_tuple["y"] == 2 + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_convert_to_other_iterable_types(dict_tuple: DictTuple) -> None: + assert list(dict_tuple) == [1, 2] + assert tuple(dict_tuple) == (1, 2) + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_keys(dict_tuple) -> None: + assert list(dict_tuple.keys()) == ["x", "y"] + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_convert_to_dict(dict_tuple: DictTuple) -> None: + assert dict(dict_tuple) == {"x": 1, "y": 2} + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def 
test_convert_items_to_dict(dict_tuple: DictTuple) -> None: + assert dict(dict_tuple.items()) == {"x": 1, "y": 2} + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_iter_over_items(dict_tuple: DictTuple) -> None: + items = [] + for k, v in dict_tuple.items(): + items.append((k, v)) + + assert items == [("x", 1), ("y", 2)] + + +@pytest.mark.parametrize("dict_tuple", SIMPLE_TEST_CASES) +def test_dicttuple_is_not_a_mapping(dict_tuple: DictTuple) -> None: + assert not isinstance(dict_tuple, Mapping) + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +class Case(Generic[_KT, _VT]): + values: Collection[_VT] + keys: Collection[_KT] + key_fn: Optional[Callable[[_VT], _KT]] + value_generator: Callable[[], Generator[_VT, Any, None]] + key_generator: Callable[[], Generator[_KT, Any, None]] + + def __init__( + self, + values: Collection[_VT], + keys: Union[Callable[[_VT], _KT], Collection[_KT]], + ) -> None: + self.values = values + + if isinstance(keys, Callable): + self.key_fn = keys + self.keys = [keys(v) for v in values] + else: + self.key_fn = None + self.keys = keys + + def value_generator() -> Generator[_VT, Any, None]: + yield from values + + def key_generator() -> Generator[_KT, Any, None]: + yield from self.keys + + self.value_generator = value_generator + self.key_generator = key_generator + + def kv(self) -> Iterable[tuple[_KT, _VT]]: + return zip(self.keys, self.values) + + @cached_property + def mapping(self) -> dict[_KT, _VT]: + return dict(self.kv()) + + def constructor_args(self, mapping: bool = True) -> list[Callable[[], tuple]]: + return [ + lambda: (self.values, self.keys), + lambda: (self.value_generator(), self.key_generator()), + lambda: (self.values, self.key_generator()), + lambda: (self.value_generator(), self.keys), + *( + [ + lambda: (self.mapping,), + lambda: (self.kv(),), + ] + if mapping + else [] + ), + *( + [ + lambda: (self.values, self.key_fn), + lambda: (self.value_generator(), self.key_fn), + ] + if 
self.key_fn is not None + else [] + ), + ] + + def generate(self) -> Generator[DictTuple[_KT, _VT], Any, None]: + return (DictTuple(*args()) for args in self.constructor_args()) + + def generate_one(self) -> DictTuple[_KT, _VT]: + return next(self.generate()) + + @classmethod + def from_kv(cls, kv: Mapping[_KT, _VT]) -> Self: + return cls(kv.values(), kv.keys()) + + def __repr__(self): + return f"{self.__class__.__qualname__}{self.mapping}" + + +TEST_CASES: list[Case] = [ + Case(values=[1, 2, 3], keys=["x", "y", "z"]), + Case(values=[1, 2, 3], keys=str), +] + + +@pytest.mark.parametrize( + "args1,args2", + chain.from_iterable(combinations(c.constructor_args(), 2) for c in TEST_CASES), +) +def test_all_equal(args1: Callable[[], tuple], args2: Callable[[], tuple]) -> None: + d1 = DictTuple(*args1()) + d2 = DictTuple(*args2()) + + assert d1 == d2 + assert d1.keys() == d2.keys() + + +@pytest.mark.parametrize( + "dict_tuple,case", + [(c.generate_one(), c) for c in TEST_CASES], +) +class TestDictTupleProperties: + def test_should_iter_over_value(self, dict_tuple: DictTuple, case: Case) -> None: + itered = (v for v in dict_tuple) + assert all(a == b for a, b in zip(itered, case.values)) + + def test_int_indexing(self, dict_tuple: DictTuple, case: Case) -> None: + for i in range(len(dict_tuple)): + assert dict_tuple[i] == case.values[i] + + def test_key_indexing(self, dict_tuple: DictTuple, case: Case) -> None: + for k in case.keys: + assert dict_tuple[k] == case.mapping[k] + + def test_convert_to_other_iterable_types( + self, dict_tuple: DictTuple, case: Case + ) -> None: + assert list(dict_tuple) == list(case.values) + assert tuple(dict_tuple) == tuple(case.values) + + def test_keys(self, dict_tuple: DictTuple, case: Case) -> None: + assert list(dict_tuple.keys()) == list(case.keys) + + def test_dicttuple_is_not_a_mapping( + self, dict_tuple: DictTuple, case: Case + ) -> None: + assert not isinstance(dict_tuple, Mapping) + + def test_convert_to_dict(self, dict_tuple: 
DictTuple, case: Case) -> None: + assert dict(dict_tuple) == case.mapping + + def test_convert_items_to_dict(self, dict_tuple: DictTuple, case: Case) -> None: + assert dict(dict_tuple.items()) == case.mapping + + +@pytest.mark.parametrize( + "args", Case(values=["z", "b"], keys=[1, 2]).constructor_args() +) +def test_keys_should_not_be_int(args: Callable[[], tuple]) -> None: + with pytest.raises(ValueError, match="int"): + DictTuple(*args()) + + +LENGTH_MISMACTH_TEST_CASES = [ + Case(values=[1, 2, 3], keys=["x", "y"]), + Case(values=[1, 2], keys=["x", "y", "z"]), +] + + +@pytest.mark.parametrize( + "args", + chain.from_iterable( + c.constructor_args(mapping=False) for c in LENGTH_MISMACTH_TEST_CASES + ), +) +def test_keys_and_values_should_have_same_length(args: Callable[[], tuple]) -> None: + with pytest.raises(ValueError, match="length"): + DictTuple(*args())