From 63d1653727488280325d8b63d9cdf2bd53e67e9e Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 7 Feb 2024 11:05:23 +0700 Subject: [PATCH 1/7] [refactor] fixing mypy issues in `syft/service/worker` --- packages/syft/src/syft/serde/third_party.py | 4 +- .../syft/src/syft/service/worker/utils.py | 132 ++++++++++-------- .../service/worker/worker_image_service.py | 15 +- .../service/worker/worker_pool_service.py | 6 +- .../src/syft/service/worker/worker_service.py | 7 +- 5 files changed, 95 insertions(+), 69 deletions(-) diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 1abfe2d9cdc..2d70250cbe6 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -96,8 +96,8 @@ def deserialize_dataframe(buf: bytes) -> DataFrame: def deserialize_series(blob: bytes) -> Series: - df = DataFrame.from_dict(deserialize(blob, from_bytes=True)) - return df[df.columns[0]] + df: DataFrame = DataFrame.from_dict(deserialize(blob, from_bytes=True)) + return Series(df[df.columns[0]]) recursive_serde_register( diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 14b5799d825..a3dcc87d5bd 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -5,6 +5,8 @@ import socket import socketserver import sys +from typing import Any +from typing import Dict from typing import List from typing import Optional from typing import Tuple @@ -385,17 +387,22 @@ def run_workers_in_kubernetes( runner = KubernetesRunner() if start_idx == 0: - pool_pods = create_kubernetes_pool( - runner=runner, - tag=worker_image.image_identifier.full_name_with_tag, - pool_name=pool_name, - replicas=worker_count, - queue_port=queue_port, - debug=debug, - reg_username=reg_username, - reg_password=reg_password, - reg_url=reg_url, - ) + if worker_image.image_identifier is not None: + pool_pods = create_kubernetes_pool( + runner=runner, + tag=worker_image.image_identifier.full_name_with_tag, + pool_name=pool_name, + replicas=worker_count, + queue_port=queue_port, + debug=debug, + reg_username=reg_username, + reg_password=reg_password, + reg_url=reg_url, + ) + else: + return SyftError( + message=f"image with uid {worker_image.id} does not have an image identifier" + ) else: pool_pods = scale_kubernetes_pool(runner, pool_name, worker_count) @@ -584,28 +591,35 @@ def _get_healthcheck_based_on_status(status: WorkerStatus) -> WorkerHealth: return WorkerHealth.UNHEALTHY -def image_build(image: SyftWorkerImage, **kwargs) -> Union[ImageBuildResult, SyftError]: - full_tag = image.image_identifier.full_name_with_tag - try: - builder = CustomWorkerBuilder() - return builder.build_image( - config=image.config, - tag=full_tag, - rm=True, - forcerm=True, - **kwargs, - ) - except docker.errors.APIError as e: - return SyftError( - message=f"Docker API error when building '{full_tag}'. Reason - {e}" - ) - except docker.errors.DockerException as e: - return SyftError( - message=f"Docker exception when building '{full_tag}'. Reason - {e}" - ) - except Exception as e: +def image_build( + image: SyftWorkerImage, **kwargs: Dict[str, Any] +) -> Union[ImageBuildResult, SyftError]: + if image.image_identifier is not None: + full_tag = image.image_identifier.full_name_with_tag + try: + builder = CustomWorkerBuilder() + return builder.build_image( + config=image.config, + tag=full_tag, + rm=True, + forcerm=True, + **kwargs, + ) + except docker.errors.APIError as e: + return SyftError( + message=f"Docker API error when building '{full_tag}'. Reason - {e}" + ) + except docker.errors.DockerException as e: + return SyftError( + message=f"Docker exception when building '{full_tag}'. Reason - {e}" + ) + except Exception as e: + return SyftError( + message=f"Unknown exception when building '{full_tag}'. Reason - {e}" + ) + else: return SyftError( - message=f"Unknown exception when building '{full_tag}'. Reason - {e}" + message=f"image with uid {image.id} does not have an image identifier" ) @@ -614,34 +628,40 @@ def image_push( username: Optional[str] = None, password: Optional[str] = None, ) -> Union[ImagePushResult, SyftError]: - full_tag = image.image_identifier.full_name_with_tag - try: - builder = CustomWorkerBuilder() - result = builder.push_image( - # this should be consistent with docker build command - tag=image.image_identifier.full_name_with_tag, - registry_url=image.image_identifier.registry_host, - username=username, - password=password, - ) + if image.image_identifier is not None: + full_tag = image.image_identifier.full_name_with_tag + try: + builder = CustomWorkerBuilder() + result = builder.push_image( + # this should be consistent with docker build command + tag=image.image_identifier.full_name_with_tag, + registry_url=image.image_identifier.registry_host, + username=username, + password=password, + ) - if "error" in result.logs.lower() or result.exit_code: + if "error" in result.logs.lower() or result.exit_code: + return SyftError( + message=f"Failed to push {full_tag}. " + f"Exit code: {result.exit_code}. " + f"Logs:\n{result.logs}" + ) + + return result + except docker.errors.APIError as e: + return SyftError(message=f"Docker API error when pushing {full_tag}. {e}") + except docker.errors.DockerException as e: return SyftError( - message=f"Failed to push {full_tag}. " - f"Exit code: {result.exit_code}. " - f"Logs:\n{result.logs}" + message=f"Docker exception when pushing {full_tag}. Reason - {e}" ) - - return result - except docker.errors.APIError as e: - return SyftError(message=f"Docker API error when pushing {full_tag}. {e}") - except docker.errors.DockerException as e: - return SyftError( - message=f"Docker exception when pushing {full_tag}. Reason - {e}" - ) - except Exception as e: + except Exception as e: + return SyftError( + message=f"Unknown exception when pushing {image.image_identifier}. Reason - {e}" + ) + else: return SyftError( - message=f"Unknown exception when pushing {image.image_identifier}. Reason - {e}" + message=f"image with uid {image.id} does not have an " + "image identifier and tag, hence we can't push it." ) diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index 78b16a67395..8422341b009 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -75,7 +75,7 @@ def build( registry_uid: Optional[UID] = None, pull: bool = True, ) -> Union[SyftSuccess, SyftError]: - registry: SyftImageRegistry = None + registry: Optional[SyftImageRegistry] = None if IN_KUBERNETES and registry_uid is None: return SyftError(message="Registry UID is required in Kubernetes mode.") @@ -96,7 +96,7 @@ def build( registry_result = image_registry_service.get_by_id(context, registry_uid) if registry_result.is_err(): return registry_result - registry: SyftImageRegistry = registry_result.ok() + registry = registry_result.ok() try: if registry: @@ -204,12 +204,13 @@ def get_all( images: List[SyftWorkerImage] = result.ok() res = {} - # if image is built index by full_name_with_tag - res.update( - {im.image_identifier.full_name_with_tag: im for im in images if im.is_built} - ) + # if image is built, index it by full_name_with_tag + for im in images: + if im.is_built and im.image_identifier is not None: + res[im.image_identifier.full_name_with_tag] = im # and then index all images by id - # TODO: jupyter repr needs to be updated to show unique values (even if multiple keys point to same value) + # TODO: jupyter repr needs to be updated to show unique values + # (even if multiple keys point to same value) res.update({im.id.to_string(): im for im in images if not im.is_built}) return DictTuple(res) diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 806881547da..d2cd7db17c7 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -566,6 +566,10 @@ def _create_workers_in_pool( number=worker_cnt + existing_worker_cnt, ) else: + if worker_image.image_identifier is not None: + registry_host = worker_image.image_identifier.registry_host + else: + registry_host = None result = run_containers( pool_name=pool_name, worker_image=worker_image, @@ -576,7 +580,7 @@ def _create_workers_in_pool( dev_mode=context.node.dev_mode, reg_username=reg_username, reg_password=reg_password, - reg_url=worker_image.image_identifier.registry_host, + reg_url=registry_host, ) if isinstance(result, SyftError): return result diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 9871d7fa18a..af9b9ba7070 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -255,7 +255,7 @@ def refresh_worker_status( workers: List[SyftWorker], worker_stash: WorkerStash, credentials: SyftVerifyKey, -): +) -> List[SyftWorker]: if IN_KUBERNETES: result = refresh_status_kubernetes(workers) else: @@ -277,7 +277,7 @@ def refresh_worker_status( return result -def refresh_status_kubernetes(workers: List[SyftWorker]): +def refresh_status_kubernetes(workers: List[SyftWorker]) -> List[SyftWorker]: updated_workers = [] runner = KubernetesRunner() for worker in workers: @@ -292,7 +292,7 @@ def refresh_status_kubernetes(workers: List[SyftWorker]): return updated_workers -def refresh_status_docker(workers: List[SyftWorker]): +def refresh_status_docker(workers: List[SyftWorker]) -> List[SyftWorker]: updated_workers = [] with contextlib.closing(docker.from_env()) as client: @@ -317,6 +317,7 @@ def _stop_worker_container( container.stop() # Remove the container and its volumes _remove_worker_container(container, force=force, v=True) + return None except Exception as e: return SyftError( message=f"Failed to delete worker with id: {worker.id}. Error: {e}" From db1ff85149b8eb55a2ff9f6122e8c1ac098edf7d Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 7 Feb 2024 15:16:34 +0700 Subject: [PATCH 2/7] [refactor] done fixing mypy issues in `syft/service/worker` --- .pre-commit-config.yaml | 2 +- .../syft/service/worker/image_identifier.py | 2 +- .../src/syft/service/worker/image_registry.py | 4 +- .../service/worker/image_registry_service.py | 2 +- .../syft/src/syft/service/worker/utils.py | 53 ++++++++++++------- .../syft/src/syft/service/worker/worker.py | 10 ++-- .../syft/service/worker/worker_image_stash.py | 3 +- .../src/syft/service/worker/worker_pool.py | 12 ++++- .../src/syft/service/worker/worker_stash.py | 4 +- 9 files changed, 60 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81b29307e8a..0e1aecdd0ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py" + files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py|^packages/syft/src/syft/service/worker" #files: "^packages/syft/src/syft/serde" args: [ "--follow-imports=skip", diff --git a/packages/syft/src/syft/service/worker/image_identifier.py b/packages/syft/src/syft/service/worker/image_identifier.py index 43623f44d36..4651c3f4f2e 100644 --- a/packages/syft/src/syft/service/worker/image_identifier.py +++ b/packages/syft/src/syft/service/worker/image_identifier.py @@ -53,7 +53,7 @@ def from_str(cls, tag: str) -> Self: return cls(repo=repo, registry=registry, tag=tag) @property - def repo_with_tag(self) -> str: + def repo_with_tag(self) -> Optional[str]: if self.repo or self.tag: return f"{self.repo}:{self.tag}" return None diff --git a/packages/syft/src/syft/service/worker/image_registry.py b/packages/syft/src/syft/service/worker/image_registry.py index 806a0946d2b..7fb9dfcb770 100644 --- a/packages/syft/src/syft/service/worker/image_registry.py +++ b/packages/syft/src/syft/service/worker/image_registry.py @@ -28,7 +28,7 @@ class SyftImageRegistry(SyftObject): url: str @validator("url") - def validate_url(cls, val: str): + def validate_url(cls, val: str) -> str: if not val: raise ValueError("Invalid Registry URL. Must not be empty") @@ -38,7 +38,7 @@ def validate_url(cls, val: str): return val @classmethod - def from_url(cls, full_str: str): + def from_url(cls, full_str: str) -> "SyftImageRegistry": # this is only for urlparse if "://" not in full_str: full_str = f"http://{full_str}" diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 8c628f84486..81aa276559c 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -62,7 +62,7 @@ def delete( self, context: AuthedServiceContext, uid: UID = None, - url: str = None, + url: str = "", ) -> Union[SyftSuccess, SyftError]: # TODO - we need to make sure that there are no workers running an image bound to this registry diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index a3dcc87d5bd..fafe95f3112 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -14,6 +14,8 @@ # third party import docker +from docker.models.containers import Container +from kr8s.objects import Pod # relative from ...abstract_node import AbstractNode @@ -49,7 +51,9 @@ def backend_container_name() -> str: return f"{hostname}-{service_name}-1" -def get_container(docker_client: docker.DockerClient, container_name: str): +def get_container( + docker_client: docker.DockerClient, container_name: str +) -> Optional[Container]: try: existing_container = docker_client.containers.get(container_name) except docker.errors.NotFound: @@ -58,14 +62,20 @@ def get_container(docker_client: docker.DockerClient, container_name: str): return existing_container -def extract_config_from_backend(worker_name: str, docker_client: docker.DockerClient): +def extract_config_from_backend( + worker_name: str, docker_client: docker.DockerClient +) -> Dict[str, Any]: # Existing main backend container backend_container = get_container( docker_client, container_name=backend_container_name() ) # Config with defaults - extracted_config = {"volume_binds": {}, "network_mode": None, "environment": {}} + extracted_config: Dict[str, Any] = { + "volume_binds": {}, + "network_mode": None, + "environment": {}, + } if backend_container is None: return extracted_config @@ -96,7 +106,7 @@ def extract_config_from_backend(worker_name: str, docker_client: docker.DockerCl return extracted_config -def get_free_tcp_port(): +def get_free_tcp_port() -> int: with socketserver.TCPServer(("localhost", 0), None) as s: free_port = s.server_address[1] return free_port @@ -115,7 +125,7 @@ def run_container_using_docker( registry_url: Optional[str] = None, ) -> ContainerSpawnStatus: if not worker_image.is_built: - raise Exception("Image must be built before running it.") + raise ValueError("Image must be built before running it.") # Get hostname hostname = socket.gethostname() @@ -167,8 +177,11 @@ def run_container_using_docker( environment["QUEUE_PORT"] = queue_port environment["CONTAINER_HOST"] = "docker" + if worker_image.image_identifier is None: + raise ValueError(f"Image {worker_image} does not have an identifier") + container = docker_client.containers.run( - worker_image.image_identifier.full_name_with_tag, + image=worker_image.image_identifier.full_name_with_tag, name=f"{hostname}-{worker_name}", detach=True, auto_remove=True, @@ -259,20 +272,22 @@ def run_workers_in_threads( return results -def prepare_kubernetes_pool_env(runner: KubernetesRunner, env_vars: dict): +def prepare_kubernetes_pool_env( + runner: KubernetesRunner, env_vars: dict +) -> Tuple[List, Dict]: # get current backend pod name backend_pod_name = os.getenv("K8S_POD_NAME") if not backend_pod_name: - raise ValueError(message="Pod name not provided in environment variable") + raise ValueError("Pod name not provided in environment variable") # get current backend's credentials path - creds_path = os.getenv("CREDENTIALS_PATH") + creds_path: Union[str, None, Path] = os.getenv("CREDENTIALS_PATH") if not creds_path: - raise ValueError(message="Credentials path not provided") + raise ValueError("Credentials path not provided") creds_path = Path(creds_path) - if not creds_path.exists(): - raise ValueError(message="Credentials file does not exist") + if creds_path is not None and not creds_path.exists(): + raise ValueError("Credentials file does not exist") # create a secret for the node credentials owned by the backend, not the pool. node_secret = KubeUtils.create_secret( @@ -285,7 +300,7 @@ def prepare_kubernetes_pool_env(runner: KubernetesRunner, env_vars: dict): # clone and patch backend environment variables backend_env = runner.get_pod_env_vars(backend_pod_name) or [] - env_vars = KubeUtils.patch_env_vars(backend_env, env_vars) + env_vars_list = KubeUtils.patch_env_vars(backend_env, env_vars) mount_secrets = { node_secret.metadata.name: { "mountPath": str(creds_path), @@ -293,7 +308,7 @@ def prepare_kubernetes_pool_env(runner: KubernetesRunner, env_vars: dict): }, } - return env_vars, mount_secrets + return env_vars_list, mount_secrets def create_kubernetes_pool( @@ -306,8 +321,8 @@ def create_kubernetes_pool( reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs, -): + **kwargs: Dict[str, Any], +) -> Union[SyftError, List[Pod]]: pool = None error = False @@ -357,7 +372,7 @@ def scale_kubernetes_pool( runner: KubernetesRunner, pool_name: str, replicas: int, -): +) -> Union[SyftError, List[Pod]]: pool = runner.get_pool(pool_name) if not pool: return SyftError(message=f"Pool does not exist. name={pool_name}") @@ -376,12 +391,12 @@ def run_workers_in_kubernetes( worker_count: int, pool_name: str, queue_port: int, - start_idx=0, + start_idx: int = 0, debug: bool = False, reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs, + **kwargs: Dict[str, Any], ) -> Union[List[ContainerSpawnStatus], SyftError]: spawn_status = [] runner = KubernetesRunner() diff --git a/packages/syft/src/syft/service/worker/worker.py b/packages/syft/src/syft/service/worker/worker.py index 0242dba177a..ef3fc4aec5d 100644 --- a/packages/syft/src/syft/service/worker/worker.py +++ b/packages/syft/src/syft/service/worker/worker.py @@ -1,4 +1,8 @@ # stdlib +from typing import Any +from typing import Callable +from typing import Dict +from typing import List # relative from ...serde.serializable import serializable @@ -39,7 +43,7 @@ class DockerWorker(SyftObject): container_id: str created_at: DateTime = DateTime.now() - def _coll_repr_(self): + def _coll_repr_(self) -> Dict[str, Any]: return { "container_name": self.container_name, "container_id": self.container_id, @@ -48,10 +52,10 @@ def _coll_repr_(self): @migrate(DockerWorker, DockerWorkerV1) -def downgrade_job_v2_to_v1(): +def downgrade_job_v2_to_v1() -> List[Callable]: return [drop(["container_name"])] @migrate(DockerWorkerV1, DockerWorker) -def upgrade_job_v2_to_v3(): +def upgrade_job_v2_to_v3() -> List[Callable]: return [make_set_default("job_consumer_id", None)] diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index 49eb6fa1802..a1580076104 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -1,5 +1,6 @@ # stdlib from typing import List +from typing import Optional from typing import Union # third party @@ -59,6 +60,6 @@ def set( def get_by_docker_config( self, credentials: SyftVerifyKey, config: DockerWorkerConfig - ): + ) -> Result[Optional[SyftWorkerImage], str]: qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)]) return self.query_one(credentials=credentials, qks=qks) diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index 377ee118b55..14f9ff0a7dd 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -87,7 +87,7 @@ def logs(self) -> Union[str, SyftError]: return api.services.worker.logs(uid=self.id) - def get_job_repr(self): + def get_job_repr(self) -> str: if self.job_id is not None: api = APIRegistry.api_for( node_uid=self.syft_node_location, @@ -117,14 +117,20 @@ def refresh_status(self) -> None: def _coll_repr_(self) -> Dict[str, Any]: self.refresh_status() + if self.image and self.image.image_identifier: image_name_with_tag = self.image.image_identifier.full_name_with_tag else: image_name_with_tag = "In Memory Worker" + + healthcheck: str = "" + if self.healthcheck is not None: + healthcheck = self.healthcheck.value + return { "Name": self.name, "Image": image_name_with_tag, - "Healthcheck (health / unhealthy)": f"{self.healthcheck.value}", + "Healthcheck (health / unhealthy)": f"{healthcheck}", "Status": f"{self.status.value}", "Job": self.get_job_repr(), "Created at": str(self.created_at), @@ -166,6 +172,8 @@ def image(self) -> Optional[Union[SyftWorkerImage, SyftError]]: ) if api is not None: return api.services.worker_image.get_by_uid(uid=self.image_id) + else: + return None @property def running_workers(self) -> Union[List[UID], SyftError]: diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 1e82dcfccdf..4a3abbbeaa1 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -59,13 +59,13 @@ def get_worker_by_name( def update_consumer_state( self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState - ): + ) -> Result[Ok, Err]: res = self.get_by_uid(credentials=credentials, uid=worker_uid) if res.is_err(): return Err( f"Failed to retrieve Worker with id: {worker_uid}. Error: {res.err()}" ) - worker: SyftWorker = res.ok() + worker: Optional[SyftWorker] = res.ok() if worker is None: return Err(f"Worker with id: {worker_uid} not found") worker.consumer_state = consumer_state From bcf2aa5117d59aa097910b041de452b4b38bfb1c Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 12 Feb 2024 09:57:24 +0700 Subject: [PATCH 3/7] [refactor] fix some issues according to comments Co-authored-by: Kien Dang --- .../src/syft/service/worker/image_registry_service.py | 3 ++- packages/syft/src/syft/service/worker/utils.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 81aa276559c..8acbd9e6f4b 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -1,5 +1,6 @@ # stdlib from typing import List +from typing import Optional from typing import Union # relative @@ -62,7 +63,7 @@ def delete( self, context: AuthedServiceContext, uid: UID = None, - url: str = "", + url: Optional[str] = None, ) -> Union[SyftSuccess, SyftError]: # TODO - we need to make sure that there are no workers running an image bound to this registry diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index fafe95f3112..4c50e3dd764 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -281,7 +281,7 @@ def prepare_kubernetes_pool_env( raise ValueError("Pod name not provided in environment variable") # get current backend's credentials path - creds_path: Union[str, None, Path] = os.getenv("CREDENTIALS_PATH") + creds_path: Optional[Union[str, Path]] = os.getenv("CREDENTIALS_PATH") if not creds_path: raise ValueError("Credentials path not provided") @@ -300,7 +300,7 @@ def prepare_kubernetes_pool_env( # clone and patch backend environment variables backend_env = runner.get_pod_env_vars(backend_pod_name) or [] - env_vars_list = KubeUtils.patch_env_vars(backend_env, env_vars) + env_vars_: List = KubeUtils.patch_env_vars(backend_env, env_vars) mount_secrets = { node_secret.metadata.name: { "mountPath": str(creds_path), @@ -308,7 +308,7 @@ def prepare_kubernetes_pool_env( }, } - return env_vars_list, mount_secrets + return env_vars_, mount_secrets def create_kubernetes_pool( @@ -321,7 +321,7 @@ def create_kubernetes_pool( reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Union[SyftError, List[Pod]]: pool = None error = False @@ -396,7 +396,7 @@ def run_workers_in_kubernetes( reg_username: Optional[str] = None, reg_password: Optional[str] = None, reg_url: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> Union[List[ContainerSpawnStatus], SyftError]: spawn_status = [] runner = KubernetesRunner() From cd817a755c85fea22b1db7111141ec0060606122 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 14 Feb 2024 14:25:56 +0700 Subject: [PATCH 4/7] [refactor] refactor some functions and their return types --- .../syft/src/syft/service/worker/worker_pool_service.py | 8 +++++--- packages/syft/src/syft/service/worker/worker_stash.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index db3465e57e5..3680ab20346 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -443,7 +443,7 @@ def scale( number: int, pool_id: Optional[UID] = None, pool_name: Optional[str] = None, - ) -> Union[SyftError, SyftSuccess, List[ContainerSpawnStatus]]: + ) -> Union[SyftError, SyftSuccess]: """ Scale the worker pool to the given number of workers in Kubernetes. Allows both scaling up and down the worker pool. @@ -466,7 +466,7 @@ def scale( return SyftSuccess(message=f"Worker pool already has {number} workers") elif number > current_worker_count: workers_to_add = number - current_worker_count - return self.add_workers( + result = self.add_workers( context=context, number=workers_to_add, pool_id=pool_id, @@ -475,6 +475,8 @@ def scale( reg_username=None, reg_password=None, ) + if isinstance(result, SyftError): + return result else: # scale down at kubernetes control plane runner = KubernetesRunner() @@ -514,7 +516,7 @@ def scale( return SyftError( message=( f"Pool {worker_pool.name} was scaled down, " - f"but failed update the stash with err: {result.err()}" + f"but failed update the stash with err: {update_result.err()}" ) ) diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 4a3abbbeaa1..cb7a914ed9b 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -59,7 +59,7 @@ def get_worker_by_name( def update_consumer_state( self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState - ) -> Result[Ok, Err]: + ) -> Result[str, str]: res = self.get_by_uid(credentials=credentials, uid=worker_uid) if res.is_err(): return Err( From c6185b16d2089459cb21b68fa29e68f77d21a100 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 14 Feb 2024 15:50:01 +0800 Subject: [PATCH 5/7] Switch success/error type order --- packages/syft/src/syft/service/worker/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 8681cebd7f6..f7443139027 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -322,7 +322,7 @@ def create_kubernetes_pool( reg_password: Optional[str] = None, reg_url: Optional[str] = None, **kwargs: Any, -) -> Union[SyftError, List[Pod]]: +) -> Union[List[Pod], SyftError]: pool = None error = False @@ -372,7 +372,7 @@ def scale_kubernetes_pool( runner: KubernetesRunner, pool_name: str, replicas: int, -) -> Union[SyftError, List[Pod]]: +) -> Union[List[Pod], SyftError]: pool = runner.get_pool(pool_name) if not pool: return SyftError(message=f"Pool does not exist. name={pool_name}") From 558edb423519d4299e63bc90e456e66e4a3ae3ef Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 14 Feb 2024 15:51:41 +0800 Subject: [PATCH 6/7] Use typing.Self instead of hard-coded class --- packages/syft/src/syft/service/worker/image_registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/worker/image_registry.py b/packages/syft/src/syft/service/worker/image_registry.py index 7fb9dfcb770..7292273c605 100644 --- a/packages/syft/src/syft/service/worker/image_registry.py +++ b/packages/syft/src/syft/service/worker/image_registry.py @@ -4,6 +4,7 @@ # third party from pydantic import validator +from typing_extensions import Self # relative from ...serde.serializable import serializable @@ -38,7 +39,7 @@ def validate_url(cls, val: str) -> str: return val @classmethod - def from_url(cls, full_str: str) -> "SyftImageRegistry": + def from_url(cls, full_str: str) -> Self: # this is only for urlparse if "://" not in full_str: full_str = f"http://{full_str}" From df7cfef3bc09c8c98640990b18a70270ad87f821 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Wed, 14 Feb 2024 17:36:35 +0800 Subject: [PATCH 7/7] Use ternary for simplicity --- packages/syft/src/syft/service/worker/worker_pool.py | 4 +--- .../syft/src/syft/service/worker/worker_pool_service.py | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_pool.py b/packages/syft/src/syft/service/worker/worker_pool.py index 14f9ff0a7dd..151addcaac0 100644 --- a/packages/syft/src/syft/service/worker/worker_pool.py +++ b/packages/syft/src/syft/service/worker/worker_pool.py @@ -123,9 +123,7 @@ def _coll_repr_(self) -> Dict[str, Any]: else: image_name_with_tag = "In Memory Worker" - healthcheck: str = "" - if self.healthcheck is not None: - healthcheck = self.healthcheck.value + healthcheck = self.healthcheck.value if self.healthcheck is not None else "" return { "Name": self.name, diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 3680ab20346..38f223332ea 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -658,10 +658,11 @@ def _create_workers_in_pool( number=worker_cnt + existing_worker_cnt, ) else: - if worker_image.image_identifier is not None: - registry_host = worker_image.image_identifier.registry_host - else: - registry_host = None + registry_host = ( + worker_image.image_identifier.registry_host + if worker_image.image_identifier is not None + else None + ) result = run_containers( pool_name=pool_name, worker_image=worker_image,