From 4c9ba8bdc5ca31728d8fc67f4599fadaf58ba727 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Wed, 27 Sep 2023 17:43:47 +0530 Subject: [PATCH] add logic to validate and save migration state of objects in the partition --- packages/syft/src/syft/node/node.py | 17 +++++++++++ .../syft/src/syft/protocol/data_protocol.py | 2 ++ .../object_search/migration_state_service.py | 30 +++++++++++++++++++ .../object_search/object_migration_state.py | 6 ++-- 4 files changed, 53 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 2b98ffe60b1..f7af15f36d3 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -63,6 +63,7 @@ from ..service.metadata.node_metadata import NodeMetadataV2 from ..service.network.network_service import NetworkService from ..service.notification.notification_service import NotificationService +from ..service.object_search.migration_state_service import MigrateStateService from ..service.policy.policy_service import PolicyService from ..service.project.project_service import ProjectService from ..service.queue.queue import APICallMessageHandler @@ -452,6 +453,22 @@ def root_client(self): root_client.api.refresh_api_callback() return root_client + def __validate_data_migration_state(self): + partition_to_be_migrated = [] + migration_state_service = self.get_service(MigrateStateService) + for partition_settings in self.document_store.partitions.values(): + object_type = partition_settings.object_type + canonical_name = object_type.__canonical_name__ + migration_state = migration_state_service.get_state(canonical_name) + if migration_state is not None: + if migration_state.current_version != migration_state.latest_version: + partition_to_be_migrated.append(canonical_name) + else: + migration_state.register_migration_state( + current_version=object_type.__version__, + canonical_name=canonical_name, + ) + @property def guest_client(self): return self.get_guest_client() diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 1b781bc2469..3f9e3096553 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -12,6 +12,7 @@ # relative from ..service.response import SyftException +from ..service.response import SyftSuccess from ..types.syft_object import SyftBaseObject from ..types.syft_object import SyftMigrationRegistry from ..util.util import get_env @@ -185,6 +186,7 @@ def upgrade(self): "supported": True, } self.save_state() + return SyftSuccess(message="Protocol successfully updated !!") def validate_current_state(self) -> bool: current_object_version_map = self.state[self.latest_version]["object_versions"] diff --git a/packages/syft/src/syft/service/object_search/migration_state_service.py b/packages/syft/src/syft/service/object_search/migration_state_service.py index cad1338587c..2adbd1478bd 100644 --- a/packages/syft/src/syft/service/object_search/migration_state_service.py +++ b/packages/syft/src/syft/service/object_search/migration_state_service.py @@ -42,3 +42,33 @@ def get_version( ) return migration_state.current_version + + @service_method(path="migration", name="get_state") + def get_state( + self, context: AuthedServiceContext, canonical_name: str + ) -> Union[bool, SyftError]: + result = self.stash.get_by_name( + canonical_name=canonical_name, credentials=context.credentials + ) + + if result.is_err(): + return SyftError(message=f"{result.err()}") + + return result.ok() + + @service_method(path="migration", name="register_migration_state") + def register_migration_state( + self, + context: AuthedServiceContext, + current_version: int, + canonical_name: str, + ) -> Union[SyftObjectMigrationState, SyftError]: + obj = SyftObjectMigrationState( + current_version=current_version, canonical_name=canonical_name + ) + result = self.stash.set(migration_state=obj, credentials=context.credentials) + + if result.is_err(): + return SyftError(message=f"{result.err()}") + + return result.ok() diff --git a/packages/syft/src/syft/service/object_search/object_migration_state.py b/packages/syft/src/syft/service/object_search/object_migration_state.py index 17bc2b35232..aab7749bea5 100644 --- a/packages/syft/src/syft/service/object_search/object_migration_state.py +++ b/packages/syft/src/syft/service/object_search/object_migration_state.py @@ -24,6 +24,8 @@ class SyftObjectMigrationState(SyftObject): __canonical_name__ = "SyftObjectMigrationState" __version__ = SYFT_OBJECT_VERSION_1 + __attr_unique__ = ["canonical_name"] + canonical_name: str current_version: int @@ -58,10 +60,10 @@ def __init__(self, store: DocumentStore) -> None: def set( self, credentials: SyftVerifyKey, - syft_object_metadata: SyftObjectMigrationState, + migration_state: SyftObjectMigrationState, add_permissions: Optional[List[ActionObjectPermission]] = None, ) -> Result[SyftObjectMigrationState, str]: - res = self.check_type(syft_object_metadata, self.object_type) + res = self.check_type(migration_state, self.object_type) # we dont use and_then logic here as it is hard because of the order of the arguments if res.is_err(): return res