From 011da3c1676fe7875c998170a8bb38735a6e8a1b Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 18 Jun 2024 18:06:41 +0300 Subject: [PATCH 1/9] implement UserCode.__call__ --- .../syft/src/syft/service/code/user_code.py | 18 +++++++++++++++++- .../syft/service/code_history/code_history.py | 6 +++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 9ee7e09e7e9..57943fd8b81 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -33,7 +33,7 @@ # relative from ...abstract_node import NodeSideType from ...abstract_node import NodeType -from ...client.api import APIRegistry +from ...client.api import APIRegistry, generate_remote_function from ...client.api import NodeIdentity from ...client.enclave_client import EnclaveMetadata from ...node.credentials import SyftVerifyKey @@ -879,6 +879,22 @@ def show_code_cell(self) -> None: ip = get_ipython() ip.set_next_input(warning_message + self.raw_code) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + api = self._get_api() + if isinstance(api, SyftError): + return api + remote_user_function = generate_remote_function( + api=api, + node_uid=self.node_uid, + signature=self.signature, + path="code.call", + make_call=api.make_call, + pre_kwargs={"uid": self.id}, + warning=None, + communication_protocol=api.communication_protocol, + ) + return remote_user_function() + class UserCodeUpdate(PartialSyftObject): __canonical_name__ = "UserCodeUpdate" diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index b5e893c87bf..55f041572f6 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -75,7 +75,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError: return SyftError( message=f"Can't access the api. You must login to {self.node_uid}" ) - if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0: + if ( + api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value + and index < 0 + ): + # negative index would dynamically resolve to a different version return SyftError( message="For security concerns we do not allow negative indexing. \ Try using absolute values when indexing" From dbefae518c087e155b78b1f20604b4554f4436b7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 19 Jun 2024 12:30:09 +0300 Subject: [PATCH 2/9] expand user code test --- .../syft/src/syft/service/code/user_code.py | 3 +- .../syft/tests/syft/users/user_code_test.py | 38 +++++++++++++++---- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 57943fd8b81..e1ba6beabfe 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -33,8 +33,9 @@ # relative from ...abstract_node import NodeSideType from ...abstract_node import NodeType -from ...client.api import APIRegistry, generate_remote_function +from ...client.api import APIRegistry from ...client.api import NodeIdentity +from ...client.api import generate_remote_function from ...client.enclave_client import EnclaveMetadata from ...node.credentials import SyftVerifyKey from ...serde.deserialize import _deserialize diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index f006525097e..882f37e4496 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -74,23 +74,47 @@ def test_user_code(worker) -> None: assert multi_call_res.get() == result.get() -def test_duplicated_user_code(worker, guest_client: User) -> None: +def test_duplicated_user_code(worker) -> None: + worker.root_client.register( + name="Jane Doe", + email="jane@caltech.edu", + password="abc123", + password_verify="abc123", + institution="Caltech", + website="https://www.caltech.edu/", + ) + ds_client = worker.root_client.login( + email="jane@caltech.edu", + password="abc123", + ) + # mock_syft_func() - result = guest_client.api.services.code.request_code_execution(mock_syft_func) + result = ds_client.api.services.code.request_code_execution(mock_syft_func) assert isinstance(result, Request) - assert len(guest_client.code.get_all()) == 1 + assert len(ds_client.code.get_all()) == 1 # request the exact same code should return an error - result = guest_client.api.services.code.request_code_execution(mock_syft_func) + result = ds_client.api.services.code.request_code_execution(mock_syft_func) assert isinstance(result, SyftError) - assert len(guest_client.code.get_all()) == 1 + assert len(ds_client.code.get_all()) == 1 # request the a different function name but same content will also succeed # flaky if not blocking mock_syft_func_2(syft_no_node=True) - result = guest_client.api.services.code.request_code_execution(mock_syft_func_2) + result = ds_client.api.services.code.request_code_execution(mock_syft_func_2) assert isinstance(result, Request) - assert len(guest_client.code.get_all()) == 2 + assert len(ds_client.code.get_all()) == 2 + + code_history = ds_client.code_history + assert code_history.code_versions, "No code version found." + + code_histories = worker.root_client.code_histories + user_code_history = code_histories[ds_client.logged_in_user] + assert not isinstance(user_code_history, SyftError) + assert user_code_history.code_versions, "No code version found." + + result = user_code_history.mock_syft_func_2[0]() + assert result.get() == 1 def random_hash() -> str: From 6d735b40f4593bca06ff8a35ce60aa6906745e6b Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 19 Jun 2024 15:04:24 +0300 Subject: [PATCH 3/9] remove self and context from function signature --- packages/syft/src/syft/service/code/user_code.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index e1ba6beabfe..06319d4828c 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -41,6 +41,8 @@ from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize +from ...serde.signature import signature_remove_context +from ...serde.signature import signature_remove_self from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime @@ -884,6 +886,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: api = self._get_api() if isinstance(api, SyftError): return api + + signature = self.signature + signature = signature_remove_self(signature) + signature = signature_remove_context(signature) remote_user_function = generate_remote_function( api=api, node_uid=self.node_uid, From ce9d9e7a2c25b095966b20956644cabefa0be2d7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 20 Jun 2024 11:04:28 +0300 Subject: [PATCH 4/9] add _post_user_code_transform_ops --- .../syft/service/code/user_code_service.py | 57 +++++++++---------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 9287e49fec4..cd07e39d26b 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -75,6 +75,19 @@ def _submit( ) -> Result[UserCode, str]: if not isinstance(code, UserCode): code = code.to(UserCode, context=context) # type: ignore[unreachable] + result = self._post_user_code_transform_ops(context, code) + if isinstance(result, SyftError): + # if the validation fails, we should remove the user code status + # and code version to prevent dangling status + root_context = AuthedServiceContext( + credentials=context.node.verify_key, node=context.node + ) + + if code.status_link is not None: + _ = context.node.get_service("usercodestatusservice").remove( + root_context, code.status_link.object_uid + ) + return result result = self.stash.set(context.credentials, code) return result @@ -133,30 +146,7 @@ def get_by_service_name( return SyftError(message=str(result.err())) return result.ok() - def _request_code_execution( - self, - context: AuthedServiceContext, - code: SubmitUserCode, - reason: str | None = "", - ) -> Request | SyftError: - user_code: UserCode = code.to(UserCode, context=context) - result = self._validate_request_code_execution(context, user_code) - if isinstance(result, SyftError): - # if the validation fails, we should remove the user code status - # and code version to prevent dangling status - root_context = AuthedServiceContext( - credentials=context.node.verify_key, node=context.node - ) - - if user_code.status_link is not None: - _ = context.node.get_service("usercodestatusservice").remove( - root_context, user_code.status_link.object_uid - ) - return result - result = self._request_code_execution_inner(context, user_code, reason) - return result - - def _validate_request_code_execution( + def _post_user_code_transform_ops( self, context: AuthedServiceContext, user_code: UserCode, @@ -197,10 +187,6 @@ def _validate_request_code_execution( if isinstance(pool_result, SyftError): return pool_result - result = self.stash.set(context.credentials, user_code) - if result.is_err(): - return SyftError(message=str(result.err())) - # Create a code history code_history_service = context.node.get_service("codehistoryservice") result = code_history_service.submit_version(context=context, code=user_code) @@ -209,7 +195,7 @@ def _validate_request_code_execution( return SyftSuccess(message="") - def _request_code_execution_inner( + def _request_code_execution( self, context: AuthedServiceContext, user_code: UserCode, @@ -260,7 +246,18 @@ def request_code_execution( reason: str | None = "", ) -> SyftSuccess | SyftError: """Request Code execution on user code""" - return self._request_code_execution(context=context, code=code, reason=reason) + + # TODO: check for duplicate submissions + user_code_or_err = self._submit(context, code) + if user_code_or_err.is_err(): + return SyftError(user_code_or_err.err()) + + result = self._request_code_execution( + context, + user_code_or_err.ok(), + reason, + ) + return result @service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError: From 40156267dcd6d359a5efb24e4abbe47dba7f4444 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 20 Jun 2024 11:29:19 +0300 Subject: [PATCH 5/9] use code directly in submit_version --- .../src/syft/service/code_history/code_history_service.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index adfd6dbee5d..54da3d491b3 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -49,11 +49,6 @@ def submit_version( if result.is_err(): return SyftError(message=str(result.err())) code = result.ok() - elif isinstance(code, UserCode): # type: ignore[unreachable] - result = user_code_service.get_by_uid(context=context, uid=code.id) - if isinstance(result, SyftError): - return result - code = result result = self.stash.get_by_service_func_name_and_verify_key( credentials=context.credentials, From d646af888b1ab60909ce0783d3c86161b9887e1b Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 20 Jun 2024 11:45:52 +0300 Subject: [PATCH 6/9] fix error message --- packages/syft/src/syft/service/code/user_code_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index cd07e39d26b..8afdcdc9bdc 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -250,7 +250,7 @@ def request_code_execution( # TODO: check for duplicate submissions user_code_or_err = self._submit(context, code) if user_code_or_err.is_err(): - return SyftError(user_code_or_err.err()) + return SyftError(message=user_code_or_err.err()) result = self._request_code_execution( context, From 29508c9e70441937441a84c5e5b686baa4832df2 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 20 Jun 2024 14:11:12 +0200 Subject: [PATCH 7/9] pass args to remote function --- packages/syft/src/syft/service/code/user_code.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ddbd47487af..6b7dbfa35dc 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -912,7 +912,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: warning=None, communication_protocol=api.communication_protocol, ) - return remote_user_function() + return remote_user_function(*args, **kwargs) class UserCodeUpdate(PartialSyftObject): From 6aee14f432199f9f8309f55894e4c27ae34626d9 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 21 Jun 2024 11:17:50 +0200 Subject: [PATCH 8/9] fix code_histories for a particular user --- .../code_history/code_history_service.py | 25 ++++++++++++++----- .../syft/tests/syft/users/user_code_test.py | 3 +++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 54da3d491b3..2c01b26fd1a 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -8,6 +8,7 @@ from ...util.telemetry import instrument from ..code.user_code import SubmitUserCode from ..code.user_code import UserCode +from ..code.user_code_service import UserCodeService from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -15,6 +16,7 @@ from ..service import service_method from ..user.user_roles import DATA_OWNER_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL +from ..user.user_roles import ServiceRole from .code_history import CodeHistoriesDict from .code_history import CodeHistory from .code_history import CodeHistoryView @@ -115,14 +117,22 @@ def delete( def fetch_histories_for_user( self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey ) -> CodeHistoriesDict | SyftError: - result = self.stash.get_by_verify_key( - credentials=context.credentials, user_verify_key=user_verify_key - ) + if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: + result = self.stash.get_by_verify_key( + credentials=context.node.verify_key, user_verify_key=user_verify_key + ) + else: + result = self.stash.get_by_verify_key( + credentials=context.credentials, user_verify_key=user_verify_key + ) - user_code_service = context.node.get_service("usercodeservice") + user_code_service: UserCodeService = context.node.get_service("usercodeservice") def get_code(uid: UID) -> UserCode | SyftError: - return user_code_service.get_by_uid(context=context, uid=uid) + return user_code_service.stash.get_by_uid( + credentials=context.node.verify_key, + uid=uid, + ).ok() if result.is_ok(): code_histories = result.ok() @@ -181,7 +191,10 @@ def get_history_for_user( def get_histories_group_by_user( self, context: AuthedServiceContext ) -> UsersCodeHistoriesDict | SyftError: - result = self.stash.get_all(credentials=context.credentials) + if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: + result = self.stash.get_all(context.credentials, has_permission=True) + else: + result = self.stash.get_all(context.credentials) if result.is_err(): return SyftError(message=result.err()) code_histories: list[CodeHistory] = result.ok() diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index fd9e6b79cb5..a1182b1630a 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -113,8 +113,11 @@ def test_duplicated_user_code(worker) -> None: code_histories = worker.root_client.code_histories user_code_history = code_histories[ds_client.logged_in_user] + assert not isinstance(code_histories, SyftError) assert not isinstance(user_code_history, SyftError) assert user_code_history.code_versions, "No code version found." + assert user_code_history.mock_syft_func.user_code_history[0].status is not None + assert user_code_history.mock_syft_func[0]._repr_markdown_(), "repr markdown failed" result = user_code_history.mock_syft_func_2[0]() assert result.get() == 1 From 6e556e1261765969df4a887cece14fc5661cebe7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 26 Jun 2024 09:26:16 +0200 Subject: [PATCH 9/9] fix type --- .../syft/src/syft/service/code_history/code_history_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 2c01b26fd1a..41839045747 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -126,7 +126,7 @@ def fetch_histories_for_user( credentials=context.credentials, user_verify_key=user_verify_key ) - user_code_service: UserCodeService = context.node.get_service("usercodeservice") + user_code_service: UserCodeService = context.node.get_service("usercodeservice") # type: ignore def get_code(uid: UID) -> UserCode | SyftError: return user_code_service.stash.get_by_uid(