Skip to content

Commit

Permalink
moved methods to migration service
Browse files Browse the repository at this point in the history
  • Loading branch information
teo-milea committed Oct 14, 2024
1 parent 68117fd commit 1fef3ed
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 30 deletions.
31 changes: 31 additions & 0 deletions packages/syft/src/syft/service/migration/migration_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from collections import defaultdict
import logging
from typing import Any

# syft absolute
import syft
Expand All @@ -16,6 +17,7 @@
from ...types.syft_object import SyftObject
from ...types.syft_object_registry import SyftObjectRegistry
from ...types.twin_object import TwinObject
from ...types.uid import UID
from ..action.action_object import Action
from ..action.action_object import ActionObject
from ..action.action_permissions import ActionObjectPermission
Expand All @@ -26,7 +28,10 @@
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..sync.sync_service import get_store
from ..sync.sync_service import get_store_by_type
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..worker.utils import DEFAULT_WORKER_POOL_NAME
from .object_migration_state import MigrationData
from .object_migration_state import StoreMetadata
Expand Down Expand Up @@ -493,3 +498,29 @@ def reset_and_restore(
)

return SyftSuccess(message="Database reset successfully.")

@service_method(
path="sync._get_object",
name="_get_object",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def _get_object(
self, context: AuthedServiceContext, uid: UID, object_type: type
) -> Any:
return (
get_store_by_type(context, object_type)
.get_by_uid(credentials=context.credentials, uid=uid)
.unwrap()
)

@service_method(
path="sync._update_object",
name="_update_object",
roles=ADMIN_ROLE_LEVEL,
)
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
return (
get_store(context, object)
.update(credentials=context.credentials, obj=object)
.unwrap()
)
27 changes: 0 additions & 27 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from .sync_stash import SyncStash
from .sync_state import SyncState

Expand Down Expand Up @@ -455,29 +454,3 @@ def build_current_state(
)
def _get_state(self, context: AuthedServiceContext) -> SyncState:
return self.build_current_state(context).unwrap()

@service_method(
path="sync._get_object",
name="_get_object",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def _get_object(
self, context: AuthedServiceContext, uid: UID, object_type: type
) -> Any:
return (
get_store_by_type(context, object_type)
.get_by_uid(credentials=context.credentials, uid=uid)
.unwrap()
)

@service_method(
path="sync._update_object",
name="_update_object",
roles=ADMIN_ROLE_LEVEL,
)
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
return (
get_store(context, object)
.update(credentials=context.credentials, obj=object)
.unwrap()
)
8 changes: 5 additions & 3 deletions packages/syft/tests/syft/service/sync/get_set_object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,20 @@ def test_get_set_object(high_worker):
root_datasite_client.upload_dataset(dataset)
dataset = root_datasite_client.datasets[0]

other_dataset = high_client.api.services.sync._get_object(
other_dataset = high_client.api.services.migration._get_object(
uid=dataset.id, object_type=Dataset
)
other_dataset.server_uid = dataset.server_uid
assert dataset == other_dataset
other_dataset.name = "new_name"
updated_dataset = high_client.api.services.sync._update_object(object=other_dataset)
updated_dataset = high_client.api.services.migration._update_object(
object=other_dataset
)
assert updated_dataset.name == "new_name"

asset = root_datasite_client.datasets[0].assets[0]
source_ao = high_client.api.services.action.get(uid=asset.action_id)
ao = high_client.api.services.sync._get_object(
ao = high_client.api.services.migration._get_object(
uid=asset.action_id, object_type=ActionObject
)
ao._set_obj_location_(
Expand Down

0 comments on commit 1fef3ed

Please sign in to comment.