diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py index 5e9522c2b88..34624986923 100644 --- a/packages/syft/src/syft/custom_worker/config.py +++ b/packages/syft/src/syft/custom_worker/config.py @@ -179,4 +179,7 @@ def test_image_build(self, tag: str, **kwargs: Any) -> SyftSuccess | SyftError: ) return SyftSuccess(message=iterator_to_string(iterator=logs)) except Exception as e: - return SyftError(message=f"Failed to build: {e}") + # stdlib + import traceback + + return SyftError(message=f"Failed to build: {e} {traceback.format_exc()}") diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 849e8267004..2c38c54d95e 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -3,6 +3,7 @@ # stdlib from collections import OrderedDict +from collections import defaultdict from collections.abc import Callable from datetime import datetime from functools import partial @@ -463,6 +464,9 @@ def stop(self) -> None: for p in self.queue_manager.producers.values(): p.close() + self.queue_manager.producers.clear() + self.queue_manager.consumers.clear() + NodeRegistry.remove_node(self.id) def close(self) -> None: @@ -567,6 +571,18 @@ def add_consumer_for_service( ) consumer.run() + def remove_consumer_with_id(self, syft_worker_id: UID) -> None: + for _, consumers in self.queue_manager.consumers.items(): + # Grab the list of consumers for the given queue + consumer_to_pop = None + for consumer_idx, consumer in enumerate(consumers): + if consumer.syft_worker_id == syft_worker_id: + consumer.close() + consumer_to_pop = consumer_idx + break + if consumer_to_pop is not None: + consumers.pop(consumer_to_pop) + @classmethod def named( cls, diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index a823f501773..d8b45baf104 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -609,16 +609,17 @@ def _repr_html_(self) -> str: worker_attr = "" if self.job_worker_id: worker = self.worker - worker_pool_id_button = CopyIDButton( - copy_text=str(worker.worker_pool_name), max_width=60 - ) - worker_attr = f""" -
- - Worker Pool: - {worker.name} on worker {worker_pool_id_button.to_html()} -
- """ + if not isinstance(worker, SyftError): + worker_pool_id_button = CopyIDButton( + copy_text=str(worker.worker_pool_name), max_width=60 + ) + worker_attr = f""" +
+ + Worker Pool: + {worker.name} on worker {worker_pool_id_button.to_html()} +
+ """ logs = self.logs(_print=False) logs_lines = logs.split("\n") if logs else [] diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 0f42904356a..43a948b2abf 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -388,6 +388,14 @@ def update_consumer_state_for_worker( return try: + # Check if worker is present in the database + worker = self.worker_stash.get_by_uid( + credentials=self.worker_stash.partition.root_verify_key, + uid=syft_worker_id, + ) + if worker.is_ok() and worker.ok() is None: + return + res = self.worker_stash.update_consumer_state( credentials=self.worker_stash.partition.root_verify_key, worker_uid=syft_worker_id, @@ -395,13 +403,14 @@ def update_consumer_state_for_worker( ) if res.is_err(): logger.error( - "Failed to update consumer state for worker id={} error={}", + "Failed to update consumer state for worker id={} to state: {} error={}", syft_worker_id, + consumer_state, res.err(), ) except Exception as e: logger.error( - f"Failed to update consumer state for worker id: {syft_worker_id}. Error: {e}" + f"Failed to update consumer state for worker id: {syft_worker_id} to state {consumer_state}. Error: {e}" ) def worker_waiting(self, worker: Worker) -> None: @@ -572,9 +581,10 @@ def delete_worker(self, worker: Worker, disconnect: bool) -> None: self.workers.pop(worker.identity, None) - self.update_consumer_state_for_worker( - worker.syft_worker_id, ConsumerState.DETACHED - ) + if worker.syft_worker_id is not None: + self.update_consumer_state_for_worker( + worker.syft_worker_id, ConsumerState.DETACHED + ) @property def alive(self) -> bool: @@ -633,7 +643,11 @@ def post_init(self) -> None: self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) self.reconnect_to_producer() + def disconnect_from_producer(self) -> None: + self.send_to_producer(QueueMsgProtocol.W_DISCONNECT) + def close(self) -> None: + self.disconnect_from_producer() self._stop.set() try: self.poller.unregister(self.socket) diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 2c4ddf60dd2..6c574b99735 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -208,6 +208,9 @@ def delete( stopped = _stop_worker_container(worker, docker_container, force) if stopped is not None: return stopped + else: + # kill the in memory worker thread + context.node.remove_consumer_with_id(syft_worker_id=worker.id) # remove the worker from the pool try: