Skip to content

Commit

Permalink
Add TensorDock backend type (#746)
Browse files Browse the repository at this point in the history
* Implement TensorDockAPIClient

* Implement TensorDockBackend

* Fix tests

* Kill instance after 15 seconds of termination

* Update server/config.yml reference

* Allow tensordock backend in shim

* Don't wait for termination to complete in CLI

* TensorDock: raise NoCapacityError if failed to deploy
  • Loading branch information
Egor-S authored Oct 27, 2023
1 parent 4433017 commit f68efea
Show file tree
Hide file tree
Showing 21 changed files with 568 additions and 40 deletions.
6 changes: 5 additions & 1 deletion docs/docs/reference/server/config.yml.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#SCHEMA# dstack._internal.server.services.config.ProjectConfig
overrides:
backends:
type: 'Union[AWSConfigInfoWithCreds, AzureConfigInfoWithCreds, GCPConfigInfoWithCreds, LambdaConfigInfoWithCreds]'
type: 'Union[AWSConfigInfoWithCreds, AzureConfigInfoWithCreds, GCPConfigInfoWithCreds, LambdaConfigInfoWithCreds, TensorDockInfoWithCreds]'

#SCHEMA# dstack._internal.server.services.config.AWSConfig

Expand All @@ -30,3 +30,7 @@
#SCHEMA# dstack._internal.server.services.config.LambdaConfig

##SCHEMA# dstack._internal.core.models.backends.lambdalabs.LambdaAPIKeyCreds

#SCHEMA# dstack._internal.server.services.config.TensorDockConfig

##SCHEMA# dstack._internal.core.models.backends.tensordock.TensorDockAPIKeyCreds
2 changes: 1 addition & 1 deletion runner/cmd/shim/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
Destination: &backendName,
EnvVars: []string{"DSTACK_BACKEND"},
Action: func(c *cli.Context, s string) error {
for _, backend := range []string{"aws", "azure", "gcp", "lambda", "local"} {
for _, backend := range []string{"aws", "azure", "gcp", "lambda", "tensordock", "local"} {
if s == backend {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/cli/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _command(self, args: argparse.Namespace):
# Gently stop the run and wait for it to finish
with console.status("Stopping..."):
run.stop(abort=False)
while not run.status.is_finished():
while not (run.status.is_finished() or run.status == RunStatus.TERMINATING):
time.sleep(2)
run.refresh()
console.print("Stopped")
Expand Down
31 changes: 16 additions & 15 deletions src/dstack/_internal/core/backends/base/offers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@ def get_catalog_offers(
requirements: Optional[Requirements] = None,
extra_filter: Optional[Callable[[InstanceOffer], bool]] = None,
) -> List[InstanceOffer]:
filters = dict(
provider=[provider],
min_cpu=requirements.cpus,
max_price=requirements.max_price,
spot=requirements.spot,
)
if requirements.memory_mib is not None:
filters["min_memory"] = requirements.memory_mib / 1024
if requirements.gpus is not None:
if requirements.gpus.name is not None:
filters["gpu_name"] = [requirements.gpus.name]
if requirements.gpus.memory_mib is not None:
filters["min_gpu_memory"] = requirements.gpus.memory_mib / 1024
if requirements.gpus.count is not None:
filters["min_gpu_count"] = requirements.gpus.count
filters = dict(provider=[provider])
if requirements is not None:
filters.update(
min_cpu=requirements.cpus,
max_price=requirements.max_price,
spot=requirements.spot,
)
if requirements.memory_mib is not None:
filters["min_memory"] = requirements.memory_mib / 1024
if requirements.gpus is not None:
if requirements.gpus.name is not None:
filters["gpu_name"] = [requirements.gpus.name]
if requirements.gpus.memory_mib is not None:
filters["min_gpu_memory"] = requirements.gpus.memory_mib / 1024
if requirements.gpus.count is not None:
filters["min_gpu_count"] = requirements.gpus.count

offers = []
for item in gpuhunt.query(**filters):
Expand Down
15 changes: 15 additions & 0 deletions src/dstack/_internal/core/backends/tensordock/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dstack._internal.core.backends.base import Backend
from dstack._internal.core.backends.tensordock.compute import TensorDockCompute
from dstack._internal.core.backends.tensordock.config import TensorDockConfig
from dstack._internal.core.models.backends.base import BackendType


class TensorDockBackend(Backend):
    """dstack backend implementation for the TensorDock GPU marketplace."""

    TYPE: BackendType = BackendType.TENSORDOCK

    def __init__(self, config: TensorDockConfig):
        """Store the config and eagerly construct the compute service."""
        self.config = config
        self._compute = TensorDockCompute(config)

    def compute(self) -> TensorDockCompute:
        """Return the compute service built at construction time."""
        return self._compute
92 changes: 92 additions & 0 deletions src/dstack/_internal/core/backends/tensordock/api_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import uuid

import requests
import yaml

from dstack._internal.core.models.instances import InstanceType
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


class TensorDockAPIClient:
    """Minimal client for the TensorDock Marketplace REST API (v0).

    Credentials are sent as form fields in every request body (not as
    headers). Methods raise ``requests.HTTPError`` both on HTTP-level
    errors and when the API returns 200 with ``success: false``
    (except ``auth_test``, which returns the flag instead).
    """

    def __init__(self, api_key: str, api_token: str):
        # NOTE(review): .rstrip("/") on this literal is a no-op (the
        # constant has no trailing slash); kept as-is.
        self.api_url = "https://marketplace.tensordock.com/api/v0".rstrip("/")
        self.api_key = api_key
        self.api_token = api_token
        # One shared session so connections are reused across calls.
        self.s = requests.Session()

    def auth_test(self) -> bool:
        """Check the credentials; return the API's ``success`` flag."""
        resp = self.s.post(
            self._url("/auth/test"), data={"api_key": self.api_key, "api_token": self.api_token}
        )
        resp.raise_for_status()
        return resp.json()["success"]

    def get_hostnode(self, hostnode_id: str) -> dict:
        """Fetch details of a single hostnode (physical machine).

        Raises ``requests.HTTPError`` on HTTP errors or ``success: false``.
        """
        logger.debug("Fetching hostnode %s", hostnode_id)
        resp = self.s.get(self._url(f"/client/deploy/hostnodes/{hostnode_id}"))
        resp.raise_for_status()
        data = resp.json()
        if not data["success"]:
            raise requests.HTTPError(data)
        return data["hostnode"]

    def deploy_single(self, instance_name: str, instance: InstanceType, cloudinit: dict) -> dict:
        """Deploy one VM on the hostnode identified by ``instance.name``.

        ``instance.name`` doubles as the hostnode id here (see the
        ``get_hostnode`` call and the ``hostnode`` form field).
        Returns the API response dict with the generated root password
        added under ``"password"``. Raises ``ValueError`` if no GPU model
        on the hostnode matches the offer; ``requests.HTTPError`` on
        HTTP errors or ``success: false``.
        """
        hostnode = self.get_hostnode(instance.name)
        # Match the offer's GPU (name + per-GPU memory) against the
        # hostnode's GPU model keys, which end with "-<mem>gb".
        gpu = instance.resources.gpus[0]
        for gpu_model in hostnode["specs"]["gpu"].keys():
            if gpu_model.endswith(f"-{gpu.memory_mib // 1024}gb"):
                if gpu.name.lower() in gpu_model.lower():
                    break
        else:
            # for/else: no break happened — no matching GPU model found.
            raise ValueError(f"Can't find GPU on the hostnode: {gpu.name}")
        form = {
            "api_key": self.api_key,
            "api_token": self.api_token,
            "password": uuid.uuid4().hex,  # we disable the password auth, but it's required
            "name": instance_name,
            "gpu_count": len(instance.resources.gpus),
            "gpu_model": gpu_model,
            "vcpus": instance.resources.cpus,
            "ram": instance.resources.memory_mib // 1024,
            # The API expects a brace-wrapped set literal, e.g. "{30000}".
            "external_ports": "{%s}" % hostnode["networking"]["ports"][0],
            "internal_ports": "{22}",
            "hostnode": instance.name,
            "storage": 100,  # TODO(egor-s): take from instance.resources
            "operating_system": "Ubuntu 22.04 LTS",
            # Newlines must be escaped so the YAML survives form encoding.
            "cloudinit_script": yaml.dump(cloudinit).replace("\n", "\\n"),
        }
        logger.debug(
            "Deploying instance hostnode=%s, cpus=%s, memory=%s, gpu=%sx %s",
            form["hostnode"],
            form["vcpus"],
            form["ram"],
            form["gpu_count"],
            form["gpu_model"],
        )
        resp = self.s.post(self._url("/client/deploy/single"), data=form)
        resp.raise_for_status()
        data = resp.json()
        if not data["success"]:
            raise requests.HTTPError(data)
        # Surface the generated password to the caller alongside the
        # API's own response fields.
        data["password"] = form["password"]
        return data

    def delete_single(self, instance_id: str) -> None:
        """Delete a deployed VM by its server id.

        Raises ``requests.HTTPError`` on HTTP errors or ``success: false``.
        """
        logger.debug("Deleting instance %s", instance_id)
        resp = self.s.post(
            self._url("/client/delete/single"),
            data={
                "api_key": self.api_key,
                "api_token": self.api_token,
                "server": instance_id,
            },
        )
        resp.raise_for_status()
        data = resp.json()
        if not data["success"]:
            raise requests.HTTPError(data)

    def _url(self, path: str) -> str:
        # Join base URL and path without producing a double slash.
        return f"{self.api_url}/{path.lstrip('/')}"
93 changes: 93 additions & 0 deletions src/dstack/_internal/core/backends/tensordock/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import List, Optional

import requests

from dstack._internal.core.backends.base import Compute
from dstack._internal.core.backends.base.compute import get_shim_commands
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.backends.tensordock.api_client import TensorDockAPIClient
from dstack._internal.core.backends.tensordock.config import TensorDockConfig
from dstack._internal.core.errors import NoCapacityError
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import (
InstanceAvailability,
InstanceOfferWithAvailability,
LaunchedInstanceInfo,
)
from dstack._internal.core.models.runs import Job, Requirements, Run


class TensorDockCompute(Compute):
    """Compute implementation backed by the TensorDock marketplace API."""

    def __init__(self, config: TensorDockConfig):
        self.config = config
        self.api_client = TensorDockAPIClient(config.creds.api_key, config.creds.api_token)

    def get_offers(
        self, requirements: Optional[Requirements] = None
    ) -> List[InstanceOfferWithAvailability]:
        """Return catalog offers for TensorDock matching ``requirements``.

        Every offer is marked AVAILABLE — no live availability check is
        performed against the marketplace here.
        """
        offers = get_catalog_offers(
            provider=BackendType.TENSORDOCK.value,
            requirements=requirements,
        )
        offers = [
            InstanceOfferWithAvailability(
                **offer.dict(), availability=InstanceAvailability.AVAILABLE
            )
            for offer in offers
        ]
        return offers

    def run_job(
        self,
        run: Run,
        job: Job,
        instance_offer: InstanceOfferWithAvailability,
        project_ssh_public_key: str,
        project_ssh_private_key: str,
    ) -> LaunchedInstanceInfo:
        """Deploy a VM for the job, bootstrapping the shim via cloud-init.

        Any ``requests.HTTPError`` from the deploy call is translated to
        ``NoCapacityError`` so the server retries elsewhere.
        """
        # Shell commands that install and start the dstack shim on boot.
        commands = get_shim_commands(
            backend=BackendType.TENSORDOCK,
            image_name=job.job_spec.image_name,
            authorized_keys=[
                run.run_spec.ssh_key_pub.strip(),
                project_ssh_public_key.strip(),
            ],
            registry_auth_required=job.job_spec.registry_auth is not None,
        )
        try:
            resp = self.api_client.deploy_single(
                instance_name=job.job_spec.job_name,
                instance=instance_offer.instance,
                cloudinit={
                    "ssh_pwauth": False,  # disable password auth
                    "users": [
                        "default",
                        {
                            "name": "user",
                            "ssh_authorized_keys": [
                                run.run_spec.ssh_key_pub.strip(),
                                project_ssh_public_key.strip(),
                            ],
                        },
                    ],
                    "runcmd": [
                        ["sh", "-c", " && ".join(commands)],
                    ],
                },
            )
        except requests.HTTPError:
            raise NoCapacityError()
        return LaunchedInstanceInfo(
            instance_id=resp["server"],
            ip_address=resp["ip"],
            region=instance_offer.region,
            username="user",
            # port_forwards maps external->internal; invert it to find
            # the external port forwarded to internal SSH port "22".
            ssh_port={v: k for k, v in resp["port_forwards"].items()}["22"],
            dockerized=True,
        )

    def terminate_instance(self, instance_id: str, region: str) -> None:
        """Best-effort delete: API errors are deliberately swallowed."""
        try:
            self.api_client.delete_single(instance_id)
        except requests.HTTPError:
            pass
9 changes: 9 additions & 0 deletions src/dstack/_internal/core/backends/tensordock/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dstack._internal.core.backends.base.config import BackendConfig
from dstack._internal.core.models.backends.tensordock import (
AnyTensorDockCreds,
TensorDockStoredConfig,
)


class TensorDockConfig(TensorDockStoredConfig, BackendConfig):
    """Full TensorDock backend config: stored settings plus credentials."""

    creds: AnyTensorDockCreds
10 changes: 10 additions & 0 deletions src/dstack/_internal/core/models/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,33 @@
LambdaConfigInfoWithCredsPartial,
LambdaConfigValues,
)
from dstack._internal.core.models.backends.tensordock import (
TensorDockConfigInfo,
TensorDockConfigInfoWithCreds,
TensorDockConfigInfoWithCredsPartial,
TensorDockConfigValues,
)

AnyConfigInfoWithoutCreds = Union[
AWSConfigInfo,
AzureConfigInfo,
GCPConfigInfo,
LambdaConfigInfo,
TensorDockConfigInfo,
]
AnyConfigInfoWithCreds = Union[
AWSConfigInfoWithCreds,
AzureConfigInfoWithCreds,
GCPConfigInfoWithCreds,
LambdaConfigInfoWithCreds,
TensorDockConfigInfoWithCreds,
]
AnyConfigInfoWithCredsPartial = Union[
AWSConfigInfoWithCredsPartial,
AzureConfigInfoWithCredsPartial,
GCPConfigInfoWithCredsPartial,
LambdaConfigInfoWithCredsPartial,
TensorDockConfigInfoWithCredsPartial,
]
AnyConfigInfo = Union[AnyConfigInfoWithoutCreds, AnyConfigInfoWithCreds]

Expand All @@ -53,6 +62,7 @@
AzureConfigValues,
GCPConfigValues,
LambdaConfigValues,
TensorDockConfigValues,
]


Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class BackendType(str, enum.Enum):
GCP = "gcp"
LAMBDA = "lambda"
LOCAL = "local"
TENSORDOCK = "tensordock"


class ConfigElementValue(BaseModel):
Expand Down
46 changes: 46 additions & 0 deletions src/dstack/_internal/core/models/backends/tensordock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List, Optional, Union

from pydantic import BaseModel
from typing_extensions import Literal

from dstack._internal.core.models.backends.base import ConfigMultiElement
from dstack._internal.core.models.common import ForbidExtra


class TensorDockConfigInfo(BaseModel):
    """TensorDock backend settings without credentials."""

    type: Literal["tensordock"] = "tensordock"
    # None = no region filter configured — presumably all regions; confirm.
    regions: Optional[List[str]] = None


class TensorDockAPIKeyCreds(ForbidExtra):
    """API key + token pair used to authenticate with TensorDock."""

    type: Literal["api_key"] = "api_key"
    api_key: str
    api_token: str


# Union of all supported TensorDock credential models (currently only one).
AnyTensorDockCreds = TensorDockAPIKeyCreds


# NOTE(review): presumably mirrors other backends' `*Creds` aliases — confirm.
TensorDockCreds = AnyTensorDockCreds


class TensorDockConfigInfoWithCreds(TensorDockConfigInfo):
    """TensorDock backend settings including credentials."""

    creds: AnyTensorDockCreds


AnyTensorDockConfigInfo = Union[TensorDockConfigInfo, TensorDockConfigInfoWithCreds]


class TensorDockConfigInfoWithCredsPartial(BaseModel):
    """Partially specified TensorDock config: every field may be omitted."""

    type: Literal["tensordock"] = "tensordock"
    creds: Optional[AnyTensorDockCreds]
    regions: Optional[List[str]]


class TensorDockConfigValues(BaseModel):
    """Selectable configuration element values.

    NOTE(review): presumably served to the UI/API for backend setup
    forms — confirm against the server config services.
    """

    type: Literal["tensordock"] = "tensordock"
    regions: Optional[ConfigMultiElement]


class TensorDockStoredConfig(TensorDockConfigInfo):
    """Same fields as TensorDockConfigInfo; presumably the shape persisted
    server-side (name suggests storage) — confirm."""

    pass
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/background/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
from dstack._internal.server.background.tasks.process_pending_jobs import process_pending_jobs
from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
from dstack._internal.server.background.tasks.process_terminating_jobs import (
process_terminating_jobs,
)


def start_background_tasks() -> AsyncIOScheduler:
    """Create, populate, and start the server's background job scheduler."""
    scheduler = AsyncIOScheduler()
    # (task, polling interval in seconds) — registered in this order.
    periodic_tasks = (
        (process_submitted_jobs, 2),
        (process_running_jobs, 2),
        (process_terminating_jobs, 2),
        (process_pending_jobs, 10),
    )
    for task, interval in periodic_tasks:
        scheduler.add_job(task, IntervalTrigger(seconds=interval))
    scheduler.start()
    return scheduler
Loading

0 comments on commit f68efea

Please sign in to comment.