Commit 348385b

Merge pull request #8164 from kiendang/dataset-index-2

Use `DictTuple(tuple)` for `Dataset.assets` and `SyftClient.datasets` instead of `TupleDict(OrderedDict)`

madhavajay authored Oct 30, 2023
2 parents 243663f + a88c9c7, commit 348385b
Showing 12 changed files with 612 additions and 98 deletions.
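For orientation, here is a minimal sketch of the semantics the new type implies. This is a hypothetical stand-in, not the actual implementation in packages/syft/src/syft/types/dicttuple.py: it assumes DictTuple is a tuple subclass whose elements can also be looked up by string key, constructible either from key/value pairs or from values plus a separate key sequence (both forms appear in the diff below).

# Hypothetical sketch of DictTuple semantics, inferred from this PR; the real
# class lives in packages/syft/src/syft/types/dicttuple.py.
from typing import Any, Iterable, Optional


class DictTuple(tuple):
    """A tuple of values that can also be indexed by string keys."""

    def __new__(cls, items: Iterable, keys: Optional[Iterable] = None) -> "DictTuple":
        if keys is None:
            pairs = list(items)  # built from (key, value) pairs
            keys = tuple(k for k, _ in pairs)
            values = tuple(v for _, v in pairs)
        else:
            values, keys = tuple(items), tuple(keys)  # values + separate keys
        obj = super().__new__(cls, values)
        obj._keys = keys
        return obj

    def keys(self) -> tuple:
        return self._keys

    def __getitem__(self, key: Any) -> Any:
        if isinstance(key, str):
            return tuple.__getitem__(self, self._keys.index(key))
        return tuple.__getitem__(self, key)


dt = DictTuple([("a", 1), ("b", 2)])
assert dt[0] == 1              # positional access, like a tuple
assert dt["b"] == 2            # key access, like a dict
assert dt.keys() == ("a", "b")

That dual construction is what lets the tutorial notebook below index client.datasets by name while positional indexing keeps working.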
2 changes: 1 addition & 1 deletion docs/source/api_reference/syft.service.dataset.dataset.rst
@@ -36,7 +36,7 @@ syft.service.dataset.dataset
CreateDataset
Dataset
DatasetUpdate
-   TupleDict
+   DictTuple



26 changes: 24 additions & 2 deletions notebooks/tutorials/data-owner/01-uploading-private-data.ipynb
@@ -48,7 +48,7 @@
"metadata": {},
"outputs": [],
"source": [
"node = sy.orchestra.launch(name=\"private-data-example-domain-1\",port=8040, reset=True)"
"node = sy.orchestra.launch(name=\"private-data-example-domain-1\", port=\"auto\", reset=True)"
]
},
{
@@ -158,6 +158,16 @@
"client.datasets"
]
},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "af495cad",
+"metadata": {},
+"outputs": [],
+"source": [
+"client.datasets[\"my dataset\"]"
+]
+},
{
"attachments": {},
"cell_type": "markdown",
@@ -265,6 +275,18 @@
"source": [
"## High Side vs Low Side"
]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "c13cdaa2",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Cleanup local domain server\n",
+"if node.node_type.value == \"python\":\n",
+"    node.land()"
+]
}
],
"metadata": {
@@ -283,7 +305,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.5"
},
"toc": {
"base_numbering": 1,
31 changes: 19 additions & 12 deletions packages/syft/src/syft/protocol/data_protocol.py
@@ -1,5 +1,7 @@
# stdlib
from collections import defaultdict
+from collections.abc import MutableMapping
+from collections.abc import MutableSequence
import hashlib
import json
import os
@@ -21,6 +23,7 @@
from ..service.response import SyftError
from ..service.response import SyftException
from ..service.response import SyftSuccess
+from ..types.dicttuple import DictTuple
from ..types.syft_object import SyftBaseObject

PROTOCOL_STATE_FILENAME = "protocol_version.json"
@@ -352,20 +355,25 @@ def check_or_stage_protocol() -> Result[SyftSuccess, SyftError]:
 
 def debox_arg_and_migrate(arg: Any, protocol_state: dict):
     """Debox the argument based on whether it is iterable or single entity."""
-    box_to_result_type = None
-
-    if type(arg) in OkErr:
-        box_to_result_type = type(arg)
-        arg = arg.value
+    constructor = None
+    extra_args = []
 
     single_entity = False
-    is_tuple = isinstance(arg, tuple)
 
-    if isinstance(arg, (list, tuple)):
+    if isinstance(arg, OkErr):
+        constructor = type(arg)
+        arg = arg.value
+
+    if isinstance(arg, MutableMapping):
+        iterable_keys = arg.keys()
+    elif isinstance(arg, MutableSequence):
         iterable_keys = range(len(arg))
-    elif isinstance(arg, dict):
-        iterable_keys = arg.keys()
+    elif isinstance(arg, tuple):
+        iterable_keys = range(len(arg))
+        constructor = type(arg)
+        if isinstance(arg, DictTuple):
+            extra_args.append(arg.keys())
+        arg = list(arg)
     else:
         iterable_keys = range(1)
         arg = [arg]
@@ -385,9 +393,8 @@ def debox_arg_and_migrate(arg: Any, protocol_state: dict):
         arg[key] = _object
 
     wrapped_arg = arg[0] if single_entity else arg
-    wrapped_arg = tuple(wrapped_arg) if is_tuple else wrapped_arg
-    if box_to_result_type is not None:
-        wrapped_arg = box_to_result_type(wrapped_arg)
+    if constructor is not None:
+        wrapped_arg = constructor(wrapped_arg, *extra_args)
 
     return wrapped_arg
 
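The key move in the hunk above is replacing the is_tuple / box_to_result_type flags with a single constructor plus a list of extra_args, so a DictTuple can be rebuilt with its keys after its values are migrated. A simplified round-trip sketch, reusing the hypothetical DictTuple class from the top of this page (migrate stands in for the per-object migration):

# Simplified sketch of the constructor/extra_args rebuild pattern; the real
# debox_arg_and_migrate also handles OkErr results, mappings, and scalars.
def debox_and_rebuild(arg, migrate):
    constructor = None
    extra_args = []

    if isinstance(arg, tuple):  # covers DictTuple, which subclasses tuple
        constructor = type(arg)
        if isinstance(arg, DictTuple):
            extra_args.append(arg.keys())  # keys ride along unchanged
        arg = [migrate(item) for item in arg]
    else:
        arg = migrate(arg)

    if constructor is not None:
        arg = constructor(arg, *extra_args)  # e.g. DictTuple(values, keys)
    return arg


rebuilt = debox_and_rebuild(DictTuple([("a", 1)]), migrate=lambda x: x + 1)
assert isinstance(rebuilt, DictTuple) and rebuilt["a"] == 2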
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
@@ -345,7 +345,7 @@
"DatasetPageView": {
"1": {
"version": 1,
"hash": "68c7a0c3e7796fdabb8f732c6d150ec4a8071ce78d69b30da18393afdcea1e59",
"hash": "6741bd16dc6089d9deea37b1bd4e895152d1a0c163b8bdfe45280b9bfc4a1354",
"action": "add"
}
},
17 changes: 13 additions & 4 deletions packages/syft/src/syft/serde/recursive_primitives.py
@@ -1,6 +1,7 @@
# stdlib
from collections import OrderedDict
from collections import defaultdict
+from collections.abc import Iterable
from collections.abc import Mapping
from enum import Enum
from enum import EnumMeta
@@ -71,23 +72,31 @@ def deserialize_iterable(iterable_type: type, blob: bytes) -> Collection:
     return iterable_type(values)
 
 
-def serialize_kv(map: Mapping) -> bytes:
+_KT = TypeVar("_KT")
+_VT = TypeVar("_VT")
+
+
+def _serialize_kv_pairs(size: int, kv_pairs: Iterable[tuple[_KT, _VT]]) -> bytes:
     # relative
     from .serialize import _serialize
 
     message = kv_iterable_schema.new_message()
 
-    message.init("keys", len(map))
-    message.init("values", len(map))
+    message.init("keys", size)
+    message.init("values", size)
 
-    for index, (k, v) in enumerate(map.items()):
+    for index, (k, v) in enumerate(kv_pairs):
         message.keys[index] = _serialize(k, to_bytes=True)
         serialized = _serialize(v, to_bytes=True)
         chunk_bytes(serialized, index, message.values)
 
     return message.to_bytes()
 
 
+def serialize_kv(map: Mapping) -> bytes:
+    return _serialize_kv_pairs(len(map), map.items())
+
+
 def get_deserialized_kv_pairs(blob: bytes) -> List[Any]:
     # relative
     from .deserialize import _deserialize
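Factoring the capnp message construction out of serialize_kv into _serialize_kv_pairs means any container that can produce key/value pairs and a size can reuse the same wire format, not just Mappings. A rough pure-Python analogue of that relationship, with the capnp message building replaced by a toy encoding:

# Rough analogue of the refactor above (capnp details elided): the generic
# pair serializer is the workhorse; serialize_kv is a thin Mapping wrapper.
from typing import Any, Iterable, Mapping, Tuple


def _serialize_kv_pairs(size: int, kv_pairs: Iterable[Tuple[Any, Any]]) -> bytes:
    chunks = [str(size).encode()]
    for k, v in kv_pairs:  # accepts dict.items(), zip(keys, values), ...
        chunks.append(repr((k, v)).encode())
    return b"|".join(chunks)


def serialize_kv(map: Mapping) -> bytes:
    return _serialize_kv_pairs(len(map), map.items())


assert serialize_kv({"a": 1}) == _serialize_kv_pairs(1, [("a", 1)])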
23 changes: 18 additions & 5 deletions packages/syft/src/syft/serde/third_party.py
@@ -27,12 +27,15 @@
 import zmq.green as zmq
 
 # relative
-from ..types.tupledict import TupleDict
+from ..types.dicttuple import DictTuple
+from ..types.dicttuple import _Meta as _DictTupleMetaClass
 from .deserialize import _deserialize as deserialize
+from .recursive_primitives import _serialize_kv_pairs
 from .recursive_primitives import deserialize_kv
+from .recursive_primitives import deserialize_type
 from .recursive_primitives import recursive_serde_register
 from .recursive_primitives import recursive_serde_register_type
-from .recursive_primitives import serialize_kv
+from .recursive_primitives import serialize_type
 from .serialize import _serialize as serialize
 
 recursive_serde_register(
@@ -128,10 +131,20 @@ def deserialize_series(blob: bytes) -> Series:
     deserialize=lambda x: Timestamp(deserialize(x, from_bytes=True)),
 )
 
+
+def _serialize_dicttuple(x: DictTuple) -> bytes:
+    return _serialize_kv_pairs(size=len(x), kv_pairs=zip(x.keys(), x))
+
+
+recursive_serde_register(
+    _DictTupleMetaClass,
+    serialize=serialize_type,
+    deserialize=deserialize_type,
+)
 recursive_serde_register(
-    TupleDict,
-    serialize=serialize_kv,
-    deserialize=functools.partial(deserialize_kv, TupleDict),
+    DictTuple,
+    serialize=_serialize_dicttuple,
+    deserialize=functools.partial(deserialize_kv, DictTuple),
 )
 
 
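Note that _serialize_dicttuple above passes zip(x.keys(), x) rather than x.items(). Judging from this diff, a DictTuple iterates over its values, since it is a tuple, and keeps its keys alongside, so the key/value pairs have to be reassembled before reaching _serialize_kv_pairs. With the sketch class from the top of this page:

dt = DictTuple([("a", 1), ("b", 2)])
list(dt)                  # [1, 2] -- iterating a tuple subclass yields values
list(zip(dt.keys(), dt))  # [('a', 1), ('b', 2)] -- pairs for serialization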
11 changes: 4 additions & 7 deletions packages/syft/src/syft/service/dataset/dataset.py
@@ -26,13 +26,13 @@
 from ...serde.serializable import serializable
 from ...store.document_store import PartitionKey
 from ...types.datetime import DateTime
+from ...types.dicttuple import DictTuple
 from ...types.syft_object import SYFT_OBJECT_VERSION_1
 from ...types.syft_object import SyftObject
 from ...types.transforms import TransformContext
 from ...types.transforms import generate_id
 from ...types.transforms import transform
 from ...types.transforms import validate_url
-from ...types.tupledict import TupleDict
 from ...types.uid import UID
 from ...util import options
 from ...util.colors import ON_SURFACE_HIGHEST
@@ -525,11 +525,8 @@ def action_ids(self) -> List[UID]:
         return data
 
     @property
-    def assets(self) -> TupleDict[str, Asset]:
-        data = TupleDict()
-        for asset in self.asset_list:
-            data[asset.name] = asset
-        return data
+    def assets(self) -> DictTuple[str, Asset]:
+        return DictTuple((asset.name, asset) for asset in self.asset_list)
 
     def _old_repr_markdown_(self) -> str:
         _repr_str = f"Syft Dataset: {self.name}\n"
@@ -606,7 +603,7 @@ class DatasetPageView(SyftObject):
     __canonical_name__ = "DatasetPageView"
     __version__ = SYFT_OBJECT_VERSION_1
 
-    datasets: TupleDict[str, Dataset]
+    datasets: DictTuple[str, Dataset]
     total: int
68 changes: 35 additions & 33 deletions packages/syft/src/syft/service/dataset/dataset_service.py
@@ -1,13 +1,13 @@
 # stdlib
-from itertools import islice
+from collections.abc import Collection
 from typing import List
 from typing import Optional
 from typing import Union
 
 # relative
 from ...serde.serializable import serializable
 from ...store.document_store import DocumentStore
-from ...types.tupledict import TupleDict
+from ...types.dicttuple import DictTuple
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from ..action.action_permissions import ActionObjectPermission
@@ -31,30 +31,40 @@
 from .dataset_stash import DatasetStash
 
 
-def _paginate_dataset_collection(
-    results: TupleDict[str, Dataset],
+def _paginate_collection(
+    collection: Collection,
     page_size: Optional[int] = 0,
     page_index: Optional[int] = 0,
-) -> DatasetPageView:
-    if page_size is None or page_size <= 0:
-        return results
+) -> Optional[slice]:
+    if page_size is None or page_size <= 0:
+        return None
 
     # If chunk size is defined, then split list into evenly sized chunks
-    total = len(results)
+    total = len(collection)
     page_index = 0 if page_index is None else page_index
 
     if page_size > total or page_index >= total // page_size or page_index < 0:
-        pass
-    else:
-        results = TupleDict(
-            islice(
-                results.items(),
-                page_size * page_index,
-                min(page_size * (page_index + 1), total),
-            )
-        )
+        return None
 
-    return DatasetPageView(datasets=results, total=total)
+    start = page_size * page_index
+    stop = min(page_size * (page_index + 1), total)
+    return slice(start, stop)
+
+
+def _paginate_dataset_collection(
+    datasets: Collection[Dataset],
+    page_size: Optional[int] = 0,
+    page_index: Optional[int] = 0,
+) -> Union[DictTuple[str, Dataset], DatasetPageView]:
+    slice_ = _paginate_collection(datasets, page_size=page_size, page_index=page_index)
+    chunk = datasets[slice_] if slice_ is not None else datasets
+    results = DictTuple((dataset.name, dataset) for dataset in chunk)
+
+    return (
+        results
+        if slice_ is None
+        else DatasetPageView(datasets=results, total=len(datasets))
+    )


@instrument
@@ -104,25 +114,19 @@ def get_all(
         context: AuthedServiceContext,
         page_size: Optional[int] = 0,
         page_index: Optional[int] = 0,
-    ) -> Union[DatasetPageView, TupleDict[str, Dataset], SyftError]:
+    ) -> Union[DatasetPageView, DictTuple[str, Dataset], SyftError]:
         """Get a Dataset"""
         result = self.stash.get_all(context.credentials)
         if not result.is_ok():
             return SyftError(message=result.err())
 
         datasets = result.ok()
 
-        results = TupleDict()
         for dataset in datasets:
             dataset.node_uid = context.node.id
-            results[dataset.name] = dataset
 
-        return (
-            results
-            if page_size <= 0 or page_size is None
-            else _paginate_dataset_collection(
-                results, page_size=page_size, page_index=page_index
-            )
+        return _paginate_dataset_collection(
+            datasets=datasets, page_size=page_size, page_index=page_index
         )

@service_method(
@@ -141,11 +145,9 @@ def search(
         if isinstance(results, SyftError):
             return results
 
-        filtered_results = TupleDict(
-            (dataset_name, dataset)
-            for dataset_name, dataset in results.items()
-            if name in dataset_name
-        )
+        filtered_results = [
+            dataset for dataset_name, dataset in results.items() if name in dataset_name
+        ]
 
         return _paginate_dataset_collection(
             filtered_results, page_size=page_size, page_index=page_index
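The pagination refactor separates deciding which slice of the collection is wanted from packaging the result, with None meaning the full collection should be returned unpaginated. A standalone sketch of the slice logic, mirroring _paginate_collection above:

# Standalone sketch mirroring _paginate_collection; None means "no pagination".
from typing import Collection, Optional


def paginate(
    collection: Collection,
    page_size: Optional[int] = 0,
    page_index: Optional[int] = 0,
) -> Optional[slice]:
    if page_size is None or page_size <= 0:
        return None  # pagination disabled: caller returns the full collection
    total = len(collection)
    page_index = 0 if page_index is None else page_index
    if page_size > total or page_index >= total // page_size or page_index < 0:
        return None  # out-of-range page: fall back to the full collection
    start = page_size * page_index
    stop = min(page_size * (page_index + 1), total)
    return slice(start, stop)


datasets = ["a", "b", "c", "d", "e"]
s = paginate(datasets, page_size=2, page_index=1)
assert s is not None and datasets[s] == ["c", "d"]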