From 741c3ce00b4f9a4f89021f67a67b71041ae91969 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 7 Aug 2024 13:52:08 +0200 Subject: [PATCH 01/11] show endpoint name for twin api jobs --- packages/syft/src/syft/server/server.py | 1 + packages/syft/src/syft/service/job/job_stash.py | 7 ++++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index b03e755f2b5..fd5eb5e82b6 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -1339,6 +1339,7 @@ def add_queueitem_to_queue( action=action, requested_by=user_id, job_type=job_type, + endpoint=queue_item.kwargs.get("path", None), ) # 🟡 TODO 36: Needs distributed lock diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 160f2703510..de966af8abe 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -107,6 +107,8 @@ class Job(SyncableSyftObject): user_code_id: UID | None = None requested_by: UID | None = None job_type: JobType = JobType.JOB + # used by JobType.TWINAPIJOB + endpoint: str | None = None __attr_searchable__ = [ "parent_job_id", @@ -452,9 +454,8 @@ def summary_html(self) -> str: try: # type_html = f'
{self.object_type_name.upper()}
' - description_html = ( - f"{self.user_code_name}" - ) + job_name = self.user_code_name or self.endpoint or "Job" + description_html = f"{job_name}" worker_summary = "" if self.job_worker_id: worker_copy_button = CopyIDButton( From 43d54afdea81c47349f648fd768ded0c485a76ed Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Wed, 7 Aug 2024 16:19:33 -0400 Subject: [PATCH 02/11] first pass at showing constants in sync widget --- packages/syft/src/syft/service/sync/diff_state.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index c54a21cd959..f10df279cc2 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -50,6 +50,7 @@ from ..job.job_stash import JobType from ..log.log import SyftLog from ..output.output_service import ExecutionOutput +from ..policy.policy import Constant from ..request.request import Request from ..response import SyftError from ..response import SyftSuccess @@ -367,6 +368,14 @@ def repr_attr_dict(self, side: str) -> dict[str, Any]: for attr in repr_attrs: value = getattr(obj, attr) res[attr] = value + + # if there are constants in UserCode input policy, add to repr + # type ignores since mypy thinks the code is unreachable for some reason + if isinstance(obj, UserCode) and obj.input_policy_init_kwargs is not None: # type: ignore + for input_policy_kwarg in obj.input_policy_init_kwargs.values(): # type: ignore + for input_val in input_policy_kwarg.values(): + if isinstance(input_val, Constant): + res[input_val.kw] = input_val.val return res def diff_attributes_str(self, side: str) -> str: From 458cd729b7064e3d120d68b5949261259a3eaffc Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 31 Jul 2024 18:47:27 +0800 Subject: [PATCH 03/11] Add pydantic validator for DictTuple so that DictTuple can be annotated with type parameters in pydantic BaseModel, e.g. class C(BaseModel): d: DictTuple[str, Dataset] previously `d: DictTuple[str, Dataset]` would raise an error and we had to settle for `d: DictTuple`. 
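As an illustrative sketch of what the validator enables (the `Page` model, its `items` field, and the sample values are invented for this example; only `DictTuple` and its import path come from this patch, and pydantic v2 is assumed):

    from pydantic import BaseModel, ValidationError

    from syft.types.dicttuple import DictTuple


    class Page(BaseModel):              # hypothetical model, for illustration only
        items: DictTuple[str, int]      # parameterized annotation now validates


    page = Page(items=DictTuple({"a": 1, "b": 2}))
    assert page.items["a"] == 1         # mapping-style access by key
    assert page.items[1] == 2           # tuple-style access by position

    try:
        Page(items=DictTuple({"a": "not-an-int"}))
    except ValidationError:
        pass                            # the value type parameter is enforced

The chain schema checks, in order, that the object is a DictTuple instance, that its items validate against the value type, and that its internal key-to-index mapping validates against the key type.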
--- packages/syft/src/syft/types/dicttuple.py | 38 +++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index 4fe202454f2..b7d16815e3d 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -7,12 +7,18 @@ from collections.abc import KeysView from collections.abc import Mapping from types import MappingProxyType +from typing import Any from typing import Generic from typing import SupportsIndex from typing import TypeVar +from typing import get_args +from typing import get_origin from typing import overload # third party +from pydantic import GetCoreSchemaHandler +from pydantic import ValidatorFunctionWrapHandler +from pydantic_core import core_schema from typing_extensions import Self _T = TypeVar("_T") @@ -233,3 +239,35 @@ def keys(self) -> KeysView[_KT]: def items(self) -> Iterable[tuple[_KT, _VT]]: return zip(self.__mapping.keys(), self) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ): + origin = get_origin(source_type) + if origin is None: # used as `x: Owner` without params + origin = source_type + kt, vt = (Any, Any) + else: + kt, vt, *_ = get_args(source_type) + + k_schema = handler.generate_schema(MappingProxyType[kt, int]) + v_schema = handler.generate_schema(vt) + + def val_k(v: cls, handler: ValidatorFunctionWrapHandler) -> cls: + handler(v.__mapping) + return v + + def val_v(v: cls, handler: ValidatorFunctionWrapHandler) -> cls: + handler(v) + return v + + return core_schema.chain_schema( + [ + core_schema.is_instance_schema(cls), + core_schema.no_info_wrap_validator_function( + val_v, core_schema.tuple_variable_schema(items_schema=v_schema) + ), + core_schema.no_info_wrap_validator_function(val_k, k_schema), + ] + ) From fb70e193ae426ed48bee67ce55b0143a93f8bd5a Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 31 Jul 2024 18:54:46 +0800 Subject: [PATCH 04/11] Use more precise typing --- packages/syft/src/syft/service/dataset/dataset.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 8ce89758de2..d9879db1e24 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -26,6 +26,7 @@ from ...types.dicttuple import DictTuple from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import generate_id @@ -596,6 +597,16 @@ def _check_asset_must_contain_mock(asset_list: list[CreateAsset]) -> None: @serializable() class DatasetPageView(SyftObject): + # version + __canonical_name__ = "DatasetPageView" + __version__ = SYFT_OBJECT_VERSION_2 + + datasets: DictTuple[str, Dataset] + total: int + + +@serializable() +class DatasetPageViewV1(SyftObject): # version __canonical_name__ = "DatasetPageView" __version__ = SYFT_OBJECT_VERSION_1 From ed8f2998ab678ee55027494ea6f8410158fe80c5 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 1 Aug 2024 10:19:24 +0800 Subject: [PATCH 05/11] Fix DictTuple typing --- .pre-commit-config.yaml | 1 - .../syft/service/dataset/dataset_service.py | 2 +- packages/syft/src/syft/types/dicttuple.py | 68 ++++++++++--------- 
packages/syft/src/syft/util/patch_ipython.py | 4 +- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b7107dd1a7a..9dcd417cccc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -150,7 +150,6 @@ repos: name: "mypy: syft" always_run: true files: "^packages/syft/src/syft/" - exclude: "packages/syft/src/syft/types/dicttuple.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index c3bc68385ad..3ef76414593 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -143,7 +143,7 @@ def search( name: str, page_size: int | None = 0, page_index: int | None = 0, - ) -> DatasetPageView | SyftError: + ) -> DatasetPageView | DictTuple[str, Dataset] | SyftError: """Search a Dataset by name""" results = self.get_all(context) diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index b7d16815e3d..21aa5d8bd26 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -1,4 +1,5 @@ # stdlib +from abc import ABCMeta from collections import OrderedDict from collections import deque from collections.abc import Callable @@ -18,10 +19,11 @@ # third party from pydantic import GetCoreSchemaHandler from pydantic import ValidatorFunctionWrapHandler +from pydantic_core import CoreSchema from pydantic_core import core_schema from typing_extensions import Self -_T = TypeVar("_T") +_T = TypeVar("_T", bound="DictTuple") _KT = TypeVar("_KT") _VT = TypeVar("_VT") @@ -46,27 +48,27 @@ # to customize the way __new__ and __init__ work together, by iterating over key_value_pairs # once to extract both keys and values, then passing keys to __new__, values to __init__ # within the same function call. -class _Meta(type): +class _Meta(ABCMeta): @overload - def __call__(cls: type[_T]) -> _T: ... + def __call__(cls: type[_T], /) -> _T: ... # type: ignore[misc] @overload - def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]]) -> _T: ... + def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]], /) -> _T: ... # type: ignore[misc] @overload - def __call__(cls: type[_T], __value: Mapping[_KT, _VT]) -> _T: ... + def __call__(cls: type[_T], __value: Mapping[_KT, _VT], /) -> _T: ... # type: ignore[misc] @overload - def __call__( - cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT] + def __call__( # type: ignore[misc] + cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT], / ) -> _T: ... @overload - def __call__( - cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT] + def __call__( # type: ignore[misc] + cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT], / ) -> _T: ... 
- def __call__( + def __call__( # type: ignore[misc] cls: type[_T], __value: Iterable | None = None, __key: Callable | Collection | None = None, @@ -75,7 +77,7 @@ def __call__( # DictTuple() if __value is None and __key is None: obj = cls.__new__(cls) - obj.__init__() + obj.__init__() # type: ignore[misc] return obj # DictTuple(DictTuple(...)) @@ -85,27 +87,27 @@ def __call__( # DictTuple({"x": 123, "y": 456}) elif isinstance(__value, Mapping) and __key is None: obj = cls.__new__(cls, __value.values()) - obj.__init__(__value.keys()) + obj.__init__(__value.keys()) # type: ignore[misc] return obj # DictTuple(EnhancedDictTuple(...)) # EnhancedDictTuple(DictTuple(...)) # where EnhancedDictTuple subclasses DictTuple - elif hasattr(__value, "items") and callable(__value.items): - return cls.__call__(__value.items()) + elif callable(__value_items := getattr(__value, "items", None)): + return cls.__call__(__value_items()) # DictTuple([("x", 123), ("y", 456)]) elif isinstance(__value, Iterable) and __key is None: keys = OrderedDict() - values = deque() + values: deque = deque() for i, (k, v) in enumerate(__value): keys[k] = i values.append(v) obj = cls.__new__(cls, values) - obj.__init__(keys) + obj.__init__(keys) # type: ignore[misc] return obj @@ -114,15 +116,15 @@ def __call__( keys = OrderedDict((k, i) for i, k in enumerate(__key)) obj = cls.__new__(cls, __value) - obj.__init__(keys) + obj.__init__(keys) # type: ignore[misc] return obj # DictTuple(["abc", "xyz"], lambda x: x[0]) # equivalent to DictTuple({"a": "abc", "x": "xyz"}) elif isinstance(__value, Iterable) and isinstance(__key, Callable): - obj = cls.__new__(cls, __value) - obj.__init__(__key) + obj = cls.__new__(cls, __value) # type: ignore[misc] + obj.__init__(__key) # type: ignore[misc] return obj @@ -171,21 +173,23 @@ class DictTuple(tuple[_VT, ...], Generic[_KT, _VT], metaclass=_Meta): # These overloads are copied from _Meta.__call__ just for IDE hints @overload - def __init__(self) -> None: ... + def __init__(self, /) -> None: ... @overload - def __init__(self, __value: Iterable[tuple[_KT, _VT]]) -> None: ... + def __init__(self, __value: Iterable[tuple[_KT, _VT]], /) -> None: ... @overload - def __init__(self, __value: Mapping[_KT, _VT]) -> None: ... + def __init__(self, __value: Mapping[_KT, _VT], /) -> None: ... @overload - def __init__(self, __value: Iterable[_VT], __key: Collection[_KT]) -> None: ... + def __init__(self, __value: Iterable[_VT], __key: Collection[_KT], /) -> None: ... @overload - def __init__(self, __value: Iterable[_VT], __key: Callable[[_VT], _KT]) -> None: ... + def __init__( + self, __value: Iterable[_VT], __key: Callable[[_VT], _KT], / + ) -> None: ... - def __init__(self, __value=None, /): + def __init__(self, __value: Any = None, /) -> None: # type: ignore[misc] if isinstance(__value, MappingProxyType): self.__mapping = __value elif isinstance(__value, Mapping): @@ -210,16 +214,16 @@ def __init__(self, __value=None, /): "or implement `__index__()`" ) - @overload - def __getitem__(self, __key: _KT) -> _VT: ... + @overload # type: ignore[override] + def __getitem__(self, __key: _KT, /) -> _VT: ... - @overload - def __getitem__(self, __key: slice) -> Self: ... + @overload # type: ignore[overload-overlap] + def __getitem__(self, __key: slice, /) -> Self: ... @overload - def __getitem__(self, __key: SupportsIndex) -> _VT: ... + def __getitem__(self, __key: SupportsIndex, /) -> _VT: ... 
- def __getitem__(self, __key, /): + def __getitem__(self, __key: _KT | slice | SupportsIndex, /) -> _VT | Self: if isinstance(__key, slice): return self.__class__( super().__getitem__(__key), @@ -243,7 +247,7 @@ def items(self) -> Iterable[tuple[_KT, _VT]]: @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler - ): + ) -> CoreSchema: origin = get_origin(source_type) if origin is None: # used as `x: Owner` without params origin = source_type diff --git a/packages/syft/src/syft/util/patch_ipython.py b/packages/syft/src/syft/util/patch_ipython.py index 99572e79ff5..9de05bb8049 100644 --- a/packages/syft/src/syft/util/patch_ipython.py +++ b/packages/syft/src/syft/util/patch_ipython.py @@ -78,8 +78,8 @@ def _patch_ipython_sanitization() -> None: escaped_itable_template = re.compile(itable_template, re.DOTALL) def display_sanitized_html(obj: SyftObject | DictTuple) -> str | None: - if callable(getattr(obj, "_repr_html_", None)): - html_str = obj._repr_html_() + if callable(obj_repr_html_ := getattr(obj, "_repr_html_", None)): + html_str = obj_repr_html_() if html_str is not None: # find matching table and jobs matching_table = escaped_template.findall(html_str) From 1b023b3a820f8f7452f666c40338513bc14712f2 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 1 Aug 2024 10:29:17 +0800 Subject: [PATCH 06/11] Add comments --- packages/syft/src/syft/types/dicttuple.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index 21aa5d8bd26..908feb016de 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -244,6 +244,13 @@ def keys(self) -> KeysView[_KT]: def items(self) -> Iterable[tuple[_KT, _VT]]: return zip(self.__mapping.keys(), self) + # https://docs.pydantic.dev/latest/concepts/types/#handling-custom-generic-classes + # pydantic validator + # this enables annotating a field with DictTuple[K, V] instead of just DictTuple + # inside a pydantic BaseModel, e.g. 
+ # + # class DatasetPageView(BaseModel): + # datasets: DictTuple[str, Dataset] @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler @@ -266,6 +273,10 @@ def val_v(v: cls, handler: ValidatorFunctionWrapHandler) -> cls: handler(v) return v + # pydantic validator for DictTuple[K, V] + # - check that object has type DictTuple + # - check that object is a tuple[V] + # - check that the keys have type K return core_schema.chain_schema( [ core_schema.is_instance_schema(cls), From cb4c697d01955a917c8c4c1ed57022da52905710 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 1 Aug 2024 10:30:57 +0800 Subject: [PATCH 07/11] Remove wrong comment --- packages/syft/src/syft/service/dataset/dataset.py | 3 --- packages/syft/src/syft/types/dicttuple.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index d9879db1e24..eab1f63b103 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -597,7 +597,6 @@ def _check_asset_must_contain_mock(asset_list: list[CreateAsset]) -> None: @serializable() class DatasetPageView(SyftObject): - # version __canonical_name__ = "DatasetPageView" __version__ = SYFT_OBJECT_VERSION_2 @@ -607,7 +606,6 @@ class DatasetPageView(SyftObject): @serializable() class DatasetPageViewV1(SyftObject): - # version __canonical_name__ = "DatasetPageView" __version__ = SYFT_OBJECT_VERSION_1 @@ -617,7 +615,6 @@ class DatasetPageViewV1(SyftObject): @serializable() class CreateDataset(Dataset): - # version __canonical_name__ = "CreateDataset" __version__ = SYFT_OBJECT_VERSION_1 asset_list: list[CreateAsset] = [] diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py index 908feb016de..4a019052c09 100644 --- a/packages/syft/src/syft/types/dicttuple.py +++ b/packages/syft/src/syft/types/dicttuple.py @@ -256,7 +256,7 @@ def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: origin = get_origin(source_type) - if origin is None: # used as `x: Owner` without params + if origin is None: origin = source_type kt, vt = (Any, Any) else: From d3e4799196edc9232f91c5dc57947d283bd1c9df Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Thu, 8 Aug 2024 14:31:27 +0800 Subject: [PATCH 08/11] Update protocol version --- packages/syft/src/syft/protocol/protocol_version.json | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index fc112b0fcca..4f42196228a 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -23,6 +23,13 @@ "hash": "cf6c1cb55d569af9823d8541ca038806bd350450a919345244ed4f432a099f34", "action": "add" } + }, + "DatasetPageView": { + "2": { + "version": 2, + "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", + "action": "add" + } } } } From 3ae891eea3f021345921892f66031a0d0b171921 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 8 Aug 2024 09:37:34 +0200 Subject: [PATCH 09/11] add job migrations --- .../syft/src/syft/service/job/job_stash.py | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index de966af8abe..1009dc216e0 100644 --- 
a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from datetime import datetime from datetime import timedelta from datetime import timezone @@ -31,9 +32,12 @@ from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime from ...types.datetime import format_timedelta +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.syncable_object import SyncableSyftObject +from ...types.transforms import make_set_default from ...types.uid import UID from ...util.markdown import as_markdown_code from ...util.telemetry import instrument @@ -86,7 +90,7 @@ def __str__(self) -> str: @serializable() class Job(SyncableSyftObject): __canonical_name__ = "JobItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID server_uid: UID @@ -931,3 +935,36 @@ def get_by_user_code_id( ) return self.query_all(credentials=credentials, qks=qks) + + +@serializable() +class JobV1(SyncableSyftObject): + __canonical_name__ = "JobItem" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + server_uid: UID + result: Any | None = None + resolved: bool = False + status: JobStatus = JobStatus.CREATED + log_id: UID | None = None + parent_job_id: UID | None = None + n_iters: int | None = 0 + current_iter: int | None = None + creation_time: str | None = Field( + default_factory=lambda: str(datetime.now(tz=timezone.utc)) + ) + action: Action | None = None + job_pid: int | None = None + job_worker_id: UID | None = None + updated_at: DateTime | None = None + user_code_id: UID | None = None + requested_by: UID | None = None + job_type: JobType = JobType.JOB + + +@migrate(JobV1, Job) +def migrate_job_update_v1_current() -> list[Callable]: + return [ + make_set_default("endpoint", None), + ] From c9136c20ebd802716463147afe42e34d6cbe4ca0 Mon Sep 17 00:00:00 2001 From: alfred-openmined-bot <145415986+alfred-openmined-bot@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:30:04 +0000 Subject: [PATCH 10/11] bump protocol and remove notebooks --- packages/syft/src/syft/protocol/protocol_version.json | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 4f42196228a..aec411969b2 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -30,6 +30,13 @@ "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", "action": "add" } + }, + "JobItem": { + "2": { + "version": 2, + "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", + "action": "add" + } } } } From 4c3e44c395c05a1a6b04bdd36bef38b4bb3c3e26 Mon Sep 17 00:00:00 2001 From: Brendan Schell Date: Thu, 8 Aug 2024 14:18:50 -0400 Subject: [PATCH 11/11] update get_mb_size util function to handle collections --- packages/syft/src/syft/util/util.py | 65 ++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 43620c4cab5..e0d729ba123 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -1,6 +1,7 @@ # stdlib import asyncio from asyncio.selector_events import BaseSelectorEventLoop +from 
collections import deque from collections.abc import Callable from collections.abc import Iterator from collections.abc import Sequence @@ -11,6 +12,7 @@ from datetime import datetime import functools import hashlib +from itertools import chain from itertools import repeat import json import logging @@ -29,6 +31,7 @@ from secrets import randbelow import socket import sys +from sys import getsizeof import threading import time import types @@ -93,8 +96,66 @@ def get_name_for(klass: type) -> str: return klass_name -def get_mb_size(data: Any) -> float: - return sys.getsizeof(data) / (1024 * 1024) +def get_mb_size(data: Any, handlers: dict | None = None) -> float: + """Returns the approximate memory footprint an object and all of its contents. + + Automatically finds the contents of the following builtin containers and + their subclasses: tuple, list, deque, dict, set and frozenset. + Otherwise, tries to read from the __slots__ or __dict__ of the object. + To search other containers, add handlers to iterate over their contents: + + handlers = {SomeContainerClass: iter, + OtherContainerClass: OtherContainerClass.get_elements} + + Lightly modified from + https://code.activestate.com/recipes/577504-compute-memory-footprint-of-an-object-and-its-cont/ + which is referenced in official sys.getsizeof documentation + https://docs.python.org/3/library/sys.html#sys.getsizeof. + + """ + + def dict_handler(d: dict[Any, Any]) -> Iterator[Any]: + return chain.from_iterable(d.items()) + + all_handlers = { + tuple: iter, + list: iter, + deque: iter, + dict: dict_handler, + set: iter, + frozenset: iter, + } + if handlers: + all_handlers.update(handlers) # user handlers take precedence + seen = set() # track which object id's have already been seen + default_size = getsizeof(0) # estimate sizeof object without __sizeof__ + + def sizeof(o: Any) -> int: + if id(o) in seen: # do not double count the same object + return 0 + seen.add(id(o)) + s = getsizeof(o, default_size) + + for typ, handler in all_handlers.items(): + if isinstance(o, typ): + s += sum(map(sizeof, handler(o))) # type: ignore + break + else: + # no __slots__ *usually* means a __dict__, but some special builtin classes + # (such as `type(None)`) have neither else, `o` has no attributes at all, + # so sys.getsizeof() actually returned the correct value + if not hasattr(o.__class__, "__slots__"): + if hasattr(o, "__dict__"): + s += sizeof(o.__dict__) + else: + s += sum( + sizeof(getattr(o, x)) + for x in o.__class__.__slots__ + if hasattr(o, x) + ) + return s + + return sizeof(data) / (1024.0 * 1024.0) def get_mb_serialized_size(data: Any) -> float:
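A rough usage check of the new recursive accounting (illustrative only — the `Bucket` class and the sample payload are invented here; only `get_mb_size` and its module path come from this patch):

    import sys

    from syft.util.util import get_mb_size


    class Bucket:
        # Hypothetical container whose contents are reachable only via a method.
        def __init__(self, elements: list) -> None:
            self.elements = elements

        def get_elements(self) -> list:
            return self.elements


    payload = {
        "blob": b"x" * (2 * 1024 * 1024),                # ~2 MB of raw bytes
        "rows": [list(range(1000)) for _ in range(10)],  # nested lists
    }

    # A flat sys.getsizeof() only measures the outer dict (well under 1 MB) ...
    print(sys.getsizeof(payload) / (1024 * 1024))
    # ... whereas the updated helper walks the nested contents (roughly 2 MB here).
    print(get_mb_size(payload))

    # User-supplied handlers take precedence and define how a custom container is walked,
    # instead of falling back to its __dict__ or __slots__.
    print(get_mb_size(Bucket([b"y" * (1024 * 1024)]), handlers={Bucket: Bucket.get_elements}))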