-
Notifications
You must be signed in to change notification settings - Fork 164
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement TensorDockAPIClient
* Implement TensorDockBackend
* Fix tests
* Kill instance after 15 seconds of terminating
* Update server/config.yml reference
* Allow tensordock backend in shim
* Don't wait for terminating completion in CLI
* TensorDock: raise NoCapacityError if failed to deploy
- Loading branch information
Showing
21 changed files
with
568 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from dstack._internal.core.backends.base import Backend | ||
from dstack._internal.core.backends.tensordock.compute import TensorDockCompute | ||
from dstack._internal.core.backends.tensordock.config import TensorDockConfig | ||
from dstack._internal.core.models.backends.base import BackendType | ||
|
||
|
||
class TensorDockBackend(Backend):
    """Backend entry point for TensorDock.

    Stores the backend configuration and owns the single
    ``TensorDockCompute`` instance built from it.
    """

    TYPE: BackendType = BackendType.TENSORDOCK

    def __init__(self, config: TensorDockConfig):
        self.config = config
        # The compute object is created once, here, and reused by compute().
        self._compute = TensorDockCompute(config)

    def compute(self) -> TensorDockCompute:
        """Return the compute object created in ``__init__``."""
        return self._compute
92 changes: 92 additions & 0 deletions
92
src/dstack/_internal/core/backends/tensordock/api_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import uuid | ||
|
||
import requests | ||
import yaml | ||
|
||
from dstack._internal.core.models.instances import InstanceType | ||
from dstack._internal.utils.logging import get_logger | ||
|
||
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
||
|
||
class TensorDockAPIClient:
    """Thin client for the TensorDock marketplace HTTP API (v0).

    All calls go through one shared ``requests.Session``. Endpoints that
    need authentication receive the API key and token as form fields.
    """

    def __init__(self, api_key: str, api_token: str):
        self.api_url = "https://marketplace.tensordock.com/api/v0"
        self.api_key = api_key
        self.api_token = api_token
        self.s = requests.Session()

    def auth_test(self) -> bool:
        """Return the ``success`` flag the API reports for the credentials.

        Raises:
            requests.HTTPError: if the HTTP status indicates an error.
        """
        resp = self.s.post(self._url("/auth/test"), data=self._credentials())
        resp.raise_for_status()
        return resp.json()["success"]

    def get_hostnode(self, hostnode_id: str) -> dict:
        """Fetch the description of one hostnode.

        Returns:
            The ``hostnode`` object from the API response.

        Raises:
            requests.HTTPError: on an HTTP error status or when the API
                reports ``success`` as false.
        """
        logger.debug("Fetching hostnode %s", hostnode_id)
        resp = self.s.get(self._url(f"/client/deploy/hostnodes/{hostnode_id}"))
        return self._parse_response(resp)["hostnode"]

    def deploy_single(self, instance_name: str, instance: "InstanceType", cloudinit: dict) -> dict:
        """Deploy one instance on the hostnode named by ``instance.name``.

        Args:
            instance_name: Name given to the new instance.
            instance: Requested instance type; its ``name`` is used as the
                hostnode id and its resources size the deployment.
            cloudinit: Cloud-init document, sent as YAML.

        Returns:
            The decoded API response with the generated ``password`` added.

        Raises:
            ValueError: if the instance type has no GPUs or the hostnode
                has no GPU model matching the first requested GPU.
            requests.HTTPError: on an HTTP error status or when the API
                reports ``success`` as false.
        """
        gpus = instance.resources.gpus
        if not gpus:
            # Fail with a clear message before making any network call.
            raise ValueError(f"Instance type has no GPUs: {instance.name}")
        hostnode = self.get_hostnode(instance.name)
        gpu = gpus[0]
        gpu_model = self._find_gpu_model(hostnode["specs"]["gpu"], gpu.name, gpu.memory_mib)
        form = {
            **self._credentials(),
            "password": uuid.uuid4().hex,  # we disable the password auth, but it's required
            "name": instance_name,
            "gpu_count": len(gpus),
            "gpu_model": gpu_model,
            "vcpus": instance.resources.cpus,
            "ram": instance.resources.memory_mib // 1024,
            # Forward the hostnode's first listed port to port 22 on the instance.
            "external_ports": "{%s}" % hostnode["networking"]["ports"][0],
            "internal_ports": "{22}",
            "hostnode": instance.name,
            "storage": 100,  # TODO(egor-s): take from instance.resources
            "operating_system": "Ubuntu 22.04 LTS",
            # Newlines are sent as literal "\n" sequences; presumably the API
            # expects a single-line script -- confirm against the API docs.
            "cloudinit_script": yaml.dump(cloudinit).replace("\n", "\\n"),
        }
        logger.debug(
            "Deploying instance hostnode=%s, cpus=%s, memory=%s, gpu=%sx %s",
            form["hostnode"],
            form["vcpus"],
            form["ram"],
            form["gpu_count"],
            form["gpu_model"],
        )
        resp = self.s.post(self._url("/client/deploy/single"), data=form)
        data = self._parse_response(resp)
        data["password"] = form["password"]
        return data

    def delete_single(self, instance_id: str):
        """Delete the instance identified by ``instance_id``.

        Raises:
            requests.HTTPError: on an HTTP error status or when the API
                reports ``success`` as false.
        """
        logger.debug("Deleting instance %s", instance_id)
        resp = self.s.post(
            self._url("/client/delete/single"),
            data={**self._credentials(), "server": instance_id},
        )
        self._parse_response(resp)

    def _credentials(self) -> dict:
        """Return the form fields that authenticate a request."""
        return {"api_key": self.api_key, "api_token": self.api_token}

    @staticmethod
    def _parse_response(resp) -> dict:
        """Decode a JSON response, raising if HTTP or the API signals failure."""
        resp.raise_for_status()
        data = resp.json()
        if not data["success"]:
            # Attach the response so callers can inspect the failed exchange.
            raise requests.HTTPError(data, response=resp)
        return data

    @staticmethod
    def _find_gpu_model(gpu_specs: dict, gpu_name: str, gpu_memory_mib: int) -> str:
        """Return the first GPU model key matching a GPU name and memory size.

        A key matches when it ends with ``-<memory in GiB>gb`` and contains
        ``gpu_name`` case-insensitively.

        Raises:
            ValueError: if no key matches.
        """
        suffix = f"-{gpu_memory_mib // 1024}gb"
        wanted = gpu_name.lower()
        for model in gpu_specs:
            if model.endswith(suffix) and wanted in model.lower():
                return model
        raise ValueError(f"Can't find GPU on the hostnode: {gpu_name}")

    def _url(self, path):
        """Join ``path`` onto the API base URL with exactly one slash."""
        return f"{self.api_url}/{path.lstrip('/')}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import List, Optional | ||
|
||
import requests | ||
|
||
from dstack._internal.core.backends.base import Compute | ||
from dstack._internal.core.backends.base.compute import get_shim_commands | ||
from dstack._internal.core.backends.base.offers import get_catalog_offers | ||
from dstack._internal.core.backends.tensordock.api_client import TensorDockAPIClient | ||
from dstack._internal.core.backends.tensordock.config import TensorDockConfig | ||
from dstack._internal.core.errors import NoCapacityError | ||
from dstack._internal.core.models.backends.base import BackendType | ||
from dstack._internal.core.models.instances import ( | ||
InstanceAvailability, | ||
InstanceOfferWithAvailability, | ||
LaunchedInstanceInfo, | ||
) | ||
from dstack._internal.core.models.runs import Job, Requirements, Run | ||
|
||
|
||
class TensorDockCompute(Compute):
    """Compute implementation that provisions instances via TensorDock."""

    def __init__(self, config: TensorDockConfig):
        self.config = config
        creds = config.creds
        self.api_client = TensorDockAPIClient(creds.api_key, creds.api_token)

    def get_offers(
        self, requirements: Optional[Requirements] = None
    ) -> List[InstanceOfferWithAvailability]:
        """Return the catalog offers for TensorDock, each marked as available."""
        catalog_offers = get_catalog_offers(
            provider=BackendType.TENSORDOCK.value,
            requirements=requirements,
        )
        return [
            InstanceOfferWithAvailability(
                **catalog_offer.dict(), availability=InstanceAvailability.AVAILABLE
            )
            for catalog_offer in catalog_offers
        ]

    def run_job(
        self,
        run: Run,
        job: Job,
        instance_offer: InstanceOfferWithAvailability,
        project_ssh_public_key: str,
        project_ssh_private_key: str,
    ) -> LaunchedInstanceInfo:
        """Deploy an instance for ``job`` and describe how to reach it.

        The run's and the project's public SSH keys are authorized both for
        the shim commands and for the ``user`` account created by cloud-init.

        Raises:
            NoCapacityError: if the deployment request raises
                ``requests.HTTPError``.
        """
        user_key = run.run_spec.ssh_key_pub.strip()
        project_key = project_ssh_public_key.strip()
        commands = get_shim_commands(
            backend=BackendType.TENSORDOCK,
            image_name=job.job_spec.image_name,
            authorized_keys=[user_key, project_key],
            registry_auth_required=job.job_spec.registry_auth is not None,
        )
        cloudinit = {
            "ssh_pwauth": False,  # disable password auth
            "users": [
                "default",
                {
                    "name": "user",
                    "ssh_authorized_keys": [user_key, project_key],
                },
            ],
            # All shim commands run as one shell invocation, chained with &&.
            "runcmd": [
                ["sh", "-c", " && ".join(commands)],
            ],
        }
        try:
            resp = self.api_client.deploy_single(
                instance_name=job.job_spec.job_name,
                instance=instance_offer.instance,
                cloudinit=cloudinit,
            )
        except requests.HTTPError:
            raise NoCapacityError()
        instance_id = resp["server"]
        ip_address = resp["ip"]
        region = instance_offer.region
        # Invert the forwards mapping to find the key whose value is "22".
        forwards_by_target = {target: source for source, target in resp["port_forwards"].items()}
        return LaunchedInstanceInfo(
            instance_id=instance_id,
            ip_address=ip_address,
            region=region,
            username="user",
            ssh_port=forwards_by_target["22"],
            dockerized=True,
        )

    def terminate_instance(self, instance_id: str, region: str):
        """Delete the instance; ``requests.HTTPError`` is deliberately ignored.

        ``region`` is accepted but not used.
        """
        try:
            self.api_client.delete_single(instance_id)
        except requests.HTTPError:
            # Best effort: a failed delete must not fail termination.
            pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from dstack._internal.core.backends.base.config import BackendConfig | ||
from dstack._internal.core.models.backends.tensordock import ( | ||
AnyTensorDockCreds, | ||
TensorDockStoredConfig, | ||
) | ||
|
||
|
||
# Backend configuration for TensorDock: the stored config fields plus credentials.
class TensorDockConfig(TensorDockStoredConfig, BackendConfig):
    # Credentials used to construct the TensorDock API client.
    creds: AnyTensorDockCreds
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import List, Optional, Union | ||
|
||
from pydantic import BaseModel | ||
from typing_extensions import Literal | ||
|
||
from dstack._internal.core.models.backends.base import ConfigMultiElement | ||
from dstack._internal.core.models.common import ForbidExtra | ||
|
||
|
||
# Credential-free description of a TensorDock backend configuration.
class TensorDockConfigInfo(BaseModel):
    # Discriminator: always the literal "tensordock".
    type: Literal["tensordock"] = "tensordock"
    # Optional list of region names; defaults to None.
    regions: Optional[List[str]] = None
|
||
|
||
# API key credentials for TensorDock: a key and a token.
class TensorDockAPIKeyCreds(ForbidExtra):
    # Discriminator: always the literal "api_key".
    type: Literal["api_key"] = "api_key"
    api_key: str
    api_token: str
|
||
|
||
# API key credentials are the only TensorDock credential type defined here,
# so both aliases resolve to TensorDockAPIKeyCreds.
AnyTensorDockCreds = TensorDockAPIKeyCreds


TensorDockCreds = AnyTensorDockCreds
|
||
|
||
# TensorDockConfigInfo extended with a required credentials field.
class TensorDockConfigInfoWithCreds(TensorDockConfigInfo):
    creds: AnyTensorDockCreds
|
||
|
||
# Either form of the config info: without or with credentials.
AnyTensorDockConfigInfo = Union[TensorDockConfigInfo, TensorDockConfigInfoWithCreds]
|
||
|
||
# Partial form of the config with credentials: creds and regions are Optional.
class TensorDockConfigInfoWithCredsPartial(BaseModel):
    # Discriminator: always the literal "tensordock".
    type: Literal["tensordock"] = "tensordock"
    creds: Optional[AnyTensorDockCreds]
    regions: Optional[List[str]]
|
||
|
||
# Config values for the TensorDock backend; regions is an optional
# ConfigMultiElement.
class TensorDockConfigValues(BaseModel):
    # Discriminator: always the literal "tensordock".
    type: Literal["tensordock"] = "tensordock"
    regions: Optional[ConfigMultiElement]
|
||
|
||
# Stored form of the config: adds no fields to TensorDockConfigInfo.
class TensorDockStoredConfig(TensorDockConfigInfo):
    pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.