diff --git a/pytest.ini b/pytest.ini index 2046d89f4..ca71dc5cc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -5,3 +5,5 @@ addopts = --allow-hosts=127.0.0.1,localhost ; unix socket for Docker/testcontainers --allow-unix-socket +markers = + shim_version diff --git a/requirements_dev.txt b/requirements_dev.txt index cba7f37fb..78ad4b818 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -5,6 +5,7 @@ pytest~=7.2 pytest-asyncio>=0.21 pytest-httpbin==2.1.0 pytest-socket>=0.7.0 +requests-mock>=1.12.1 openai>=1.53.0,<2.0.0 freezegun>=1.2.0 ruff==0.5.3 # Should match .pre-commit-config.yaml diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index baa60b37e..5f95e7a94 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -452,7 +452,7 @@ def _process_provisioning_with_shim( # the previous job container now, making the shim available (state=running->pending) # for the next try. 
logger.warning( - "%s: failed to sumbit, shim is already running a job, stopping it now, retry later", + "%s: failed to submit, shim is already running a job, stopping it now, retry later", fmt(job_model), ) shim_client.stop(force=True) diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 66ff30143..a38141871 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -1,4 +1,5 @@ from base64 import b64decode +from enum import Enum from typing import Dict, List, Optional, Union from pydantic import Field, validator @@ -94,13 +95,55 @@ class ShimVolumeInfo(CoreModel): device_name: Optional[str] = None -class TaskConfigBody(CoreModel): +class TaskStatus(str, Enum): + PENDING = "pending" + PREPARING = "preparing" + PULLING = "pulling" + CREATING = "creating" + RUNNING = "running" + TERMINATED = "terminated" + + +class TaskInfoResponse(CoreModel): + id: str + status: TaskStatus + termination_reason: str + termination_message: str + + +class TaskSubmitRequest(CoreModel): + id: str + name: str + registry_username: str + registry_password: str + image_name: str + container_user: str + privileged: bool + gpu: int + cpu: float + memory: int + shm_size: int + volumes: list[ShimVolumeInfo] + volume_mounts: list[VolumeMountPoint] + instance_mounts: list[InstanceMountPoint] + host_ssh_user: str + host_ssh_keys: list[str] + container_ssh_keys: list[str] + + +class TaskTerminateRequest(CoreModel): + termination_reason: str + termination_message: str + timeout: int + + +class LegacySubmitBody(CoreModel): username: str password: str image_name: str privileged: bool container_name: str - container_user: Optional[str] + container_user: str shm_size: int public_keys: List[str] ssh_user: str @@ -110,7 +153,7 @@ class TaskConfigBody(CoreModel): instance_mounts: List[InstanceMountPoint] -class StopBody(CoreModel): +class LegacyStopBody(CoreModel): force: bool = False @@ 
-119,6 +162,6 @@ class JobResult(CoreModel): reason_message: str -class PullBody(CoreModel): +class LegacyPullResponse(CoreModel): state: str result: Optional[JobResult] diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index aafe50753..22f4c346e 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,10 +1,12 @@ from dataclasses import dataclass from http import HTTPStatus -from typing import BinaryIO, Dict, List, Optional, Union +from typing import BinaryIO, Dict, List, Optional, TypeVar, Union +import packaging.version import requests import requests.exceptions +from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.envs import Env from dstack._internal.core.models.repos.remote import RemoteRepoCreds from dstack._internal.core.models.resources import Memory @@ -12,20 +14,28 @@ from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint from dstack._internal.server.schemas.runner import ( HealthcheckResponse, + JobResult, + LegacyPullResponse, + LegacyStopBody, + LegacySubmitBody, MetricsResponse, - PullBody, PullResponse, ShimVolumeInfo, - StopBody, SubmitBody, - TaskConfigBody, + TaskInfoResponse, + TaskStatus, + TaskSubmitRequest, + TaskTerminateRequest, ) from dstack._internal.utils.common import get_or_error +from dstack._internal.utils.logging import get_logger REMOTE_SHIM_PORT = 10998 REMOTE_RUNNER_PORT = 10999 REQUEST_TIMEOUT = 15 +logger = get_logger(__name__) + @dataclass class HealthStatus: @@ -120,24 +130,33 @@ def _url(self, path: str) -> str: class ShimClient: + # API v2 (a.k.a. Future API) — `/api/tasks/[:id[/{terminate,remove}]]` + # API v1 (a.k.a. 
Legacy API) — `/api/{submit,pull,stop}` + _API_V2_MIN_SHIM_VERSION = (0, 18, 34) + + # A surrogate task ID for API-v1-over-v2 emulation (`_v2_compat_*` methods) + _LEGACY_TASK_ID = "00000000-0000-0000-0000-000000000000" + + _shim_version: Optional["_Version"] + _api_version: int + _negotiated: bool = False + def __init__( self, port: int, hostname: str = "localhost", ): - self.secure = False - self.hostname = hostname - self.port = port + self._session = requests.Session() + self._base_url = f"http://{hostname}:{port}" def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckResponse]: try: - resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT) - resp.raise_for_status() - return HealthcheckResponse.__response__.parse_obj(resp.json()) + resp = self._request("GET", "/api/healthcheck", raise_for_status=True) except requests.exceptions.RequestException: if unmask_exeptions: raise return None + return self._response(HealthcheckResponse, resp) def submit( self, @@ -146,7 +165,7 @@ def submit( image_name: str, privileged: bool, container_name: str, - container_user: Optional[str], + container_user: str, shm_size: Optional[Memory], public_keys: List[str], ssh_user: str, @@ -154,50 +173,190 @@ def submit( mounts: List[VolumeMountPoint], volumes: List[Volume], instance_mounts: List[InstanceMountPoint], - ): + ) -> bool: """ Returns `True` if submitted and `False` if the shim already has a job (`409 Conflict`). Other error statuses raise an exception. 
""" - _shm_size = int(shm_size * 1024 * 1024 * 1024) if shm_size else 0 - volume_infos = [_volume_to_shim_volume_info(v) for v in volumes] - post_body = TaskConfigBody( + if self._is_api_v2_supported(): + return self._v2_compat_submit( + username=username, + password=password, + image_name=image_name, + privileged=privileged, + container_name=container_name, + container_user=container_user, + shm_size=shm_size, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, + mounts=mounts, + volumes=volumes, + instance_mounts=instance_mounts, + ) + body = LegacySubmitBody( username=username, password=password, image_name=image_name, privileged=privileged, container_name=container_name, container_user=container_user, - shm_size=_shm_size, + shm_size=int(shm_size * 1024**3) if shm_size else 0, public_keys=public_keys, ssh_user=ssh_user, ssh_key=ssh_key, mounts=mounts, - volumes=volume_infos, + volumes=[_volume_to_shim_volume_info(v) for v in volumes], instance_mounts=instance_mounts, - ).dict() - resp = requests.post( - self._url("/api/submit"), - json=post_body, - timeout=REQUEST_TIMEOUT, ) + resp = self._request("POST", "/api/submit", body) if resp.status_code == HTTPStatus.CONFLICT: return False resp.raise_for_status() return True - def stop(self, force: bool = False): - body = StopBody(force=force) - resp = requests.post(self._url("/api/stop"), json=body.dict(), timeout=REQUEST_TIMEOUT) + def stop(self, force: bool = False) -> None: + if self._is_api_v2_supported(): + return self._v2_compat_stop(force) + body = LegacyStopBody(force=force) + self._request("POST", "/api/stop", body, raise_for_status=True) + + def pull(self) -> LegacyPullResponse: + if self._is_api_v2_supported(): + return self._v2_compat_pull() + resp = self._request("GET", "/api/pull", raise_for_status=True) + return self._response(LegacyPullResponse, resp) + + def _v2_compat_submit( + self, + username: str, + password: str, + image_name: str, + privileged: bool, + container_name: str, + 
container_user: str, + shm_size: Optional[Memory], + public_keys: list[str], + ssh_user: str, + ssh_key: str, + mounts: list[VolumeMountPoint], + volumes: list[Volume], + instance_mounts: List[InstanceMountPoint], + ) -> bool: + task_id = self._LEGACY_TASK_ID + resp = self._request("GET", f"/api/tasks/{task_id}") + if resp.status_code != HTTPStatus.NOT_FOUND: + resp.raise_for_status() + task = self._response(TaskInfoResponse, resp) + if task.status != TaskStatus.TERMINATED: + return False + self._request("POST", f"/api/tasks/{task_id}/remove", raise_for_status=True) + body = TaskSubmitRequest( + id=task_id, + name=container_name, + registry_username=username, + registry_password=password, + image_name=image_name, + container_user=container_user, + privileged=privileged, + gpu=-1, + cpu=0, + memory=0, + shm_size=int(shm_size * 1024**3) if shm_size else 0, + volumes=[_volume_to_shim_volume_info(v) for v in volumes], + volume_mounts=mounts, + instance_mounts=instance_mounts, + host_ssh_user=ssh_user, + host_ssh_keys=[ssh_key], + container_ssh_keys=public_keys, + ) + resp = self._request("POST", "/api/tasks", body, raise_for_status=True) + return True + + def _v2_compat_stop(self, force: bool = False) -> None: + task_id = self._LEGACY_TASK_ID + body = TaskTerminateRequest( + termination_reason="", + termination_message="", + timeout=0 if force else 10, + ) + resp = self._request("POST", f"/api/tasks/{task_id}/terminate", body) + if resp.status_code == HTTPStatus.NOT_FOUND: + return resp.raise_for_status() - def pull(self) -> PullBody: - resp = requests.get(self._url("/api/pull"), timeout=REQUEST_TIMEOUT) + def _v2_compat_pull(self) -> LegacyPullResponse: + task_id = self._LEGACY_TASK_ID + resp = self._request("GET", f"/api/tasks/{task_id}") + if resp.status_code == HTTPStatus.NOT_FOUND: + return LegacyPullResponse( + state="pending", + result=JobResult(reason="", reason_message=""), + ) resp.raise_for_status() - return PullBody.__response__.parse_obj(resp.json()) + 
task = self._response(TaskInfoResponse, resp) + if task.status in [TaskStatus.PENDING, TaskStatus.PREPARING, TaskStatus.PULLING]: + state = "pulling" + elif task.status == TaskStatus.CREATING: + state = "creating" + elif task.status == TaskStatus.RUNNING: + state = "running" + elif task.status == TaskStatus.TERMINATED: + state = "pending" + else: + assert False, f"should not reach here: {task.status}" + return LegacyPullResponse( + state=state, + result=JobResult( + reason=task.termination_reason, reason_message=task.termination_message + ), + ) - def _url(self, path: str) -> str: - return f"{'https' if self.secure else 'http'}://{self.hostname}:{self.port}/{path.lstrip('/')}" + def _request( + self, + method: str, + path: str, + body: Optional[CoreModel] = None, + *, + raise_for_status: bool = False, + ) -> requests.Response: + url = f"{self._base_url}/{path.lstrip('/')}" + if body is not None: + json = body.dict() + else: + json = None + resp = self._session.request(method, url, json=json, timeout=REQUEST_TIMEOUT) + if raise_for_status: + resp.raise_for_status() + return resp + + _M = TypeVar("_M", bound=CoreModel) + + def _response(self, model_cls: type[_M], response: requests.Response) -> _M: + return model_cls.__response__.parse_obj(response.json()) + + def _is_api_v2_supported(self) -> bool: + if not self._negotiated: + self._negotiate() + return self._api_version >= 2 + + def _negotiate(self) -> None: + resp = self._request("GET", "/api/healthcheck", raise_for_status=True) + raw_version = self._response(HealthcheckResponse, resp).version + version = _parse_version(raw_version) + if version is None or version >= self._API_V2_MIN_SHIM_VERSION: + api_version = 2 + else: + api_version = 1 + logger.debug( + "shim version: %s %s (API v%s)", + raw_version, + version or "(latest)", + api_version, + ) + self._shim_version = version + self._api_version = api_version + self._negotiated = True def health_response_to_health_status(data: HealthcheckResponse) -> 
HealthStatus: @@ -221,3 +380,30 @@ def _volume_to_shim_volume_info(volume: Volume) -> ShimVolumeInfo: init_fs=not volume.external, device_name=device_name, ) + + +_Version = tuple[int, int, int] + + +def _parse_version(version_string: str) -> Optional[_Version]: + """ + Returns a (major, minor, micro) tuple if the version is final. + Returns `None`, which means "latest", if: + * the version is prerelease or dev build -- assuming that in most cases it's a build based on + the latest final release + * the version consists of only major part or not valid at all, e.g., staging builds have + GitHub run number (e.g., 1234) instead of the version -- assuming that it's a "bleeding edge", + not yet released version + """ + try: + version = packaging.version.parse(version_string) + except packaging.version.InvalidVersion: + return None + if version.is_prerelease or version.is_devrelease: + return None + release = version.release + if len(release) <= 1: + return None + if len(release) == 2: + return (*release, 0) + return release[:3] diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index fbfb4afd1..9021a8519 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -599,11 +599,15 @@ def get_volume( volume_id: Optional[str] = None, provisioning_data: Optional[VolumeProvisioningData] = None, attachment_data: Optional[VolumeAttachmentData] = None, + device_name: Optional[str] = None, ) -> Volume: if id_ is None: id_ = uuid.uuid4() if configuration is None: configuration = get_volume_configuration() + if device_name is not None: + assert attachment_data is None, "attachment_data and device_name are mutually exclusive" + attachment_data = VolumeAttachmentData(device_name=device_name) return Volume( id=id_, name=name, diff --git a/src/tests/_internal/server/services/runner/__init__.py b/src/tests/_internal/server/services/runner/__init__.py new file mode 100644 index
000000000..e69de29bb diff --git a/src/tests/_internal/server/services/runner/test_client.py b/src/tests/_internal/server/services/runner/test_client.py new file mode 100644 index 000000000..9e8cc2d2b --- /dev/null +++ b/src/tests/_internal/server/services/runner/test_client.py @@ -0,0 +1,522 @@ +from collections.abc import Generator +from typing import Optional + +import pytest +import requests_mock + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.resources import Memory +from dstack._internal.core.models.volumes import InstanceMountPoint, VolumeMountPoint +from dstack._internal.server.schemas.runner import ( + HealthcheckResponse, + JobResult, + LegacyPullResponse, +) +from dstack._internal.server.services.runner.client import ShimClient, _parse_version +from dstack._internal.server.testing.common import get_volume, get_volume_configuration + + +class BaseShimClientTest: + @pytest.fixture + def adapter(self) -> Generator[requests_mock.Adapter, None, None]: + adapter = requests_mock.Adapter() + with requests_mock.Mocker(adapter=adapter): + yield adapter + return + + @pytest.fixture + def client(self, request: pytest.FixtureRequest, adapter: requests_mock.Adapter) -> ShimClient: + shim_version_marker = request.node.get_closest_marker("shim_version") + if shim_version_marker is not None: + healthcheck_resp = {"service": "dstack-shim", "version": shim_version_marker.args[0]} + adapter.register_uri("GET", "/api/healthcheck", json=healthcheck_resp) + return ShimClient(port=10998, hostname="localhost") + + def assert_request( + self, + adapter: requests_mock.Adapter, + index: int, + method: str, + path: str, + json: Optional[dict] = None, + ): + history = adapter.request_history + assert index < len(history), "index out of history bounds" + req = history[index] + assert req.method == method + assert req.path == path + if json is not None: + assert req.json() == json + + +class 
TestShimClientNegotiate(BaseShimClientTest): + @pytest.mark.parametrize( + ["expected_shim_version", "expected_api_version"], + [ + # final versions with optional build metadata ("local segment" according to PEP 440); + # boundary-value cases + pytest.param((0, 18, 33), 1, marks=pytest.mark.shim_version("0.18.33")), + pytest.param((0, 18, 33), 1, marks=pytest.mark.shim_version("0.18.33+build.1")), + pytest.param((0, 18, 34), 2, marks=pytest.mark.shim_version("0.18.34")), + pytest.param((0, 18, 34), 2, marks=pytest.mark.shim_version("0.18.34+build.1")), + # looks like major-only version, but not a version at all (stgn build), + # assuming the latest version + pytest.param(None, 2, marks=pytest.mark.shim_version("1494")), + # invalid versions, assuming local builds with the latest version + pytest.param(None, 2, marks=pytest.mark.shim_version("latest")), + pytest.param(None, 2, marks=pytest.mark.shim_version("0.17.0-next")), + # even though this version is less than _API_V2_MIN_SHIM_VERSION, for the sake of + # simplicity we assume that any non-final version is the latest; normally, users + # should not use non-latest RC versions + pytest.param(None, 2, marks=pytest.mark.shim_version("0.17.0rc1")), + ], + ) + def test( + self, + client: ShimClient, + adapter: requests_mock.Adapter, + expected_shim_version: Optional[tuple[int, int, int]], + expected_api_version: int, + ): + assert not hasattr(client, "_shim_version") + assert not hasattr(client, "_api_version") + + client._negotiate() + + assert client._shim_version == expected_shim_version + assert client._api_version == expected_api_version + assert adapter.call_count == 1 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + + +@pytest.mark.shim_version("0.18.30") +class TestShimClientV1(BaseShimClientTest): + def test_healthcheck(self, client: ShimClient, adapter: requests_mock.Adapter): + resp = client.healthcheck() + + assert resp == HealthcheckResponse(service="dstack-shim", version="0.18.30") + assert
adapter.call_count == 1 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + + def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri("POST", "/api/submit", json={"state": "pulling"}) + volume = get_volume( + name="vol", + volume_id="vol-id", + configuration=get_volume_configuration(backend=BackendType.GCP), + external=False, + device_name="/dev/sdv", + ) + + submitted = client.submit( + username="", + password="", + image_name="debian", + privileged=False, + container_name="test-0-0", + container_user="root", + shm_size=None, + public_keys=["project_key", "user_key"], + ssh_user="dstack", + ssh_key="host_key", + mounts=[VolumeMountPoint(name="vol", path="/vol")], + volumes=[volume], + instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")], + ) + + assert submitted is True + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + expected_request = { + "username": "", + "password": "", + "image_name": "debian", + "privileged": False, + "container_name": "test-0-0", + "container_user": "root", + "shm_size": 0, + "public_keys": ["project_key", "user_key"], + "ssh_user": "dstack", + "ssh_key": "host_key", + "mounts": [{"name": "vol", "path": "/vol"}], + "volumes": [ + { + "backend": "gcp", + "name": "vol", + "volume_id": "vol-id", + "init_fs": True, + "device_name": "/dev/sdv", + } + ], + "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}], + } + self.assert_request(adapter, 1, "POST", "/api/submit", expected_request) + + def test_submit_conflict(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri("POST", "/api/submit", status_code=409) + + submitted = client.submit( + username="", + password="", + image_name="debian", + privileged=False, + container_name="test-0-0", + container_user="root", + shm_size=None, + public_keys=["project_key", "user_key"], + ssh_user="dstack", + ssh_key="host_key", + 
mounts=[], + volumes=[], + instance_mounts=[], + ) + + assert submitted is False + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "POST", "/api/submit") + + def test_stop(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri("POST", "/api/stop", json={"state": "pending"}) + + client.stop() + + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "POST", "/api/stop", {"force": False}) + + def test_stop_force(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri("POST", "/api/stop", json={"state": "pending"}) + + client.stop(force=True) + + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "POST", "/api/stop", {"force": True}) + + def test_pull(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri( + "GET", + "/api/pull", + json={ + "state": "pending", + "result": {"reason": "CONTAINER_EXITED_WITH_ERROR", "reason_message": "killed"}, + }, + ) + + resp = client.pull() + + assert resp == LegacyPullResponse( + state="pending", + result=JobResult(reason="CONTAINER_EXITED_WITH_ERROR", reason_message="killed"), + ) + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/pull") + + +@pytest.mark.shim_version("0.18.40") +class TestShimClientV2Compat(BaseShimClientTest): + def test_healthcheck(self, client: ShimClient, adapter: requests_mock.Adapter): + resp = client.healthcheck() + + assert resp == HealthcheckResponse(service="dstack-shim", version="0.18.40") + assert adapter.call_count == 1 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + + def test_submit(self, client: ShimClient, adapter: requests_mock.Adapter): + tasks_url = "/api/tasks" + legacy_task_url = 
f"{tasks_url}/00000000-0000-0000-0000-000000000000" + remove_legacy_task_url = f"{legacy_task_url}/remove" + adapter.register_uri( + "GET", + legacy_task_url, + json={ + "id": "00000000-0000-0000-0000-000000000000", + "status": "terminated", + "termination_reason": "CONTAINER_EXITED_WITH_ERROR", + "termination_message": "killed", + "container_name": "horrible-mule-1-0-0-44f7cb95", # ignored + }, + ) + adapter.register_uri("POST", remove_legacy_task_url) + adapter.register_uri("POST", tasks_url) + volume = get_volume( + name="vol", + volume_id="vol-id", + configuration=get_volume_configuration(backend=BackendType.GCP), + external=False, + device_name="/dev/sdv", + ) + + submitted = client.submit( + username="user", + password="pass", + image_name="debian", + privileged=True, + container_name="test-0-0", + container_user="root", + shm_size=Memory.parse("512MB"), + public_keys=["project_key", "user_key"], + ssh_user="dstack", + ssh_key="host_key", + mounts=[VolumeMountPoint(name="vol", path="/vol")], + volumes=[volume], + instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")], + ) + + assert submitted is True + assert adapter.call_count == 4 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", legacy_task_url) + self.assert_request(adapter, 2, "POST", remove_legacy_task_url) + expected_request = { + "id": "00000000-0000-0000-0000-000000000000", + "name": "test-0-0", + "registry_username": "user", + "registry_password": "pass", + "image_name": "debian", + "container_user": "root", + "privileged": True, + "gpu": -1, + "cpu": 0, + "memory": 0, + "shm_size": 536870912, + "volumes": [ + { + "backend": "gcp", + "name": "vol", + "volume_id": "vol-id", + "init_fs": True, + "device_name": "/dev/sdv", + } + ], + "volume_mounts": [{"name": "vol", "path": "/vol"}], + "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}], + "host_ssh_user": "dstack", + "host_ssh_keys": ["host_key"], + 
"container_ssh_keys": ["project_key", "user_key"], + } + self.assert_request(adapter, 3, "POST", tasks_url, expected_request) + + def test_submit_no_task(self, client: ShimClient, adapter: requests_mock.Adapter): + tasks_url = "/api/tasks" + legacy_task_url = f"{tasks_url}/00000000-0000-0000-0000-000000000000" + adapter.register_uri("GET", legacy_task_url, status_code=404) + adapter.register_uri("POST", tasks_url) + volume = get_volume( + name="vol", + volume_id="vol-id", + configuration=get_volume_configuration(backend=BackendType.GCP), + external=False, + device_name="/dev/sdv", + ) + + submitted = client.submit( + username="user", + password="pass", + image_name="debian", + privileged=True, + container_name="test-0-0", + container_user="root", + shm_size=Memory.parse("512MB"), + public_keys=["project_key", "user_key"], + ssh_user="dstack", + ssh_key="host_key", + mounts=[VolumeMountPoint(name="vol", path="/vol")], + volumes=[volume], + instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")], + ) + + assert submitted is True + assert adapter.call_count == 3 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", legacy_task_url) + expected_request = { + "id": "00000000-0000-0000-0000-000000000000", + "name": "test-0-0", + "registry_username": "user", + "registry_password": "pass", + "image_name": "debian", + "container_user": "root", + "privileged": True, + "gpu": -1, + "cpu": 0, + "memory": 0, + "shm_size": 536870912, + "volumes": [ + { + "backend": "gcp", + "name": "vol", + "volume_id": "vol-id", + "init_fs": True, + "device_name": "/dev/sdv", + } + ], + "volume_mounts": [{"name": "vol", "path": "/vol"}], + "instance_mounts": [{"instance_path": "/mnt/nfs/home", "path": "/home"}], + "host_ssh_user": "dstack", + "host_ssh_keys": ["host_key"], + "container_ssh_keys": ["project_key", "user_key"], + } + self.assert_request(adapter, 2, "POST", tasks_url, expected_request) + + def 
test_submit_conflict(self, client: ShimClient, adapter: requests_mock.Adapter): + tasks_url = "/api/tasks" + legacy_task_url = f"{tasks_url}/00000000-0000-0000-0000-000000000000" + adapter.register_uri( + "GET", + legacy_task_url, + json={ + "id": "00000000-0000-0000-0000-000000000000", + "status": "running", + "termination_reason": "", + "termination_message": "", + }, + ) + adapter.register_uri("POST", tasks_url) + volume = get_volume( + name="vol", + volume_id="vol-id", + configuration=get_volume_configuration(backend=BackendType.GCP), + external=False, + device_name="/dev/sdv", + ) + + submitted = client.submit( + username="user", + password="pass", + image_name="debian", + privileged=True, + container_name="test-0-0", + container_user="root", + shm_size=Memory.parse("512MB"), + public_keys=["project_key", "user_key"], + ssh_user="dstack", + ssh_key="host_key", + mounts=[VolumeMountPoint(name="vol", path="/vol")], + volumes=[volume], + instance_mounts=[InstanceMountPoint(instance_path="/mnt/nfs/home", path="/home")], + ) + + assert submitted is False + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", legacy_task_url) + + def test_stop(self, client: ShimClient, adapter: requests_mock.Adapter): + url = "/api/tasks/00000000-0000-0000-0000-000000000000/terminate" + adapter.register_uri( + "POST", + url, + json={ + "id": "00000000-0000-0000-0000-000000000000", + "status": "terminated", + "termination_reason": "", + "termination_message": "", + }, + ) + + client.stop() + + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + expected_request = {"termination_reason": "", "termination_message": "", "timeout": 10} + self.assert_request(adapter, 1, "POST", url, expected_request) + + def test_stop_no_task(self, client: ShimClient, adapter: requests_mock.Adapter): + url = "/api/tasks/00000000-0000-0000-0000-000000000000/terminate" + 
adapter.register_uri("POST", url, status_code=404) + + client.stop() + + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + expected_request = {"termination_reason": "", "termination_message": "", "timeout": 10} + self.assert_request(adapter, 1, "POST", url, expected_request) + + def test_stop_force(self, client: ShimClient, adapter: requests_mock.Adapter): + url = "/api/tasks/00000000-0000-0000-0000-000000000000/terminate" + adapter.register_uri( + "POST", + url, + json={ + "id": "00000000-0000-0000-0000-000000000000", + "status": "terminated", + "termination_reason": "", + "termination_message": "", + }, + ) + + client.stop(force=True) + + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + expected_request = {"termination_reason": "", "termination_message": "", "timeout": 0} + self.assert_request(adapter, 1, "POST", url, expected_request) + + def test_pull(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri( + "GET", + "/api/tasks/00000000-0000-0000-0000-000000000000", + json={ + "id": "00000000-0000-0000-0000-000000000000", + "status": "terminated", + "termination_reason": "CONTAINER_EXITED_WITH_ERROR", + "termination_message": "killed", + "container_name": "horrible-mule-1-0-0-44f7cb95", # ignored + }, + ) + + resp = client.pull() + + assert resp == LegacyPullResponse( + state="pending", + result=JobResult(reason="CONTAINER_EXITED_WITH_ERROR", reason_message="killed"), + ) + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/tasks/00000000-0000-0000-0000-000000000000") + + def test_pull_no_task(self, client: ShimClient, adapter: requests_mock.Adapter): + adapter.register_uri( + "GET", + "/api/tasks/00000000-0000-0000-0000-000000000000", + status_code=404, + ) + + resp = client.pull() + + assert resp == LegacyPullResponse( + state="pending", + 
result=JobResult(reason="", reason_message=""), + ) + assert adapter.call_count == 2 + self.assert_request(adapter, 0, "GET", "/api/healthcheck") + self.assert_request(adapter, 1, "GET", "/api/tasks/00000000-0000-0000-0000-000000000000") + + +class TestParseVersion: + @pytest.mark.parametrize( + ["value", "expected"], + [ + ["1.12", (1, 12, 0)], + ["1.12.3", (1, 12, 3)], + ["1.12.3.1", (1, 12, 3)], + ["1.12.3+build.1", (1, 12, 3)], # local builds are OK + ], + ) + def test_valid_final(self, value: str, expected: tuple[int, int, int]): + assert _parse_version(value) == expected + + @pytest.mark.parametrize("value", ["1.12alpha1", "1.12.3rc1", "1.12.3.dev0"]) + def test_valid_pre_dev_local(self, value: str): + assert _parse_version(value) is None + + @pytest.mark.parametrize("value", ["1", "1234"]) + def test_valid_major_only(self, value: str): + assert _parse_version(value) is None + + @pytest.mark.parametrize("value", ["", "foo", "1.12.3-next.20241231"]) + def test_invalid(self, value: str): + assert _parse_version(value) is None