diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b7107dd1a7a..9dcd417cccc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -150,7 +150,6 @@ repos:
name: "mypy: syft"
always_run: true
files: "^packages/syft/src/syft/"
- exclude: "packages/syft/src/syft/types/dicttuple.py"
args: [
"--follow-imports=skip",
"--ignore-missing-imports",
diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json
index fc112b0fcca..aec411969b2 100644
--- a/packages/syft/src/syft/protocol/protocol_version.json
+++ b/packages/syft/src/syft/protocol/protocol_version.json
@@ -23,6 +23,20 @@
"hash": "cf6c1cb55d569af9823d8541ca038806bd350450a919345244ed4f432a099f34",
"action": "add"
}
+ },
+ "DatasetPageView": {
+ "2": {
+ "version": 2,
+ "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352",
+ "action": "add"
+ }
+ },
+ "JobItem": {
+ "2": {
+ "version": 2,
+ "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105",
+ "action": "add"
+ }
}
}
}
diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py
index ac0b118f466..66bb446fbf5 100644
--- a/packages/syft/src/syft/server/server.py
+++ b/packages/syft/src/syft/server/server.py
@@ -1370,6 +1370,7 @@ def add_queueitem_to_queue(
action=action,
requested_by=user_id,
job_type=job_type,
+ endpoint=queue_item.kwargs.get("path", None),
)
# 🟡 TODO 36: Needs distributed lock
diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py
index 8ce89758de2..eab1f63b103 100644
--- a/packages/syft/src/syft/service/dataset/dataset.py
+++ b/packages/syft/src/syft/service/dataset/dataset.py
@@ -26,6 +26,7 @@
from ...types.dicttuple import DictTuple
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
from ...types.transforms import generate_id
@@ -596,7 +597,15 @@ def _check_asset_must_contain_mock(asset_list: list[CreateAsset]) -> None:
@serializable()
class DatasetPageView(SyftObject):
- # version
+ __canonical_name__ = "DatasetPageView"
+ __version__ = SYFT_OBJECT_VERSION_2
+
+ datasets: DictTuple[str, Dataset]
+ total: int
+
+
+@serializable()
+class DatasetPageViewV1(SyftObject):
__canonical_name__ = "DatasetPageView"
__version__ = SYFT_OBJECT_VERSION_1
@@ -606,7 +615,6 @@ class DatasetPageView(SyftObject):
@serializable()
class CreateDataset(Dataset):
- # version
__canonical_name__ = "CreateDataset"
__version__ = SYFT_OBJECT_VERSION_1
asset_list: list[CreateAsset] = []
diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py
index c3bc68385ad..3ef76414593 100644
--- a/packages/syft/src/syft/service/dataset/dataset_service.py
+++ b/packages/syft/src/syft/service/dataset/dataset_service.py
@@ -143,7 +143,7 @@ def search(
name: str,
page_size: int | None = 0,
page_index: int | None = 0,
- ) -> DatasetPageView | SyftError:
+ ) -> DatasetPageView | DictTuple[str, Dataset] | SyftError:
"""Search a Dataset by name"""
results = self.get_all(context)
diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py
index 3ebf6a65b4b..7b32b6281cf 100644
--- a/packages/syft/src/syft/service/job/job_stash.py
+++ b/packages/syft/src/syft/service/job/job_stash.py
@@ -1,4 +1,5 @@
# stdlib
+from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -31,9 +32,12 @@
from ...store.document_store import UIDPartitionKey
from ...types.datetime import DateTime
from ...types.datetime import format_timedelta
+from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
+from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.syncable_object import SyncableSyftObject
+from ...types.transforms import make_set_default
from ...types.uid import UID
from ...util.markdown import as_markdown_code
from ...util.telemetry import instrument
@@ -86,7 +90,7 @@ def __str__(self) -> str:
@serializable()
class Job(SyncableSyftObject):
__canonical_name__ = "JobItem"
- __version__ = SYFT_OBJECT_VERSION_1
+ __version__ = SYFT_OBJECT_VERSION_2
id: UID
server_uid: UID
@@ -107,6 +111,8 @@ class Job(SyncableSyftObject):
user_code_id: UID | None = None
requested_by: UID | None = None
job_type: JobType = JobType.JOB
+ # used by JobType.TWINAPIJOB
+ endpoint: str | None = None
__attr_searchable__ = [
"parent_job_id",
@@ -452,9 +458,8 @@ def summary_html(self) -> str:
try:
# type_html = f'
{self.object_type_name.upper()}
'
- description_html = (
- f"{self.user_code_name}"
- )
+ job_name = self.user_code_name or self.endpoint or "Job"
+ description_html = f"{job_name}"
worker_summary = ""
if self.job_worker_id:
worker_copy_button = CopyIDButton(
@@ -931,3 +936,36 @@ def get_by_user_code_id(
)
return self.query_all(credentials=credentials, qks=qks)
+
+
+@serializable()
+class JobV1(SyncableSyftObject):
+ __canonical_name__ = "JobItem"
+ __version__ = SYFT_OBJECT_VERSION_1
+
+ id: UID
+ server_uid: UID
+ result: Any | None = None
+ resolved: bool = False
+ status: JobStatus = JobStatus.CREATED
+ log_id: UID | None = None
+ parent_job_id: UID | None = None
+ n_iters: int | None = 0
+ current_iter: int | None = None
+ creation_time: str | None = Field(
+ default_factory=lambda: str(datetime.now(tz=timezone.utc))
+ )
+ action: Action | None = None
+ job_pid: int | None = None
+ job_worker_id: UID | None = None
+ updated_at: DateTime | None = None
+ user_code_id: UID | None = None
+ requested_by: UID | None = None
+ job_type: JobType = JobType.JOB
+
+
+@migrate(JobV1, Job)
+def migrate_job_update_v1_current() -> list[Callable]:
+ return [
+ make_set_default("endpoint", None),
+ ]
diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py
index c54a21cd959..f10df279cc2 100644
--- a/packages/syft/src/syft/service/sync/diff_state.py
+++ b/packages/syft/src/syft/service/sync/diff_state.py
@@ -50,6 +50,7 @@
from ..job.job_stash import JobType
from ..log.log import SyftLog
from ..output.output_service import ExecutionOutput
+from ..policy.policy import Constant
from ..request.request import Request
from ..response import SyftError
from ..response import SyftSuccess
@@ -367,6 +368,14 @@ def repr_attr_dict(self, side: str) -> dict[str, Any]:
for attr in repr_attrs:
value = getattr(obj, attr)
res[attr] = value
+
+ # if there are constants in UserCode input policy, add to repr
+ # type ignores since mypy thinks the code is unreachable for some reason
+ if isinstance(obj, UserCode) and obj.input_policy_init_kwargs is not None: # type: ignore
+ for input_policy_kwarg in obj.input_policy_init_kwargs.values(): # type: ignore
+ for input_val in input_policy_kwarg.values():
+ if isinstance(input_val, Constant):
+ res[input_val.kw] = input_val.val
return res
def diff_attributes_str(self, side: str) -> str:
diff --git a/packages/syft/src/syft/types/dicttuple.py b/packages/syft/src/syft/types/dicttuple.py
index 4fe202454f2..4a019052c09 100644
--- a/packages/syft/src/syft/types/dicttuple.py
+++ b/packages/syft/src/syft/types/dicttuple.py
@@ -1,4 +1,5 @@
# stdlib
+from abc import ABCMeta
from collections import OrderedDict
from collections import deque
from collections.abc import Callable
@@ -7,15 +8,22 @@
from collections.abc import KeysView
from collections.abc import Mapping
from types import MappingProxyType
+from typing import Any
from typing import Generic
from typing import SupportsIndex
from typing import TypeVar
+from typing import get_args
+from typing import get_origin
from typing import overload
# third party
+from pydantic import GetCoreSchemaHandler
+from pydantic import ValidatorFunctionWrapHandler
+from pydantic_core import CoreSchema
+from pydantic_core import core_schema
from typing_extensions import Self
-_T = TypeVar("_T")
+_T = TypeVar("_T", bound="DictTuple")
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
@@ -40,27 +48,27 @@
# to customize the way __new__ and __init__ work together, by iterating over key_value_pairs
# once to extract both keys and values, then passing keys to __new__, values to __init__
# within the same function call.
-class _Meta(type):
+class _Meta(ABCMeta):
@overload
- def __call__(cls: type[_T]) -> _T: ...
+ def __call__(cls: type[_T], /) -> _T: ... # type: ignore[misc]
@overload
- def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]]) -> _T: ...
+ def __call__(cls: type[_T], __value: Iterable[tuple[_KT, _VT]], /) -> _T: ... # type: ignore[misc]
@overload
- def __call__(cls: type[_T], __value: Mapping[_KT, _VT]) -> _T: ...
+ def __call__(cls: type[_T], __value: Mapping[_KT, _VT], /) -> _T: ... # type: ignore[misc]
@overload
- def __call__(
- cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT]
+ def __call__( # type: ignore[misc]
+ cls: type[_T], __value: Iterable[_VT], __key: Collection[_KT], /
) -> _T: ...
@overload
- def __call__(
- cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT]
+ def __call__( # type: ignore[misc]
+ cls: type[_T], __value: Iterable[_VT], __key: Callable[[_VT], _KT], /
) -> _T: ...
- def __call__(
+ def __call__( # type: ignore[misc]
cls: type[_T],
__value: Iterable | None = None,
__key: Callable | Collection | None = None,
@@ -69,7 +77,7 @@ def __call__(
# DictTuple()
if __value is None and __key is None:
obj = cls.__new__(cls)
- obj.__init__()
+ obj.__init__() # type: ignore[misc]
return obj
# DictTuple(DictTuple(...))
@@ -79,27 +87,27 @@ def __call__(
# DictTuple({"x": 123, "y": 456})
elif isinstance(__value, Mapping) and __key is None:
obj = cls.__new__(cls, __value.values())
- obj.__init__(__value.keys())
+ obj.__init__(__value.keys()) # type: ignore[misc]
return obj
# DictTuple(EnhancedDictTuple(...))
# EnhancedDictTuple(DictTuple(...))
# where EnhancedDictTuple subclasses DictTuple
- elif hasattr(__value, "items") and callable(__value.items):
- return cls.__call__(__value.items())
+ elif callable(__value_items := getattr(__value, "items", None)):
+ return cls.__call__(__value_items())
# DictTuple([("x", 123), ("y", 456)])
elif isinstance(__value, Iterable) and __key is None:
keys = OrderedDict()
- values = deque()
+ values: deque = deque()
for i, (k, v) in enumerate(__value):
keys[k] = i
values.append(v)
obj = cls.__new__(cls, values)
- obj.__init__(keys)
+ obj.__init__(keys) # type: ignore[misc]
return obj
@@ -108,15 +116,15 @@ def __call__(
keys = OrderedDict((k, i) for i, k in enumerate(__key))
obj = cls.__new__(cls, __value)
- obj.__init__(keys)
+ obj.__init__(keys) # type: ignore[misc]
return obj
# DictTuple(["abc", "xyz"], lambda x: x[0])
# equivalent to DictTuple({"a": "abc", "x": "xyz"})
elif isinstance(__value, Iterable) and isinstance(__key, Callable):
- obj = cls.__new__(cls, __value)
- obj.__init__(__key)
+ obj = cls.__new__(cls, __value) # type: ignore[misc]
+ obj.__init__(__key) # type: ignore[misc]
return obj
@@ -165,21 +173,23 @@ class DictTuple(tuple[_VT, ...], Generic[_KT, _VT], metaclass=_Meta):
# These overloads are copied from _Meta.__call__ just for IDE hints
@overload
- def __init__(self) -> None: ...
+ def __init__(self, /) -> None: ...
@overload
- def __init__(self, __value: Iterable[tuple[_KT, _VT]]) -> None: ...
+ def __init__(self, __value: Iterable[tuple[_KT, _VT]], /) -> None: ...
@overload
- def __init__(self, __value: Mapping[_KT, _VT]) -> None: ...
+ def __init__(self, __value: Mapping[_KT, _VT], /) -> None: ...
@overload
- def __init__(self, __value: Iterable[_VT], __key: Collection[_KT]) -> None: ...
+ def __init__(self, __value: Iterable[_VT], __key: Collection[_KT], /) -> None: ...
@overload
- def __init__(self, __value: Iterable[_VT], __key: Callable[[_VT], _KT]) -> None: ...
+ def __init__(
+ self, __value: Iterable[_VT], __key: Callable[[_VT], _KT], /
+ ) -> None: ...
- def __init__(self, __value=None, /):
+ def __init__(self, __value: Any = None, /) -> None: # type: ignore[misc]
if isinstance(__value, MappingProxyType):
self.__mapping = __value
elif isinstance(__value, Mapping):
@@ -204,16 +214,16 @@ def __init__(self, __value=None, /):
"or implement `__index__()`"
)
- @overload
- def __getitem__(self, __key: _KT) -> _VT: ...
+ @overload # type: ignore[override]
+ def __getitem__(self, __key: _KT, /) -> _VT: ...
- @overload
- def __getitem__(self, __key: slice) -> Self: ...
+ @overload # type: ignore[overload-overlap]
+ def __getitem__(self, __key: slice, /) -> Self: ...
@overload
- def __getitem__(self, __key: SupportsIndex) -> _VT: ...
+ def __getitem__(self, __key: SupportsIndex, /) -> _VT: ...
- def __getitem__(self, __key, /):
+ def __getitem__(self, __key: _KT | slice | SupportsIndex, /) -> _VT | Self:
if isinstance(__key, slice):
return self.__class__(
super().__getitem__(__key),
@@ -233,3 +243,46 @@ def keys(self) -> KeysView[_KT]:
def items(self) -> Iterable[tuple[_KT, _VT]]:
return zip(self.__mapping.keys(), self)
+
+ # https://docs.pydantic.dev/latest/concepts/types/#handling-custom-generic-classes
+ # pydantic validator
+ # this enables annotating a field with DictTuple[K, V] instead of just DictTuple
+ # inside a pydantic BaseModel, e.g.
+ #
+ # class DatasetPageView(BaseModel):
+ # datasets: DictTuple[str, Dataset]
+ @classmethod
+ def __get_pydantic_core_schema__(
+ cls, source_type: Any, handler: GetCoreSchemaHandler
+ ) -> CoreSchema:
+ origin = get_origin(source_type)
+ if origin is None:
+ origin = source_type
+ kt, vt = (Any, Any)
+ else:
+ kt, vt, *_ = get_args(source_type)
+
+ k_schema = handler.generate_schema(MappingProxyType[kt, int])
+ v_schema = handler.generate_schema(vt)
+
+ def val_k(v: cls, handler: ValidatorFunctionWrapHandler) -> cls:
+ handler(v.__mapping)
+ return v
+
+ def val_v(v: cls, handler: ValidatorFunctionWrapHandler) -> cls:
+ handler(v)
+ return v
+
+ # pydantic validator for DictTuple[K, V]
+ # - check that object has type DictTuple
+ # - check that object is a tuple[V]
+ # - check that the keys have type K
+ return core_schema.chain_schema(
+ [
+ core_schema.is_instance_schema(cls),
+ core_schema.no_info_wrap_validator_function(
+ val_v, core_schema.tuple_variable_schema(items_schema=v_schema)
+ ),
+ core_schema.no_info_wrap_validator_function(val_k, k_schema),
+ ]
+ )
diff --git a/packages/syft/src/syft/util/patch_ipython.py b/packages/syft/src/syft/util/patch_ipython.py
index 99572e79ff5..9de05bb8049 100644
--- a/packages/syft/src/syft/util/patch_ipython.py
+++ b/packages/syft/src/syft/util/patch_ipython.py
@@ -78,8 +78,8 @@ def _patch_ipython_sanitization() -> None:
escaped_itable_template = re.compile(itable_template, re.DOTALL)
def display_sanitized_html(obj: SyftObject | DictTuple) -> str | None:
- if callable(getattr(obj, "_repr_html_", None)):
- html_str = obj._repr_html_()
+ if callable(obj_repr_html_ := getattr(obj, "_repr_html_", None)):
+ html_str = obj_repr_html_()
if html_str is not None:
# find matching table and jobs
matching_table = escaped_template.findall(html_str)
diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py
index 43620c4cab5..e0d729ba123 100644
--- a/packages/syft/src/syft/util/util.py
+++ b/packages/syft/src/syft/util/util.py
@@ -1,6 +1,7 @@
# stdlib
import asyncio
from asyncio.selector_events import BaseSelectorEventLoop
+from collections import deque
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
@@ -11,6 +12,7 @@
from datetime import datetime
import functools
import hashlib
+from itertools import chain
from itertools import repeat
import json
import logging
@@ -29,6 +31,7 @@
from secrets import randbelow
import socket
import sys
+from sys import getsizeof
import threading
import time
import types
@@ -93,8 +96,66 @@ def get_name_for(klass: type) -> str:
return klass_name
-def get_mb_size(data: Any) -> float:
- return sys.getsizeof(data) / (1024 * 1024)
+def get_mb_size(data: Any, handlers: dict | None = None) -> float:
+    """Returns the approximate memory footprint of an object and all of its contents.
+
+ Automatically finds the contents of the following builtin containers and
+ their subclasses: tuple, list, deque, dict, set and frozenset.
+ Otherwise, tries to read from the __slots__ or __dict__ of the object.
+ To search other containers, add handlers to iterate over their contents:
+
+ handlers = {SomeContainerClass: iter,
+ OtherContainerClass: OtherContainerClass.get_elements}
+
+ Lightly modified from
+ https://code.activestate.com/recipes/577504-compute-memory-footprint-of-an-object-and-its-cont/
+ which is referenced in official sys.getsizeof documentation
+ https://docs.python.org/3/library/sys.html#sys.getsizeof.
+
+ """
+
+ def dict_handler(d: dict[Any, Any]) -> Iterator[Any]:
+ return chain.from_iterable(d.items())
+
+ all_handlers = {
+ tuple: iter,
+ list: iter,
+ deque: iter,
+ dict: dict_handler,
+ set: iter,
+ frozenset: iter,
+ }
+ if handlers:
+ all_handlers.update(handlers) # user handlers take precedence
+ seen = set() # track which object id's have already been seen
+ default_size = getsizeof(0) # estimate sizeof object without __sizeof__
+
+ def sizeof(o: Any) -> int:
+ if id(o) in seen: # do not double count the same object
+ return 0
+ seen.add(id(o))
+ s = getsizeof(o, default_size)
+
+ for typ, handler in all_handlers.items():
+ if isinstance(o, typ):
+ s += sum(map(sizeof, handler(o))) # type: ignore
+ break
+ else:
+            # no __slots__ *usually* means a __dict__, but some special builtin
+            # classes (such as `type(None)`) have neither; in that case `o` has no
+            # attributes at all, so sys.getsizeof() already returned the correct value
+ if not hasattr(o.__class__, "__slots__"):
+ if hasattr(o, "__dict__"):
+ s += sizeof(o.__dict__)
+ else:
+ s += sum(
+ sizeof(getattr(o, x))
+ for x in o.__class__.__slots__
+ if hasattr(o, x)
+ )
+ return s
+
+ return sizeof(data) / (1024.0 * 1024.0)
def get_mb_serialized_size(data: Any) -> float: