diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 4aa0d54bcb4..92538709083 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -237,6 +237,21 @@ class RemoteFunction(SyftObject): def __ipython_inspector_signature_override__(self) -> Optional[Signature]: return self.signature + def prepare_args_and_kwargs( + self, args: List[Any], kwargs: Dict[str, Any] + ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + # Validate and migrate args and kwargs + res = validate_callable_args_and_kwargs(args, kwargs, self.signature) + if isinstance(res, SyftError): + return res + args, kwargs = res + + args, kwargs = migrate_args_and_kwargs( + to_protocol=self.communication_protocol, args=args, kwargs=kwargs + ) + + return args, kwargs + def __call__(self, *args, **kwargs): if "blocking" in self.signature.parameters: raise Exception( @@ -248,17 +263,11 @@ def __call__(self, *args, **kwargs): blocking = bool(kwargs["blocking"]) del kwargs["blocking"] - # Migrate args and kwargs to communication protocol - args, kwargs = migrate_args_and_kwargs( - to_protocol=self.communication_protocol, args=args, kwargs=kwargs - ) - - res = validate_callable_args_and_kwargs(args, kwargs, self.signature) - + res = self.prepare_args_and_kwargs(args, kwargs) if isinstance(res, SyftError): return res - _valid_args, _valid_kwargs = res + _valid_args, _valid_kwargs = res if self.pre_kwargs: _valid_kwargs.update(self.pre_kwargs) @@ -289,6 +298,33 @@ class RemoteUserCodeFunction(RemoteFunction): __version__ = SYFT_OBJECT_VERSION_1 __repr_attrs__ = RemoteFunction.__repr_attrs__ + ["user_code_id"] + api: SyftAPI + + def prepare_args_and_kwargs( + self, args: List[Any], kwargs: Dict[str, Any] + ) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]: + # relative + from ..service.action.action_object import convert_to_pointers + + # Validate and migrate args and kwargs + res = validate_callable_args_and_kwargs(args, kwargs, self.signature) + if isinstance(res, SyftError): + return res + args, kwargs = res + + args, kwargs = convert_to_pointers( + api=self.api, + node_uid=self.node_uid, + args=args, + kwargs=kwargs, + ) + + args, kwargs = migrate_args_and_kwargs( + to_protocol=self.communication_protocol, args=args, kwargs=kwargs + ) + + return args, kwargs + @property def user_code_id(self) -> Optional[UID]: return self.pre_kwargs.get("uid", None) @@ -308,6 +344,7 @@ def jobs(self) -> Union[List[Job], SyftError]: def generate_remote_function( + api: SyftAPI, node_uid: UID, signature: Signature, path: str, @@ -324,6 +361,7 @@ def generate_remote_function( # UserCodes are always code.call with a user_code_id if path == "code.call" and pre_kwargs is not None and "uid" in pre_kwargs: remote_function = RemoteUserCodeFunction( + api=api, node_uid=node_uid, signature=signature, path=path, @@ -742,6 +780,7 @@ def build_endpoint_tree(endpoints, communication_protocol): signature = signature_remove_context(signature) if isinstance(v, APIEndpoint): endpoint_function = generate_remote_function( + self, self.node_uid, signature, v.service_path, @@ -1007,3 +1046,7 @@ def validate_callable_args_and_kwargs(args, kwargs, signature: Signature): _valid_args.append(arg) return _valid_args, _valid_kwargs + + +RemoteFunction.update_forward_refs() +RemoteUserCodeFunction.update_forward_refs() diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index f38adc30376..7f81b21af59 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -1247,21 +1247,33 @@ def add_queueitem_to_queue( return job def _get_existing_user_code_jobs( - self, user_code_id: UID, credentials: SyftVerifyKey + self, context: AuthedServiceContext, user_code_id: UID ) -> Union[List[Job], SyftError]: - role = self.get_role_for_credentials(credentials=credentials) - context = AuthedServiceContext(node=self, credentials=credentials, role=role) - job_service = self.get_service("jobservice") return job_service.get_by_user_code_id( context=context, user_code_id=user_code_id ) + def _is_usercode_call_on_owned_kwargs( + self, context: AuthedServiceContext, api_call: SyftAPICall + ) -> bool: + if api_call.path != "code.call": + return False + user_code_service = self.get_service("usercodeservice") + return user_code_service.is_execution_on_owned_args(api_call.kwargs, context) + def add_api_call_to_queue(self, api_call, parent_job_id=None): unsigned_call = api_call if isinstance(api_call, SignedSyftAPICall): unsigned_call = api_call.message + credentials = api_call.credentials + context = AuthedServiceContext( + node=self, + credentials=credentials, + role=self.get_role_for_credentials(credentials=credentials), + ) + is_user_code = unsigned_call.path == "code.call" service_str, method_str = unsigned_call.path.split(".") @@ -1270,9 +1282,16 @@ def add_api_call_to_queue(self, api_call, parent_job_id=None): if is_user_code: action = Action.from_api_call(unsigned_call) - if self.node_side_type == NodeSideType.LOW_SIDE: + is_usercode_call_on_owned_kwargs = self._is_usercode_call_on_owned_kwargs( + context, unsigned_call + ) + # Low side does not execute jobs, unless this is a mock execution + if ( + not is_usercode_call_on_owned_kwargs + and self.node_side_type == NodeSideType.LOW_SIDE + ): existing_jobs = self._get_existing_user_code_jobs( - action.user_code_id, api_call.credentials + context, action.user_code_id ) if isinstance(existing_jobs, SyftError): return existing_jobs diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index e03e704737b..67f204a0e92 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1103,6 +1103,54 @@ "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", "action": "add" } + }, + "User": { + "1": { + "version": 1, + "hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", + "action": "add" + } + }, + "UserUpdate": { + "1": { + "version": 1, + "hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", + "action": "add" + } + }, + "UserCreate": { + "1": { + "version": 1, + "hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", + "action": "add" + } + }, + "UserView": { + "1": { + "version": 1, + "hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 93f19bf8428..336fb996f33 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -424,11 +424,14 @@ def convert_to_pointers( args: Optional[List] = None, kwargs: Optional[Dict] = None, ) -> Tuple[List, Dict]: + # relative + from ..dataset.dataset import Asset + arg_list = [] kwarg_dict = {} if args is not None: for arg in args: - if not isinstance(arg, ActionObject): + if not isinstance(arg, (ActionObject, Asset, UID)): arg = ActionObject.from_obj( syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, @@ -443,7 +446,7 @@ def convert_to_pointers( if kwargs is not None: for k, arg in kwargs.items(): - if not isinstance(arg, ActionObject): + if not isinstance(arg, (ActionObject, Asset, UID)): arg = ActionObject.from_obj( syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 70be8bcba1a..0c46b9871e9 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -68,7 +68,7 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any: if isinstance(blob_store_result, SyftError): return blob_store_result - np_pointer = self.set(context, np_obj) + np_pointer = self._set(context, np_obj) return np_pointer @service_method( @@ -80,6 +80,14 @@ def set( self, context: AuthedServiceContext, action_object: Union[ActionObject, TwinObject], + ) -> Result[ActionObject, str]: + return self._set(context, action_object, has_result_read_permission=True) + + def _set( + self, + context: AuthedServiceContext, + action_object: Union[ActionObject, TwinObject], + has_result_read_permission: bool = False, ) -> Result[ActionObject, str]: """Save an object to the action store""" # 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable @@ -90,8 +98,10 @@ def set( action_object.private_obj.syft_created_at = DateTime.now() action_object.mock_obj.syft_created_at = DateTime.now() - has_result_read_permission = context.extra_kwargs.get( - "has_result_read_permission", False + # If either context or argument is True, has_result_read_permission is True + has_result_read_permission = ( + context.extra_kwargs.get("has_result_read_permission", False) + or has_result_read_permission ) result = self.store.set( @@ -404,7 +414,7 @@ def set_result_to_store(self, result_action_object, context, output_policy=None) # pass permission information to the action store as extra kwargs context.extra_kwargs = {"has_result_read_permission": True} - set_result = self.set(context, result_action_object) + set_result = self._set(context, result_action_object) if set_result.is_err(): return set_result @@ -664,7 +674,7 @@ def execute( "has_result_read_permission": has_result_read_permission } - set_result = self.set(context, result_action_object) + set_result = self._set(context, result_action_object) if set_result.is_err(): return Err( f"Failed executing action {action}, set result is an error: {set_result.err()}" diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 38fb6dbd11f..020830d2508 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1218,7 +1218,7 @@ def launch_job(func: UserCode, **kwargs): kw2id = {} for k, v in kwargs.items(): value = ActionObject.from_obj(v) - ptr = action_service.set(context, value) + ptr = action_service._set(context, value) ptr = ptr.ok() kw2id[k] = ptr.id try: diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 78bd8b92c87..20c8630c231 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -318,6 +318,38 @@ def is_execution_allowed(self, code, context, output_policy): else: return True + def is_execution_on_owned_args_allowed(self, context: AuthedServiceContext) -> bool: + if context.role == ServiceRole.ADMIN: + return True + user_service = context.node.get_service("userservice") + current_user = user_service.get_current_user(context=context) + return current_user.mock_execution_permission + + def keep_owned_kwargs( + self, kwargs: Dict[str, Any], context: AuthedServiceContext + ) -> Dict[str, Any]: + """Return only the kwargs that are owned by the user""" + action_service = context.node.get_service("actionservice") + + mock_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, UID): + # Jobs have UID kwargs instead of ActionObject + v = action_service.get(context, uid=v) + if v.is_ok(): + v = v.ok() + if ( + isinstance(v, ActionObject) + and v.syft_client_verify_key == context.credentials + ): + mock_kwargs[k] = v + return mock_kwargs + + def is_execution_on_owned_args( + self, kwargs: Dict[str, Any], context: AuthedServiceContext + ) -> bool: + return len(self.keep_owned_kwargs(kwargs, context)) == len(kwargs) + @service_method(path="code.call", name="call", roles=GUEST_ROLE_LEVEL) def call( self, context: AuthedServiceContext, uid: UID, **kwargs: Any @@ -339,18 +371,28 @@ def _call( ) -> Result[ActionObject, Err]: """Call a User Code Function""" try: - # Unroll variables - kwarg2id = map_kwargs_to_id(kwargs) - - # get code item code_result = self.stash.get_by_uid(context.credentials, uid=uid) if code_result.is_err(): return code_result code: UserCode = code_result.ok() + + # Set Permissions + if self.is_execution_on_owned_args(kwargs, context): + if self.is_execution_on_owned_args_allowed(context): + context.has_execute_permissions = True + else: + return Err( + "You do not have the permissions for mock execution, please contact the admin" + ) override_execution_permission = ( context.has_execute_permissions or context.role == ServiceRole.ADMIN ) + # Override permissions bypasses the cache, since we do not check in/out policies + skip_fill_cache = override_execution_permission + # We do not read from output policy cache if there are mock arguments + skip_read_cache = len(self.keep_owned_kwargs(kwargs, context)) > 0 + # Check output policy output_policy = code.output_policy if not override_execution_permission: can_execute = self.is_execution_allowed( @@ -362,7 +404,10 @@ def _call( "UserCodeStatus.DENIED: Function has no output policy" ) if not (is_valid := output_policy.valid): - if len(output_policy.output_history) > 0: + if ( + len(output_policy.output_history) > 0 + and not skip_read_cache + ): result = resolve_outputs( context=context, output_ids=output_policy.last_output_ids, @@ -375,6 +420,7 @@ def _call( # Execute the code item action_service = context.node.get_service("actionservice") + kwarg2id = map_kwargs_to_id(kwargs) result_action_object: Result[ Union[ActionObject, TwinObject], str ] = action_service._user_code_execute( @@ -397,7 +443,7 @@ def _call( # this currently only works for nested syft_functions # and admins executing on high side (TODO, decide if we want to increment counter) - if not override_execution_permission: + if not skip_fill_cache: output_policy.apply_output(context=context, outputs=result) code.output_policy = output_policy if not ( diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 1189c69eb16..c6211f9d0d3 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -22,8 +22,10 @@ from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...types.syft_metaclass import Empty +from ...types.syft_migration import migrate from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import drop @@ -38,11 +40,27 @@ from .user_roles import ServiceRole +class UserV1(SyftObject): + __canonical_name__ = "User" + __version__ = SYFT_OBJECT_VERSION_1 + + email: Optional[EmailStr] + name: Optional[str] + hashed_password: Optional[str] + salt: Optional[str] + signing_key: Optional[SyftSigningKey] + verify_key: Optional[SyftVerifyKey] + role: Optional[ServiceRole] + institution: Optional[str] + website: Optional[str] = None + created_at: Optional[str] = None + + @serializable() class User(SyftObject): # version __canonical_name__ = "User" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: Optional[UID] @@ -61,6 +79,8 @@ def make_email(cls, v: EmailStr) -> EmailStr: institution: Optional[str] website: Optional[str] = None created_at: Optional[str] = None + # TODO where do we put this flag? + mock_execution_permission: bool = False # serde / storage rules __attr_searchable__ = ["name", "email", "verify_key", "role"] @@ -106,10 +126,24 @@ def check_pwd(password: str, hashed_password: str) -> bool: ) +class UserUpdateV1(PartialSyftObject): + __canonical_name__ = "UserUpdate" + __version__ = SYFT_OBJECT_VERSION_1 + + email: EmailStr + name: str + role: ServiceRole + password: str + password_verify: str + verify_key: SyftVerifyKey + institution: str + website: str + + @serializable() class UserUpdate(PartialSyftObject): __canonical_name__ = "UserUpdate" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 @pydantic.validator("email", pre=True) def make_email(cls, v: Any) -> Any: @@ -129,12 +163,28 @@ def str_to_role(cls, v: Any) -> Any: verify_key: SyftVerifyKey institution: str website: str + mock_execution_permission: bool + + +class UserCreateV1(UserUpdateV1): + __canonical_name__ = "UserCreate" + __version__ = SYFT_OBJECT_VERSION_1 + + email: EmailStr + name: str + role: Optional[ServiceRole] = None + password: str + password_verify: Optional[str] = None + verify_key: Optional[SyftVerifyKey] + institution: Optional[str] + website: Optional[str] + created_by: Optional[SyftSigningKey] @serializable() class UserCreate(UserUpdate): __canonical_name__ = "UserCreate" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 email: EmailStr name: str @@ -145,6 +195,7 @@ class UserCreate(UserUpdate): institution: Optional[str] website: Optional[str] created_by: Optional[SyftSigningKey] + mock_execution_permission: bool = False __repr_attrs__ = ["name", "email"] @@ -160,16 +211,28 @@ class UserSearch(PartialSyftObject): name: str +class UserViewV1(SyftObject): + __canonical_name__ = "UserView" + __version__ = SYFT_OBJECT_VERSION_1 + + email: EmailStr + name: str + role: ServiceRole # make sure role cant be set without uid + institution: Optional[str] + website: Optional[str] + + @serializable() class UserView(SyftObject): __canonical_name__ = "UserView" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 email: EmailStr name: str role: ServiceRole # make sure role cant be set without uid institution: Optional[str] website: Optional[str] + mock_execution_permission: bool __repr_attrs__ = ["name", "email", "institution", "website", "role"] @@ -242,6 +305,7 @@ def update( institution: Union[Empty, str] = Empty, website: Union[str, Empty] = Empty, role: Union[str, Empty] = Empty, + mock_execution_permission: Union[bool, Empty] = Empty, ) -> Union[SyftSuccess, SyftError]: """Used to update name, institution, website of a user.""" api = APIRegistry.api_for( @@ -255,6 +319,7 @@ def update( institution=institution, website=website, role=role, + mock_execution_permission=mock_execution_permission, ) result = api.services.user.update(uid=self.id, user_update=user_update) @@ -266,6 +331,9 @@ def update( return SyftSuccess(message="User details successfully updated.") + def allow_mock_execution(self, allow: bool = True) -> Union[SyftSuccess, SyftError]: + return self.update(mock_execution_permission=allow) + @serializable() class UserViewPage(SyftObject): @@ -300,7 +368,19 @@ def user_create_to_user() -> List[Callable]: @transform(User, UserView) def user_to_view_user() -> List[Callable]: - return [keep(["id", "email", "name", "role", "institution", "website"])] + return [ + keep( + [ + "id", + "email", + "name", + "role", + "institution", + "website", + "mock_execution_permission", + ] + ) + ] @serializable() @@ -316,3 +396,43 @@ class UserPrivateKey(SyftObject): @transform(User, UserPrivateKey) def user_to_user_verify() -> List[Callable]: return [keep(["email", "signing_key", "id", "role"])] + + +@migrate(UserV1, User) +def upgrade_user_v1_to_v2(): + return [make_set_default(key="mock_execution_permission", value=False)] + + +@migrate(User, UserV1) +def downgrade_user_v2_to_v1(): + return [drop(["mock_execution_permission"])] + + +@migrate(UserUpdateV1, UserUpdate) +def upgrade_user_update_v1_to_v2(): + return [make_set_default(key="mock_execution_permission", value=False)] + + +@migrate(UserUpdate, UserUpdateV1) +def downgrade_user_update_v2_to_v1(): + return [drop(["mock_execution_permission"])] + + +@migrate(UserCreateV1, UserCreate) +def upgrade_user_create_v1_to_v2(): + return [make_set_default(key="mock_execution_permission", value=False)] + + +@migrate(UserCreate, UserCreateV1) +def downgrade_user_create_v2_to_v1(): + return [drop(["mock_execution_permission"])] + + +@migrate(UserViewV1, UserView) +def upgrade_user_view_v1_to_v2(): + return [make_set_default(key="mock_execution_permission", value=False)] + + +@migrate(UserView, UserViewV1) +def downgrade_user_view_v2_to_v1(): + return [drop(["mock_execution_permission"])] diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 0cca4a98529..8b7c7e2a91d 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -220,12 +220,14 @@ def update( self, context: AuthedServiceContext, uid: UID, user_update: UserUpdate ) -> Union[UserView, SyftError]: updates_role = user_update.role is not Empty + can_edit_roles = ServiceRoleCapability.CAN_EDIT_ROLES in context.capabilities() - if ( - updates_role - and ServiceRoleCapability.CAN_EDIT_ROLES not in context.capabilities() - ): + if updates_role and not can_edit_roles: return SyftError(message=f"{context.role} is not allowed to edit roles") + if (user_update.mock_execution_permission is not Empty) and not can_edit_roles: + return SyftError( + message=f"{context.role} is not allowed to update permissions" + ) # Get user to be updated by its UID result = self.stash.get_by_uid(credentials=context.credentials, uid=uid) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 095f197f891..0a278370a6b 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -825,7 +825,7 @@ def __init__(self, *args, **kwargs) -> None: fields_with_default = set() for _field_name, _field in self.__fields__.items(): - if _field.default or _field.allow_none: + if _field.default is not None or _field.allow_none: fields_with_default.add(_field_name) # Fields whose values are set via a validator hook diff --git a/packages/syft/tests/syft/syft_functions/syft_function_test.py b/packages/syft/tests/syft/syft_functions/syft_function_test.py index cfa0d40e260..c81ae3d4561 100644 --- a/packages/syft/tests/syft/syft_functions/syft_function_test.py +++ b/packages/syft/tests/syft/syft_functions/syft_function_test.py @@ -17,11 +17,13 @@ @pytest.fixture def node(): + random.seed() + name = f"nested_job_test_domain-{random.randint(0,1000)}" _node = sy.orchestra.launch( - name="nested_job_test_domain", + name=name, dev_mode=True, reset=True, - n_consumers=3, + n_consumers=4, create_producer=True, queue_port=random.randint(13000, 13300), ) @@ -42,7 +44,7 @@ def test_nested_jobs(node): ## Dataset x = ActionObject.from_obj([1, 2]) - x_ptr = x.send(ds_client) + x_ptr = x.send(client) ## aggregate function @sy.syft_function() @@ -89,7 +91,6 @@ def process_all(domain, x): job = ds_client.code.process_all(x=x_ptr, blocking=False) job.wait() - # stdlib assert len(job.subjobs) == 3 # stdlib diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 2e66925649a..1b2de804655 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -29,8 +29,22 @@ def test_func_2(): return 1 -def test_user_code(worker, guest_client: User) -> None: - # test_func() +def test_user_code(worker) -> None: + root_domain_client = worker.root_client + root_domain_client.register( + name="data-scientist", + email="test_user@openmined.org", + password="0000", + password_verify="0000", + ) + guest_client = root_domain_client.login( + email="test_user@openmined.org", + password="0000", + ) + + users = root_domain_client.users.get_all() + users[-1].allow_mock_execution() + guest_client.api.services.code.request_code_execution(test_func) root_domain_client = worker.root_client @@ -145,3 +159,117 @@ def test_nested_requests(worker, guest_client: User): assert linked_obj.resolve.id == inner.id assert outer.status.approved assert not inner.status.approved + + +def test_user_code_mock_execution(worker) -> None: + # Setup + root_domain_client = worker.root_client + + # TODO guest_client fixture is not in root_domain_client.users + root_domain_client.register( + name="data-scientist", + email="test_user@openmined.org", + password="0000", + password_verify="0000", + ) + ds_client = root_domain_client.login( + email="test_user@openmined.org", + password="0000", + ) + + dataset = sy.Dataset( + name="my-dataset", + asset_list=[ + sy.Asset( + name="numpy-data", + data=np.array([0, 1, 2, 3, 4]), + mock=np.array([5, 6, 7, 8, 9]), + ) + ], + ) + root_domain_client.upload_dataset(dataset) + + # DS requests code execution + data = ds_client.datasets[0].assets[0] + + @sy.syft_function_single_use(data=data) + def compute_mean(data): + return data.mean() + + compute_mean.code = dedent(compute_mean.code) + ds_client.api.services.code.request_code_execution(compute_mean) + + # Guest attempts to set own permissions + guest_user = ds_client.users.get_current_user() + res = guest_user.allow_mock_execution() + assert isinstance(res, SyftError) + + # Mock execution fails, no permissions + result = ds_client.api.services.code.compute_mean(data=data.mock) + assert isinstance(result, SyftError) + + # DO grants permissions + users = root_domain_client.users.get_all() + guest_user = [u for u in users if u.id == guest_user.id][0] + guest_user.allow_mock_execution() + + # Mock execution succeeds + result = ds_client.api.services.code.compute_mean(data=data.mock).get() + assert isinstance(result, float) + + +def test_mock_multiple_arguments(worker) -> None: + # Setup + root_domain_client = worker.root_client + + root_domain_client.register( + name="data-scientist", + email="test_user@openmined.org", + password="0000", + password_verify="0000", + ) + ds_client = root_domain_client.login( + email="test_user@openmined.org", + password="0000", + ) + + dataset = sy.Dataset( + name="my-dataset", + asset_list=[ + sy.Asset( + name="numpy-data", + data=np.array([0, 1, 2, 3, 4]), + mock=np.array([5, 6, 7, 8, 9]), + ) + ], + ) + root_domain_client.upload_dataset(dataset) + users = root_domain_client.users.get_all() + users[-1].allow_mock_execution() + + # DS requests code execution + data = ds_client.datasets[0].assets[0] + + @sy.syft_function_single_use(data1=data, data2=data) + def compute_sum(data1, data2): + return data1 + data2 + + compute_sum.code = dedent(compute_sum.code) + ds_client.api.services.code.request_code_execution(compute_sum) + root_domain_client.requests[-1].approve() + + # Mock execution succeeds, result not cached + result = ds_client.api.services.code.compute_sum(data1=1, data2=1) + assert result.get() == 2 + + # Mixed execution fails on input policy + result = ds_client.api.services.code.compute_sum(data1=1, data2=data) + assert isinstance(result, SyftError) + + # Real execution succeeds + result = ds_client.api.services.code.compute_sum(data1=data, data2=data) + assert np.equal(result.get(), np.array([0, 2, 4, 6, 8])).all() + + # Mixed execution fails, no result from cache + result = ds_client.api.services.code.compute_sum(data1=1, data2=data) + assert isinstance(result, SyftError)