diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 240dc31ba1a..e83a1c56d8b 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -1435,26 +1435,36 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: print("Building Default Worker Image") - # Build the Image for given tag - result = image_build_method( - context, image_uid=default_image.id, tag=DEFAULT_WORKER_IMAGE_TAG - ) - - if isinstance(result, SyftError): - print("Failed to build default worker image: ", result.message) - return + if not default_image.is_built: + # Build the Image for given tag + result = image_build_method( + context, image_uid=default_image.id, tag=DEFAULT_WORKER_IMAGE_TAG + ) - create_pool_method = node.get_service_method(SyftWorkerPoolService.create_pool) + if isinstance(result, SyftError): + print("Failed to build default worker image: ", result.message) + return + default_worker_pool = node.get_default_worker_pool() worker_count = node.queue_config.client_config.n_consumers - print("Creating default Worker Pool") - result = create_pool_method( - context, - name=DEFAULT_WORKER_POOL_NAME, - image_uid=default_image.id, - number=worker_count, - ) + # Create worker pool if it doesn't exists + if default_worker_pool is None: + create_pool_method = node.get_service_method(SyftWorkerPoolService.create_pool) + print("Creating default Worker Pool") + result = create_pool_method( + context, + name=DEFAULT_WORKER_POOL_NAME, + image_uid=default_image.id, + number=worker_count, + ) + + else: + # Else add a worker to existing worker pool + add_worker_method = node.get_service_method(SyftWorkerPoolService.add_workers) + result = add_worker_method( + context=context, number=worker_count, pool_name=DEFAULT_WORKER_POOL_NAME + ) if isinstance(result, SyftError): print(f"Failed to create Worker for Default workers. Error: {result.message}") @@ -1465,7 +1475,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if container_status.error: print( f"Failed to create container: Worker: {container_status.worker}," - "Error: {container_status.error}" + f"Error: {container_status.error}" ) return diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 23f9392adf5..a62a8a9e8ba 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -577,6 +577,7 @@ def _run(self): # Call Message Handler try: message = msg.pop() + self.associate_job(message) self.message_handler.handle_message( message=message, syft_worker_id=self.syft_worker_id, diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 5e8ccdc4e5c..bba383c2ec7 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -142,7 +142,7 @@ def run_container_using_docker( container_name=container_name, ) if existing_container: - existing_container.stop() + existing_container.remove(force=True) # Extract Config from backend container backend_host_config = extract_config_from_backend( diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index d83edcfcdc6..c2e3a5a1d5d 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -8,6 +8,7 @@ # relative from ...custom_worker.config import DockerWorkerConfig +from ...custom_worker.config import WorkerConfig from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...store.document_store import BaseUIDStoreStash @@ -19,7 +20,7 @@ from ..action.action_permissions import ActionPermission from .worker_image import SyftWorkerImage -DockerWorkerConfigPK = PartitionKey(key="config", type_=DockerWorkerConfig) +WorkerConfigPK = PartitionKey(key="config", type_=WorkerConfig) @serializable() @@ -59,5 +60,5 @@ def set( def get_by_docker_config( self, credentials: SyftVerifyKey, config: DockerWorkerConfig ): - qks = QueryKeys(qks=[DockerWorkerConfigPK.with_obj(config)]) + qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)]) return self.query_one(credentials=credentials, qks=qks)