Skip to content

Commit

Permalink
Merge pull request #8205 from kiendang/dicttuple-pydantic-validate
Browse files Browse the repository at this point in the history
Work around pydantic tuple validator for DictTuple
  • Loading branch information
shubham3121 authored Nov 3, 2023
2 parents cd2c0bf + f73dbb2 commit 42d4d58
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 4 deletions.
31 changes: 31 additions & 0 deletions notebooks/tutorials/data-owner/01-uploading-private-data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,37 @@
"client.datasets[\"my dataset\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b9105cf",
"metadata": {},
"outputs": [],
"source": [
"search_result = client.datasets.search(\"my\", page_size=1 , page_index=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ed822bd",
"metadata": {},
"outputs": [],
"source": [
"from syft.service.dataset.dataset import DatasetPageView\n",
"assert isinstance(search_result, DatasetPageView)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8513b8f5",
"metadata": {},
"outputs": [],
"source": [
"search_result.datasets"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@
"DatasetPageView": {
"1": {
"version": 1,
"hash": "6741bd16dc6089d9deea37b1bd4e895152d1a0c163b8bdfe45280b9bfc4a1354",
"hash": "b1de14bb9b6a259648dfc59b6a48fa526116afe50a689c24b8bb36fd0e6a97f8",
"action": "add"
}
},
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ class DatasetPageView(SyftObject):
__canonical_name__ = "DatasetPageView"
__version__ = SYFT_OBJECT_VERSION_1

datasets: DictTuple[str, Dataset]
datasets: DictTuple
total: int


Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/dataset/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _paginate_collection(
page_size: Optional[int] = 0,
page_index: Optional[int] = 0,
) -> Optional[slice]:
if page_size is None or page_index <= 0:
if page_size is None or page_size <= 0:
return None

# If chunk size is defined, then split list into evenly sized chunks
Expand All @@ -58,7 +58,7 @@ def _paginate_dataset_collection(
) -> Union[DictTuple[str, Dataset], DatasetPageView]:
slice_ = _paginate_collection(datasets, page_size=page_size, page_index=page_index)
chunk = datasets[slice_] if slice_ is not None else datasets
results = DictTuple((dataset.name, dataset) for dataset in chunk)
results = DictTuple(chunk, lambda dataset: dataset.name)

return (
results
Expand Down
16 changes: 16 additions & 0 deletions packages/syft/src/syft/types/dicttuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,30 @@ def __call__(
__key: Optional[Union[Callable, Collection]] = None,
/,
) -> _T:
# DictTuple()
if __value is None and __key is None:
obj = cls.__new__(cls)
obj.__init__()
return obj

# DictTuple(DictTuple(...))
elif type(__value) is cls:
return __value

# DictTuple({"x": 123, "y": 456})
elif isinstance(__value, Mapping) and __key is None:
obj = cls.__new__(cls, __value.values())
obj.__init__(__value.keys())

return obj

# DictTuple(EnhancedDictTuple(...))
# EnhancedDictTuple(DictTuple(...))
# where EnhancedDictTuple subclasses DictTuple
elif hasattr(__value, "items") and callable(__value.items):
return cls.__call__(__value.items())

# DictTuple([("x", 123), ("y", 456)])
elif isinstance(__value, Iterable) and __key is None:
keys = OrderedDict()
values = deque()
Expand All @@ -95,6 +108,7 @@ def __call__(

return obj

# DictTuple([123, 456], ["x", "y"])
elif isinstance(__value, Iterable) and isinstance(__key, Iterable):
keys = OrderedDict((k, i) for i, k in enumerate(__key))

Expand All @@ -103,6 +117,8 @@ def __call__(

return obj

# DictTuple(["abc", "xyz"], lambda x: x[0])
# equivalent to DictTuple({"a": "abc", "x": "xyz"})
elif isinstance(__value, Iterable) and isinstance(__key, Callable):
obj = cls.__new__(cls, __value)
obj.__init__(__key)
Expand Down
57 changes: 57 additions & 0 deletions packages/syft/tests/syft/types/dicttuple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
from typing import Optional
from typing import TypeVar
from typing import Union
import uuid

# third party
from faker import Faker
import pytest
from typing_extensions import Self

# syft absolute
from syft.service.dataset.dataset import Contributor
from syft.service.dataset.dataset import Dataset
from syft.service.dataset.dataset import DatasetPageView
from syft.service.user.roles import Roles
from syft.types.dicttuple import DictTuple


Expand Down Expand Up @@ -219,6 +225,13 @@ def test_convert_to_dict(self, dict_tuple: DictTuple, case: Case) -> None:
def test_convert_items_to_dict(self, dict_tuple: DictTuple, case: Case) -> None:
    """The ``items()`` view must rebuild the case's reference mapping."""
    rebuilt_mapping = dict(dict_tuple.items())
    assert rebuilt_mapping == case.mapping

def test_constructing_dicttuple_from_itself(
    self, dict_tuple: DictTuple, case: Case
) -> None:
    """Passing a DictTuple to the DictTuple constructor keeps values and keys."""
    rebuilt = DictTuple(dict_tuple)

    expected_values = tuple(case.values)
    expected_keys = tuple(case.keys)

    assert tuple(rebuilt) == expected_values
    assert tuple(rebuilt.keys()) == expected_keys


@pytest.mark.parametrize(
"args", Case(values=["z", "b"], keys=[1, 2]).constructor_args()
Expand All @@ -243,3 +256,47 @@ def test_keys_should_not_be_int(args: Callable[[], tuple]) -> None:
def test_keys_and_values_should_have_same_length(args: Callable[[], tuple]) -> None:
with pytest.raises(ValueError, match="length"):
DictTuple(*args())


def test_datasetpageview(faker: Faker):
    """Check that a DatasetPageView can be built from a DictTuple of datasets.

    Ten datasets with random hex names are keyed by name in a DictTuple,
    which is then passed as the ``datasets`` field of a DatasetPageView.
    """
    dataset_count = 10

    owner = Contributor(
        name=faker.name(), role=str(Roles.UPLOADER), email=faker.email()
    )

    # Kept as a lazy generator on purpose: DictTuple receives an iterable
    # plus a key function, exactly as in the service code path.
    # NOTE(review): the keyword is ``contributor`` (singular) -- confirm it
    # matches the field name declared on Dataset.
    dataset_stream = (
        Dataset(name=uuid.uuid4().hex, contributor={owner}, uploader=owner)
        for _ in range(dataset_count)
    )
    datasets_by_name = DictTuple(dataset_stream, lambda d: d.name)

    page_view = DatasetPageView(datasets=datasets_by_name, total=dataset_count)
    assert page_view


class EnhancedDictTuple(DictTuple):
    """Bare DictTuple subclass used to exercise subclass construction."""


@pytest.mark.parametrize(
    "args,case",
    chain.from_iterable(
        ((args, c) for args in c.constructor_args()) for c in TEST_CASES
    ),
)
def test_subclassing_dicttuple(args: Callable[[], tuple], case: Case):
    """DictTuple and a subclass must agree on values and keys.

    Covers four constructions for every constructor-argument form of every
    test case: each class built directly from the arguments, and each class
    built from an instance of the other.
    """
    base = DictTuple(*args())
    enhanced = EnhancedDictTuple(*args())

    # Cross-construction: each class is fed an instance of the other one.
    base_from_enhanced = DictTuple(enhanced)
    enhanced_from_base = EnhancedDictTuple(base)

    expected_values = tuple(case.values)
    expected_keys = tuple(case.keys)

    candidates = (base, enhanced, base_from_enhanced, enhanced_from_base)
    for candidate in candidates:
        assert tuple(candidate) == expected_values
        assert tuple(candidate.keys()) == expected_keys

0 comments on commit 42d4d58

Please sign in to comment.