Skip to content

Commit

Permalink
Merge pull request #8884 from OpenMined/rasswanth/add-annotations-to-…
Browse files Browse the repository at this point in the history
…custom-worker-pool

Add ability to pass pod Annotations and Labels during Worker Pool Launch
  • Loading branch information
rasswanth-s authored Jun 4, 2024
2 parents d9302f3 + 216a9bd commit bff337a
Show file tree
Hide file tree
Showing 13 changed files with 373 additions and 130 deletions.
367 changes: 249 additions & 118 deletions notebooks/api/0.8/11-container-images-k8s.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion packages/grid/helm/examples/azure/azure.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ node:
name: syft-azure
side: high
rootEmail: [email protected]
defaultWorkerPoolCount: 1
resourcesPreset: 2xlarge

defaultWorkerPool:
count: 1
podLabels: null
podAnnotations: null

ingress:
# Make sure cluster is created with --enable-app-routing
# az aks create -g group-name -n cluster-name -l region --enable-app-routing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ spec:
- name: DEFAULT_WORKER_POOL_IMAGE
value: "{{ .Values.global.registry }}/openmined/grid-backend:{{ .Values.global.version }}"
- name: DEFAULT_WORKER_POOL_COUNT
value: {{ .Values.node.defaultWorkerPoolCount | quote }}
value: {{ .Values.node.defaultWorkerPool.count | quote }}
- name: DEFAULT_WORKER_POOL_POD_LABELS
value: {{ .Values.node.defaultWorkerPool.podLabels | toJson | quote }}
- name: DEFAULT_WORKER_POOL_POD_ANNOTATIONS
value: {{ .Values.node.defaultWorkerPool.podAnnotations | toJson | quote }}
- name: USE_INTERNAL_REGISTRY
value: {{ .Values.node.useInternalRegistry | quote }}
{{- if .Values.node.defaultBucketName }}
Expand Down
9 changes: 7 additions & 2 deletions packages/grid/helm/syft/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,20 @@ node:
rootEmail: [email protected]
type: domain
side: high
inMemoryWorkers: false
defaultWorkerPoolCount: 1
defaultBucketName: null
inMemoryWorkers: false
queuePort: 5556
logLevel: info
debuggerEnabled: false
associationRequestAutoApproval: false
useInternalRegistry: true

# Default Worker pool settings
defaultWorkerPool:
count: 1
podLabels: null
podAnnotations: null

# SMTP Settings
smtp:
host: smtp.sendgrid.net
Expand Down
6 changes: 5 additions & 1 deletion packages/grid/helm/values.dev.high.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ registry:

node:
rootEmail: [email protected]
defaultWorkerPoolCount: 1
side: high

resourcesPreset: 2xlarge
resources: null

defaultWorkerPool:
count: 1
podLabels: null
podAnnotations: null

secret:
defaultRootPassword: changethis

Expand Down
6 changes: 5 additions & 1 deletion packages/grid/helm/values.dev.low.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ registry:

node:
rootEmail: [email protected]
defaultWorkerPoolCount: 1
side: low

resourcesPreset: 2xlarge
resources: null

defaultWorkerPool:
count: 1
podLabels: null
podAnnotations: null

secret:
defaultRootPassword: changethis

Expand Down
6 changes: 5 additions & 1 deletion packages/grid/helm/values.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ registry:

node:
rootEmail: [email protected]
defaultWorkerPoolCount: 1
associationRequestAutoApproval: true

resourcesPreset: null
resources: null

defaultWorkerPool:
count: 1
podLabels: null
podAnnotations: null

secret:
defaultRootPassword: changethis

Expand Down
22 changes: 18 additions & 4 deletions packages/syft/src/syft/custom_worker/runner_k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def create_pool(
registry_username: str | None = None,
registry_password: str | None = None,
reg_url: str | None = None,
pod_annotations: dict[str, str] | None = None,
pod_labels: dict[str, str] | None = None,
**kwargs: Any,
) -> StatefulSet:
try:
Expand All @@ -52,6 +54,8 @@ def create_pool(
env_vars=env_vars,
mount_secrets=mount_secrets,
pull_secret=pull_secret,
pod_annotations=pod_annotations,
pod_labels=pod_labels,
**kwargs,
)

Expand Down Expand Up @@ -147,6 +151,8 @@ def _create_stateful_set(
env_vars: list[dict] | None = None,
mount_secrets: dict | None = None,
pull_secret: Secret | None = None,
pod_annotations: dict[str, str] | None = None,
pod_labels: dict[str, str] | None = None,
**kwargs: Any,
) -> StatefulSet:
"""Create a stateful set for a pool"""
Expand Down Expand Up @@ -182,6 +188,16 @@ def _create_stateful_set(
}
]

default_pod_labels = {
"app.kubernetes.io/name": KUBERNETES_NAMESPACE,
"app.kubernetes.io/component": pool_name,
}

if isinstance(pod_labels, dict):
pod_labels = {**default_pod_labels, **pod_labels}
else:
pod_labels = default_pod_labels

stateful_set = StatefulSet(
{
"metadata": {
Expand All @@ -201,10 +217,8 @@ def _create_stateful_set(
},
"template": {
"metadata": {
"labels": {
"app.kubernetes.io/name": KUBERNETES_NAMESPACE,
"app.kubernetes.io/component": pool_name,
}
"labels": pod_labels,
"annotations": pod_annotations,
},
"spec": {
# TODO: make this configurable
Expand Down
15 changes: 15 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from datetime import datetime
from functools import partial
import hashlib
import json
import os
from pathlib import Path
import shutil
Expand Down Expand Up @@ -217,6 +218,16 @@ def get_default_worker_pool_count(node: Node) -> int:
)


def get_default_worker_pool_pod_annotations() -> dict[str, str] | None:
annotations = get_env("DEFAULT_WORKER_POOL_POD_ANNOTATIONS", "null")
return json.loads(annotations)


def get_default_worker_pool_pod_labels() -> dict[str, str] | None:
labels = get_env("DEFAULT_WORKER_POOL_POD_LABELS", "null")
return json.loads(labels)


def in_kubernetes() -> bool:
return get_container_host() == "k8s"

Expand Down Expand Up @@ -1719,6 +1730,8 @@ def create_default_worker_pool(node: Node) -> SyftError | None:
default_pool_name = node.settings.default_worker_pool
default_worker_pool = node.get_default_worker_pool()
default_worker_tag = get_default_worker_tag_by_env(node.dev_mode)
default_worker_pool_pod_annotations = get_default_worker_pool_pod_annotations()
default_worker_pool_pod_labels = get_default_worker_pool_pod_labels()
worker_count = get_default_worker_pool_count(node)
context = AuthedServiceContext(
node=node,
Expand Down Expand Up @@ -1775,6 +1788,8 @@ def create_default_worker_pool(node: Node) -> SyftError | None:
pool_name=default_pool_name,
image_uid=default_image.id,
num_workers=worker_count,
pod_annotations=default_worker_pool_pod_annotations,
pod_labels=default_worker_pool_pod_labels,
)
else:
# Else add a worker to existing worker pool
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,13 @@
"action": "remove"
}
},
"CreateCustomWorkerPoolChange": {
"3": {
"version": 3,
"hash": "e982f2ebcdc6fe23a65a014109e33ba7c487bb7ca5623723cf5ec7642f86828c",
"action": "add"
}
},
"NodePeerUpdate": {
"1": {
"version": 1,
Expand Down
19 changes: 18 additions & 1 deletion packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,12 +298,14 @@ def __repr_syft_nested__(self) -> str:
@serializable()
class CreateCustomWorkerPoolChange(Change):
__canonical_name__ = "CreateCustomWorkerPoolChange"
__version__ = SYFT_OBJECT_VERSION_2
__version__ = SYFT_OBJECT_VERSION_3

pool_name: str
num_workers: int
image_uid: UID | None = None
config: WorkerConfig | None = None
pod_annotations: dict[str, str] | None = None
pod_labels: dict[str, str] | None = None

__repr_attrs__ = ["pool_name", "num_workers", "image_uid"]

Expand Down Expand Up @@ -337,6 +339,8 @@ def _run(
num_workers=self.num_workers,
registry_username=context.extra_kwargs.get("registry_username", None),
registry_password=context.extra_kwargs.get("registry_password", None),
pod_annotations=self.pod_annotations,
pod_labels=self.pod_labels,
)
if isinstance(result, SyftError):
return Err(result)
Expand All @@ -361,6 +365,19 @@ def __repr_syft_nested__(self) -> str:
)


@serializable()
class CreateCustomWorkerPoolChangeV2(Change):
__canonical_name__ = "CreateCustomWorkerPoolChange"
__version__ = SYFT_OBJECT_VERSION_2

pool_name: str
num_workers: int
image_uid: UID | None = None
config: WorkerConfig | None = None

__repr_attrs__ = ["pool_name", "num_workers", "image_uid"]


@serializable()
class Request(SyncableSyftObject):
__canonical_name__ = "Request"
Expand Down
12 changes: 12 additions & 0 deletions packages/syft/src/syft/service/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def create_kubernetes_pool(
registry_username: str | None = None,
registry_password: str | None = None,
reg_url: str | None = None,
pod_annotations: dict[str, str] | None = None,
pod_labels: dict[str, str] | None = None,
**kwargs: Any,
) -> list[Pod] | SyftError:
pool = None
Expand Down Expand Up @@ -363,6 +365,8 @@ def create_kubernetes_pool(
registry_username=registry_username,
registry_password=registry_password,
reg_url=reg_url,
pod_annotations=pod_annotations,
pod_labels=pod_labels,
)
except Exception as e:
if pool:
Expand Down Expand Up @@ -405,6 +409,8 @@ def run_workers_in_kubernetes(
registry_username: str | None = None,
registry_password: str | None = None,
reg_url: str | None = None,
pod_annotations: dict[str, str] | None = None,
pod_labels: dict[str, str] | None = None,
**kwargs: Any,
) -> list[ContainerSpawnStatus] | SyftError:
spawn_status = []
Expand All @@ -422,6 +428,8 @@ def run_workers_in_kubernetes(
registry_username=registry_username,
registry_password=registry_password,
reg_url=reg_url,
pod_annotations=pod_annotations,
pod_labels=pod_labels,
)
else:
return SyftError(
Expand Down Expand Up @@ -504,6 +512,8 @@ def run_containers(
registry_username: str | None = None,
registry_password: str | None = None,
reg_url: str | None = None,
pod_annotations: dict[str, str] | None = None,
pod_labels: dict[str, str] | None = None,
) -> list[ContainerSpawnStatus] | SyftError:
results = []

Expand Down Expand Up @@ -540,6 +550,8 @@ def run_containers(
registry_username=registry_username,
registry_password=registry_password,
reg_url=reg_url,
pod_annotations=pod_annotations,
pod_labels=pod_labels,
)

return results
Expand Down
Loading

0 comments on commit bff337a

Please sign in to comment.