diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ee750bfcf9f..29c60242ceb 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -40,11 +40,14 @@ from ...abstract_node import NodeType from ...client.api import APIRegistry from ...client.api import NodeIdentity +from ...client.api import generate_remote_function from ...client.enclave_client import EnclaveMetadata from ...node.credentials import SyftVerifyKey from ...serde.deserialize import _deserialize from ...serde.serializable import serializable from ...serde.serialize import _serialize +from ...serde.signature import signature_remove_context +from ...serde.signature import signature_remove_self from ...store.document_store import PartitionKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime @@ -928,6 +931,26 @@ def show_code_cell(self) -> None: ip = get_ipython() ip.set_next_input(warning_message + self.raw_code) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + api = self._get_api() + if isinstance(api, SyftError): + return api + + signature = self.signature + signature = signature_remove_self(signature) + signature = signature_remove_context(signature) + remote_user_function = generate_remote_function( + api=api, + node_uid=self.node_uid, + signature=self.signature, + path="code.call", + make_call=api.make_call, + pre_kwargs={"uid": self.id}, + warning=None, + communication_protocol=api.communication_protocol, + ) + return remote_user_function(*args, **kwargs) + class UserCodeUpdate(PartialSyftObject): __canonical_name__ = "UserCodeUpdate" diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 374fcf22475..511dad72731 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -72,6 +72,19 @@ def _submit( ) -> Result[UserCode, str]: if not isinstance(code, UserCode): code = code.to(UserCode, context=context) # type: ignore[unreachable] + result = self._post_user_code_transform_ops(context, code) + if isinstance(result, SyftError): + # if the validation fails, we should remove the user code status + # and code version to prevent dangling status + root_context = AuthedServiceContext( + credentials=context.node.verify_key, node=context.node + ) + + if code.status_link is not None: + _ = context.node.get_service("usercodestatusservice").remove( + root_context, code.status_link.object_uid + ) + return result result = self.stash.set(context.credentials, code) return result @@ -130,30 +143,7 @@ def get_by_service_name( return SyftError(message=str(result.err())) return result.ok() - def _request_code_execution( - self, - context: AuthedServiceContext, - code: SubmitUserCode, - reason: str | None = "", - ) -> Request | SyftError: - user_code: UserCode = code.to(UserCode, context=context) - result = self._validate_request_code_execution(context, user_code) - if isinstance(result, SyftError): - # if the validation fails, we should remove the user code status - # and code version to prevent dangling status - root_context = AuthedServiceContext( - credentials=context.node.verify_key, node=context.node - ) - - if user_code.status_link is not None: - _ = context.node.get_service("usercodestatusservice").remove( - root_context, user_code.status_link.object_uid - ) - return result - result = self._request_code_execution_inner(context, user_code, reason) - return result - - def _validate_request_code_execution( + def _post_user_code_transform_ops( self, context: AuthedServiceContext, user_code: UserCode, @@ -194,10 +184,6 @@ def _validate_request_code_execution( if isinstance(pool_result, SyftError): return pool_result - result = self.stash.set(context.credentials, user_code) - if result.is_err(): - return SyftError(message=str(result.err())) - # Create a code history code_history_service = context.node.get_service("codehistoryservice") result = code_history_service.submit_version(context=context, code=user_code) @@ -206,7 +192,7 @@ def _validate_request_code_execution( return SyftSuccess(message="") - def _request_code_execution_inner( + def _request_code_execution( self, context: AuthedServiceContext, user_code: UserCode, @@ -257,7 +243,18 @@ def request_code_execution( reason: str | None = "", ) -> Request | SyftError: """Request Code execution on user code""" - return self._request_code_execution(context=context, code=code, reason=reason) + + # TODO: check for duplicate submissions + user_code_or_err = self._submit(context, code) + if user_code_or_err.is_err(): + return SyftError(message=user_code_or_err.err()) + + result = self._request_code_execution( + context, + user_code_or_err.ok(), + reason, + ) + return result @service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError: diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index 488083cf0c6..e013ef22c34 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -95,7 +95,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError: return SyftError( message=f"Can't access the api. You must login to {self.node_uid}" ) - if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0: + if ( + api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value + and index < 0 + ): + # negative index would dynamically resolve to a different version return SyftError( message="For security concerns we do not allow negative indexing. \ Try using absolute values when indexing" diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index adfd6dbee5d..41839045747 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -8,6 +8,7 @@ from ...util.telemetry import instrument from ..code.user_code import SubmitUserCode from ..code.user_code import UserCode +from ..code.user_code_service import UserCodeService from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -15,6 +16,7 @@ from ..service import service_method from ..user.user_roles import DATA_OWNER_ROLE_LEVEL from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL +from ..user.user_roles import ServiceRole from .code_history import CodeHistoriesDict from .code_history import CodeHistory from .code_history import CodeHistoryView @@ -49,11 +51,6 @@ def submit_version( if result.is_err(): return SyftError(message=str(result.err())) code = result.ok() - elif isinstance(code, UserCode): # type: ignore[unreachable] - result = user_code_service.get_by_uid(context=context, uid=code.id) - if isinstance(result, SyftError): - return result - code = result result = self.stash.get_by_service_func_name_and_verify_key( credentials=context.credentials, @@ -120,14 +117,22 @@ def delete( def fetch_histories_for_user( self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey ) -> CodeHistoriesDict | SyftError: - result = self.stash.get_by_verify_key( - credentials=context.credentials, user_verify_key=user_verify_key - ) + if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: + result = self.stash.get_by_verify_key( + credentials=context.node.verify_key, user_verify_key=user_verify_key + ) + else: + result = self.stash.get_by_verify_key( + credentials=context.credentials, user_verify_key=user_verify_key + ) - user_code_service = context.node.get_service("usercodeservice") + user_code_service: UserCodeService = context.node.get_service("usercodeservice") # type: ignore def get_code(uid: UID) -> UserCode | SyftError: - return user_code_service.get_by_uid(context=context, uid=uid) + return user_code_service.stash.get_by_uid( + credentials=context.node.verify_key, + uid=uid, + ).ok() if result.is_ok(): code_histories = result.ok() @@ -186,7 +191,10 @@ def get_history_for_user( def get_histories_group_by_user( self, context: AuthedServiceContext ) -> UsersCodeHistoriesDict | SyftError: - result = self.stash.get_all(credentials=context.credentials) + if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: + result = self.stash.get_all(context.credentials, has_permission=True) + else: + result = self.stash.get_all(context.credentials) if result.is_err(): return SyftError(message=result.err()) code_histories: list[CodeHistory] = result.ok() diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 333d246e37f..a1182b1630a 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -77,23 +77,50 @@ def test_user_code(worker) -> None: assert multi_call_res.get() == result.get() -def test_duplicated_user_code(worker, guest_client: User) -> None: +def test_duplicated_user_code(worker) -> None: + worker.root_client.register( + name="Jane Doe", + email="jane@caltech.edu", + password="abc123", + password_verify="abc123", + institution="Caltech", + website="https://www.caltech.edu/", + ) + ds_client = worker.root_client.login( + email="jane@caltech.edu", + password="abc123", + ) + # mock_syft_func() - result = guest_client.api.services.code.request_code_execution(mock_syft_func) + result = ds_client.api.services.code.request_code_execution(mock_syft_func) assert isinstance(result, Request) - assert len(guest_client.code.get_all()) == 1 + assert len(ds_client.code.get_all()) == 1 # request the exact same code should return an error - result = guest_client.api.services.code.request_code_execution(mock_syft_func) + result = ds_client.api.services.code.request_code_execution(mock_syft_func) assert isinstance(result, SyftError) - assert len(guest_client.code.get_all()) == 1 + assert len(ds_client.code.get_all()) == 1 # request the a different function name but same content will also succeed # flaky if not blocking mock_syft_func_2(syft_no_node=True) - result = guest_client.api.services.code.request_code_execution(mock_syft_func_2) + result = ds_client.api.services.code.request_code_execution(mock_syft_func_2) assert isinstance(result, Request) - assert len(guest_client.code.get_all()) == 2 + assert len(ds_client.code.get_all()) == 2 + + code_history = ds_client.code_history + assert code_history.code_versions, "No code version found." + + code_histories = worker.root_client.code_histories + user_code_history = code_histories[ds_client.logged_in_user] + assert not isinstance(code_histories, SyftError) + assert not isinstance(user_code_history, SyftError) + assert user_code_history.code_versions, "No code version found." + assert user_code_history.mock_syft_func.user_code_history[0].status is not None + assert user_code_history.mock_syft_func[0]._repr_markdown_(), "repr markdown failed" + + result = user_code_history.mock_syft_func_2[0]() + assert result.get() == 1 def random_hash() -> str: