diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 00b0673f221..c40677221e2 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -38,6 +38,7 @@ from ..serde.signature import signature_remove_self from ..service.context import AuthedServiceContext from ..service.context import ChangeContext +from ..service.metadata.node_metadata import NodeMetadataJSON from ..service.response import SyftAttributeError from ..service.response import SyftError from ..service.response import SyftSuccess @@ -50,6 +51,7 @@ from ..types.identity import Identity from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SYFT_OBJECT_VERSION_2 +from ..types.syft_object import SYFT_OBJECT_VERSION_3 from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftMigrationRegistry from ..types.syft_object import SyftObject @@ -244,6 +246,7 @@ class RemoteFunction(SyftObject): node_uid: UID signature: Signature + refresh_api_callback: Callable | None = None path: str make_call: Callable pre_kwargs: dict[str, Any] | None = None @@ -306,6 +309,13 @@ def function_call( return result = self.make_call(api_call=api_call, cache_result=cache_result) + # TODO: annotate this on the service method decorator + API_CALLS_THAT_REQUIRE_REFRESH = ["settings.enable_eager_execution"] + + if path in API_CALLS_THAT_REQUIRE_REFRESH: + if self.refresh_api_callback is not None: + self.refresh_api_callback() + result, _ = migrate_args_and_kwargs( [result], kwargs={}, to_latest_protocol=True ) @@ -507,6 +517,7 @@ def generate_remote_function( custom_function = bool(path == "api.call_in_jobs") remote_function = RemoteFunction( node_uid=node_uid, + refresh_api_callback=api.refresh_api_callback, signature=signature, path=path, make_call=make_call, @@ -805,6 +816,38 @@ def result_needs_api_update(api_call_result: Any) -> bool: return False +@serializable( + attrs=[ + "endpoints", + "node_uid", + "node_name", + "lib_endpoints", + "communication_protocol", + ] +) +class SyftAPIV2(SyftObject): + # version + __canonical_name__ = "SyftAPI" + __version__ = SYFT_OBJECT_VERSION_2 + + # fields + connection: NodeConnection | None = None + node_uid: UID | None = None + node_name: str | None = None + endpoints: dict[str, APIEndpoint] + lib_endpoints: dict[str, LibEndpoint] | None = None + api_module: APIModule | None = None + libs: APIModule | None = None + signing_key: SyftSigningKey | None = None + # serde / storage rules + refresh_api_callback: Callable | None = None + __user_role: ServiceRole = ServiceRole.NONE + communication_protocol: PROTOCOL_TYPE + + # informs getattr does not have nasty side effects + __syft_allow_autocomplete__ = ["services"] + + @instrument @serializable( attrs=[ @@ -818,7 +861,7 @@ def result_needs_api_update(api_call_result: Any) -> bool: class SyftAPI(SyftObject): # version __canonical_name__ = "SyftAPI" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_3 # fields connection: NodeConnection | None = None @@ -833,6 +876,7 @@ class SyftAPI(SyftObject): refresh_api_callback: Callable | None = None __user_role: ServiceRole = ServiceRole.NONE communication_protocol: PROTOCOL_TYPE + metadata: NodeMetadataJSON | None = None # informs getattr does not have nasty side effects __syft_allow_autocomplete__ = ["services"] diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index 3e27b445015..498c22e7536 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -37,8 +37,8 @@ from ..serde.serializable import serializable from ..serde.serialize import _serialize from ..service.context import NodeServiceContext +from ..service.metadata.node_metadata import NodeMetadata from ..service.metadata.node_metadata import NodeMetadataJSON -from ..service.metadata.node_metadata import NodeMetadataV3 from ..service.response import SyftError from ..service.response import SyftSuccess from ..service.user.user import UserCreate @@ -241,7 +241,10 @@ def get_node_metadata( return NodeMetadataJSON(**metadata_json) def get_api( - self, credentials: SyftSigningKey, communication_protocol: int + self, + credentials: SyftSigningKey, + communication_protocol: int, + metadata: NodeMetadataJSON | None = None, ) -> SyftAPI: params = { "verify_key": str(credentials.verify_key), @@ -264,6 +267,7 @@ def get_api( obj.connection = self obj.signing_key = credentials obj.communication_protocol = communication_protocol + obj.metadata = metadata if self.proxy_target_uid: obj.node_uid = self.proxy_target_uid return cast(SyftAPI, obj) @@ -378,7 +382,10 @@ def to_blob_route(self, path: str, host: str | None = None) -> GridURL: return GridURL(port=8333).with_path(path) def get_api( - self, credentials: SyftSigningKey, communication_protocol: int + self, + credentials: SyftSigningKey, + communication_protocol: int, + metadata: NodeMetadataJSON | None = None, ) -> SyftAPI: # todo: its a bit odd to identify a user by its verify key maybe? if self.proxy_target_uid: @@ -400,6 +407,7 @@ def get_api( obj.connection = self obj.signing_key = credentials obj.communication_protocol = communication_protocol + obj.metadata = metadata if self.proxy_target_uid: obj.node_uid = self.proxy_target_uid return obj @@ -697,7 +705,7 @@ def exchange_route( return self.api.services.network.exchange_credentials_with( self_node_route=self_node_route, remote_node_route=remote_node_route, - remote_node_verify_key=client.metadata.to(NodeMetadataV3).verify_key, + remote_node_verify_key=client.metadata.to(NodeMetadata).verify_key, ) else: raise ValueError( @@ -942,7 +950,9 @@ def _fetch_api(self, credentials: SyftSigningKey) -> SyftAPI: _api: SyftAPI = self.connection.get_api( credentials=credentials, communication_protocol=self.communication_protocol, + metadata=self.metadata, ) + self._fetch_node_metadata(self.credentials) def refresh_callback() -> SyftAPI: return self._fetch_api(self.credentials) @@ -958,6 +968,7 @@ def refresh_callback() -> SyftAPI: api=_api, ) self._api = _api + self._api.metadata = self.metadata self.services = _api.services return _api diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 2340ca33ba2..cdd3e2fec93 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -65,7 +65,7 @@ from ..service.job.job_stash import JobType from ..service.log.log_service import LogService from ..service.metadata.metadata_service import MetadataService -from ..service.metadata.node_metadata import NodeMetadataV3 +from ..service.metadata.node_metadata import NodeMetadata from ..service.network.network_service import NetworkService from ..service.network.utils import PeerHealthCheckTask from ..service.notification.notification_service import NotificationService @@ -1023,7 +1023,8 @@ def settings(self) -> NodeSettings: return settings @property - def metadata(self) -> NodeMetadataV3: + def metadata(self) -> NodeMetadata: + show_warnings = self.enable_warnings settings_data = self.settings name = settings_data.name organization = settings_data.organization @@ -1033,8 +1034,9 @@ def metadata(self) -> NodeMetadataV3: node_side_type = ( settings_data.node_side_type.value if settings_data.node_side_type else "" ) + eager_execution_enabled = settings_data.eager_execution_enabled - return NodeMetadataV3( + return NodeMetadata( name=name, id=self.id, verify_key=self.verify_key, @@ -1046,6 +1048,7 @@ def metadata(self) -> NodeMetadataV3: node_type=node_type, node_side_type=node_side_type, show_warnings=show_warnings, + eager_execution_enabled=eager_execution_enabled, ) @property diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 36f59d41b20..233e7c3c355 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -13,28 +13,66 @@ }, "dev": { "object_versions": { - "NodeSettingsUpdate": { + "NodeMetadata": { + "5": { + "version": 5, + "hash": "70197b4725dbdea0560ed8388e4d20b76808bee988f3630c5f916ee8f48761f8", + "action": "add" + } + }, + "SyftAPI": { "3": { "version": 3, - "hash": "0f812fdd5aecc3e3aa1a7c953bbf7f8d8b03a77c5cdbb37e981fa91c8134c9f4", + "hash": "b1b9d131a4f204ef2d56dc91bab3b945d5581080565232ede864f32015c0882a", "action": "add" + } + }, + "HTMLObject": { + "1": { + "version": 1, + "hash": "010d9aaca95f3fdfc8d1f97d01c1bd66483da774a59275b310c08d6912f7f863", + "action": "add" + } + }, + "NodeSettingsUpdate": { + "2": { + "version": 2, + "hash": "e1dc9d2f30c4aae1f7359eb3fd44de5537788cd3c69be5f30c36fb019f07c261", + "action": "remove" }, "4": { "version": 4, "hash": "ec783a7cd097e2bc4273a519d11023c796aebb9e3710c1d8332c0e46966d4ae0", "action": "add" + }, + "5": { + "version": 5, + "hash": "fd89638bb3d6dda9095905aab7ed2883f0b3dd5245900e8e141eec87921c2c9e", + "action": "add" } }, "NodeSettings": { - "4": { - "version": 4, - "hash": "318e578f8a9af213a6af0cc2c567b62196b0ff81769d808afff4dd1eb7c372b8", - "action": "add" + "3": { + "version": 3, + "hash": "2d5f6e79f074f75b5cfc2357eac7cf635b8f083421009a513240b4dbbd5a0fc1", + "action": "remove" }, "5": { "version": 5, "hash": "cde18eb23fdffcfba47bc0e85efdbba1d59f1f5d6baa9c9690e1af14b35eb74e", "action": "add" + }, + "6": { + "version": 6, + "hash": "986d201418035e59b12787dfaf60aa2af17817c1894ce42ab4b982ed73127403", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "5": { + "version": 5, + "hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3", + "action": "add" } }, "EnclaveMetadata": { @@ -80,6 +118,13 @@ "action": "add" } }, + "CreateCustomImageChange": { + "3": { + "version": 3, + "hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995", + "action": "add" + } + }, "TwinAPIContextView": { "1": { "version": 1, @@ -181,6 +226,13 @@ "action": "add" } }, + "NodeMetadataUpdate": { + "2": { + "version": 2, + "hash": "520ae8ffc0c057ffa827cb7b267a19fb6b92e3cf3c0a3666ac34e271b6dd0aed", + "action": "remove" + } + }, "SyncStateItem": { "1": { "version": 1, @@ -194,27 +246,6 @@ "hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b", "action": "remove" } - }, - "BlobRetrievalByURL": { - "5": { - "version": 5, - "hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3", - "action": "add" - } - }, - "CreateCustomImageChange": { - "3": { - "version": 3, - "hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995", - "action": "add" - } - }, - "HTMLObject": { - "1": { - "version": 1, - "hash": "010d9aaca95f3fdfc8d1f97d01c1bd66483da774a59275b310c08d6912f7f863", - "action": "add" - } } } } diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 3cfe5fe2617..dffa3d3d9de 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -428,8 +428,6 @@ def make_action_side_effect( - Ok[[Tuple[PreHookContext, Tuple[Any, ...], Dict[str, Any]]] on success - Err[str] on failure """ - # relative - try: action = context.obj.syft_make_action_with_self( op=context.op_name, @@ -794,11 +792,12 @@ def _save_to_blob_storage_(self, data: Any) -> SyftError | None: from ...types.blob_storage import CreateBlobStorageEntry if not isinstance(data, ActionDataEmpty): - if isinstance(data, BlobFile) and not data.uploaded: - api = APIRegistry.api_for( - self.syft_node_location, self.syft_client_verify_key - ) - data.upload_to_blobstorage_from_api(api) + if isinstance(data, BlobFile): + if not data.uploaded: + api = APIRegistry.api_for( + self.syft_node_location, self.syft_client_verify_key + ) + data._upload_to_blobstorage_from_api(api) else: serialized = serialize(data, to_bytes=True) size = sys.getsizeof(serialized) @@ -867,7 +866,7 @@ def is_pointer(self) -> bool: @property def syft_lineage_id(self) -> LineageID: - """Compute the LineageID of the ActionObject, using the `id` and the `syft_history_hash` memebers""" + """Compute the LineageID of the ActionObject, using the `id` and the `syft_history_hash` members""" return LineageID(self.id, self.syft_history_hash) model_config = ConfigDict(validate_assignment=True) @@ -1477,36 +1476,66 @@ def __post_init__(self) -> None: if HOOK_ALWAYS not in self.syft_pre_hooks__: self.syft_pre_hooks__[HOOK_ALWAYS] = [] - if HOOK_ON_POINTERS not in self.syft_post_hooks__: + if HOOK_ON_POINTERS not in self.syft_pre_hooks__: self.syft_pre_hooks__[HOOK_ON_POINTERS] = [] - # this should be a list as orders matters - for side_effect in [make_action_side_effect]: - if side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: - self.syft_pre_hooks__[HOOK_ALWAYS].append(side_effect) - - for side_effect in [send_action_side_effect]: - if side_effect not in self.syft_pre_hooks__[HOOK_ON_POINTERS]: - self.syft_pre_hooks__[HOOK_ON_POINTERS].append(side_effect) - - if trace_action_side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: - self.syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) - if HOOK_ALWAYS not in self.syft_post_hooks__: self.syft_post_hooks__[HOOK_ALWAYS] = [] if HOOK_ON_POINTERS not in self.syft_post_hooks__: self.syft_post_hooks__[HOOK_ON_POINTERS] = [] - for side_effect in [propagate_node_uid]: - if side_effect not in self.syft_post_hooks__[HOOK_ALWAYS]: - self.syft_post_hooks__[HOOK_ALWAYS].append(side_effect) + api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) + eager_execution_enabled = ( + api is not None + and api.metadata is not None + and api.metadata.eager_execution_enabled + ) + + self._syft_add_pre_hooks__(eager_execution_enabled) + self._syft_add_post_hooks__(eager_execution_enabled) if isinstance(self.syft_action_data_type, ActionObject): raise Exception("Nested ActionObjects", self.syft_action_data_repr_) self.syft_history_hash = hash(self.id) + def _syft_add_pre_hooks__(self, eager_execution: bool) -> None: + """ + Add pre-hooks + + Args: + eager_execution: bool: If eager execution is enabled, hooks for + tracing and executing the action on remote are added. + """ + + # this should be a list as orders matters + for side_effect in [make_action_side_effect]: + if side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: + self.syft_pre_hooks__[HOOK_ALWAYS].append(side_effect) + + if eager_execution: + for side_effect in [send_action_side_effect]: + if side_effect not in self.syft_pre_hooks__[HOOK_ON_POINTERS]: + self.syft_pre_hooks__[HOOK_ON_POINTERS].append(side_effect) + + if trace_action_side_effect not in self.syft_pre_hooks__[HOOK_ALWAYS]: + self.syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) + + def _syft_add_post_hooks__(self, eager_execution: bool) -> None: + """ + Add post-hooks + + Args: + eager_execution: bool: If eager execution is enabled, hooks for + tracing and executing the action on remote are added. + """ + if eager_execution: + # this should be a list as orders matters + for side_effect in [propagate_node_uid]: + if side_effect not in self.syft_post_hooks__[HOOK_ALWAYS]: + self.syft_post_hooks__[HOOK_ALWAYS].append(side_effect) + def _syft_run_pre_hooks__( self, context: PreHookContext, name: str, args: Any, kwargs: Any ) -> tuple[PreHookContext, tuple[Any, ...], dict[str, Any]]: @@ -1635,7 +1664,7 @@ def _syft_attr_propagate_ids( result.syft_node_location = context.syft_node_location result.syft_client_verify_key = context.syft_client_verify_key - # Propogate Syft blob storage entry id + # Propagate Syft blob storage entry id object_attrs = [ "syft_blob_storage_entry_id", "syft_action_data_repr_", diff --git a/packages/syft/src/syft/service/action/plan.py b/packages/syft/src/syft/service/action/plan.py index 0bab10c0958..2b89a436b64 100644 --- a/packages/syft/src/syft/service/action/plan.py +++ b/packages/syft/src/syft/service/action/plan.py @@ -68,6 +68,8 @@ def planify(func: Callable) -> ActionObject: client = worker.root_client if client is None: raise ValueError("Not able to get client for plan building") + if client.settings is not None: + client.settings.enable_eager_execution(enable=True) TraceResultRegistry.set_trace_result_for_current_thread(client=client) try: # TraceResult._client = client diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index c292b139b33..a1e7eebb799 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -8,7 +8,7 @@ from ..service import AbstractService from ..service import service_method from ..user.user_roles import GUEST_ROLE_LEVEL -from .node_metadata import NodeMetadataV3 +from .node_metadata import NodeMetadata @instrument @@ -20,7 +20,7 @@ def __init__(self, store: DocumentStore) -> None: @service_method( path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL ) - def get_metadata(self, context: AuthedServiceContext) -> NodeMetadataV3: + def get_metadata(self, context: AuthedServiceContext) -> NodeMetadata: return context.node.metadata # type: ignore # @service_method(path="metadata.get_admin", name="get_admin", roles=GUEST_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/metadata/node_metadata.py b/packages/syft/src/syft/service/metadata/node_metadata.py index 746e3336cd5..de60b90a412 100644 --- a/packages/syft/src/syft/service/metadata/node_metadata.py +++ b/packages/syft/src/syft/service/metadata/node_metadata.py @@ -14,8 +14,8 @@ from ...node.credentials import SyftVerifyKey from ...protocol.data_protocol import get_data_protocol from ...serde.serializable import serializable -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SYFT_OBJECT_VERSION_5 from ...types.syft_object import StorableObjectType from ...types.syft_object import SyftObject from ...types.transforms import convert_types @@ -44,24 +44,33 @@ def check_version( @serializable() -class NodeMetadataUpdate(SyftObject): - __canonical_name__ = "NodeMetadataUpdate" - __version__ = SYFT_OBJECT_VERSION_2 - - name: str | None = None - organization: str | None = None - description: str | None = None - on_board: bool | None = None - id: UID | None = None # type: ignore[assignment] - verify_key: SyftVerifyKey | None = None - highest_object_version: int | None = None - lowest_object_version: int | None = None - syft_version: str | None = None - admin_email: str | None = None +class NodeMetadata(SyftObject): + __canonical_name__ = "NodeMetadata" + __version__ = SYFT_OBJECT_VERSION_5 + + name: str + id: UID + verify_key: SyftVerifyKey + highest_version: int + lowest_version: int + syft_version: str + node_type: NodeType = NodeType.DOMAIN + organization: str = "OpenMined" + description: str = "Text" + node_side_type: str + show_warnings: bool + eager_execution_enabled: bool + + def check_version(self, client_version: str) -> bool: + return check_version( + client_version=client_version, + server_version=self.syft_version, + server_name=self.name, + ) @serializable() -class NodeMetadataV3(SyftObject): +class NodeMetadataV4(SyftObject): __canonical_name__ = "NodeMetadata" __version__ = SYFT_OBJECT_VERSION_4 @@ -98,6 +107,7 @@ class NodeMetadataJSON(BaseModel, StorableObjectType): organization: str = "OpenMined" description: str = "My cool domain" signup_enabled: bool = False + eager_execution_enabled: bool = False admin_email: str = "" node_side_type: str show_warnings: bool @@ -119,7 +129,7 @@ def check_version(self, client_version: str) -> bool: ) -@transform(NodeMetadataV3, NodeMetadataJSON) +@transform(NodeMetadata, NodeMetadataJSON) def metadata_to_json() -> list[Callable]: return [ drop(["__canonical_name__"]), @@ -130,7 +140,7 @@ def metadata_to_json() -> list[Callable]: ] -@transform(NodeMetadataJSON, NodeMetadataV3) +@transform(NodeMetadataJSON, NodeMetadata) def json_to_metadata() -> list[Callable]: return [ drop(["metadata_version", "supported_protocols"]), diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 10074353146..410b996eeca 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -33,7 +33,7 @@ from ...util.util import prompt_warning_message from ..context import AuthedServiceContext from ..data_subject.data_subject import NamePartitionKey -from ..metadata.node_metadata import NodeMetadataV3 +from ..metadata.node_metadata import NodeMetadata from ..request.request import Request from ..request.request import RequestStatus from ..request.request import SubmitRequest @@ -953,7 +953,7 @@ def node_route_to_http_connection( return HTTPConnection(url=url, proxy_target_uid=obj.proxy_target_uid) -@transform(NodeMetadataV3, NodePeer) +@transform(NodeMetadata, NodePeer) def metadata_to_peer() -> list[Callable]: return [ keep( diff --git a/packages/syft/src/syft/service/network/node_peer.py b/packages/syft/src/syft/service/network/node_peer.py index 35292dd89dd..4a5447e293d 100644 --- a/packages/syft/src/syft/service/network/node_peer.py +++ b/packages/syft/src/syft/service/network/node_peer.py @@ -23,7 +23,7 @@ from ...types.transforms import TransformContext from ...types.uid import UID from ..context import NodeServiceContext -from ..metadata.node_metadata import NodeMetadataV3 +from ..metadata.node_metadata import NodeMetadata from .routes import HTTPNodeRoute from .routes import NodeRoute from .routes import NodeRouteType @@ -210,7 +210,7 @@ def from_client(client: SyftClient) -> "NodePeer": if not client.metadata: raise ValueError("Client has to have metadata first") - peer = client.metadata.to(NodeMetadataV3).to(NodePeer) + peer = client.metadata.to(NodeMetadata).to(NodePeer) route = connection_to_route(client.connection) peer.node_routes.append(route) return peer diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index d9b84ef9f15..981f7ff9192 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -24,7 +24,7 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...serde.serialize import _serialize -from ...service.metadata.node_metadata import NodeMetadataV3 +from ...service.metadata.node_metadata import NodeMetadata from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.identity import Identity @@ -60,7 +60,7 @@ class EventAlreadyAddedException(SyftException): pass -@transform(NodeMetadataV3, NodeIdentity) +@transform(NodeMetadata, NodeIdentity) def metadata_to_node_identity() -> list[Callable]: return [rename("id", "node_id"), rename("name", "node_name")] @@ -1245,7 +1245,7 @@ def to_node_identity(val: SyftClient | NodeIdentity) -> NodeIdentity: if isinstance(val, NodeIdentity): return val elif isinstance(val, SyftClient) and val.metadata is not None: - metadata = val.metadata.to(NodeMetadataV3) + metadata = val.metadata.to(NodeMetadata) return metadata.to(NodeIdentity) else: raise SyftException( diff --git a/packages/syft/src/syft/service/settings/settings.py b/packages/syft/src/syft/service/settings/settings.py index 0ac31b9ff7d..94adfbf307c 100644 --- a/packages/syft/src/syft/service/settings/settings.py +++ b/packages/syft/src/syft/service/settings/settings.py @@ -15,9 +15,9 @@ from ...types.syft_metaclass import Empty from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject -from ...types.syft_object import SYFT_OBJECT_VERSION_2 -from ...types.syft_object import SYFT_OBJECT_VERSION_3 from ...types.syft_object import SYFT_OBJECT_VERSION_4 +from ...types.syft_object import SYFT_OBJECT_VERSION_5 +from ...types.syft_object import SYFT_OBJECT_VERSION_6 from ...types.syft_object import SyftObject from ...types.transforms import drop from ...types.transforms import make_set_default @@ -31,36 +31,7 @@ @serializable() -class NodeSettingsUpdateV2(PartialSyftObject): - __canonical_name__ = "NodeSettingsUpdate" - __version__ = SYFT_OBJECT_VERSION_2 - - id: UID - name: str - organization: str - description: str - on_board: bool - signup_enabled: bool - admin_email: str - - -@serializable() -class NodeSettingsUpdateV3(PartialSyftObject): - __canonical_name__ = "NodeSettingsUpdate" - __version__ = SYFT_OBJECT_VERSION_3 - - id: UID - name: str - organization: str - description: str - on_board: bool - signup_enabled: bool - admin_email: str - association_request_auto_approval: bool - - -@serializable() -class NodeSettingsUpdate(PartialSyftObject): +class NodeSettingsUpdateV4(PartialSyftObject): __canonical_name__ = "NodeSettingsUpdate" __version__ = SYFT_OBJECT_VERSION_4 id: UID @@ -89,57 +60,25 @@ def validate_node_side_type(cls, v: str) -> type[Empty]: @serializable() -class NodeSettings(SyftObject): - __canonical_name__ = "NodeSettings" - __version__ = 5 - __repr_attrs__ = [ - "name", - "organization", - "deployed_on", - "signup_enabled", - "admin_email", - ] - +class NodeSettingsUpdate(PartialSyftObject): + __canonical_name__ = "NodeSettingsUpdate" + __version__ = SYFT_OBJECT_VERSION_5 id: UID - name: str = "Node" - deployed_on: str - organization: str = "OpenMined" - verify_key: SyftVerifyKey - on_board: bool = True - description: str = "Text" - node_type: NodeType = NodeType.DOMAIN + name: str + organization: str + description: str + on_board: bool signup_enabled: bool admin_email: str - node_side_type: NodeSideType = NodeSideType.HIGH_SIDE - show_warnings: bool association_request_auto_approval: bool - default_worker_pool: str = DEFAULT_WORKER_POOL_NAME - welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject( - text=DEFAULT_WELCOME_MSG - ) - - def _repr_html_(self) -> Any: - return f""" - -
-

Settings

-

Id: {self.id}

-

Name: {self.name}

-

Organization: {self.organization}

-

Deployed on: {self.deployed_on}

-

Signup enabled: {self.signup_enabled}

-

Admin email: {self.admin_email}

-
- - """ + welcome_markdown: HTMLObject | MarkdownDescription + eager_execution_enabled: bool = False @serializable() -class NodeSettingsV4(SyftObject): +class NodeSettings(SyftObject): __canonical_name__ = "NodeSettings" - __version__ = SYFT_OBJECT_VERSION_4 + __version__ = SYFT_OBJECT_VERSION_6 __repr_attrs__ = [ "name", "organization", @@ -161,7 +100,11 @@ class NodeSettingsV4(SyftObject): node_side_type: NodeSideType = NodeSideType.HIGH_SIDE show_warnings: bool association_request_auto_approval: bool + eager_execution_enabled: bool = False default_worker_pool: str = DEFAULT_WORKER_POOL_NAME + welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject( + text=DEFAULT_WELCOME_MSG + ) def _repr_html_(self) -> Any: return f""" @@ -182,9 +125,9 @@ def _repr_html_(self) -> Any: @serializable() -class NodeSettingsV2(SyftObject): +class NodeSettingsV5(SyftObject): __canonical_name__ = "NodeSettings" - __version__ = SYFT_OBJECT_VERSION_3 + __version__ = SYFT_OBJECT_VERSION_5 __repr_attrs__ = [ "name", "organization", @@ -205,23 +148,28 @@ class NodeSettingsV2(SyftObject): admin_email: str node_side_type: NodeSideType = NodeSideType.HIGH_SIDE show_warnings: bool + association_request_auto_approval: bool + default_worker_pool: str = DEFAULT_WORKER_POOL_NAME + welcome_markdown: HTMLObject | MarkdownDescription = HTMLObject( + text=DEFAULT_WELCOME_MSG + ) -@migrate(NodeSettingsV2, NodeSettings) +@migrate(NodeSettingsV5, NodeSettings) def upgrade_node_settings() -> list[Callable]: - return [make_set_default("association_request_auto_approval", False)] + return [make_set_default("eager_execution_enabled", False)] -@migrate(NodeSettings, NodeSettingsV2) +@migrate(NodeSettings, NodeSettingsV5) def downgrade_node_settings() -> list[Callable]: - return [drop(["association_request_auto_approval"])] + return [drop(["eager_execution_enabled"])] -@migrate(NodeSettingsUpdateV2, NodeSettingsUpdate) +@migrate(NodeSettingsUpdateV4, NodeSettingsUpdate) def upgrade_node_settings_update() -> list[Callable]: return [] -@migrate(NodeSettings, NodeSettingsV2) +@migrate(NodeSettingsUpdate, NodeSettingsUpdateV4) def downgrade_node_settings_update() -> list[Callable]: - return [drop(["association_request_auto_approval"])] + return [drop(["eager_execution_enabled"])] diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index f90abb02137..35ef9262860 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -72,6 +72,20 @@ def set( def update( self, context: AuthedServiceContext, settings: NodeSettingsUpdate ) -> Result[SyftSuccess, SyftError]: + res = self._update(context, settings) + if res.is_ok(): + return SyftSuccess( + message=( + "Settings updated successfully. " + + "You must call .refresh() to sync your client with the changes." + ) + ) + else: + return SyftError(message=res.err()) + + def _update( + self, context: AuthedServiceContext, settings: NodeSettingsUpdate + ) -> Result[Ok, Err]: """ Update the Node Settings using the provided values. @@ -97,7 +111,6 @@ def update( >>> node_client.update(name='foo', organization='bar', description='baz', signup_enabled=True) SyftSuccess: Settings updated successfully. """ - result = self.stash.get_all(context.credentials) if result.is_ok(): current_settings = result.ok() @@ -106,19 +119,11 @@ def update( update=settings.to_dict(exclude_empty=True) ) update_result = self.stash.update(context.credentials, new_settings) - if update_result.is_ok(): - return SyftSuccess( - message=( - "Settings updated successfully. " - + "You must call .refresh() to sync your client with the changes." - ) - ) - else: - return SyftError(message=update_result.err()) + return update_result else: - return SyftError(message="No settings found") + return Err(value="No settings found") else: - return SyftError(message=result.err()) + return result @service_method( path="settings.set_node_side_type_dangerous", @@ -202,10 +207,8 @@ def allow_guest_signup( """Enable/Disable Registration for Data Scientist or Guest Users.""" flags.CAN_REGISTER = enable - method = context.node.get_service_method(SettingsService.update) settings = NodeSettingsUpdate(signup_enabled=enable) - - result = method(context=context, settings=settings) + result = self._update(context=context, settings=settings) if isinstance(result, SyftError): return SyftError(message=f"Failed to update settings: {result.err()}") @@ -213,6 +216,26 @@ def allow_guest_signup( message = "enabled" if enable else "disabled" return SyftSuccess(message=f"Registration feature successfully {message}") + @service_method( + path="settings.enable_eager_execution", + name="enable_eager_execution", + roles=ADMIN_ROLE_LEVEL, + warning=HighSideCRUDWarning(confirmation=True), + ) + def enable_eager_execution( + self, context: AuthedServiceContext, enable: bool + ) -> SyftSuccess | SyftError: + """Enable/Disable eager execution.""" + settings = NodeSettingsUpdate(eager_execution_enabled=enable) + + result = self._update(context=context, settings=settings) + + if result.is_err(): + return SyftError(message=f"Failed to update settings: {result.err()}") + + message = "enabled" if enable else "disabled" + return SyftSuccess(message=f"Eager execution {message}") + @service_method( path="settings.allow_association_request_auto_approval", name="allow_association_request_auto_approval", @@ -221,7 +244,7 @@ def allow_association_request_auto_approval( self, context: AuthedServiceContext, enable: bool ) -> SyftSuccess | SyftError: new_settings = NodeSettingsUpdate(association_request_auto_approval=enable) - result = self.update(context, settings=new_settings) + result = self._update(context, settings=new_settings) if isinstance(result, SyftError): return result @@ -275,7 +298,7 @@ def welcome_customize( welcome_msg = HTMLObject(text=html) new_settings = NodeSettingsUpdate(welcome_markdown=welcome_msg) - result = self.update(context=context, settings=new_settings) + result = self._update(context=context, settings=new_settings) if isinstance(result, SyftError): return result diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index a9b2adb1c97..c28f5c31615 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -14,6 +14,7 @@ def test_actionobject_method(worker): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) action_store = worker.get_service("actionservice").store obj = ActionObject.from_obj("abc") pointer = root_domain_client.api.services.action.set(obj) diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py index fcfb10d3bdb..a597a79b2d6 100644 --- a/packages/syft/tests/syft/eager_test.py +++ b/packages/syft/tests/syft/eager_test.py @@ -12,6 +12,9 @@ def test_eager_permissions(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client + input_obj = TwinObject( private_obj=np.array([[3, 3, 3], [3, 3, 3]]), mock_obj=np.array([[1, 1, 1], [1, 1, 1]]), @@ -35,6 +38,7 @@ def test_eager_permissions(worker, guest_client): def test_plan(worker): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) guest_client = worker.guest_client @planify @@ -76,6 +80,7 @@ def my_plan(x=np.array([[2, 2, 2], [2, 2, 2]])): # noqa: B008 @currently_fail_on_python_3_12(raises=AttributeError) def test_plan_with_function_call(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) guest_client = worker.guest_client @planify @@ -98,12 +103,14 @@ def my_plan(x=np.array([[2, 2, 2], [2, 2, 2]])): # noqa: B008 def test_plan_with_object_instantiation(worker, guest_client): + root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client + @planify def my_plan(x=np.array([1, 2, 3, 4, 5, 6])): # noqa: B008 return x + 1 - root_domain_client = worker.root_client - plan_ptr = my_plan.send(guest_client) input_obj = TwinObject( @@ -123,6 +130,8 @@ def my_plan(x=np.array([1, 2, 3, 4, 5, 6])): # noqa: B008 def test_setattribute(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client private_data, mock_data = ( np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), @@ -160,6 +169,9 @@ def test_setattribute(worker, guest_client): def test_getattribute(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client + obj = TwinObject( private_obj=np.array([[1, 2, 3], [4, 5, 6]]), mock_obj=np.array([[1, 1, 1], [1, 1, 1]]), @@ -176,6 +188,8 @@ def test_getattribute(worker, guest_client): def test_eager_method(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client obj = TwinObject( private_obj=np.array([[1, 2, 3], [4, 5, 6]]), @@ -197,6 +211,8 @@ def test_eager_method(worker, guest_client): def test_eager_dunder_method(worker, guest_client): root_domain_client = worker.root_client + assert root_domain_client.settings.enable_eager_execution(enable=True) + guest_client = worker.guest_client obj = TwinObject( private_obj=np.array([[1, 2, 3], [4, 5, 6]]), diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index fa8efab4eaf..a595fdd0e8d 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -160,6 +160,23 @@ def test_actionobject_hooks_init(orig_obj: Any): assert HOOK_ALWAYS in obj.syft_pre_hooks__ assert HOOK_ALWAYS in obj.syft_post_hooks__ + assert HOOK_ON_POINTERS in obj.syft_pre_hooks__ + assert HOOK_ON_POINTERS in obj.syft_post_hooks__ + + assert make_action_side_effect in obj.syft_pre_hooks__[HOOK_ALWAYS] + + +def test_actionobject_add_pre_hooks(): + # Eager execution is disabled by default + obj = ActionObject.from_obj(1) + + assert make_action_side_effect in obj.syft_pre_hooks__[HOOK_ALWAYS] + assert send_action_side_effect not in obj.syft_pre_hooks__[HOOK_ON_POINTERS] + assert propagate_node_uid not in obj.syft_post_hooks__[HOOK_ALWAYS] + + # eager exec tests: + obj._syft_add_pre_hooks__(eager_execution=True) + obj._syft_add_post_hooks__(eager_execution=True) assert make_action_side_effect in obj.syft_pre_hooks__[HOOK_ALWAYS] assert send_action_side_effect in obj.syft_pre_hooks__[HOOK_ON_POINTERS] @@ -566,6 +583,8 @@ def test_actionobject_syft_get_attr_context(): ) def test_actionobject_syft_execute_hooks(worker, testcase): client = worker.root_client + assert client.settings.enable_eager_execution(enable=True) + orig_obj, op, args, kwargs, expected = testcase obj = helper_make_action_obj(orig_obj) @@ -918,7 +937,7 @@ def test_actionobject_syft_getattr_int(orig_obj: int, worker, scenario): assert (3 >> obj) == (3 >> orig_obj) -def test_actionobject_syft_getattr_int_history(worker): +def test_actionobject_syft_getattr_int_history(): orig_obj = 5 obj1 = ActionObject.from_obj(orig_obj) obj2 = ActionObject.from_obj(orig_obj)