Skip to content

Commit

Permalink
Merge pull request #8482 from OpenMined/eelco/sync-dependencies
Browse files Browse the repository at this point in the history
dependencies for sync
  • Loading branch information
eelcovdw authored Feb 14, 2024
2 parents 7f02029 + 643d91b commit 6319a5e
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 6 deletions.
4 changes: 3 additions & 1 deletion packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def debox_args_and_kwargs(args: Any, kwargs: Any) -> Tuple[Any, Any]:
"syft_action_data_cache",
"reload_cache",
"syft_resolved",
"refresh_object"
"refresh_object",
]


Expand Down Expand Up @@ -1091,7 +1091,9 @@ def get_from(self, client: SyftClient) -> Any:
return res.syft_action_data

def refresh_object(self):
# relative
from ...client.api import APIRegistry

api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
Expand Down
53 changes: 53 additions & 0 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import Union
Expand Down Expand Up @@ -61,6 +62,7 @@
from ...util.markdown import as_markdown_code
from ..action.action_object import Action
from ..action.action_object import ActionObject
from ..action.action_object import TwinMode
from ..context import AuthedServiceContext
from ..dataset.dataset import Asset
from ..job.job_stash import Job
Expand Down Expand Up @@ -554,6 +556,57 @@ def assets(self) -> List[Asset]:
all_assets += assets
return all_assets

def get_dependencies(
self, visited: Optional[Set[str]] = None
) -> Dict[str, List[Any]]:
# Usercode dependents are: input_policy inputs, output_policy outputs, nested_codes

visited = visited or set()
visited.add(self.id)
dependencies = {self.id: []}

# NOTE input and output policy are stored directly on the code object,
# so dependents are on the code object as well
api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key)
action_service = api.services.action

all_input_ids = []
for _, inputs in self.input_policy_init_kwargs.items():
all_input_ids.extend(inputs.values())

for input_id in all_input_ids:
dependencies[self.id].append(
action_service.get(input_id, twin_mode=TwinMode.NONE)
)

output_policy = self.output_policy
if output_policy is not None:
all_output_ids = []
for output in output_policy.output_history:
if isinstance(output.outputs, list):
all_output_ids.extend(output.outputs)
else:
all_output_ids.extend(output.outputs.values())

for output_id in all_output_ids:
dependencies[self.id].append(
action_service.get(output_id, twin_mode=TwinMode.NONE)
)

if self.nested_codes is not None:
for obj_link, _ in self.nested_codes.values():
if visited and obj_link.id in visited:
continue
obj = obj_link.resolve
deps = obj.get_dependencies(visited=visited)

visited.update(deps.keys())
dependencies.update(deps)

# remove empty values
dependencies = {k: v for k, v in dependencies.items() if len(v) > 0}
return dependencies

@property
def unsafe_function(self) -> Optional[Callable]:
warning = SyftWarning(
Expand Down
33 changes: 32 additions & 1 deletion packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Union

# third party
Expand Down Expand Up @@ -334,7 +335,6 @@ def _get_log_objs(self):
)
return api.services.log.get(self.log_id)


def logs(self, stdout=True, stderr=True, _print=True):
api = APIRegistry.api_for(
node_uid=self.node_uid,
Expand Down Expand Up @@ -461,6 +461,37 @@ def resolve(self) -> Union[Any, SyftNotReady]:
return self.result
return SyftNotReady(message=f"{self.id} not ready yet.")

def get_dependencies(
self, visited: Optional[Set[UID]] = None
) -> Dict[UID, List[Any]]:
# result, usercode, logs, subjobs
visited = visited or set()
visited.add(self.id)

api = APIRegistry.api_for(
node_uid=self.node_uid,
user_verify_key=self.syft_client_verify_key,
)

dependencies = {self.id: []}
result_id = self.result.id
if result_id not in visited:
result_obj = api.services.action.get(result_id, resolve_nested=False)
dependencies[self.id].append(result_obj)

if self.log_id not in visited:
log_obj = api.services.log.get(self.log_id)
dependencies[self.id].append(log_obj)

for subjob in self.subjobs:
if subjob.id not in visited:
dependencies[self.id].append(subjob)
sub_dependents = subjob.get_dependencies(visited=visited)
dependencies.update(sub_dependents)
visited.update(sub_dependents.keys())

return dependencies


@serializable()
class JobInfo(SyftObject):
Expand Down
7 changes: 5 additions & 2 deletions packages/syft/src/syft/service/log/log.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# relative
from syft.types.syft_migration import migrate
from syft.types.transforms import drop, make_set_default
from ...serde.serializable import serializable
from ...types.syft_migration import migrate
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SYFT_OBJECT_VERSION_2
from ...types.syft_object import SyftObject
from ...types.transforms import drop
from ...types.transforms import make_set_default


@serializable()
Expand Down Expand Up @@ -33,12 +34,14 @@ def restart(self) -> None:
self.stderr = ""
self.stdout = ""


@migrate(SyftLogV1, SyftLog)
def upgrade_syftlog_v1_to_v2():
return [
make_set_default("stderr", ""),
]


@migrate(SyftLog, SyftLogV1)
def downgrade_syftlog_v2_to_v1():
return [
Expand Down
18 changes: 16 additions & 2 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Type
from typing import Union

Expand Down Expand Up @@ -343,8 +344,6 @@ class Request(SyftObject):
request_hash: str
changes: List[Change]
history: List[ChangeStatus] = []

__dependent_objects__ = [".usercode.something", ]

__attr_searchable__ = [
"requesting_user_verify_key",
Expand Down Expand Up @@ -758,6 +757,21 @@ def sync_job(self, job_info: JobInfo, **kwargs) -> Result[SyftSuccess, SyftError
job.apply_info(job_info)
return job_service.update(job)

def get_dependencies(
self, visited: Optional[Set[UID]] = None
) -> Dict[str, SyftObject]:
dependencies = {}
if not isinstance(self.codes, SyftError):
dependencies[self.id] = self.codes

visited = visited or set()
visited.add(self.id)
for dep in dependencies.get(self.id, []):
code_deps = dep.get_dependencies(visited=visited)
dependencies.update(code_deps)
visited.update(code_deps.keys())
return dependencies


@serializable()
class RequestInfo(SyftObject):
Expand Down

0 comments on commit 6319a5e

Please sign in to comment.