From 81d567a503b802ddaf6bb69a47d589f5fe3f20cb Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 13 Feb 2024 15:15:25 +0100 Subject: [PATCH 1/4] add get_dependents --- .../syft/src/syft/service/code/user_code.py | 66 +++++++++++++++++++ .../syft/src/syft/service/job/job_stash.py | 37 +++++++++++ .../syft/src/syft/service/request/request.py | 20 ++++++ 3 files changed, 123 insertions(+) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index a449845115d..b8d0cb86b54 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -61,6 +61,7 @@ from ...util.markdown import as_markdown_code from ..action.action_object import Action from ..action.action_object import ActionObject +from ..action.action_object import TwinMode from ..context import AuthedServiceContext from ..dataset.dataset import Asset from ..job.job_stash import Job @@ -554,6 +555,71 @@ def assets(self) -> List[Asset]: all_assets += assets return all_assets + def get_dependents( + self, visited: Optional[List[str]] = None + ) -> Dict[str, List[Any]]: + # Usercode dependents are: input_policy inputs, output_policy outputs, nested_codes + + visited = visited or [] + visited = visited + [self.id] + dependents = {self.id: []} + + # NOTE input and output policy are stored directly on the code object, + # so dependents are on the code object as well + api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + action_service = api.services.action + input_policy = self.input_policy + if input_policy is not None: + all_input_ids = [] + for _, inputs in input_policy.inputs.items(): + all_input_ids.extend(inputs.values()) + + for input_id in all_input_ids: + dependents[self.id].append( + action_service.get(input_id, twin_mode=TwinMode.NONE) + ) + + output_policy = self.output_policy + if output_policy is not None: + all_output_ids = [] + for output in output_policy.output_history: + if isinstance(output.outputs, list): + all_output_ids.extend(output.outputs) + else: + all_output_ids.extend(output.outputs.values()) + + for output_id in all_output_ids: + dependents[self.id].append( + action_service.get(output_id, twin_mode=TwinMode.NONE) + ) + + if self.nested_codes is not None: + for obj_link, _ in self.nested_codes.values(): + if visited and obj_link.id in visited: + continue + obj = obj_link.resolve + deps = obj.get_dependents(visited=visited) + + visited.extend(list(deps.keys())) + dependents.update(deps) + + job_service = api.services.job + user_code_jobs = job_service.get_by_user_code_id(self.id) + for job in user_code_jobs: + dependents[self.id].append(job) + if job.id not in visited: + visited.append(job.id) + job_dependents = job.get_dependents(visited=visited) + for k, v in job_dependents.items(): + if k not in dependents: + dependents[k] = v + + return dependents + + # remove empty values + dependents = {k: v for k, v in dependents.items() if len(v) > 0} + return dependents + @property def unsafe_function(self) -> Optional[Callable]: warning = SyftWarning( diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index efa03deac94..fcdd9d8e0e8 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -453,6 +453,43 @@ def resolve(self) -> Union[Any, SyftNotReady]: return self.result return SyftNotReady(message=f"{self.id} not ready yet.") + def get_dependents( + self, visited: Optional[List[UID]] = None + ) -> Dict[UID, List[Any]]: + # result, usercode, logs, subjobs + visited = visited or [] + visited = visited + [self.id] + + api = APIRegistry.api_for( + node_uid=self.node_uid, + user_verify_key=self.syft_client_verify_key, + ) + + dependents = {self.id: []} + result_id = self.result.id + if result_id not in visited: + result_obj = api.services.action.get(result_id, resolve_nested=False) + dependents[self.id].append(result_obj) + + if self.user_code_id not in visited: + user_code_obj = api.services.code.get_by_id(self.user_code_id) + dependents[self.id].append(user_code_obj) + + if self.log_id not in visited: + log_obj = api.services.log.get(self.log_id) + print("log_obj", type(log_obj)) + dependents[self.id].append(log_obj) + + for subjob in self.subjobs: + if subjob.id not in visited: + dependents[self.id].append(subjob) + sub_dependents = subjob.get_dependents(visited=visited) + for key, value in sub_dependents.items(): + if key not in dependents: + dependents[key] = value + + return dependents + @serializable() class JobInfo(SyftObject): diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 699e577e2b9..0f6577d6f78 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -756,6 +756,26 @@ def sync_job(self, job_info: JobInfo, **kwargs) -> Result[SyftSuccess, SyftError job.apply_info(job_info) return job_service.update(job) + def get_dependents( + self, visited: Optional[List[UID]] = None + ) -> Dict[str, SyftObject]: + visited = visited or [] + visited = visited + [self.id] + dependents = {self.id: []} + if not isinstance(self.codes, SyftError): + dependents[self.id].extend(self.codes) + + for dep in dependents[self.id]: + code_deps = dep.get_dependents(visited=visited) + for k, v in code_deps.items(): + if k not in dependents: + dependents[k] = v + visited.append(k) + return dependents + + def get_dependencies(self): + return [] + @serializable() class RequestInfo(SyftObject): From 4d5d9dc1b937613db0fb6821ed9e86c062f0708c Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 14 Feb 2024 09:56:14 +0100 Subject: [PATCH 2/4] update get_dependencies --- .../syft/src/syft/service/code/user_code.py | 38 +++++++------------ .../syft/src/syft/service/job/job_stash.py | 32 +++++++--------- packages/syft/src/syft/service/log/log.py | 7 +++- .../syft/src/syft/service/request/request.py | 36 +++++++++--------- 4 files changed, 48 insertions(+), 65 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b8d0cb86b54..ff9144479ea 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -20,6 +20,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Tuple from typing import Type from typing import Union @@ -555,14 +556,14 @@ def assets(self) -> List[Asset]: all_assets += assets return all_assets - def get_dependents( - self, visited: Optional[List[str]] = None + def get_dependencies( + self, visited: Optional[Set[str]] = None ) -> Dict[str, List[Any]]: # Usercode dependents are: input_policy inputs, output_policy outputs, nested_codes - visited = visited or [] - visited = visited + [self.id] - dependents = {self.id: []} + visited = visited or set() + visited.add(self.id) + dependencies = {self.id: []} # NOTE input and output policy are stored directly on the code object, # so dependents are on the code object as well @@ -575,7 +576,7 @@ def get_dependents( all_input_ids.extend(inputs.values()) for input_id in all_input_ids: - dependents[self.id].append( + dependencies[self.id].append( action_service.get(input_id, twin_mode=TwinMode.NONE) ) @@ -589,7 +590,7 @@ def get_dependents( all_output_ids.extend(output.outputs.values()) for output_id in all_output_ids: - dependents[self.id].append( + dependencies[self.id].append( action_service.get(output_id, twin_mode=TwinMode.NONE) ) @@ -598,27 +599,14 @@ def get_dependents( if visited and obj_link.id in visited: continue obj = obj_link.resolve - deps = obj.get_dependents(visited=visited) + deps = obj.get_dependencies(visited=visited) - visited.extend(list(deps.keys())) - dependents.update(deps) - - job_service = api.services.job - user_code_jobs = job_service.get_by_user_code_id(self.id) - for job in user_code_jobs: - dependents[self.id].append(job) - if job.id not in visited: - visited.append(job.id) - job_dependents = job.get_dependents(visited=visited) - for k, v in job_dependents.items(): - if k not in dependents: - dependents[k] = v - - return dependents + visited.update(deps.keys()) + dependencies.update(deps) # remove empty values - dependents = {k: v for k, v in dependents.items() if len(v) > 0} - return dependents + dependencies = {k: v for k, v in dependencies.items() if len(v) > 0} + return dependencies @property def unsafe_function(self) -> Optional[Callable]: diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 1a8dacf288c..c83cf68c03d 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -6,6 +6,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Union # third party @@ -334,7 +335,6 @@ def _get_log_objs(self): ) return api.services.log.get(self.log_id) - def logs(self, stdout=True, stderr=True, _print=True): api = APIRegistry.api_for( node_uid=self.node_uid, @@ -461,42 +461,36 @@ def resolve(self) -> Union[Any, SyftNotReady]: return self.result return SyftNotReady(message=f"{self.id} not ready yet.") - def get_dependents( - self, visited: Optional[List[UID]] = None + def get_dependencies( + self, visited: Optional[Set[UID]] = None ) -> Dict[UID, List[Any]]: # result, usercode, logs, subjobs - visited = visited or [] - visited = visited + [self.id] + visited = visited or set() + visited.add(self.id) api = APIRegistry.api_for( node_uid=self.node_uid, user_verify_key=self.syft_client_verify_key, ) - dependents = {self.id: []} + dependencies = {self.id: []} result_id = self.result.id if result_id not in visited: result_obj = api.services.action.get(result_id, resolve_nested=False) - dependents[self.id].append(result_obj) - - if self.user_code_id not in visited: - user_code_obj = api.services.code.get_by_id(self.user_code_id) - dependents[self.id].append(user_code_obj) + dependencies[self.id].append(result_obj) if self.log_id not in visited: log_obj = api.services.log.get(self.log_id) - print("log_obj", type(log_obj)) - dependents[self.id].append(log_obj) + dependencies[self.id].append(log_obj) for subjob in self.subjobs: if subjob.id not in visited: - dependents[self.id].append(subjob) - sub_dependents = subjob.get_dependents(visited=visited) - for key, value in sub_dependents.items(): - if key not in dependents: - dependents[key] = value + dependencies[self.id].append(subjob) + sub_dependents = subjob.get_dependencies(visited=visited) + dependencies.update(sub_dependents) + visited.update(sub_dependents.keys()) - return dependents + return dependencies @serializable() diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index bd7f03810f5..2cb87024721 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,10 +1,11 @@ # relative -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default @serializable() @@ -33,12 +34,14 @@ def restart(self) -> None: self.stderr = "" self.stdout = "" + @migrate(SyftLogV1, SyftLog) def upgrade_syftlog_v1_to_v2(): return [ make_set_default("stderr", ""), ] + @migrate(SyftLog, SyftLogV1) def downgrade_syftlog_v2_to_v1(): return [ diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 0f3a3c1c6ed..1dec107c6cc 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -7,6 +7,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Type from typing import Union @@ -343,8 +344,10 @@ class Request(SyftObject): request_hash: str changes: List[Change] history: List[ChangeStatus] = [] - - __dependent_objects__ = [".usercode.something", ] + + __dependent_objects__ = [ + ".usercode.something", + ] __attr_searchable__ = [ "requesting_user_verify_key", @@ -758,25 +761,20 @@ def sync_job(self, job_info: JobInfo, **kwargs) -> Result[SyftSuccess, SyftError job.apply_info(job_info) return job_service.update(job) - def get_dependents( - self, visited: Optional[List[UID]] = None + def get_dependencies( + self, visited: Optional[Set[UID]] = None ) -> Dict[str, SyftObject]: - visited = visited or [] - visited = visited + [self.id] - dependents = {self.id: []} + dependencies = {} if not isinstance(self.codes, SyftError): - dependents[self.id].extend(self.codes) - - for dep in dependents[self.id]: - code_deps = dep.get_dependents(visited=visited) - for k, v in code_deps.items(): - if k not in dependents: - dependents[k] = v - visited.append(k) - return dependents - - def get_dependencies(self): - return [] + dependencies[self.id] = self.codes + + visited = visited or set() + visited.add(self.id) + for dep in dependencies.get(self.id, []): + code_deps = dep.get_dependencies(visited=visited) + dependencies.update(code_deps) + visited.update(code_deps.keys()) + return dependencies @serializable() From 536c59eb0cf9f20c97192c11b0a4efe5be5ea1dc Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 14 Feb 2024 09:57:17 +0100 Subject: [PATCH 3/4] cleanup --- packages/syft/src/syft/service/request/request.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 1dec107c6cc..a1553a897f2 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -345,10 +345,6 @@ class Request(SyftObject): changes: List[Change] history: List[ChangeStatus] = [] - __dependent_objects__ = [ - ".usercode.something", - ] - __attr_searchable__ = [ "requesting_user_verify_key", "approving_user_verify_key", From 643d91b2d0e8f4bdd36bde57050c4c4be1c2987f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 14 Feb 2024 12:16:33 +0100 Subject: [PATCH 4/4] fix input policy before approve --- .../src/syft/service/action/action_object.py | 4 +++- .../syft/src/syft/service/code/user_code.py | 17 ++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index cdf9ee483e0..ded7b9c3a87 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -567,7 +567,7 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]: "syft_action_data_cache", "reload_cache", "syft_resolved", - "refresh_object" + "refresh_object", ] @@ -1091,7 +1091,9 @@ def get_from(self, client: SyftClient) -> Any: return res.syft_action_data def refresh_object(self): + # relative from ...client.api import APIRegistry + api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index ff9144479ea..4954c13eb8e 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -569,16 +569,15 @@ def get_dependencies( # so dependents are on the code object as well api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) action_service = api.services.action - input_policy = self.input_policy - if input_policy is not None: - all_input_ids = [] - for _, inputs in input_policy.inputs.items(): - all_input_ids.extend(inputs.values()) - for input_id in all_input_ids: - dependencies[self.id].append( - action_service.get(input_id, twin_mode=TwinMode.NONE) - ) + all_input_ids = [] + for _, inputs in self.input_policy_init_kwargs.items(): + all_input_ids.extend(inputs.values()) + + for input_id in all_input_ids: + dependencies[self.id].append( + action_service.get(input_id, twin_mode=TwinMode.NONE) + ) output_policy = self.output_policy if output_policy is not None: