diff --git a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb index 608b2971b97..48632c0ffe9 100644 --- a/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb +++ b/notebooks/tutorials/data-owner/01-uploading-private-data.ipynb @@ -168,6 +168,37 @@ "client.datasets[\"my dataset\"]" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b9105cf", + "metadata": {}, + "outputs": [], + "source": [ + "search_result = client.datasets.search(\"my\", page_size=1 , page_index=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ed822bd", + "metadata": {}, + "outputs": [], + "source": [ + "from syft.service.dataset.dataset import DatasetPageView\n", + "assert isinstance(search_result, DatasetPageView)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8513b8f5", + "metadata": {}, + "outputs": [], + "source": [ + "search_result.datasets" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 204b96691b0..2a9cf876103 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -369,7 +369,7 @@ "DatasetPageView": { "1": { "version": 1, - "hash": "6741bd16dc6089d9deea37b1bd4e895152d1a0c163b8bdfe45280b9bfc4a1354", + "hash": "b1de14bb9b6a259648dfc59b6a48fa526116afe50a689c24b8bb36fd0e6a97f8", "action": "add" } }, diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index f482e4ba67e..8cb524b1148 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -603,7 +603,7 @@ class DatasetPageView(SyftObject): __canonical_name__ = "DatasetPageView" __version__ = SYFT_OBJECT_VERSION_1 - datasets: DictTuple[str, Dataset] + datasets: DictTuple total: int diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 2bce6a10ea4..2558e1e560b 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -36,7 +36,7 @@ def _paginate_collection( page_size: Optional[int] = 0, page_index: Optional[int] = 0, ) -> Optional[slice]: - if page_size is None or page_index <= 0: + if page_size is None or page_size <= 0: return None # If chunk size is defined, then split list into evenly sized chunks @@ -58,7 +58,7 @@ def _paginate_dataset_collection( ) -> Union[DictTuple[str, Dataset], DatasetPageView]: slice_ = _paginate_collection(datasets, page_size=page_size, page_index=page_index) chunk = datasets[slice_] if slice_ is not None else datasets - results = DictTuple((dataset.name, dataset) for dataset in chunk) + results = DictTuple(chunk, lambda dataset: dataset.name) return ( results diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index 77ed0b1f0a9..2af66bda704 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -71,17 +71,30 @@ def __call__( __key: Optional[Union[Callable, Collection]] = None, /, ) -> _T: + # DictTuple() if __value is None and __key is None: obj = cls.__new__(cls) obj.__init__() return obj + # DictTuple(DictTuple(...)) + elif type(__value) is cls: + return __value + + # DictTuple({"x": 123, "y": 456}) elif isinstance(__value, Mapping) and __key is None: obj = cls.__new__(cls, __value.values()) obj.__init__(__value.keys()) return obj + # DictTuple(EnhancedDictTuple(...)) + # EnhancedDictTuple(DictTuple(...)) + # where EnhancedDictTuple subclasses DictTuple + elif hasattr(__value, "items") and callable(__value.items): + return cls.__call__(__value.items()) + + # DictTuple([("x", 123), ("y", 456)]) elif isinstance(__value, Iterable) and __key is None: keys = OrderedDict() values = deque() @@ -95,6 +108,7 @@ def __call__( return obj + # DictTuple([123, 456], ["x", "y"]) elif isinstance(__value, Iterable) and isinstance(__key, Iterable): keys = OrderedDict((k, i) for i, k in enumerate(__key)) @@ -103,6 +117,8 @@ def __call__( return obj + # DictTuple(["abc", "xyz"], lambda x: x[0]) + # equivalent to DictTuple({"a": "abc", "x": "xyz"}) elif isinstance(__value, Iterable) and isinstance(__key, Callable): obj = cls.__new__(cls, __value) obj.__init__(__key) diff --git a/packages/syft/tests/syft/types/dicttuple_test.py b/packages/syft/tests/syft/types/dicttuple_test.py index eb5f0947881..220496b0032 100644 --- a/packages/syft/tests/syft/types/dicttuple_test.py +++ b/packages/syft/tests/syft/types/dicttuple_test.py @@ -12,12 +12,18 @@ from typing import Optional from typing import TypeVar from typing import Union +import uuid # third party +from faker import Faker import pytest from typing_extensions import Self # syft absolute +from syft.service.dataset.dataset import Contributor +from syft.service.dataset.dataset import Dataset +from syft.service.dataset.dataset import DatasetPageView +from syft.service.user.roles import Roles from syft.types.dicttuple import DictTuple @@ -219,6 +225,13 @@ def test_convert_to_dict(self, dict_tuple: DictTuple, case: Case) -> None: def test_convert_items_to_dict(self, dict_tuple: DictTuple, case: Case) -> None: assert dict(dict_tuple.items()) == case.mapping + def test_constructing_dicttuple_from_itself( + self, dict_tuple: DictTuple, case: Case + ) -> None: + dd = DictTuple(dict_tuple) + assert tuple(dd) == tuple(case.values) + assert tuple(dd.keys()) == tuple(case.keys) + @pytest.mark.parametrize( "args", Case(values=["z", "b"], keys=[1, 2]).constructor_args() @@ -243,3 +256,47 @@ def test_keys_should_not_be_int(args: Callable[[], tuple]) -> None: def test_keys_and_values_should_have_same_length(args: Callable[[], tuple]) -> None: with pytest.raises(ValueError, match="length"): DictTuple(*args()) + + +def test_datasetpageview(faker: Faker): + uploader = Contributor( + name=faker.name(), role=str(Roles.UPLOADER), email=faker.email() + ) + + length = 10 + datasets = ( + Dataset(name=uuid.uuid4().hex, contributor={uploader}, uploader=uploader) + for _ in range(length) + ) + dict_tuple = DictTuple(datasets, lambda d: d.name) + + assert DatasetPageView(datasets=dict_tuple, total=length) + + +class EnhancedDictTuple(DictTuple): + pass + + +@pytest.mark.parametrize( + "args,case", + chain.from_iterable( + ((args, c) for args in c.constructor_args()) for c in TEST_CASES + ), +) +def test_subclassing_dicttuple(args: Callable[[], tuple], case: Case): + dict_tuple = DictTuple(*args()) + enhanced_dict_tuple = EnhancedDictTuple(*args()) + dict_tuple_enhanced_dict_tuple = DictTuple(enhanced_dict_tuple) + enhanced_dict_tuple_dict_tuple = EnhancedDictTuple(dict_tuple) + + values = tuple(case.values) + keys = tuple(case.keys) + + for d in ( + dict_tuple, + enhanced_dict_tuple, + dict_tuple_enhanced_dict_tuple, + enhanced_dict_tuple_dict_tuple, + ): + assert tuple(d) == values + assert tuple(d.keys()) == keys