Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DictTuple(tuple) for Dataset.assets and SyftClient.datasets instead of TupleDict(OrderedDict) #8164

Merged
merged 32 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
6277dde
Add DictTuple
kiendang Oct 13, 2023
21b58ad
Use DictTuple for Dataset.assets
kiendang Oct 13, 2023
fe71be3
Use mappingproxy for immutability
kiendang Oct 13, 2023
00cad56
Make client.datasets return DictTuple
kiendang Oct 13, 2023
9cc9f60
Fix DictTuple typing
kiendang Oct 14, 2023
de2fec4
Improve DictTuple __repr__
kiendang Oct 14, 2023
82c44eb
Add typing for DictTuple.keys()
kiendang Oct 14, 2023
47791d4
Add serde for DictTuple
kiendang Oct 14, 2023
baa44a7
Add docstring for DictTuple
kiendang Oct 14, 2023
c96b3f0
Add implementation notes
kiendang Oct 14, 2023
5c16a8e
Use if else instead of pattern matching
kiendang Oct 17, 2023
05ddc2b
More precise error message
kiendang Oct 17, 2023
0399458
Add DictTuple.items()
kiendang Oct 18, 2023
cb5fc8e
Add tests for DictTuple
kiendang Oct 18, 2023
9a0240c
Add more typing to tests
kiendang Oct 18, 2023
7b4fb6c
Add support for creating keys from callable
kiendang Oct 18, 2023
6b4842c
Add test for indexing
kiendang Oct 18, 2023
46aef9f
Fix attaching attribute for DictTuple
kiendang Oct 21, 2023
f58a9d9
Add SyftClient.datasets indexing by name example
kiendang Oct 21, 2023
8ca63c9
Use DictTuple for DatasetPageView
kiendang Oct 21, 2023
e683477
Remove TupleDict class
kiendang Oct 21, 2023
5cb369c
Fix Worker serde for Python 3.10
kiendang Oct 21, 2023
7c427b6
zip() Python 3.9 compat
kiendang Oct 21, 2023
cfcb807
Rename tupledict.py to dicttuple.py
kiendang Oct 21, 2023
f55a7a5
Optimize init
kiendang Oct 22, 2023
28b0eb3
Add a more complete test suite
kiendang Oct 22, 2023
b0b7d1e
Raise an error on attempting to create a DictTuple with int keys
kiendang Oct 23, 2023
2c2678d
Typo
kiendang Oct 23, 2023
622c630
More tests
kiendang Oct 24, 2023
6f78424
Fix tests
kiendang Oct 24, 2023
16c9d80
Add tests for converting to dict
kiendang Oct 24, 2023
a88c9c7
Improve typing
kiendang Oct 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api_reference/syft.service.dataset.dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ syft.service.dataset.dataset
CreateDataset
Dataset
DatasetUpdate
TupleDict
DictTuple



Expand Down
26 changes: 24 additions & 2 deletions notebooks/tutorials/data-owner/01-uploading-private-data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"metadata": {},
"outputs": [],
"source": [
"node = sy.orchestra.launch(name=\"private-data-example-domain-1\",port=8040, reset=True)"
"node = sy.orchestra.launch(name=\"private-data-example-domain-1\", port=\"auto\", reset=True)"
]
},
{
Expand Down Expand Up @@ -158,6 +158,16 @@
"client.datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af495cad",
"metadata": {},
"outputs": [],
"source": [
"client.datasets[\"my dataset\"]"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -265,6 +275,18 @@
"source": [
"## High Side vs Low Side"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c13cdaa2",
"metadata": {},
"outputs": [],
"source": [
"# Cleanup local domain server\n",
"if node.node_type.value == \"python\":\n",
" node.land()"
]
}
],
"metadata": {
Expand All @@ -283,7 +305,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.5"
},
"toc": {
"base_numbering": 1,
Expand Down
31 changes: 19 additions & 12 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# stdlib
from collections import defaultdict
from collections.abc import MutableMapping
from collections.abc import MutableSequence
import hashlib
import json
import os
Expand All @@ -21,6 +23,7 @@
from ..service.response import SyftError
from ..service.response import SyftException
from ..service.response import SyftSuccess
from ..types.dicttuple import DictTuple
from ..types.syft_object import SyftBaseObject

PROTOCOL_STATE_FILENAME = "protocol_version.json"
Expand Down Expand Up @@ -316,20 +319,25 @@ def check_or_stage_protocol() -> Result[SyftSuccess, SyftError]:

def debox_arg_and_migrate(arg: Any, protocol_state: dict):
"""Debox the argument based on whether it is iterable or single entity."""
box_to_result_type = None

if type(arg) in OkErr:
box_to_result_type = type(arg)
arg = arg.value
constructor = None
extra_args = []

single_entity = False
is_tuple = isinstance(arg, tuple)

if isinstance(arg, (list, tuple)):
if isinstance(arg, OkErr):
constructor = type(arg)
arg = arg.value

if isinstance(arg, MutableMapping):
iterable_keys = arg.keys()
elif isinstance(arg, MutableSequence):
iterable_keys = range(len(arg))
elif isinstance(arg, tuple):
iterable_keys = range(len(arg))
constructor = type(arg)
if isinstance(arg, DictTuple):
extra_args.append(arg.keys())
arg = list(arg)
elif isinstance(arg, dict):
iterable_keys = arg.keys()
else:
iterable_keys = range(1)
arg = [arg]
Expand All @@ -349,9 +357,8 @@ def debox_arg_and_migrate(arg: Any, protocol_state: dict):
arg[key] = _object

wrapped_arg = arg[0] if single_entity else arg
wrapped_arg = tuple(wrapped_arg) if is_tuple else wrapped_arg
if box_to_result_type is not None:
wrapped_arg = box_to_result_type(wrapped_arg)
if constructor is not None:
wrapped_arg = constructor(wrapped_arg, *extra_args)

return wrapped_arg

Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@
"DatasetPageView": {
"1": {
"version": 1,
"hash": "68c7a0c3e7796fdabb8f732c6d150ec4a8071ce78d69b30da18393afdcea1e59",
"hash": "6741bd16dc6089d9deea37b1bd4e895152d1a0c163b8bdfe45280b9bfc4a1354",
"action": "add"
}
},
Expand Down
17 changes: 13 additions & 4 deletions packages/syft/src/syft/serde/recursive_primitives.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from collections import OrderedDict
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Mapping
from enum import Enum
from enum import EnumMeta
Expand Down Expand Up @@ -71,23 +72,31 @@ def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection:
return iterable_type(values)


def serialize_kv(map: Mapping) -> bytes:
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")


def _serialize_kv_pairs(size: int, kv_pairs: Iterable[tuple[_KT, _VT]]) -> bytes:
# relative
from .serialize import _serialize

message = kv_iterable_schema.new_message()

message.init("keys", len(map))
message.init("values", len(map))
message.init("keys", size)
message.init("values", size)

for index, (k, v) in enumerate(map.items()):
for index, (k, v) in enumerate(kv_pairs):
message.keys[index] = _serialize(k, to_bytes=True)
serialized = _serialize(v, to_bytes=True)
chunk_bytes(serialized, index, message.values)

return message.to_bytes()


def serialize_kv(map: Mapping) -> bytes:
    """Serialize a mapping by delegating to the generic key/value-pair serializer."""
    pairs = map.items()
    return _serialize_kv_pairs(len(map), pairs)


def get_deserialized_kv_pairs(blob: bytes) -> List[Any]:
# relative
from .deserialize import _deserialize
Expand Down
23 changes: 18 additions & 5 deletions packages/syft/src/syft/serde/third_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@
import zmq.green as zmq

# relative
from ..types.tupledict import TupleDict
from ..types.dicttuple import DictTuple
from ..types.dicttuple import _Meta as _DictTupleMetaClass
from .deserialize import _deserialize as deserialize
from .recursive_primitives import _serialize_kv_pairs
from .recursive_primitives import deserialize_kv
from .recursive_primitives import deserialize_type
from .recursive_primitives import recursive_serde_register
from .recursive_primitives import recursive_serde_register_type
from .recursive_primitives import serialize_kv
from .recursive_primitives import serialize_type
from .serialize import _serialize as serialize

recursive_serde_register(
Expand Down Expand Up @@ -128,10 +131,20 @@ def deserialize_series(blob: bytes) -> Series:
deserialize=lambda x: Timestamp(deserialize(x, from_bytes=True)),
)


def _serialize_dicttuple(x: DictTuple) -> bytes:
    """Serialize a DictTuple: pair each key with its positional tuple value."""
    key_value_pairs = zip(x.keys(), x)
    return _serialize_kv_pairs(size=len(x), kv_pairs=key_value_pairs)


recursive_serde_register(
_DictTupleMetaClass,
serialize=serialize_type,
deserialize=deserialize_type,
)
recursive_serde_register(
TupleDict,
serialize=serialize_kv,
deserialize=functools.partial(deserialize_kv, TupleDict),
DictTuple,
serialize=_serialize_dicttuple,
deserialize=functools.partial(deserialize_kv, DictTuple),
)


Expand Down
11 changes: 4 additions & 7 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.datetime import DateTime
from ...types.dicttuple import DictTuple
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import generate_id
from ...types.transforms import transform
from ...types.transforms import validate_url
from ...types.tupledict import TupleDict
from ...types.uid import UID
from ...util import options
from ...util.colors import ON_SURFACE_HIGHEST
Expand Down Expand Up @@ -525,11 +525,8 @@ def action_ids(self) -> List[UID]:
return data

@property
def assets(self) -> TupleDict[str, Asset]:
data = TupleDict()
for asset in self.asset_list:
data[asset.name] = asset
return data
def assets(self) -> DictTuple[str, Asset]:
    """Expose this dataset's assets as a DictTuple keyed by asset name."""
    named_assets = ((a.name, a) for a in self.asset_list)
    return DictTuple(named_assets)

def _old_repr_markdown_(self) -> str:
_repr_str = f"Syft Dataset: {self.name}\n"
Expand Down Expand Up @@ -606,7 +603,7 @@ class DatasetPageView(SyftObject):
__canonical_name__ = "DatasetPageView"
__version__ = SYFT_OBJECT_VERSION_1

datasets: TupleDict[str, Dataset]
datasets: DictTuple[str, Dataset]
total: int


Expand Down
68 changes: 35 additions & 33 deletions packages/syft/src/syft/service/dataset/dataset_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# stdlib
from itertools import islice
from collections.abc import Collection
from typing import List
from typing import Optional
from typing import Union

# relative
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
from ...types.tupledict import TupleDict
from ...types.dicttuple import DictTuple
from ...types.uid import UID
from ...util.telemetry import instrument
from ..action.action_permissions import ActionObjectPermission
Expand All @@ -31,30 +31,40 @@
from .dataset_stash import DatasetStash


def _paginate_dataset_collection(
results: TupleDict[str, Dataset],
def _paginate_collection(
collection: Collection,
page_size: Optional[int] = 0,
page_index: Optional[int] = 0,
) -> DatasetPageView:
if page_size is None or page_size <= 0:
return results
) -> Optional[slice]:
if page_size is None or page_index <= 0:
return None

# If chunk size is defined, then split list into evenly sized chunks
total = len(results)
total = len(collection)
page_index = 0 if page_index is None else page_index

if page_size > total or page_index >= total // page_size or page_index < 0:
pass
else:
results = TupleDict(
islice(
results.items(),
page_size * page_index,
min(page_size * (page_index + 1), total),
)
)
return None

return DatasetPageView(datasets=results, total=total)
start = page_size * page_index
stop = min(page_size * (page_index + 1), total)
return slice(start, stop)


def _paginate_dataset_collection(
    datasets: Collection[Dataset],
    page_size: Optional[int] = 0,
    page_index: Optional[int] = 0,
) -> Union[DictTuple[str, Dataset], DatasetPageView]:
    """Key datasets by name, wrapping them in a DatasetPageView when a page was selected.

    When pagination does not apply (no slice), the full DictTuple is returned
    directly; otherwise only the selected page is keyed and wrapped together
    with the total count.
    """
    # NOTE(review): datasets is sliced below, so it is assumed to support
    # indexing (Sequence-like), despite the broader Collection annotation.
    slice_ = _paginate_collection(datasets, page_size=page_size, page_index=page_index)
    if slice_ is None:
        page = datasets
    else:
        page = datasets[slice_]

    results = DictTuple((ds.name, ds) for ds in page)

    if slice_ is None:
        return results
    return DatasetPageView(datasets=results, total=len(datasets))


@instrument
Expand Down Expand Up @@ -104,25 +114,19 @@ def get_all(
context: AuthedServiceContext,
page_size: Optional[int] = 0,
page_index: Optional[int] = 0,
) -> Union[DatasetPageView, TupleDict[str, Dataset], SyftError]:
) -> Union[DatasetPageView, DictTuple[str, Dataset], SyftError]:
"""Get a Dataset"""
result = self.stash.get_all(context.credentials)
if not result.is_ok():
return SyftError(message=result.err())

datasets = result.ok()

results = TupleDict()
for dataset in datasets:
dataset.node_uid = context.node.id
results[dataset.name] = dataset

return (
results
if page_size <= 0 or page_size is None
else _paginate_dataset_collection(
results, page_size=page_size, page_index=page_index
)

return _paginate_dataset_collection(
datasets=datasets, page_size=page_size, page_index=page_index
)

@service_method(
Expand All @@ -141,11 +145,9 @@ def search(
if isinstance(results, SyftError):
return results

filtered_results = TupleDict(
(dataset_name, dataset)
for dataset_name, dataset in results.items()
if name in dataset_name
)
filtered_results = [
dataset for dataset_name, dataset in results.items() if name in dataset_name
]

return _paginate_dataset_collection(
filtered_results, page_size=page_size, page_index=page_index
Expand Down
Loading