diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index cdf9ee483e0..ded7b9c3a87 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -567,7 +567,7 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]: "syft_action_data_cache", "reload_cache", "syft_resolved", - "refresh_object" + "refresh_object", ] @@ -1091,7 +1091,9 @@ def get_from(self, client: SyftClient) -> Any: return res.syft_action_data def refresh_object(self): + # relative from ...client.api import APIRegistry + api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index a449845115d..4954c13eb8e 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -20,6 +20,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Tuple from typing import Type from typing import Union @@ -61,6 +62,7 @@ from ...util.markdown import as_markdown_code from ..action.action_object import Action from ..action.action_object import ActionObject +from ..action.action_object import TwinMode from ..context import AuthedServiceContext from ..dataset.dataset import Asset from ..job.job_stash import Job @@ -554,6 +556,57 @@ def assets(self) -> List[Asset]: all_assets += assets return all_assets + def get_dependencies( + self, visited: Optional[Set[str]] = None + ) -> Dict[str, List[Any]]: + # Usercode dependents are: input_policy inputs, output_policy outputs, nested_codes + + visited = visited or set() + visited.add(self.id) + dependencies = {self.id: []} + + # NOTE input and output policy are stored directly on the code object, + # so dependents are on the code object as well + api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) + action_service = api.services.action + + all_input_ids = [] + for _, inputs in self.input_policy_init_kwargs.items(): + all_input_ids.extend(inputs.values()) + + for input_id in all_input_ids: + dependencies[self.id].append( + action_service.get(input_id, twin_mode=TwinMode.NONE) + ) + + output_policy = self.output_policy + if output_policy is not None: + all_output_ids = [] + for output in output_policy.output_history: + if isinstance(output.outputs, list): + all_output_ids.extend(output.outputs) + else: + all_output_ids.extend(output.outputs.values()) + + for output_id in all_output_ids: + dependencies[self.id].append( + action_service.get(output_id, twin_mode=TwinMode.NONE) + ) + + if self.nested_codes is not None: + for obj_link, _ in self.nested_codes.values(): + if visited and obj_link.id in visited: + continue + obj = obj_link.resolve + deps = obj.get_dependencies(visited=visited) + + visited.update(deps.keys()) + dependencies.update(deps) + + # remove empty values + dependencies = {k: v for k, v in dependencies.items() if len(v) > 0} + return dependencies + @property def unsafe_function(self) -> Optional[Callable]: warning = SyftWarning( diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index ac97b1406e8..c83cf68c03d 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -6,6 +6,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Union # third party @@ -334,7 +335,6 @@ def _get_log_objs(self): ) return api.services.log.get(self.log_id) - def logs(self, stdout=True, stderr=True, _print=True): api = APIRegistry.api_for( node_uid=self.node_uid, @@ -461,6 +461,37 @@ def resolve(self) -> Union[Any, SyftNotReady]: return self.result return SyftNotReady(message=f"{self.id} not ready yet.") + def get_dependencies( + self, visited: Optional[Set[UID]] = None + ) -> Dict[UID, List[Any]]: + # result, usercode, logs, subjobs + visited = visited or set() + visited.add(self.id) + + api = APIRegistry.api_for( + node_uid=self.node_uid, + user_verify_key=self.syft_client_verify_key, + ) + + dependencies = {self.id: []} + result_id = self.result.id + if result_id not in visited: + result_obj = api.services.action.get(result_id, resolve_nested=False) + dependencies[self.id].append(result_obj) + + if self.log_id not in visited: + log_obj = api.services.log.get(self.log_id) + dependencies[self.id].append(log_obj) + + for subjob in self.subjobs: + if subjob.id not in visited: + dependencies[self.id].append(subjob) + sub_dependents = subjob.get_dependencies(visited=visited) + dependencies.update(sub_dependents) + visited.update(sub_dependents.keys()) + + return dependencies + @serializable() class JobInfo(SyftObject): diff --git a/packages/syft/src/syft/service/log/log.py b/packages/syft/src/syft/service/log/log.py index bd7f03810f5..2cb87024721 100644 --- a/packages/syft/src/syft/service/log/log.py +++ b/packages/syft/src/syft/service/log/log.py @@ -1,10 +1,11 @@ # relative -from syft.types.syft_migration import migrate -from syft.types.transforms import drop, make_set_default from ...serde.serializable import serializable +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import drop +from ...types.transforms import make_set_default @serializable() @@ -33,12 +34,14 @@ def restart(self) -> None: self.stderr = "" self.stdout = "" + @migrate(SyftLogV1, SyftLog) def upgrade_syftlog_v1_to_v2(): return [ make_set_default("stderr", ""), ] + @migrate(SyftLog, SyftLogV1) def downgrade_syftlog_v2_to_v1(): return [ diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 53b21cf221e..a1553a897f2 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -7,6 +7,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Set from typing import Type from typing import Union @@ -343,8 +344,6 @@ class Request(SyftObject): request_hash: str changes: List[Change] history: List[ChangeStatus] = [] - - __dependent_objects__ = [".usercode.something", ] __attr_searchable__ = [ "requesting_user_verify_key", @@ -758,6 +757,21 @@ def sync_job(self, job_info: JobInfo, **kwargs) -> Result[SyftSuccess, SyftError job.apply_info(job_info) return job_service.update(job) + def get_dependencies( + self, visited: Optional[Set[UID]] = None + ) -> Dict[str, SyftObject]: + dependencies = {} + if not isinstance(self.codes, SyftError): + dependencies[self.id] = self.codes + + visited = visited or set() + visited.add(self.id) + for dep in dependencies.get(self.id, []): + code_deps = dep.get_dependencies(visited=visited) + dependencies.update(code_deps) + visited.update(code_deps.keys()) + return dependencies + @serializable() class RequestInfo(SyftObject):