Skip to content

Commit

Permalink
Merge pull request #8397 from kiendang/worker-get
Browse files Browse the repository at this point in the history
Move `worker_pool_service.get_worker` to `worker_service.get`
  • Loading branch information
shubham3121 authored Jan 16, 2024
2 parents 9295685 + 2b2823c commit 341c5db
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 173 deletions.
53 changes: 0 additions & 53 deletions packages/syft/src/syft/service/worker/worker_pool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from .worker_pool import WorkerPool
from .worker_pool import WorkerStatus
from .worker_pool import _get_worker_container
from .worker_pool import _get_worker_container_status
from .worker_pool_stash import SyftWorkerPoolStash
from .worker_service import WorkerService
from .worker_stash import WorkerStash
Expand Down Expand Up @@ -288,58 +287,6 @@ def filter_by_image_id(

return result.ok()

@service_method(
path="worker_pool.get_worker",
name="get_worker",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def get_worker(
self, context: AuthedServiceContext, worker_pool_id: UID, worker_id: UID
) -> Union[SyftWorker, SyftError]:
worker_pool_worker = self._get_worker_pool_and_worker(
context, worker_pool_id, worker_id
)
if isinstance(worker_pool_worker, SyftError):
return worker_pool_worker

_, linked_worker = worker_pool_worker

result = linked_worker.resolve_with_context(context=context)

if result.is_err():
return SyftError(
message=f"Failed to retrieve Linked SyftWorker {linked_worker.object_uid}"
)

worker = result.ok()

if context.node.in_memory_workers:
return worker

with contextlib.closing(docker.from_env()) as client:
worker_status = _get_worker_container_status(client, worker)

if isinstance(worker_status, SyftError):
return worker_status

if worker_status != WorkerStatus.PENDING:
worker.status = worker_status

result = self.worker_stash.update(
credentials=context.credentials,
obj=worker,
)

return (
SyftError(
message=f"Failed to update worker status. Error: {result.err()}"
)
if result.is_err()
else worker
)

return worker

@service_method(
path="worker_pool.get_worker_status",
name="get_worker_status",
Expand Down
168 changes: 49 additions & 119 deletions packages/syft/src/syft/service/worker/worker_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# stdlib
import contextlib
import socket
from typing import List
from typing import Tuple
from typing import Union

# third party
import docker
from result import Result

# relative
from ...node.credentials import SyftVerifyKey
Expand All @@ -31,93 +29,6 @@
from .worker_pool import _get_worker_container_status
from .worker_stash import WorkerStash

WORKER_NUM = 0


def get_main_backend():
hostname = socket.gethostname()
return f"{hostname}-backend-1"


def start_worker_container(
worker_num: int, context: AuthedServiceContext, syft_worker_uid
):
client = docker.from_env()
existing_container_name = get_main_backend()
hostname = socket.gethostname()
worker_name = f"{hostname}-worker-{worker_num}"
return create_new_container_from_existing(
worker_name=worker_name,
client=client,
existing_container_name=existing_container_name,
syft_worker_uid=syft_worker_uid,
)


def create_new_container_from_existing(
worker_name: str,
client: docker.client.DockerClient,
existing_container_name: str,
syft_worker_uid,
) -> docker.models.containers.Container:
# Get the existing container
existing_container = client.containers.get(existing_container_name)

# Inspect the existing container
details = existing_container.attrs

# Extract relevant settings
image = details["Config"]["Image"]
command = details["Config"]["Cmd"]
environment = details["Config"]["Env"]
ports = details["NetworkSettings"]["Ports"]
host_config = details["HostConfig"]

volumes = {}
for vol in host_config["Binds"]:
parts = vol.split(":")
key = parts[0]
bind = parts[1]
mode = parts[2]
if "/storage" in bind:
# we need this because otherwise we are using the same node private key
# which will make account creation fail
worker_postfix = worker_name.split("-", 1)[1]
key = f"{key}-{worker_postfix}"
volumes[key] = {"bind": bind, "mode": mode}

# we need this because otherwise we are using the same node private key
# which will make account creation fail

environment = dict([e.split("=", 1) for e in environment])
environment["CREATE_PRODUCER"] = "false"
environment["N_CONSUMERS"] = 1
environment["DOCKER_WORKER_NAME"] = worker_name
environment["DEFAULT_ROOT_USERNAME"] = worker_name
environment["DEFAULT_ROOT_EMAIL"] = f"{worker_name}@openmined.org"
environment["PORT"] = str(8003 + WORKER_NUM)
environment["HTTP_PORT"] = str(88 + WORKER_NUM)
environment["HTTPS_PORT"] = str(446 + WORKER_NUM)
environment["SYFT_WORKER_UID"] = str(syft_worker_uid)

environment.pop("NODE_PRIVATE_KEY", None)

new_container = client.containers.create(
name=worker_name,
image=image,
command=command,
environment=environment,
ports=ports,
detach=True,
volumes=volumes,
tty=True,
stdin_open=True,
network_mode=f"container:{existing_container.id}",
)

new_container.start()
return new_container


@instrument
@serializable()
Expand Down Expand Up @@ -160,17 +71,19 @@ def list(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]:
return workers

# If container workers, check their statuses
for idx, worker in enumerate(workers):
result = check_and_update_status_for_worker(
worker=worker,
worker_stash=self.stash,
credentials=context.credentials,
)
if result.is_err():
return SyftError(
message=f"Failed to update status for worker: {worker.id}. Error: {result.err()}"
with contextlib.closing(docker.from_env()) as client:
for idx, worker in enumerate(workers):
worker_ = _check_and_update_status_for_worker(
client=client,
worker=worker,
worker_stash=self.stash,
credentials=context.credentials,
)
workers[idx] = worker

if not isinstance(worker_, SyftWorker):
return worker_

workers[idx] = worker_

return workers

Expand All @@ -182,37 +95,48 @@ def status(
context: AuthedServiceContext,
uid: UID,
) -> Union[Tuple[WorkerStatus, WorkerHealth], SyftError]:
result = self.stash.get_by_uid(credentials=context.credentials, uid=uid)
if result.is_err():
return SyftError(message=f"Failed to retrieve worker with UID {uid}")
worker: SyftWorker = result.ok()
worker = self.get(context=context, uid=uid)

if context.node.in_memory_workers:
return worker.status, worker.healthcheck
if not isinstance(worker, SyftWorker):
return worker

result = check_and_update_status_for_worker(
worker=worker,
worker_stash=self.stash,
credentials=context.credentials,
)
return worker.status, worker.healthcheck

@service_method(
path="worker.get",
name="get",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def get(
self, context: AuthedServiceContext, uid: UID
) -> Union[SyftWorker, SyftError]:
result = self.stash.get_by_uid(credentials=context.credentials, uid=uid)
if result.is_err():
return SyftError(
message=f"Failed to update status for worker: {worker.id}. Error: {result.err()}"
)
return SyftError(message=f"Failed to retrieve worker with UID {uid}")

worker = result.ok()
if worker is None:
return SyftError(message=f"Worker does not exist for UID {uid}")

return worker.status, worker.healthcheck
if context.node.in_memory_workers:
return worker

with contextlib.closing(docker.from_env()) as client:
return _check_and_update_status_for_worker(
client=client,
worker=worker,
worker_stash=self.stash,
credentials=context.credentials,
)


def check_and_update_status_for_worker(
def _check_and_update_status_for_worker(
client: docker.DockerClient,
worker: SyftWorker,
worker_stash: WorkerStash,
credentials: SyftVerifyKey,
) -> Result[SyftWorker, str]:
with contextlib.closing(docker.from_env()) as client:
worker_status = _get_worker_container_status(client, worker)
) -> Union[SyftWorker, SyftError]:
worker_status = _get_worker_container_status(client, worker)

if isinstance(worker_status, SyftError):
return worker_status
Expand All @@ -226,4 +150,10 @@ def check_and_update_status_for_worker(
obj=worker,
)

return result
return (
SyftError(
message=f"Failed to update status for worker: {worker.id}. Error: {result.err()}"
)
if result.is_err()
else result.ok()
)
2 changes: 1 addition & 1 deletion packages/syft/src/syft/store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def update(
credentials: SyftVerifyKey,
obj: BaseStash.object_type,
has_permission=False,
) -> Optional[Result[BaseStash.object_type, str]]:
) -> Result[BaseStash.object_type, str]:
qk = self.partition.store_query_key(obj)
return self.partition.update(
credentials=credentials, qk=qk, obj=obj, has_permission=has_permission
Expand Down

0 comments on commit 341c5db

Please sign in to comment.