diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index f5bde266ae9..eb5b1d1cc44 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -14,6 +14,8 @@ # third party from argon2 import PasswordHasher +from cachetools import TTLCache +from cachetools import cached from pydantic import field_validator import requests from requests import Response @@ -200,6 +202,8 @@ def session(self) -> Session: return self.session_cache def _make_get(self, path: str, params: dict | None = None) -> bytes: + if params is None: + return self._make_get_no_params(path) url = self.url.with_path(path) response = self.session.get( str(url), @@ -218,6 +222,26 @@ def _make_get(self, path: str, params: dict | None = None) -> bytes: return response.content + @cached(cache=TTLCache(maxsize=128, ttl=300)) + def _make_get_no_params(self, path: str) -> bytes: + print(path) + url = self.url.with_path(path) + response = self.session.get( + str(url), + headers=self.headers, + verify=verify_tls(), + proxies={}, + ) + if response.status_code != 200: + raise requests.ConnectionError( + f"Failed to fetch {url}. Response returned with code {response.status_code}" + ) + + # upgrade to tls if available + self.url = upgrade_tls(self.url, response) + + return response.content + def _make_post( self, path: str, diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index f357f867467..170c103d7a8 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from collections.abc import MutableMapping from collections.abc import MutableSequence +from functools import cache import hashlib import json from operator import itemgetter @@ -529,12 +530,20 @@ def reset_dev_protocol(self) -> None: def get_data_protocol(raise_exception: bool = False) -> DataProtocol: - return DataProtocol( + return _get_data_protocol( filename=data_protocol_file_name(), raise_exception=raise_exception, ) +@cache +def _get_data_protocol(filename: str, raise_exception: bool = False) -> DataProtocol: + return DataProtocol( + filename=filename, + raise_exception=raise_exception, + ) + + def stage_protocol_changes() -> Result[SyftSuccess, SyftError]: data_protocol = get_data_protocol(raise_exception=True) return data_protocol.stage_protocol_changes() diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b71f5aa4cc6..34c0dc70acd 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -332,6 +332,8 @@ class UserCode(SyncableSyftObject): origin_node_side_type: NodeSideType l0_deny_reason: str | None = None + _has_output_read_permissions_cache: bool | None = None + __table_coll_widths__ = [ "min-content", "auto", @@ -439,9 +441,14 @@ def _compute_status_l0( if isinstance(api, SyftError): return api node_identity = NodeIdentity.from_api(api) - is_approved = api.output.has_output_read_permissions( - self.id, self.user_verify_key - ) + + if self._has_output_read_permissions_cache is None: + is_approved = api.output.has_output_read_permissions( + self.id, self.user_verify_key + ) + self._has_output_read_permissions_cache = is_approved + else: + is_approved = self._has_output_read_permissions_cache else: # Serverside node_identity = NodeIdentity.from_node(context.node) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 011e96c25d7..ac7fe6b607f 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -1212,6 +1212,8 @@ class UserCodeStatusChange(Change): @property def code(self) -> UserCode: + if self.linked_user_code._resolve_cache: + return self.linked_user_code._resolve_cache return self.linked_user_code.resolve def get_user_code(self, context: AuthedServiceContext) -> UserCode: diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 9778e98f200..d5f8eb60caf 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -566,11 +566,11 @@ class ObjectDiffBatch(SyftObject): root_diff: ObjectDiff sync_direction: SyncDirection | None - def resolve(self) -> "ResolveWidget": + def resolve(self, build_state: bool = True) -> "ResolveWidget": # relative from .resolve_widget import ResolveWidget - return ResolveWidget(self) + return ResolveWidget(self, build_state=build_state) def walk_graph( self, @@ -1142,14 +1142,16 @@ class NodeDiff(SyftObject): include_ignored: bool = False - def resolve(self) -> "PaginatedResolveWidget | SyftSuccess": + def resolve( + self, build_state: bool = True + ) -> "PaginatedResolveWidget | SyftSuccess": if len(self.batches) == 0: return SyftSuccess(message="No batches to resolve") # relative from .resolve_widget import PaginatedResolveWidget - return PaginatedResolveWidget(batches=self.batches) + return PaginatedResolveWidget(batches=self.batches, build_state=build_state) def __getitem__(self, idx: Any) -> ObjectDiffBatch: return self.batches[idx] diff --git a/packages/syft/src/syft/service/sync/resolve_widget.py b/packages/syft/src/syft/service/sync/resolve_widget.py index 496fb7a65eb..4a868634df3 100644 --- a/packages/syft/src/syft/service/sync/resolve_widget.py +++ b/packages/syft/src/syft/service/sync/resolve_widget.py @@ -105,10 +105,19 @@ def __init__( direction: SyncDirection, with_box: bool = True, show_share_warning: bool = False, + build_state: bool = True, ): - self.low_properties = diff.repr_attr_dict("low") - self.high_properties = diff.repr_attr_dict("high") - self.statuses = diff.repr_attr_diffstatus_dict() + build_state = build_state + + if build_state: + self.low_properties = diff.repr_attr_dict("low") + self.high_properties = diff.repr_attr_dict("high") + self.statuses = diff.repr_attr_diffstatus_dict() + else: + self.low_properties = {} + self.high_properties = {} + self.statuses = {} + self.direction = direction self.diff: ObjectDiff = diff self.with_box = with_box @@ -203,9 +212,10 @@ def __init__( self, diff: ObjectDiff, direction: SyncDirection, + build_state: bool = True, ): self.direction = direction - + self.build_state = build_state self.share_private_data = False self.diff: ObjectDiff = diff self.sync: bool = False @@ -275,6 +285,7 @@ def build(self) -> widgets.VBox: self.direction, with_box=False, show_share_warning=self.show_share_button, + build_state=self.build_state, ).widget accordion, share_private_checkbox, sync_checkbox = self.build_accordion( @@ -411,8 +422,12 @@ def _on_share_private_data_change(self, change: Any) -> None: class ResolveWidget: def __init__( - self, obj_diff_batch: ObjectDiffBatch, on_sync_callback: Callable | None = None + self, + obj_diff_batch: ObjectDiffBatch, + on_sync_callback: Callable | None = None, + build_state: bool = True, ): + self.build_state = build_state self.obj_diff_batch: ObjectDiffBatch = obj_diff_batch self.id2widget: dict[ UID, CollapsableObjectDiffWidget | MainObjectDiffWidget @@ -483,6 +498,7 @@ def batch_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: CollapsableObjectDiffWidget( diff, direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) for diff in dependents ] @@ -498,7 +514,9 @@ def dependent_root_diff_widgets(self) -> list[CollapsableObjectDiffWidget]: ] widgets = [ CollapsableObjectDiffWidget( - diff, direction=self.obj_diff_batch.sync_direction + diff, + direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) for diff in other_roots ] @@ -509,6 +527,7 @@ def main_object_diff_widget(self) -> MainObjectDiffWidget: obj_diff_widget = MainObjectDiffWidget( self.obj_diff_batch.root_diff, direction=self.obj_diff_batch.sync_direction, + build_state=self.build_state, ) return obj_diff_widget @@ -712,12 +731,14 @@ class PaginatedResolveWidget: paginated by a PaginationControl widget. """ - def __init__(self, batches: list[ObjectDiffBatch]): + def __init__(self, batches: list[ObjectDiffBatch], build_state: bool = True): + self.build_state = build_state self.batches = batches self.resolve_widgets: list[ResolveWidget] = [ ResolveWidget( batch, on_sync_callback=partial(self.on_click_sync, i), + build_state=build_state, ) for i, batch in enumerate(self.batches) ] diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 93f63d1f8b4..6e76a799930 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -26,6 +26,8 @@ class LinkedObject(SyftObject): object_type: type[SyftObject] object_uid: UID + _resolve_cache: SyftObject | None = None + __exclude_sync_diff_attrs__ = ["node_uid"] def __str__(self) -> str: @@ -46,7 +48,9 @@ def resolve(self) -> SyftObject: if api is None: raise ValueError(f"api is None. You must login to {self.node_uid}") - return api.services.notifications.resolve_object(self) + resolve: SyftObject = api.services.notifications.resolve_object(self) + self._resolve_cache = resolve + return resolve def resolve_with_context(self, context: NodeServiceContext) -> Any: if context.node is None: diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 9df3f22300c..863b65581ab 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -7,6 +7,7 @@ from collections.abc import Mapping from collections.abc import Sequence from collections.abc import Set +from functools import cache from hashlib import sha256 import inspect from inspect import Signature @@ -229,6 +230,11 @@ def get_transform( ) +@cache +def cached_get_type_hints(cls: type) -> dict[str, Any]: + return typing.get_type_hints(cls) + + class SyftMigrationRegistry: __migration_version_registry__: dict[str, dict[int, str]] = {} __migration_transform_registry__: dict[str, dict[str, Callable]] = {} @@ -578,7 +584,7 @@ def _syft_set_validate_private_attrs_(self, **kwargs: Any) -> None: return # Validate and set private attributes # https://github.com/pydantic/pydantic/issues/2105 - annotations = typing.get_type_hints(self.__class__) + annotations = cached_get_type_hints(self.__class__) for attr, decl in self.__private_attributes__.items(): value = kwargs.get(attr, decl.get_default()) var_annotation = annotations.get(attr)