Skip to content

Commit

Permalink
add logic to validate and save migration state of objects in the part…
Browse files Browse the repository at this point in the history
…ition
  • Loading branch information
shubham3121 committed Sep 27, 2023
1 parent cda159a commit 4c9ba8b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 2 deletions.
17 changes: 17 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from ..service.metadata.node_metadata import NodeMetadataV2
from ..service.network.network_service import NetworkService
from ..service.notification.notification_service import NotificationService
from ..service.object_search.migration_state_service import MigrateStateService
from ..service.policy.policy_service import PolicyService
from ..service.project.project_service import ProjectService
from ..service.queue.queue import APICallMessageHandler
Expand Down Expand Up @@ -452,6 +453,22 @@ def root_client(self):
root_client.api.refresh_api_callback()
return root_client

def __validate_data_migration_state(self):
partition_to_be_migrated = []
migration_state_service = self.get_service(MigrateStateService)
for partition_settings in self.document_store.partitions.values():
object_type = partition_settings.object_type
canonical_name = object_type.__canonical_name__
migration_state = migration_state_service.get_state(canonical_name)
if migration_state is not None:
if migration_state.current_version != migration_state.latest_version:
partition_to_be_migrated.append(canonical_name)
else:
migration_state.register_migration_state(
current_version=object_type.__version__,
canonical_name=canonical_name,
)

@property
def guest_client(self):
return self.get_guest_client()
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# relative
from ..service.response import SyftException
from ..service.response import SyftSuccess
from ..types.syft_object import SyftBaseObject
from ..types.syft_object import SyftMigrationRegistry
from ..util.util import get_env
Expand Down Expand Up @@ -185,6 +186,7 @@ def upgrade(self):
"supported": True,
}
self.save_state()
return SyftSuccess(message="Protocol successfully updated !!")

def validate_current_state(self) -> bool:
current_object_version_map = self.state[self.latest_version]["object_versions"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,33 @@ def get_version(
)

return migration_state.current_version

@service_method(path="migration", name="get_state")
def get_state(
self, context: AuthedServiceContext, canonical_name: str
) -> Union[bool, SyftError]:
result = self.stash.get_by_name(
canonical_name=canonical_name, credentials=context.credentials
)

if result.is_err():
return SyftError(message=f"{result.err()}")

return result.ok()

@service_method(path="migration", name="register_migration_state")
def register_migration_state(
self,
context: AuthedServiceContext,
current_version: int,
canonical_name: str,
) -> Union[SyftObjectMigrationState, SyftError]:
obj = SyftObjectMigrationState(
current_version=current_version, canonical_name=canonical_name
)
result = self.stash.set(migration_state=obj, credentials=context.credentials)

if result.is_err():
return SyftError(message=f"{result.err()}")

return result.ok()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SyftObjectMigrationState(SyftObject):
__canonical_name__ = "SyftObjectMigrationState"
__version__ = SYFT_OBJECT_VERSION_1

__attr_unique__ = ["canonical_name"]

canonical_name: str
current_version: int

Expand Down Expand Up @@ -58,10 +60,10 @@ def __init__(self, store: DocumentStore) -> None:
def set(
self,
credentials: SyftVerifyKey,
syft_object_metadata: SyftObjectMigrationState,
migration_state: SyftObjectMigrationState,
add_permissions: Optional[List[ActionObjectPermission]] = None,
) -> Result[SyftObjectMigrationState, str]:
res = self.check_type(syft_object_metadata, self.object_type)
res = self.check_type(migration_state, self.object_type)
# we dont use and_then logic here as it is hard because of the order of the arguments
if res.is_err():
return res
Expand Down

0 comments on commit 4c9ba8b

Please sign in to comment.