diff --git a/poetry.lock b/poetry.lock index 1868163ed..938b69e33 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1273,6 +1273,17 @@ docs = ["IPython", "bump2version", "furo", "sphinx", "sphinx-argparse", "towncri lint = ["black", "check-manifest", "flake8", "isort", "mypy"] test = ["Cython", "greenlet", "ipython", "pytest", "pytest-cov", "setuptools"] +[[package]] +name = "mergedeep" +version = "1.3.4" +description = "A deep merge function for 🐍." +optional = false +python-versions = ">=3.6" +files = [ + {file = "mergedeep-1.3.4-py3-none-any.whl", hash = "sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307"}, + {file = "mergedeep-1.3.4.tar.gz", hash = "sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8"}, +] + [[package]] name = "msgpack" version = "1.0.8" @@ -2795,4 +2806,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d656bab99c2e5a911ee1003db9e0682141328ae3ef1e1620945f8479451425bf" +content-hash = "641a3685dcb9a044e49d903bdb0d8911d410f7a65e7835dda087a194c47e3c64" diff --git a/pyproject.toml b/pyproject.toml index af7dd1ca0..5841ebd50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ cryptography = "40.0.2" executing = "1.2.0" pydantic = "< 2" ipywidgets = "8.1.2" +mergedeep = "1.3.4" [tool.poetry.group.docs] optional = true diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 015f15eda..149a5d480 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -18,11 +18,9 @@ cluster setup queue, a list of all existing clusters, and the user's working namespace. """ -import re from time import sleep from typing import List, Optional, Tuple, Dict -from kubernetes import config from ray.job_submission import JobSubmissionClient from .auth import config_check, api_config_handler @@ -41,13 +39,11 @@ RayCluster, RayClusterStatus, ) -from kubernetes import client, config -from kubernetes.utils import parse_quantity import yaml import os import requests -from kubernetes import config +from kubernetes import client, config from kubernetes.client.rest import ApiException @@ -145,6 +141,8 @@ def create_app_wrapper(self): gpu = self.config.num_gpus workers = self.config.num_workers template = self.config.template + head_template = self.config.head_template + worker_template = self.config.worker_template image = self.config.image appwrapper = self.config.appwrapper env = self.config.envs @@ -167,6 +165,8 @@ def create_app_wrapper(self): gpu=gpu, workers=workers, template=template, + head_template=head_template, + worker_template=worker_template, image=image, appwrapper=appwrapper, env=env, diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 970673652..66bf7a13c 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -22,6 +22,8 @@ import pathlib import typing +import kubernetes + dir = pathlib.Path(__file__).parent.parent.resolve() @@ -46,6 +48,8 @@ class ClusterConfiguration: max_memory: typing.Union[int, str] = 2 num_gpus: int = 0 template: str = f"{dir}/templates/base-template.yaml" + head_template: kubernetes.client.V1PodTemplateSpec = None + worker_template: kubernetes.client.V1PodTemplateSpec = None appwrapper: bool = False envs: dict = field(default_factory=dict) image: str = "" diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 3192ae1bc..56b79352b 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -20,16 +20,12 @@ from typing import Optional import typing import yaml -import sys import os -import argparse import uuid from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling from ..cluster.auth import api_config_handler, config_check -from os import urandom -from base64 import b64encode -from urllib3.util import parse_url +from mergedeep import merge, Strategy def read_template(template): @@ -278,6 +274,16 @@ def write_user_yaml(user_yaml, output_file_name): print(f"Written to: {output_file_name}") +def apply_head_template(cluster_yaml: dict, head_template: client.V1PodTemplateSpec): + head = cluster_yaml.get("spec").get("headGroupSpec") + merge(head["template"], head_template.to_dict(), strategy=Strategy.ADDITIVE) + + +def apply_worker_template(cluster_yaml: dict, worker_template: client.V1PodTemplateSpec): + worker = cluster_yaml.get("spec").get("workerGroupSpecs")[0] + merge(worker["template"], worker_template.to_dict(), strategy=Strategy.ADDITIVE) + + def generate_appwrapper( name: str, namespace: str, @@ -291,6 +297,8 @@ def generate_appwrapper( gpu: int, workers: int, template: str, + head_template: client.V1PodTemplateSpec, + worker_template: client.V1PodTemplateSpec, image: str, appwrapper: bool, env, @@ -302,6 +310,12 @@ def generate_appwrapper( volume_mounts: list[client.V1VolumeMount], ): cluster_yaml = read_template(template) + + if head_template: + apply_head_template(cluster_yaml, head_template) + if worker_template: + apply_worker_template(cluster_yaml, worker_template) + appwrapper_name, cluster_name = gen_names(name) update_names(cluster_yaml, cluster_name, namespace) update_nodes( diff --git a/tests/unit_test.py b/tests/unit_test.py index db908df60..31c25a061 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -20,8 +20,6 @@ import re import uuid -from codeflare_sdk.cluster import cluster - parent = Path(__file__).resolve().parents[1] aw_dir = os.path.expanduser("~/.codeflare/resources/") sys.path.append(str(parent) + "/src") @@ -69,17 +67,18 @@ createClusterConfig, ) -import codeflare_sdk.utils.kube_api_helpers from codeflare_sdk.utils.generate_yaml import ( gen_names, is_openshift_cluster, ) import openshift -from openshift.selector import Selector import ray import pytest import yaml + +from kubernetes.client import V1PodTemplateSpec, V1PodSpec, V1Toleration + from unittest.mock import MagicMock from pytest_mock import MockerFixture from ray.job_submission import JobSubmissionClient @@ -268,6 +267,41 @@ def test_config_creation(): assert config.appwrapper == True +def test_cluster_config_with_worker_template(mocker): + mocker.patch("kubernetes.client.ApisApi.get_api_versions") + mocker.patch( + "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", + return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"), + ) + + cluster = Cluster(ClusterConfiguration( + name="unit-test-cluster", + namespace="ns", + num_workers=2, + min_cpus=3, + max_cpus=4, + min_memory=5, + max_memory=6, + num_gpus=7, + image="test/ray:2.20.0-py39-cu118", + worker_template=V1PodTemplateSpec( + spec=V1PodSpec( + containers=[], + tolerations=[V1Toleration( + key="nvidia.com/gpu", + operator="Exists", + effect="NoSchedule", + )], + node_selector={ + "nvidia.com/gpu.present": "true", + }, + ) + ), + )) + + assert cluster + + def test_cluster_creation(mocker): # Create AppWrapper containing a Ray Cluster with no local queue specified mocker.patch("kubernetes.client.ApisApi.get_api_versions")