Skip to content

Commit

Permalink
Merge pull request #8413 from OpenMined/eelco/execute-on-mock
Browse files Browse the repository at this point in the history
Execute usercode on mock data
  • Loading branch information
eelcovdw authored Jan 24, 2024
2 parents a708983 + 4879553 commit 9affd9d
Show file tree
Hide file tree
Showing 12 changed files with 464 additions and 44 deletions.
59 changes: 51 additions & 8 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,21 @@ class RemoteFunction(SyftObject):
def __ipython_inspector_signature_override__(self) -> Optional[Signature]:
return self.signature

def prepare_args_and_kwargs(
self, args: List[Any], kwargs: Dict[str, Any]
) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
# Validate and migrate args and kwargs
res = validate_callable_args_and_kwargs(args, kwargs, self.signature)
if isinstance(res, SyftError):
return res
args, kwargs = res

args, kwargs = migrate_args_and_kwargs(
to_protocol=self.communication_protocol, args=args, kwargs=kwargs
)

return args, kwargs

def __call__(self, *args, **kwargs):
if "blocking" in self.signature.parameters:
raise Exception(
Expand All @@ -248,17 +263,11 @@ def __call__(self, *args, **kwargs):
blocking = bool(kwargs["blocking"])
del kwargs["blocking"]

# Migrate args and kwargs to communication protocol
args, kwargs = migrate_args_and_kwargs(
to_protocol=self.communication_protocol, args=args, kwargs=kwargs
)

res = validate_callable_args_and_kwargs(args, kwargs, self.signature)

res = self.prepare_args_and_kwargs(args, kwargs)
if isinstance(res, SyftError):
return res
_valid_args, _valid_kwargs = res

_valid_args, _valid_kwargs = res
if self.pre_kwargs:
_valid_kwargs.update(self.pre_kwargs)

Expand Down Expand Up @@ -289,6 +298,33 @@ class RemoteUserCodeFunction(RemoteFunction):
__version__ = SYFT_OBJECT_VERSION_1
__repr_attrs__ = RemoteFunction.__repr_attrs__ + ["user_code_id"]

api: SyftAPI

def prepare_args_and_kwargs(
self, args: List[Any], kwargs: Dict[str, Any]
) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
# relative
from ..service.action.action_object import convert_to_pointers

# Validate and migrate args and kwargs
res = validate_callable_args_and_kwargs(args, kwargs, self.signature)
if isinstance(res, SyftError):
return res
args, kwargs = res

args, kwargs = convert_to_pointers(
api=self.api,
node_uid=self.node_uid,
args=args,
kwargs=kwargs,
)

args, kwargs = migrate_args_and_kwargs(
to_protocol=self.communication_protocol, args=args, kwargs=kwargs
)

return args, kwargs

@property
def user_code_id(self) -> Optional[UID]:
return self.pre_kwargs.get("uid", None)
Expand All @@ -308,6 +344,7 @@ def jobs(self) -> Union[List[Job], SyftError]:


def generate_remote_function(
api: SyftAPI,
node_uid: UID,
signature: Signature,
path: str,
Expand All @@ -324,6 +361,7 @@ def generate_remote_function(
# UserCodes are always code.call with a user_code_id
if path == "code.call" and pre_kwargs is not None and "uid" in pre_kwargs:
remote_function = RemoteUserCodeFunction(
api=api,
node_uid=node_uid,
signature=signature,
path=path,
Expand Down Expand Up @@ -742,6 +780,7 @@ def build_endpoint_tree(endpoints, communication_protocol):
signature = signature_remove_context(signature)
if isinstance(v, APIEndpoint):
endpoint_function = generate_remote_function(
self,
self.node_uid,
signature,
v.service_path,
Expand Down Expand Up @@ -1007,3 +1046,7 @@ def validate_callable_args_and_kwargs(args, kwargs, signature: Signature):
_valid_args.append(arg)

return _valid_args, _valid_kwargs


RemoteFunction.update_forward_refs()
RemoteUserCodeFunction.update_forward_refs()
31 changes: 25 additions & 6 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,21 +1247,33 @@ def add_queueitem_to_queue(
return job

def _get_existing_user_code_jobs(
self, user_code_id: UID, credentials: SyftVerifyKey
self, context: AuthedServiceContext, user_code_id: UID
) -> Union[List[Job], SyftError]:
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(node=self, credentials=credentials, role=role)

job_service = self.get_service("jobservice")
return job_service.get_by_user_code_id(
context=context, user_code_id=user_code_id
)

def _is_usercode_call_on_owned_kwargs(
self, context: AuthedServiceContext, api_call: SyftAPICall
) -> bool:
if api_call.path != "code.call":
return False
user_code_service = self.get_service("usercodeservice")
return user_code_service.is_execution_on_owned_args(api_call.kwargs, context)

def add_api_call_to_queue(self, api_call, parent_job_id=None):
unsigned_call = api_call
if isinstance(api_call, SignedSyftAPICall):
unsigned_call = api_call.message

credentials = api_call.credentials
context = AuthedServiceContext(
node=self,
credentials=credentials,
role=self.get_role_for_credentials(credentials=credentials),
)

is_user_code = unsigned_call.path == "code.call"

service_str, method_str = unsigned_call.path.split(".")
Expand All @@ -1270,9 +1282,16 @@ def add_api_call_to_queue(self, api_call, parent_job_id=None):
if is_user_code:
action = Action.from_api_call(unsigned_call)

if self.node_side_type == NodeSideType.LOW_SIDE:
is_usercode_call_on_owned_kwargs = self._is_usercode_call_on_owned_kwargs(
context, unsigned_call
)
# Low side does not execute jobs, unless this is a mock execution
if (
not is_usercode_call_on_owned_kwargs
and self.node_side_type == NodeSideType.LOW_SIDE
):
existing_jobs = self._get_existing_user_code_jobs(
action.user_code_id, api_call.credentials
context, action.user_code_id
)
if isinstance(existing_jobs, SyftError):
return existing_jobs
Expand Down
48 changes: 48 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,54 @@
"hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a",
"action": "add"
}
},
"User": {
"1": {
"version": 1,
"hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a",
"action": "remove"
},
"2": {
"version": 2,
"hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7",
"action": "add"
}
},
"UserUpdate": {
"1": {
"version": 1,
"hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195",
"action": "remove"
},
"2": {
"version": 2,
"hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246",
"action": "add"
}
},
"UserCreate": {
"1": {
"version": 1,
"hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f",
"action": "remove"
},
"2": {
"version": 2,
"hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646",
"action": "add"
}
},
"UserView": {
"1": {
"version": 1,
"hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675",
"action": "remove"
},
"2": {
"version": 2,
"hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473",
"action": "add"
}
}
}
}
Expand Down
7 changes: 5 additions & 2 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,14 @@ def convert_to_pointers(
args: Optional[List] = None,
kwargs: Optional[Dict] = None,
) -> Tuple[List, Dict]:
# relative
from ..dataset.dataset import Asset

arg_list = []
kwarg_dict = {}
if args is not None:
for arg in args:
if not isinstance(arg, ActionObject):
if not isinstance(arg, (ActionObject, Asset, UID)):
arg = ActionObject.from_obj(
syft_action_data=arg,
syft_client_verify_key=api.signing_key.verify_key,
Expand All @@ -443,7 +446,7 @@ def convert_to_pointers(

if kwargs is not None:
for k, arg in kwargs.items():
if not isinstance(arg, ActionObject):
if not isinstance(arg, (ActionObject, Asset, UID)):
arg = ActionObject.from_obj(
syft_action_data=arg,
syft_client_verify_key=api.signing_key.verify_key,
Expand Down
20 changes: 15 additions & 5 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def np_array(self, context: AuthedServiceContext, data: Any) -> Any:
if isinstance(blob_store_result, SyftError):
return blob_store_result

np_pointer = self.set(context, np_obj)
np_pointer = self._set(context, np_obj)
return np_pointer

@service_method(
Expand All @@ -80,6 +80,14 @@ def set(
self,
context: AuthedServiceContext,
action_object: Union[ActionObject, TwinObject],
) -> Result[ActionObject, str]:
return self._set(context, action_object, has_result_read_permission=True)

def _set(
self,
context: AuthedServiceContext,
action_object: Union[ActionObject, TwinObject],
has_result_read_permission: bool = False,
) -> Result[ActionObject, str]:
"""Save an object to the action store"""
# 🟡 TODO 9: Create some kind of type checking / protocol for SyftSerializable
Expand All @@ -90,8 +98,10 @@ def set(
action_object.private_obj.syft_created_at = DateTime.now()
action_object.mock_obj.syft_created_at = DateTime.now()

has_result_read_permission = context.extra_kwargs.get(
"has_result_read_permission", False
# If either context or argument is True, has_result_read_permission is True
has_result_read_permission = (
context.extra_kwargs.get("has_result_read_permission", False)
or has_result_read_permission
)

result = self.store.set(
Expand Down Expand Up @@ -404,7 +414,7 @@ def set_result_to_store(self, result_action_object, context, output_policy=None)
# pass permission information to the action store as extra kwargs
context.extra_kwargs = {"has_result_read_permission": True}

set_result = self.set(context, result_action_object)
set_result = self._set(context, result_action_object)

if set_result.is_err():
return set_result
Expand Down Expand Up @@ -664,7 +674,7 @@ def execute(
"has_result_read_permission": has_result_read_permission
}

set_result = self.set(context, result_action_object)
set_result = self._set(context, result_action_object)
if set_result.is_err():
return Err(
f"Failed executing action {action}, set result is an error: {set_result.err()}"
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ def launch_job(func: UserCode, **kwargs):
kw2id = {}
for k, v in kwargs.items():
value = ActionObject.from_obj(v)
ptr = action_service.set(context, value)
ptr = action_service._set(context, value)
ptr = ptr.ok()
kw2id[k] = ptr.id
try:
Expand Down
Loading

0 comments on commit 9affd9d

Please sign in to comment.