Skip to content

Commit

Permalink
Merge pull request #8380 from OpenMined/fix-wp-bugs
Browse files Browse the repository at this point in the history
Fix bugs in Default Worker Pool Spinning
  • Loading branch information
shubham3121 authored Jan 10, 2024
2 parents 5254b0d + 04d116c commit 4bede98
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 20 deletions.
44 changes: 27 additions & 17 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,26 +1435,36 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:

print("Building Default Worker Image")

# Build the Image for given tag
result = image_build_method(
context, image_uid=default_image.id, tag=DEFAULT_WORKER_IMAGE_TAG
)

if isinstance(result, SyftError):
print("Failed to build default worker image: ", result.message)
return
if not default_image.is_built:
# Build the Image for given tag
result = image_build_method(
context, image_uid=default_image.id, tag=DEFAULT_WORKER_IMAGE_TAG
)

create_pool_method = node.get_service_method(SyftWorkerPoolService.create_pool)
if isinstance(result, SyftError):
print("Failed to build default worker image: ", result.message)
return

default_worker_pool = node.get_default_worker_pool()
worker_count = node.queue_config.client_config.n_consumers

print("Creating default Worker Pool")
result = create_pool_method(
context,
name=DEFAULT_WORKER_POOL_NAME,
image_uid=default_image.id,
number=worker_count,
)
# Create worker pool if it doesn't exists
if default_worker_pool is None:
create_pool_method = node.get_service_method(SyftWorkerPoolService.create_pool)
print("Creating default Worker Pool")
result = create_pool_method(
context,
name=DEFAULT_WORKER_POOL_NAME,
image_uid=default_image.id,
number=worker_count,
)

else:
# Else add a worker to existing worker pool
add_worker_method = node.get_service_method(SyftWorkerPoolService.add_workers)
result = add_worker_method(
context=context, number=worker_count, pool_name=DEFAULT_WORKER_POOL_NAME
)

if isinstance(result, SyftError):
print(f"Failed to create Worker for Default workers. Error: {result.message}")
Expand All @@ -1465,7 +1475,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:
if container_status.error:
print(
f"Failed to create container: Worker: {container_status.worker},"
"Error: {container_status.error}"
f"Error: {container_status.error}"
)
return

Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ def _run(self):
# Call Message Handler
try:
message = msg.pop()
self.associate_job(message)
self.message_handler.handle_message(
message=message,
syft_worker_id=self.syft_worker_id,
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run_container_using_docker(
container_name=container_name,
)
if existing_container:
existing_container.stop()
existing_container.remove(force=True)

# Extract Config from backend container
backend_host_config = extract_config_from_backend(
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/service/worker/worker_image_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

# relative
from ...custom_worker.config import DockerWorkerConfig
from ...custom_worker.config import WorkerConfig
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
from ...store.document_store import BaseUIDStoreStash
Expand All @@ -19,7 +20,7 @@
from ..action.action_permissions import ActionPermission
from .worker_image import SyftWorkerImage

DockerWorkerConfigPK = PartitionKey(key="config", type_=DockerWorkerConfig)
WorkerConfigPK = PartitionKey(key="config", type_=WorkerConfig)


@serializable()
Expand Down Expand Up @@ -59,5 +60,5 @@ def set(
def get_by_docker_config(
self, credentials: SyftVerifyKey, config: DockerWorkerConfig
):
qks = QueryKeys(qks=[DockerWorkerConfigPK.with_obj(config)])
qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)])
return self.query_one(credentials=credentials, qks=qks)

0 comments on commit 4bede98

Please sign in to comment.