Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added admin methods for get and set #9345

Merged
merged 8 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions packages/syft/src/syft/service/migration/migration_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# stdlib
from collections import defaultdict
import logging
from typing import Any

# syft absolute
import syft
Expand All @@ -16,6 +17,7 @@
from ...types.syft_object import SyftObject
from ...types.syft_object_registry import SyftObjectRegistry
from ...types.twin_object import TwinObject
from ...types.uid import UID
from ..action.action_object import Action
from ..action.action_object import ActionObject
from ..action.action_permissions import ActionObjectPermission
Expand All @@ -26,7 +28,10 @@
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..sync.sync_service import get_store
from ..sync.sync_service import get_store_by_type
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..worker.utils import DEFAULT_WORKER_POOL_NAME
from .object_migration_state import MigrationData
from .object_migration_state import StoreMetadata
Expand Down Expand Up @@ -493,3 +498,29 @@ def reset_and_restore(
)

return SyftSuccess(message="Database reset successfully.")

@service_method(
path="migration._get_object",
name="_get_object",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def _get_object(
self, context: AuthedServiceContext, uid: UID, object_type: type
) -> Any:
return (
get_store_by_type(context, object_type)
.get_by_uid(credentials=context.credentials, uid=uid)
.unwrap()
)

@service_method(
path="migration._update_object",
name="_update_object",
roles=ADMIN_ROLE_LEVEL,
)
def _update_object(self, context: AuthedServiceContext, object: Any) -> Any:
return (
get_store(context, object)
.update(credentials=context.credentials, obj=object)
.unwrap()
)
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ def get_status(self, context: AuthedServiceContext | None = None) -> RequestStat
# which tries to send an email to the admin and ends up here
pass # lets keep going

self.refresh()
if len(self.history) == 0:
return RequestStatus.PENDING

Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@


def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash:
if isinstance(item, ActionObject):
return get_store_by_type(context=context, obj_type=type(item))


def get_store_by_type(context: AuthedServiceContext, obj_type: type) -> ObjectStash:
if issubclass(obj_type, ActionObject):
service = context.server.services.action # type: ignore
return service.stash # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore
service = context.server.get_service(TYPE_TO_SERVICE[obj_type]) # type: ignore
return service.stash


Expand Down
11 changes: 11 additions & 0 deletions packages/syft/src/syft/types/syft_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,17 @@ def make_id(cls, values: Any) -> Any:
__table_coll_widths__: ClassVar[list[str] | None] = None
__table_sort_attr__: ClassVar[str | None] = None

def refresh(self) -> None:
try:
api = self._get_api()
new_object = api.services.migration._get_object(
uid=self.id, object_type=type(self)
)
if type(new_object) == type(self):
self.__dict__.update(new_object.__dict__)
except Exception as _:
return

def __syft_get_funcs__(self) -> list[tuple[str, Signature]]:
funcs = print_type_cache[type(self)]
if len(funcs) > 0:
Expand Down
57 changes: 57 additions & 0 deletions packages/syft/tests/syft/service/sync/get_set_object_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# third party

# syft absolute
import syft as sy
from syft.client.datasite_client import DatasiteClient
from syft.service.action.action_object import ActionObject
from syft.service.dataset.dataset import Dataset


def get_ds_client(client: DatasiteClient) -> DatasiteClient:
client.register(
name="a",
email="[email protected]",
password="asdf",
password_verify="asdf",
)
return client.login(email="[email protected]", password="asdf")


def test_get_set_object(high_worker):
high_client: DatasiteClient = high_worker.root_client
_ = get_ds_client(high_client)
root_datasite_client = high_worker.root_client
dataset = sy.Dataset(
name="local_test",
asset_list=[
sy.Asset(
name="local_test",
data=[1, 2, 3],
mock=[1, 1, 1],
)
],
)
root_datasite_client.upload_dataset(dataset)
dataset = root_datasite_client.datasets[0]

other_dataset = high_client.api.services.migration._get_object(
uid=dataset.id, object_type=Dataset
)
other_dataset.server_uid = dataset.server_uid
assert dataset == other_dataset
other_dataset.name = "new_name"
updated_dataset = high_client.api.services.migration._update_object(
object=other_dataset
)
assert updated_dataset.name == "new_name"

asset = root_datasite_client.datasets[0].assets[0]
source_ao = high_client.api.services.action.get(uid=asset.action_id)
ao = high_client.api.services.migration._get_object(
uid=asset.action_id, object_type=ActionObject
)
ao._set_obj_location_(
high_worker.id,
root_datasite_client.credentials,
)
assert source_ao == ao
Loading