Skip to content

Commit

Permalink
Merge pull request #8519 from OpenMined/eelco/sync-hierarchical-resolve
Browse files Browse the repository at this point in the history
hierarchical resolve for sync
  • Loading branch information
eelcovdw authored Feb 27, 2024
2 parents 9a0bf38 + f861ca0 commit ef4b8b4
Show file tree
Hide file tree
Showing 15 changed files with 1,757 additions and 588 deletions.
1,858 changes: 1,410 additions & 448 deletions notebooks/node syncing/syncing.ipynb

Large diffs are not rendered by default.

24 changes: 18 additions & 6 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..service.dataset.dataset import CreateDataset
from ..service.response import SyftError
from ..service.response import SyftSuccess
from ..service.sync.diff_state import ResolvedSyncState
from ..service.user.roles import Roles
from ..service.user.user_roles import ServiceRole
from ..types.blob_storage import BlobFile
Expand Down Expand Up @@ -140,9 +141,6 @@ def upload_dataset(self, dataset: CreateDataset) -> Union[SyftSuccess, SyftError
return tuple(valid.err())
return valid.err()

def apply_state(self, resolved_sync_state):
self._sync_items(resolved_sync_state.low_side_state)

def create_actionobject(self, action_object):
print("syncing obj with blob id", action_object.syft_blob_storage_entry_id)
action_object = action_object.refresh_object()
Expand All @@ -155,8 +153,8 @@ def create_actionobject(self, action_object):

def get_permissions_for_other_node(self, items):
if len(items) > 0:
assert len(set([i.syft_node_location for i in items])) == 1
assert len(set([i.syft_client_verify_key for i in items])) == 1
assert len({i.syft_node_location for i in items}) == 1
assert len({i.syft_client_verify_key for i in items}) == 1
item = items[0]
api = APIRegistry.api_for(
item.syft_node_location, item.syft_client_verify_key
Expand All @@ -165,13 +163,27 @@ def get_permissions_for_other_node(self, items):
else:
return {}

def _sync_items(self, items):
def apply_state(
    self, resolved_state: ResolvedSyncState
) -> Union[SyftSuccess, SyftError]:
    """Push the objects of a resolved sync state to this node.

    Created and updated objects are sent through the sync service together
    with the permissions looked up via `get_permissions_for_other_node`.
    Every ActionObject among them is first passed to `create_actionobject`.

    Args:
        resolved_state: the resolved state whose `create_objs` and
            `update_objs` are synced.

    Returns:
        The response of `sync.sync_items`, or the first SyftError returned by
        either `sync.sync_items` or `sync.get_state`.

    Raises:
        NotImplementedError: if `resolved_state.delete_objs` is not empty.
    """
    if len(resolved_state.delete_objs) != 0:
        raise NotImplementedError("TODO implement delete")

    items = resolved_state.create_objs + resolved_state.update_objs
    permissions = self.get_permissions_for_other_node(items)

    for candidate in items:
        if isinstance(candidate, ActionObject):
            self.create_actionobject(candidate)

    sync_service = self.api.services.sync
    res = sync_service.sync_items(items, permissions)
    if isinstance(res, SyftError):
        return res

    # Add updated node state to store to have a previous_state for next sync
    new_state = self.api.services.sync.get_state(add_to_store=True)
    if isinstance(new_state, SyftError):
        return new_state

    self._fetch_api(self.credentials)
    return res

Expand Down
91 changes: 44 additions & 47 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,53 @@
# third party
from IPython.display import Markdown
from IPython.display import display
# stdlib
from typing import Optional

# relative
from ..service.sync.diff_state import DiffState
from ..service.sync.diff_state import ResolvedSyncState
from ..service.sync.diff_state import display_diff_hierarchy
from ..service.sync.diff_state import resolve_diff


def compare_states(low_state, high_state) -> DiffState:
    """Build the DiffState describing how the low and high sync states differ."""
    diff_state = DiffState.from_sync_state(
        low_state=low_state,
        high_state=high_state,
    )
    return diff_state


def resolve(state: DiffState, force_approve: bool = False):
low_new_objs = []
high_new_objs = []
# new_objs = state.objs_to_sync()
for new_obj in state.diffs:
if new_obj.merge_state == "NEW":
if new_obj.low_obj is None:
state_list = low_new_objs
source = "LOW"
destination = "HIGH"
obj_to_sync = new_obj.high_obj
if new_obj.high_obj is None:
state_list = high_new_objs
source = "HIGH"
destination = "LOW"
obj_to_sync = new_obj.low_obj
if hasattr(obj_to_sync, "_repr_markdown_"):
display(Markdown(obj_to_sync._repr_markdown_()))
else:
display(obj_to_sync)

if force_approve:
state_list.append(obj_to_sync)
else:
print(
f"Do you approve moving this object from the {source} side to the {destination} side (approve/deny): ",
flush=True,
)
while True:
decision = input()
if decision == "approve":
state_list.append(obj_to_sync)
break
elif decision == "deny":
break
else:
print("Please write `approve` or `deny`:", flush=True)
if new_obj.merge_state == "DIFF":
# TODO: this is a shortcut
state_list = low_new_objs
state_list.append(new_obj.high_obj)
# pass

return low_new_objs, high_new_objs
def get_user_input_for_resolve():
    """Ask on stdin which side to keep until a valid answer is given.

    The answer is lower-cased before it is checked, so e.g. ``HIGH`` is
    accepted and returned as ``"high"``.

    Returns:
        str: either ``"low"`` or ``"high"``.
    """
    valid_choices = ["low", "high"]
    print(
        "Do you want to keep the low state or the high state for these objects? choose 'low' or 'high'"
    )

    answer = input().lower()
    while answer not in valid_choices:
        print("Please choose between `low` or `high`")
        answer = input().lower()
    return answer


def resolve(state: DiffState, decision: Optional[str] = None):
    """Resolve every hierarchy of `state` that contains differences.

    Hierarchies whose items all have merge state ``"SAME"`` are skipped. Each
    remaining hierarchy is displayed and then resolved towards one side.

    Args:
        state: the diff between the low and the high side.
        decision: side to keep for *all* hierarchies. When ``None`` the user
            is asked separately for every hierarchy that has differences.

    Returns:
        A ``(resolved_state_low, resolved_state_high)`` tuple of
        ResolvedSyncState objects collecting the resolved diffs.
    """
    resolved_state_low = ResolvedSyncState()
    resolved_state_high = ResolvedSyncState()

    for diff_hierarchy in state.hierarchies:
        if all(item.merge_state == "SAME" for item, _ in diff_hierarchy):
            # Hierarchy has no diffs
            continue

        display_diff_hierarchy(diff_hierarchy)

        # Use a per-hierarchy variable: overwriting `decision` with the first
        # interactive answer would silently apply that answer to every later
        # hierarchy without asking the user again.
        if decision is None:
            hierarchy_decision = get_user_input_for_resolve()
        else:
            hierarchy_decision = decision
            print(f"Decision: Syncing all objects from {decision} side")

        for diff, _ in diff_hierarchy:
            low_resolved_diff, high_resolved_diff = resolve_diff(
                diff, decision=hierarchy_decision
            )
            resolved_state_low.add(low_resolved_diff)
            resolved_state_high.add(high_resolved_diff)

    return resolved_state_low, resolved_state_high
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -1279,7 +1279,7 @@
"SyncStateItem": {
"1": {
"version": 1,
"hash": "3846e3429925f4009a9bf65ce8a8fa85751fee9937d200d947131822ef6f1533",
"hash": "7e1f22d0e24bb615b077d76feae7bed96a49a998358bd842aba18e8d69a22481",
"action": "add"
}
},
Expand Down
10 changes: 10 additions & 0 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..service import TYPE_TO_SERVICE
from ..service import UserLibConfigRegistry
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL
from ..user.user_roles import ServiceRole
from .action_object import Action
Expand Down Expand Up @@ -704,6 +705,15 @@ def exists(
else:
return SyftError(message=f"Object: {obj_id} does not exist")

@service_method(path="action.delete", name="delete", roles=ADMIN_ROLE_LEVEL)
def delete(
    self, context: AuthedServiceContext, uid: UID
) -> Union[SyftSuccess, SyftError]:
    """Delete the object stored under `uid` from the action store.

    Args:
        context: authenticated context; its credentials are passed to the store.
        uid: id of the object to delete.

    Returns:
        SyftSuccess when the store reports success, otherwise a SyftError
        carrying the store's error as its message.
    """
    res = self.store.delete(context.credentials, uid)
    if res.is_err():
        # Stringify the error, as the other services' delete methods do, so
        # the message is a plain string whatever the store put in the result.
        return SyftError(message=str(res.err()))
    return SyftSuccess(message="Great Success!")


def resolve_action_args(
action: Action, context: AuthedServiceContext, service: ActionService
Expand Down
11 changes: 11 additions & 0 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..service import SERVICE_TO_TYPES
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL
from ..user.user_roles import ServiceRole
Expand Down Expand Up @@ -74,6 +75,16 @@ def _submit(
result = self.stash.set(context.credentials, code)
return result

@service_method(path="code.delete", name="delete", roles=ADMIN_ROLE_LEVEL)
def delete(
    self, context: AuthedServiceContext, uid: UID
) -> Union[SyftSuccess, SyftError]:
    """Delete User Code

    Removes the entry with the given `uid` from the stash using the caller's
    credentials. Returns SyftSuccess on success, or a SyftError whose message
    is the stringified stash error.
    """
    outcome = self.stash.delete_by_uid(context.credentials, uid)
    if not outcome.is_err():
        return SyftSuccess(message="User Code Deleted")
    failure_text = str(outcome.err())
    return SyftError(message=failure_text)

@service_method(
path="code.sync_code_from_request",
name="sync_code_from_request",
Expand Down
14 changes: 14 additions & 0 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import DATA_OWNER_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from .job_stash import Job
Expand Down Expand Up @@ -76,6 +77,19 @@ def get_by_user_code_id(
res = res.ok()
return res

@service_method(
    path="job.delete",
    name="delete",
    roles=ADMIN_ROLE_LEVEL,
)
def delete(
    self, context: AuthedServiceContext, uid: UID
) -> Union[SyftSuccess, SyftError]:
    """Delete the job with the given `uid` from the job stash.

    Args:
        context: authenticated context; its credentials are passed to the stash.
        uid: id of the job to delete.

    Returns:
        SyftSuccess when the stash reports success, otherwise a SyftError
        carrying the stash's error as its message.
    """
    res = self.stash.delete_by_uid(context.credentials, uid)
    if res.is_err():
        # Stringify the error, as the other services' delete methods do, so
        # the message is a plain string whatever the stash put in the result.
        return SyftError(message=str(res.err()))
    return SyftSuccess(message="Great Success!")

@service_method(
path="job.restart",
name="restart",
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def get_dependencies(self) -> List[UID]:
if self.user_code_id:
dependencies.append(self.user_code_id)

if self.result:
if self.result is not None:
dependencies.append(self.result.id.id)

if self.log_id:
Expand All @@ -482,7 +482,7 @@ def get_dependencies(self) -> List[UID]:

def get_sync_dependencies(self, api=None) -> List[UID]:
dependencies = []
if self.result:
if self.result is not None:
dependencies.append(self.result.id.id)

if self.log_id:
Expand Down
5 changes: 5 additions & 0 deletions packages/syft/src/syft/service/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ class SyftLog(SyftObject):
__canonical_name__ = "SyftLog"
__version__ = SYFT_OBJECT_VERSION_2

__repr_attrs__ = ["stdout", "stderr"]

stdout: str = ""
stderr: str = ""

def append(self, new_str: str) -> None:
    """Add `new_str` to the end of the captured stdout text."""
    self.stdout = self.stdout + new_str

def append_error(self, new_str: str) -> None:
    """Add `new_str` to the end of the captured stderr text."""
    self.stderr = self.stderr + new_str

Expand Down
26 changes: 26 additions & 0 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,14 @@ def _get_latest_or_create_job(self) -> Union[Job, SyftError]:

return job

def _is_action_object_from_job(self, action_object: ActionObject) -> Optional[Job]:
    """Return the job whose result is `action_object`, if there is one.

    The jobs for this request's code are fetched through the job service and
    their result ids are compared with `action_object.id`.

    Args:
        action_object: the object to look for among the job results.

    Returns:
        The first job whose result has the same id, or None when no job matches.
    """
    api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key)
    job_service = api.services.job
    existing_jobs = job_service.get_by_user_code_id(self.code.id)
    for job in existing_jobs:
        # Test for presence with `is not None` rather than truthiness.
        # NOTE(review): presumably a result object can evaluate as falsy while
        # still being a real result -- confirm against ActionObject.
        if job.result is not None and job.result.id == action_object.id:
            return job
    return None

def accept_by_depositing_result(self, result: Any, force: bool = False):
# this code is extremely brittle because its a work around that relies on
# the type of request being very specifically tied to code which needs approving
Expand All @@ -660,6 +668,16 @@ def accept_by_depositing_result(self, result: Any, force: bool = False):
message="JobInfo should not include result. Use sync_job instead."
)
result = job_info.result
elif isinstance(result, ActionObject):
# Do not allow accepting a result produced by a Job,
# This can cause an inconsistent Job state
if self._is_action_object_from_job(result):
action_object_job = self._is_action_object_from_job(result)
if action_object_job is not None:
return SyftError(
message=f"This ActionObject is the result of Job {action_object_job.id}, "
f"please use the `Job.info` instead."
)
else:
# NOTE result is added at the end of function (once ActionObject is created)
job_info = JobInfo(
Expand Down Expand Up @@ -767,8 +785,16 @@ def accept_by_depositing_result(self, result: Any, force: bool = False):
if isinstance(approved, SyftError):
return approved

print("ActionObject 4", type(action_object), action_object)

job_info.result = action_object
job = self._get_latest_or_create_job()

existing_result = job.result.id if job.result is not None else None
print("New result", action_object)
print(
f"Job({job.id}) Setting new result {existing_result} -> {job_info.result.id}"
)
job.apply_info(job_info)

api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key)
Expand Down
10 changes: 10 additions & 0 deletions packages/syft/src/syft/service/request/request_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..service import TYPE_TO_SERVICE
from ..service import service_method
from ..user.user import UserView
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL
from ..user.user_service import UserService
from .request import Change
Expand Down Expand Up @@ -104,6 +105,15 @@ def submit(
print("Failed to submit Request", e)
raise e

@service_method(path="request.delete", name="delete", roles=ADMIN_ROLE_LEVEL)
def delete(
    self, context: AuthedServiceContext, uid: UID
) -> Union[SyftSuccess, SyftError]:
    """Delete the request with the given `uid` from the request stash.

    Args:
        context: authenticated context; its credentials are passed to the stash.
        uid: id of the request to delete.

    Returns:
        SyftSuccess when the stash reports success, otherwise a SyftError
        whose message is the stringified stash error.
    """
    result = self.stash.delete_by_uid(context.credentials, uid)
    if result.is_err():
        return SyftError(message=str(result.err()))
    # Return a SyftSuccess, as annotated and as the other delete service
    # methods do, instead of handing the stash's result wrapper to the caller.
    return SyftSuccess(message="Request Deleted")

def expand_node(self, context: AuthedServiceContext, code_obj: UserCode):
user_code_service = context.node.get_service("usercodeservice")
nested_requests = user_code_service.solve_nested_requests(context, code_obj)
Expand Down
Loading

0 comments on commit ef4b8b4

Please sign in to comment.