diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index dc9eb40e81e..181380ef87f 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1240,7 +1240,7 @@ def get_from(self, client: SyftClient) -> Any: else: return res.syft_action_data - def refresh_object(self, resolve_nested: bool = True) -> ActionObject: + def refresh_object(self, resolve_nested: bool = True) -> ActionObject | SyftError: # relative from ...client.api import APIRegistry @@ -1275,9 +1275,10 @@ def get(self, block: bool = False) -> Any: self.wait() res = self.refresh_object() - if not isinstance(res, ActionObject): return SyftError(message=f"{res}") # type: ignore + elif issubclass(res.syft_action_data_type, Err): + return SyftError(message=f"{res.syft_action_data.err()}") else: if not self.has_storage_permission(): prompt_warning_message( @@ -1415,7 +1416,7 @@ def remove_trace_hook(cls) -> bool: def as_empty_data(self) -> ActionDataEmpty: return ActionDataEmpty(syft_internal_type=self.syft_internal_type) - def wait(self, timeout: int | None = None) -> ActionObject: + def wait(self, timeout: int | None = None) -> ActionObject | SyftError: # relative from ...client.api import APIRegistry @@ -1429,12 +1430,18 @@ def wait(self, timeout: int | None = None) -> ActionObject: obj_id = self.id counter = 0 - while api and not api.services.action.is_resolved(obj_id): - time.sleep(1) - if timeout is not None: - counter += 1 - if counter > timeout: - return SyftError(message="Reached Timeout!") + while api: + obj_resolved: bool | str = api.services.action.is_resolved(obj_id) + if isinstance(obj_resolved, str): + return SyftError(message=obj_resolved) + if obj_resolved: + break + if not obj_resolved: + time.sleep(1) + if timeout is not None: + counter += 1 + if counter > timeout: + return SyftError(message="Reached Timeout!") return self diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 0dcb2271a6c..b1fa39e5330 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -183,8 +183,6 @@ def is_resolved( uid: UID, ) -> Result[Ok[bool], Err[str]]: """Get an object from the action store""" - # relative - result = self._get(context, uid) if result.is_ok(): obj = result.ok() @@ -192,7 +190,6 @@ def is_resolved( result = self.resolve_links( context, obj.syft_action_data.action_object_id.id ) - # Checking in case any error occurred if result.is_err(): return result @@ -205,7 +202,6 @@ def is_resolved( # If it's not an action data link or non resolved (empty). It's resolved return Ok(True) - # If it's not in the store or permission error, return the error return result diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index bb85134c387..05ad2ae1549 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -317,7 +317,7 @@ def fetch(self) -> None: ) job: Job | None = api.make_call(call) if job is None: - return + return None self.resolved = job.resolved if job.resolved: self.result = job.result @@ -640,7 +640,7 @@ def _repr_html_(self) -> str: def wait( self, job_only: bool = False, timeout: int | None = None - ) -> Any | SyftNotReady: + ) -> Any | SyftNotReady | SyftError: self.fetch() if self.resolved: return self.resolve @@ -652,29 +652,28 @@ def wait( if api is None: raise ValueError( - f"Can't access Syft API. You must login to {self.syft_node_location}" + f"Can't access Syft API. You must login to node with id '{self.syft_node_location}'" ) workers = api.services.worker.get_all() if not isinstance(workers, SyftError) and len(workers) == 0: return SyftError( - message="This node has no workers. " - "You need to start a worker to run jobs " - "by setting n_consumers > 0." + message=f"Node {self.syft_node_location} has no workers. " + f"You need to start a worker to run jobs " + f"by setting n_consumers > 0." ) - if not job_only and self.result is not None: - self.result.wait(timeout) - print_warning = True counter = 0 while True: self.fetch() - if isinstance(self.result, SyftError | Err) or self.status in [ - JobStatus.ERRORED, - JobStatus.INTERRUPTED, - ]: - return self.result + if self.resolved: + if isinstance(self.result, SyftError | Err) or self.status in [ # type: ignore[unreachable] + JobStatus.ERRORED, + JobStatus.INTERRUPTED, + ]: + return self.result + break if print_warning and self.result is not None: result_obj = api.services.action.get( # type: ignore[unreachable] self.result.id, resolve_nested=False @@ -686,15 +685,17 @@ def wait( "Use job.wait().get() instead to wait for the linked result." ) print_warning = False + sleep(1) - if self.resolved: - break # type: ignore[unreachable] - # TODO: fix the mypy issue + if timeout is not None: counter += 1 if counter > timeout: return SyftError(message="Reached Timeout!") + # if self.resolve returns self.result as error, then we + # return SyftError and not wait for the result + # otherwise if a job is resolved and not errored out, we wait for the result if not job_only and self.result is not None: # type: ignore[unreachable] self.result.wait(timeout) diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 968e4b7c975..e8e755d450f 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -184,7 +184,6 @@ def handle_message_multiprocessing( try: call_method = getattr(worker.get_service(queue_item.service), queue_item.method) - role = worker.get_role_for_credentials(credentials=credentials) context = AuthedServiceContext( @@ -205,7 +204,6 @@ def handle_message_multiprocessing( ) result: Any = call_method(context, *queue_item.args, **queue_item.kwargs) - status = Status.COMPLETED job_status = JobStatus.COMPLETED @@ -227,11 +225,7 @@ def handle_message_multiprocessing( job_status = JobStatus.ERRORED # stdlib - raise e - # result = SyftError( - # message=f"Failed with exception: {e}, {traceback.format_exc()}" - # ) - # print("HAD AN ERROR WHILE HANDLING MESSAGE", result.message) + logger.error(f"Error while handle message multiprocessing: {e}") queue_item.result = result queue_item.resolved = True diff --git a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py index 0bd022ae604..d68124e9b4d 100644 --- a/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py +++ b/packages/syft/tests/syft/service/sync/sync_resolve_single_test.py @@ -1,5 +1,4 @@ # third party -from result import Err # syft absolute import syft @@ -142,7 +141,7 @@ def compute() -> int: client_low_ds.refresh() res = client_low_ds.code.compute(blocking=True) - assert isinstance(res.get(), Err) + assert isinstance(res.get(), SyftError) def test_ignore_unignore_single(low_worker, high_worker): diff --git a/tests/integration/local/twin_api_sync_test.py b/tests/integration/local/twin_api_sync_test.py index fed1905370f..9212b7d6905 100644 --- a/tests/integration/local/twin_api_sync_test.py +++ b/tests/integration/local/twin_api_sync_test.py @@ -3,7 +3,6 @@ # third party import pytest -from result import Err # syft absolute import syft @@ -11,6 +10,7 @@ from syft.client.domain_client import DomainClient from syft.client.syncing import compare_clients from syft.client.syncing import resolve +from syft.service.action.action_object import ActionObject from syft.service.job.job_stash import JobStatus from syft.service.response import SyftError from syft.service.response import SyftSuccess @@ -149,9 +149,11 @@ def compute_sum(): users[-1].allow_mock_execution() result = ds_client.api.services.code.compute_sum(blocking=True) - assert isinstance(result.get(), Err) + assert isinstance(result, ActionObject) + assert isinstance(result.get(), SyftError) job_info = ds_client.api.services.code.compute_sum(blocking=False) result = job_info.wait(timeout=10) - assert isinstance(result.get(), Err) + assert isinstance(result, ActionObject) + assert isinstance(result.get(), SyftError) assert job_info.status == JobStatus.ERRORED