diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 3c303a5d1f7..336fb996f33 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -1092,11 +1092,14 @@ def get_from(self, client: SyftClient) -> Any: else: return res.syft_action_data - def get(self) -> Any: + def get(self, block: bool = False) -> Any: """Get the object from a Syft Client""" # relative from ...client.api import APIRegistry + if block: + self.wait() + api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index a6f2ed4a821..efa03deac94 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -409,7 +409,7 @@ def _repr_markdown_(self) -> str: """ return as_markdown_code(md) - def wait(self): + def wait(self, job_only=False): # stdlib from time import sleep @@ -422,6 +422,9 @@ def wait(self): if self.resolved: return self.resolve + if not job_only: + self.result.wait() + print_warning = True while True: self.fetch() @@ -429,11 +432,11 @@ def wait(self): result_obj = api.services.action.get( self.result.id, resolve_nested=False ) - if isinstance(result_obj.syft_action_data, ActionDataLink): + if isinstance(result_obj.syft_action_data, ActionDataLink) and job_only: print( "You're trying to wait on a job that has a link as a result." "This means that the job may be ready but the linked result may not." - "Use job.result.wait() instead to wait for the linked result." + "Use job.wait().get() instead to wait for the linked result." ) print_warning = False sleep(2) diff --git a/packages/syft/tests/syft/syft_functions/syft_function_test.py b/packages/syft/tests/syft/syft_functions/syft_function_test.py index 44c08bac286..c81ae3d4561 100644 --- a/packages/syft/tests/syft/syft_functions/syft_function_test.py +++ b/packages/syft/tests/syft/syft_functions/syft_function_test.py @@ -95,9 +95,9 @@ def process_all(domain, x): assert len(job.subjobs) == 3 # stdlib + assert job.wait().get() == 5 sub_results = [j.wait().get() for j in job.subjobs] assert set(sub_results) == {2, 3, 5} - assert job.result.wait().get() == 5 job = client.jobs[-1] assert job.job_worker_id is not None