Skip to content

Commit

Permalink
Implement offers cache (#2197)
Browse files Browse the repository at this point in the history
* Implement get_offers_cached()

* Limit active sessions to prevent possible QueuePool overflow errors

* Fix tests
  • Loading branch information
r4victor authored Jan 20, 2025
1 parent 1fdb91a commit 91bdc80
Show file tree
Hide file tree
Showing 22 changed files with 88 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class AWSVolumeBackendData(CoreModel):

class AWSCompute(Compute):
def __init__(self, config: AWSConfig):
super().__init__()
self.config = config
if is_core_model_instance(config.creds, AWSAccessKeyCreds):
self.session = boto3.Session(
Expand Down
17 changes: 14 additions & 3 deletions src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import enum
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Tuple

from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -70,6 +71,7 @@

class AzureCompute(Compute):
def __init__(self, config: AzureConfig, credential: TokenCredential):
super().__init__()
self.config = config
self.credential = credential
self._compute_client = compute_mgmt.ComputeManagementClient(
Expand Down Expand Up @@ -391,13 +393,22 @@ def _get_offers_with_availability(
offers = [offer for offer in offers if offer.region in config_locations]
locations = set(offer.region for offer in offers)

has_quota = set()
for location in locations:
def get_location_quotas(location: str) -> List[str]:
quotas = []
resources = compute_client.resource_skus.list(filter=f"location eq '{location}'")
for resource in resources:
if resource.resource_type != "virtualMachines" or not _vm_type_available(resource):
continue
has_quota.add((resource.name, location))
quotas.append((resource.name, location))
return quotas

has_quota = set()
with ThreadPoolExecutor(max_workers=8) as executor:
futures = []
for location in locations:
futures.append(executor.submit(get_location_quotas, location))
for future in as_completed(futures):
has_quota.update(future.result())

offers_with_availability = []
for offer in offers:
Expand Down
23 changes: 23 additions & 0 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
import os
import re
import threading
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Dict, List, Optional

import git
import requests
import yaml
from cachetools import TTLCache, cachedmethod

from dstack._internal import settings
from dstack._internal.core.consts import (
Expand Down Expand Up @@ -41,6 +44,10 @@


class Compute(ABC):
def __init__(self):
self._offers_cache_lock = threading.Lock()
self._offers_cache = TTLCache(maxsize=5, ttl=30)

@abstractmethod
def get_offers(
self, requirements: Optional[Requirements] = None
Expand Down Expand Up @@ -177,6 +184,22 @@ def detach_volume(self, volume: Volume, instance_id: str):
"""
raise NotImplementedError()

def _get_offers_cached_key(self, requirements: Optional[Requirements] = None) -> int:
# Requirements is not hashable, so we use a hack to get arguments hash
if requirements is None:
return hash(None)
return hash(json.dumps(requirements.dict(), sort_keys=True))

@cachedmethod(
cache=lambda self: self._offers_cache,
key=_get_offers_cached_key,
lock=lambda self: self._offers_cache_lock,
)
def get_offers_cached(
self, requirements: Optional[Requirements] = None
) -> List[InstanceOfferWithAvailability]:
return self.get_offers(requirements)


def get_instance_name(run: Run, job: Job) -> str:
return f"{run.project_name.lower()}-{job.job_spec.job_name}"
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/cudo/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

class CudoCompute(Compute):
def __init__(self, config: CudoConfig):
super().__init__()
self.config = config
self.api_client = CudoApiClient(config.creds.api_key)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/datacrunch/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

class DataCrunchCompute(Compute):
def __init__(self, config: DataCrunchConfig):
super().__init__()
self.config = config
self.api_client = DataCrunchAPIClient(config.creds.client_id, config.creds.client_secret)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class GCPVolumeDiskBackendData(CoreModel):

class GCPCompute(Compute):
def __init__(self, config: GCPConfig):
super().__init__()
self.config = config
self.credentials, self.project_id = auth.authenticate(config.creds)
self.instances_client = compute_v1.InstancesClient(credentials=self.credentials)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

class KubernetesCompute(Compute):
def __init__(self, config: KubernetesConfig):
super().__init__()
self.config = config
self.api = get_api_from_config_data(config.kubeconfig.data)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/lambdalabs/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

class LambdaCompute(Compute):
def __init__(self, config: LambdaConfig):
super().__init__()
self.config = config
self.api_client = LambdaAPIClient(config.creds.api_key)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/nebius/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

class NebiusCompute(Compute):
def __init__(self, config: NebiusConfig):
super().__init__()
self.config = config
self.api_client = NebiusAPIClient(json.loads(self.config.creds.data))

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/oci/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

class OCICompute(Compute):
def __init__(self, config: OCIConfig):
super().__init__()
self.config = config
self.regions = make_region_clients_map(config.regions or [], config.creds)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/runpod/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class RunpodCompute(Compute):
_last_cleanup_time = None

def __init__(self, config: RunpodConfig):
super().__init__()
self.config = config
self.api_client = RunpodApiClient(config.creds.api_key)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/tensordock/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

class TensorDockCompute(Compute):
def __init__(self, config: TensorDockConfig):
super().__init__()
self.config = config
self.api_client = TensorDockAPIClient(config.creds.api_key, config.creds.api_token)

Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/vastai/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

class VastAICompute(Compute):
def __init__(self, config: VastAIConfig):
super().__init__()
self.config = config
self.api_client = VastAIAPIClient(config.creds.api_key)
self.catalog = gpuhunt.Catalog(balance_resources=False, auto_reload=False)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/vultr/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

class VultrCompute(Compute):
def __init__(self, config: VultrConfig):
super().__init__()
self.config = config
self.api_client = VultrApiClient(config.creds.api_key)

Expand Down
14 changes: 6 additions & 8 deletions src/dstack/_internal/server/background/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,36 @@ def start_background_tasks() -> AsyncIOScheduler:
# * 150 active instances with up to 2 minutes processing latency
_scheduler.add_job(collect_metrics, IntervalTrigger(seconds=10), max_instances=1)
_scheduler.add_job(delete_metrics, IntervalTrigger(minutes=5), max_instances=1)
# process_submitted_jobs and process_instances processing rate is 75 jobs(instances) per minute.
# Currently limited by cloud rate limits such as AWS ListServiceQuotas requests.
# TODO: Fix unnecessary requests to clouds and increase this.
# process_submitted_jobs and process_instances max processing rate is 75 jobs(instances) per minute.
_scheduler.add_job(
process_submitted_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=5,
max_instances=2,
)
_scheduler.add_job(
process_running_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=5,
max_instances=2,
)
_scheduler.add_job(
process_terminating_jobs,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=5,
max_instances=2,
)
_scheduler.add_job(
process_runs,
IntervalTrigger(seconds=2, jitter=1),
kwargs={"batch_size": 5},
max_instances=5,
max_instances=2,
)
_scheduler.add_job(
process_instances,
IntervalTrigger(seconds=4, jitter=2),
kwargs={"batch_size": 5},
max_instances=5,
max_instances=2,
)
_scheduler.add_job(process_fleets, IntervalTrigger(seconds=10, jitter=2))
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(self, url: str, engine: Optional[AsyncEngine] = None):
self.url,
echo=settings.SQL_ECHO_ENABLED,
poolclass=AsyncAdaptedQueuePool,
pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW,
)
self.session_maker = sessionmaker(
bind=self.engine,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ async def get_instance_offers(
Returns list of instances satisfying minimal resource requirements sorted by price
"""
logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends])
tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends]
tasks = [run_async(backend.compute().get_offers_cached, requirements) for backend in backends]
offers_by_backend = []
for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)):
if isinstance(result, BackendError):
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
ALEMBIC_MIGRATIONS_LOCATION = os.getenv(
"DSTACK_ALEMBIC_MIGRATIONS_LOCATION", "dstack._internal.server:migrations"
)
# Users may want to increase pool size to support more concurrent resources
# if their db supports many connections
DB_POOL_SIZE = int(os.getenv("DSTACK_DB_POOL_SIZE", 10))
DB_MAX_OVERFLOW = int(os.getenv("DSTACK_DB_MAX_OVERFLOW", 10))

SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None
SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ async def test_creates_instance(self, test_db, session: AsyncSession):

backend_mock = Mock()
backend_mock.TYPE = BackendType.AWS
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
backend_mock.compute.return_value.create_instance.return_value = JobProvisioningData(
backend=offer.backend,
instance_type=offer.instance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def test_provisions_job(
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = backend
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
backend=offer.backend,
instance_type=offer.instance,
Expand All @@ -143,7 +143,7 @@ async def test_provisions_job(
)
await process_submitted_jobs()
m.assert_called_once()
backend_mock.compute.return_value.get_offers.assert_called_once()
backend_mock.compute.return_value.get_offers_cached.assert_called_once()
backend_mock.compute.return_value.run_job.assert_called_once()

await session.refresh(job)
Expand Down Expand Up @@ -199,7 +199,7 @@ async def test_fails_job_when_privileged_true_and_no_offers_with_create_instance
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
backend=offer.backend,
instance_type=offer.instance,
Expand All @@ -218,7 +218,7 @@ async def test_fails_job_when_privileged_true_and_no_offers_with_create_instance
datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc)
await process_submitted_jobs()
m.assert_called_once()
backend_mock.compute.return_value.get_offers.assert_not_called()
backend_mock.compute.return_value.get_offers_cached.assert_not_called()
backend_mock.compute.return_value.run_job.assert_not_called()

await session.refresh(job)
Expand Down Expand Up @@ -272,7 +272,7 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.RUNPOD
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
backend=offer.backend,
instance_type=offer.instance,
Expand All @@ -291,7 +291,7 @@ async def test_fails_job_when_instance_mounts_and_no_offers_with_create_instance
datetime_mock.return_value = datetime(2023, 1, 2, 3, 30, 0, tzinfo=timezone.utc)
await process_submitted_jobs()
m.assert_called_once()
backend_mock.compute.return_value.get_offers.assert_not_called()
backend_mock.compute.return_value.get_offers_cached.assert_not_called()
backend_mock.compute.return_value.run_job.assert_not_called()

await session.refresh(job)
Expand Down Expand Up @@ -494,7 +494,7 @@ async def test_creates_new_instance_in_existing_fleet(self, test_db, session: As
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.AWS
backend_mock.compute.return_value.get_offers.return_value = [offer]
backend_mock.compute.return_value.get_offers_cached.return_value = [offer]
backend_mock.compute.return_value.run_job.return_value = JobProvisioningData(
backend=offer.backend,
instance_type=offer.instance,
Expand All @@ -511,7 +511,7 @@ async def test_creates_new_instance_in_existing_fleet(self, test_db, session: As
)
await process_submitted_jobs()
m.assert_called_once()
backend_mock.compute.return_value.get_offers.assert_called_once()
backend_mock.compute.return_value.get_offers_cached.assert_called_once()
backend_mock.compute.return_value.run_job.assert_called_once()

await session.refresh(job)
Expand Down
4 changes: 2 additions & 2 deletions src/tests/_internal/server/routers/test_fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,13 +773,13 @@ async def test_returns_plan(self, test_db, session: AsyncSession, client: AsyncC
backend_mock = Mock()
m.return_value = [backend_mock]
backend_mock.TYPE = BackendType.AWS
backend_mock.compute.return_value.get_offers.return_value = offers
backend_mock.compute.return_value.get_offers_cached.return_value = offers
response = await client.post(
f"/api/project/{project.name}/fleets/get_plan",
headers=get_auth_headers(user.token),
json={"spec": spec.dict()},
)
backend_mock.compute.return_value.get_offers.assert_called_once()
backend_mock.compute.return_value.get_offers_cached.assert_called_once()

assert response.status_code == 200
assert response.json() == {
Expand Down
Loading

0 comments on commit 91bdc80

Please sign in to comment.