Skip to content

Commit

Permalink
Merge pull request #8860 from OpenMined/eager-settings
Browse files Browse the repository at this point in the history
feat: make eager execution optional, disabled by default
  • Loading branch information
tcp authored May 31, 2024
2 parents 5837a36 + 3f2c1b3 commit 7953126
Show file tree
Hide file tree
Showing 16 changed files with 328 additions and 191 deletions.
46 changes: 45 additions & 1 deletion packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..serde.signature import signature_remove_self
from ..service.context import AuthedServiceContext
from ..service.context import ChangeContext
from ..service.metadata.node_metadata import NodeMetadataJSON
from ..service.response import SyftAttributeError
from ..service.response import SyftError
from ..service.response import SyftSuccess
Expand All @@ -50,6 +51,7 @@
from ..types.identity import Identity
from ..types.syft_object import SYFT_OBJECT_VERSION_1
from ..types.syft_object import SYFT_OBJECT_VERSION_2
from ..types.syft_object import SYFT_OBJECT_VERSION_3
from ..types.syft_object import SyftBaseObject
from ..types.syft_object import SyftMigrationRegistry
from ..types.syft_object import SyftObject
Expand Down Expand Up @@ -244,6 +246,7 @@ class RemoteFunction(SyftObject):

node_uid: UID
signature: Signature
refresh_api_callback: Callable | None = None
path: str
make_call: Callable
pre_kwargs: dict[str, Any] | None = None
Expand Down Expand Up @@ -306,6 +309,13 @@ def function_call(
return
result = self.make_call(api_call=api_call, cache_result=cache_result)

# TODO: annotate this on the service method decorator
API_CALLS_THAT_REQUIRE_REFRESH = ["settings.enable_eager_execution"]

if path in API_CALLS_THAT_REQUIRE_REFRESH:
if self.refresh_api_callback is not None:
self.refresh_api_callback()

result, _ = migrate_args_and_kwargs(
[result], kwargs={}, to_latest_protocol=True
)
Expand Down Expand Up @@ -507,6 +517,7 @@ def generate_remote_function(
custom_function = bool(path == "api.call_in_jobs")
remote_function = RemoteFunction(
node_uid=node_uid,
refresh_api_callback=api.refresh_api_callback,
signature=signature,
path=path,
make_call=make_call,
Expand Down Expand Up @@ -805,6 +816,38 @@ def result_needs_api_update(api_call_result: Any) -> bool:
return False


@serializable(
attrs=[
"endpoints",
"node_uid",
"node_name",
"lib_endpoints",
"communication_protocol",
]
)
class SyftAPIV2(SyftObject):
# version
__canonical_name__ = "SyftAPI"
__version__ = SYFT_OBJECT_VERSION_2

# fields
connection: NodeConnection | None = None
node_uid: UID | None = None
node_name: str | None = None
endpoints: dict[str, APIEndpoint]
lib_endpoints: dict[str, LibEndpoint] | None = None
api_module: APIModule | None = None
libs: APIModule | None = None
signing_key: SyftSigningKey | None = None
# serde / storage rules
refresh_api_callback: Callable | None = None
__user_role: ServiceRole = ServiceRole.NONE
communication_protocol: PROTOCOL_TYPE

# informs getattr does not have nasty side effects
__syft_allow_autocomplete__ = ["services"]


@instrument
@serializable(
attrs=[
Expand All @@ -818,7 +861,7 @@ def result_needs_api_update(api_call_result: Any) -> bool:
class SyftAPI(SyftObject):
# version
__canonical_name__ = "SyftAPI"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

# fields
connection: NodeConnection | None = None
Expand All @@ -833,6 +876,7 @@ class SyftAPI(SyftObject):
refresh_api_callback: Callable | None = None
__user_role: ServiceRole = ServiceRole.NONE
communication_protocol: PROTOCOL_TYPE
metadata: NodeMetadataJSON | None = None

# informs getattr does not have nasty side effects
__syft_allow_autocomplete__ = ["services"]
Expand Down
19 changes: 15 additions & 4 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
from ..serde.serializable import serializable
from ..serde.serialize import _serialize
from ..service.context import NodeServiceContext
from ..service.metadata.node_metadata import NodeMetadata
from ..service.metadata.node_metadata import NodeMetadataJSON
from ..service.metadata.node_metadata import NodeMetadataV3
from ..service.response import SyftError
from ..service.response import SyftSuccess
from ..service.user.user import UserCreate
Expand Down Expand Up @@ -241,7 +241,10 @@ def get_node_metadata(
return NodeMetadataJSON(**metadata_json)

def get_api(
self, credentials: SyftSigningKey, communication_protocol: int
self,
credentials: SyftSigningKey,
communication_protocol: int,
metadata: NodeMetadataJSON | None = None,
) -> SyftAPI:
params = {
"verify_key": str(credentials.verify_key),
Expand All @@ -264,6 +267,7 @@ def get_api(
obj.connection = self
obj.signing_key = credentials
obj.communication_protocol = communication_protocol
obj.metadata = metadata
if self.proxy_target_uid:
obj.node_uid = self.proxy_target_uid
return cast(SyftAPI, obj)
Expand Down Expand Up @@ -378,7 +382,10 @@ def to_blob_route(self, path: str, host: str | None = None) -> GridURL:
return GridURL(port=8333).with_path(path)

def get_api(
self, credentials: SyftSigningKey, communication_protocol: int
self,
credentials: SyftSigningKey,
communication_protocol: int,
metadata: NodeMetadataJSON | None = None,
) -> SyftAPI:
# todo: its a bit odd to identify a user by its verify key maybe?
if self.proxy_target_uid:
Expand All @@ -400,6 +407,7 @@ def get_api(
obj.connection = self
obj.signing_key = credentials
obj.communication_protocol = communication_protocol
obj.metadata = metadata
if self.proxy_target_uid:
obj.node_uid = self.proxy_target_uid
return obj
Expand Down Expand Up @@ -697,7 +705,7 @@ def exchange_route(
return self.api.services.network.exchange_credentials_with(
self_node_route=self_node_route,
remote_node_route=remote_node_route,
remote_node_verify_key=client.metadata.to(NodeMetadataV3).verify_key,
remote_node_verify_key=client.metadata.to(NodeMetadata).verify_key,
)
else:
raise ValueError(
Expand Down Expand Up @@ -942,7 +950,9 @@ def _fetch_api(self, credentials: SyftSigningKey) -> SyftAPI:
_api: SyftAPI = self.connection.get_api(
credentials=credentials,
communication_protocol=self.communication_protocol,
metadata=self.metadata,
)
self._fetch_node_metadata(self.credentials)

def refresh_callback() -> SyftAPI:
return self._fetch_api(self.credentials)
Expand All @@ -958,6 +968,7 @@ def refresh_callback() -> SyftAPI:
api=_api,
)
self._api = _api
self._api.metadata = self.metadata
self.services = _api.services
return _api

Expand Down
9 changes: 6 additions & 3 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from ..service.job.job_stash import JobType
from ..service.log.log_service import LogService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadataV3
from ..service.metadata.node_metadata import NodeMetadata
from ..service.network.network_service import NetworkService
from ..service.network.utils import PeerHealthCheckTask
from ..service.notification.notification_service import NotificationService
Expand Down Expand Up @@ -1023,7 +1023,8 @@ def settings(self) -> NodeSettings:
return settings

@property
def metadata(self) -> NodeMetadataV3:
def metadata(self) -> NodeMetadata:
show_warnings = self.enable_warnings
settings_data = self.settings
name = settings_data.name
organization = settings_data.organization
Expand All @@ -1033,8 +1034,9 @@ def metadata(self) -> NodeMetadataV3:
node_side_type = (
settings_data.node_side_type.value if settings_data.node_side_type else ""
)
eager_execution_enabled = settings_data.eager_execution_enabled

return NodeMetadataV3(
return NodeMetadata(
name=name,
id=self.id,
verify_key=self.verify_key,
Expand All @@ -1046,6 +1048,7 @@ def metadata(self) -> NodeMetadataV3:
node_type=node_type,
node_side_type=node_side_type,
show_warnings=show_warnings,
eager_execution_enabled=eager_execution_enabled,
)

@property
Expand Down
85 changes: 58 additions & 27 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,66 @@
},
"dev": {
"object_versions": {
"NodeSettingsUpdate": {
"NodeMetadata": {
"5": {
"version": 5,
"hash": "70197b4725dbdea0560ed8388e4d20b76808bee988f3630c5f916ee8f48761f8",
"action": "add"
}
},
"SyftAPI": {
"3": {
"version": 3,
"hash": "0f812fdd5aecc3e3aa1a7c953bbf7f8d8b03a77c5cdbb37e981fa91c8134c9f4",
"hash": "b1b9d131a4f204ef2d56dc91bab3b945d5581080565232ede864f32015c0882a",
"action": "add"
}
},
"HTMLObject": {
"1": {
"version": 1,
"hash": "010d9aaca95f3fdfc8d1f97d01c1bd66483da774a59275b310c08d6912f7f863",
"action": "add"
}
},
"NodeSettingsUpdate": {
"2": {
"version": 2,
"hash": "e1dc9d2f30c4aae1f7359eb3fd44de5537788cd3c69be5f30c36fb019f07c261",
"action": "remove"
},
"4": {
"version": 4,
"hash": "ec783a7cd097e2bc4273a519d11023c796aebb9e3710c1d8332c0e46966d4ae0",
"action": "add"
},
"5": {
"version": 5,
"hash": "fd89638bb3d6dda9095905aab7ed2883f0b3dd5245900e8e141eec87921c2c9e",
"action": "add"
}
},
"NodeSettings": {
"4": {
"version": 4,
"hash": "318e578f8a9af213a6af0cc2c567b62196b0ff81769d808afff4dd1eb7c372b8",
"action": "add"
"3": {
"version": 3,
"hash": "2d5f6e79f074f75b5cfc2357eac7cf635b8f083421009a513240b4dbbd5a0fc1",
"action": "remove"
},
"5": {
"version": 5,
"hash": "cde18eb23fdffcfba47bc0e85efdbba1d59f1f5d6baa9c9690e1af14b35eb74e",
"action": "add"
},
"6": {
"version": 6,
"hash": "986d201418035e59b12787dfaf60aa2af17817c1894ce42ab4b982ed73127403",
"action": "add"
}
},
"BlobRetrievalByURL": {
"5": {
"version": 5,
"hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3",
"action": "add"
}
},
"EnclaveMetadata": {
Expand Down Expand Up @@ -80,6 +118,13 @@
"action": "add"
}
},
"CreateCustomImageChange": {
"3": {
"version": 3,
"hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995",
"action": "add"
}
},
"TwinAPIContextView": {
"1": {
"version": 1,
Expand Down Expand Up @@ -181,6 +226,13 @@
"action": "add"
}
},
"NodeMetadataUpdate": {
"2": {
"version": 2,
"hash": "520ae8ffc0c057ffa827cb7b267a19fb6b92e3cf3c0a3666ac34e271b6dd0aed",
"action": "remove"
}
},
"SyncStateItem": {
"1": {
"version": 1,
Expand All @@ -194,27 +246,6 @@
"hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b",
"action": "remove"
}
},
"BlobRetrievalByURL": {
"5": {
"version": 5,
"hash": "4934bf72bb10ac0a670c87ab735175088274e090819436563543473e64cf15e3",
"action": "add"
}
},
"CreateCustomImageChange": {
"3": {
"version": 3,
"hash": "e5f099940a7623f145f51f3e15b97a910a1d7fda1f67739420fed3035d1f2995",
"action": "add"
}
},
"HTMLObject": {
"1": {
"version": 1,
"hash": "010d9aaca95f3fdfc8d1f97d01c1bd66483da774a59275b310c08d6912f7f863",
"action": "add"
}
}
}
}
Expand Down
Loading

0 comments on commit 7953126

Please sign in to comment.