Skip to content

Commit

Permalink
Merge branch 'dev' into more-datatypes-for-assets
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvanderveen authored Jun 20, 2024
2 parents 1f51ef7 + df2859b commit f9ec721
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 42 deletions.
25 changes: 16 additions & 9 deletions packages/syft/src/syft/service/action/action_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,7 @@ def get_from(self, client: SyftClient) -> Any:
else:
return res.syft_action_data

def refresh_object(self, resolve_nested: bool = True) -> ActionObject:
def refresh_object(self, resolve_nested: bool = True) -> ActionObject | SyftError:
# relative
from ...client.api import APIRegistry

Expand Down Expand Up @@ -1275,9 +1275,10 @@ def get(self, block: bool = False) -> Any:
self.wait()

res = self.refresh_object()

if not isinstance(res, ActionObject):
return SyftError(message=f"{res}") # type: ignore
elif issubclass(res.syft_action_data_type, Err):
return SyftError(message=f"{res.syft_action_data.err()}")
else:
if not self.has_storage_permission():
prompt_warning_message(
Expand Down Expand Up @@ -1415,7 +1416,7 @@ def remove_trace_hook(cls) -> bool:
def as_empty_data(self) -> ActionDataEmpty:
return ActionDataEmpty(syft_internal_type=self.syft_internal_type)

def wait(self, timeout: int | None = None) -> ActionObject:
def wait(self, timeout: int | None = None) -> ActionObject | SyftError:
# relative
from ...client.api import APIRegistry

Expand All @@ -1429,12 +1430,18 @@ def wait(self, timeout: int | None = None) -> ActionObject:
obj_id = self.id

counter = 0
while api and not api.services.action.is_resolved(obj_id):
time.sleep(1)
if timeout is not None:
counter += 1
if counter > timeout:
return SyftError(message="Reached Timeout!")
while api:
obj_resolved: bool | str = api.services.action.is_resolved(obj_id)
if isinstance(obj_resolved, str):
return SyftError(message=obj_resolved)
if obj_resolved:
break
if not obj_resolved:
time.sleep(1)
if timeout is not None:
counter += 1
if counter > timeout:
return SyftError(message="Reached Timeout!")

return self

Expand Down
4 changes: 0 additions & 4 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,13 @@ def is_resolved(
uid: UID,
) -> Result[Ok[bool], Err[str]]:
"""Get an object from the action store"""
# relative

result = self._get(context, uid)
if result.is_ok():
obj = result.ok()
if obj.is_link:
result = self.resolve_links(
context, obj.syft_action_data.action_object_id.id
)

# Checking in case any error occurred
if result.is_err():
return result
Expand All @@ -205,7 +202,6 @@ def is_resolved(

# If it's not an action data link or non resolved (empty). It's resolved
return Ok(True)

# If it's not in the store or permission error, return the error
return result

Expand Down
35 changes: 18 additions & 17 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def fetch(self) -> None:
)
job: Job | None = api.make_call(call)
if job is None:
return
return None
self.resolved = job.resolved
if job.resolved:
self.result = job.result
Expand Down Expand Up @@ -640,7 +640,7 @@ def _repr_html_(self) -> str:

def wait(
self, job_only: bool = False, timeout: int | None = None
) -> Any | SyftNotReady:
) -> Any | SyftNotReady | SyftError:
self.fetch()
if self.resolved:
return self.resolve
Expand All @@ -652,29 +652,28 @@ def wait(

if api is None:
raise ValueError(
f"Can't access Syft API. You must login to {self.syft_node_location}"
f"Can't access Syft API. You must login to node with id '{self.syft_node_location}'"
)

workers = api.services.worker.get_all()
if not isinstance(workers, SyftError) and len(workers) == 0:
return SyftError(
message="This node has no workers. "
"You need to start a worker to run jobs "
"by setting n_consumers > 0."
message=f"Node {self.syft_node_location} has no workers. "
f"You need to start a worker to run jobs "
f"by setting n_consumers > 0."
)

if not job_only and self.result is not None:
self.result.wait(timeout)

print_warning = True
counter = 0
while True:
self.fetch()
if isinstance(self.result, SyftError | Err) or self.status in [
JobStatus.ERRORED,
JobStatus.INTERRUPTED,
]:
return self.result
if self.resolved:
if isinstance(self.result, SyftError | Err) or self.status in [ # type: ignore[unreachable]
JobStatus.ERRORED,
JobStatus.INTERRUPTED,
]:
return self.result
break
if print_warning and self.result is not None:
result_obj = api.services.action.get( # type: ignore[unreachable]
self.result.id, resolve_nested=False
Expand All @@ -686,15 +685,17 @@ def wait(
"Use job.wait().get() instead to wait for the linked result."
)
print_warning = False

sleep(1)
if self.resolved:
break # type: ignore[unreachable]
# TODO: fix the mypy issue

if timeout is not None:
counter += 1
if counter > timeout:
return SyftError(message="Reached Timeout!")

# if self.resolve returns self.result as error, then we
# return SyftError and not wait for the result
# otherwise if a job is resolved and not errored out, we wait for the result
if not job_only and self.result is not None: # type: ignore[unreachable]
self.result.wait(timeout)

Expand Down
8 changes: 1 addition & 7 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def handle_message_multiprocessing(

try:
call_method = getattr(worker.get_service(queue_item.service), queue_item.method)

role = worker.get_role_for_credentials(credentials=credentials)

context = AuthedServiceContext(
Expand All @@ -205,7 +204,6 @@ def handle_message_multiprocessing(
)

result: Any = call_method(context, *queue_item.args, **queue_item.kwargs)

status = Status.COMPLETED
job_status = JobStatus.COMPLETED

Expand All @@ -227,11 +225,7 @@ def handle_message_multiprocessing(
job_status = JobStatus.ERRORED
# stdlib

raise e
# result = SyftError(
# message=f"Failed with exception: {e}, {traceback.format_exc()}"
# )
# print("HAD AN ERROR WHILE HANDLING MESSAGE", result.message)
logger.error(f"Error while handle message multiprocessing: {e}")

queue_item.result = result
queue_item.resolved = True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# third party
from result import Err

# syft absolute
import syft
Expand Down Expand Up @@ -142,7 +141,7 @@ def compute() -> int:

client_low_ds.refresh()
res = client_low_ds.code.compute(blocking=True)
assert isinstance(res.get(), Err)
assert isinstance(res.get(), SyftError)


def test_ignore_unignore_single(low_worker, high_worker):
Expand Down
8 changes: 5 additions & 3 deletions tests/integration/local/twin_api_sync_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

# third party
import pytest
from result import Err

# syft absolute
import syft
import syft as sy
from syft.client.domain_client import DomainClient
from syft.client.syncing import compare_clients
from syft.client.syncing import resolve
from syft.service.action.action_object import ActionObject
from syft.service.job.job_stash import JobStatus
from syft.service.response import SyftError
from syft.service.response import SyftSuccess
Expand Down Expand Up @@ -149,9 +149,11 @@ def compute_sum():

users[-1].allow_mock_execution()
result = ds_client.api.services.code.compute_sum(blocking=True)
assert isinstance(result.get(), Err)
assert isinstance(result, ActionObject)
assert isinstance(result.get(), SyftError)

job_info = ds_client.api.services.code.compute_sum(blocking=False)
result = job_info.wait(timeout=10)
assert isinstance(result.get(), Err)
assert isinstance(result, ActionObject)
assert isinstance(result.get(), SyftError)
assert job_info.status == JobStatus.ERRORED

0 comments on commit f9ec721

Please sign in to comment.