Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[chore]: Move connection pool to proxy deps #2235

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/dstack/_internal/proxy/gateway/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from dstack._internal.proxy.gateway.services.server_client import HTTPMultiClient
from dstack._internal.proxy.gateway.services.stats import StatsCollector
from dstack._internal.proxy.lib.routers.model_proxy import router as model_proxy_router
from dstack._internal.proxy.lib.services.service_connection import service_replica_connection_pool
from dstack._internal.utils.common import run_async
from dstack.version import __version__

Expand All @@ -42,12 +41,13 @@ async def lifespan(app: FastAPI):
injector = get_gateway_injector_from_app(app)
repo = await get_gateway_proxy_repo(await injector.get_repo().__anext__())
nginx = injector.get_nginx()
service_conn_pool = await injector.get_service_connection_pool()
await run_async(nginx.write_global_conf)
await apply_all(repo, nginx)
await apply_all(repo, nginx, service_conn_pool)

yield

await service_replica_connection_pool.remove_all()
await service_conn_pool.remove_all()


def make_app(repo: Optional[GatewayProxyRepo] = None, nginx: Optional[Nginx] = None) -> FastAPI:
Expand Down
12 changes: 10 additions & 2 deletions src/dstack/_internal/proxy/gateway/deps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, AsyncGenerator

from fastapi import Depends, FastAPI, Request

Expand All @@ -23,10 +23,18 @@ def __init__(
nginx: Nginx,
stats_collector: StatsCollector,
) -> None:
super().__init__(repo=repo, auth=auth)
super().__init__()
self._repo = repo
self._auth = auth
self._nginx = nginx
self._stats_collector = stats_collector

async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
yield self._repo

async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]:
yield self._auth

def get_nginx(self) -> Nginx:
return self._nginx

Expand Down
10 changes: 10 additions & 0 deletions src/dstack/_internal/proxy/gateway/routers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
RegisterServiceRequest,
)
from dstack._internal.proxy.gateway.services.nginx import Nginx
from dstack._internal.proxy.lib.deps import get_service_connection_pool
from dstack._internal.proxy.lib.services.service_connection import ServiceConnectionPool

router = APIRouter(prefix="/{project_name}")

Expand All @@ -21,6 +23,7 @@ async def register_service(
body: RegisterServiceRequest,
repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)],
nginx: Annotated[Nginx, Depends(get_nginx)],
service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)],
) -> OkResponse:
await registry_services.register_service(
project_name=project_name.lower(),
Expand All @@ -33,6 +36,7 @@ async def register_service(
ssh_private_key=body.ssh_private_key,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
return OkResponse()

Expand All @@ -43,12 +47,14 @@ async def unregister_service(
run_name: str,
repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)],
nginx: Annotated[Nginx, Depends(get_nginx)],
service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)],
) -> OkResponse:
await registry_services.unregister_service(
project_name=project_name.lower(),
run_name=run_name.lower(),
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
return OkResponse()

Expand All @@ -60,6 +66,7 @@ async def register_replica(
body: RegisterReplicaRequest,
repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)],
nginx: Annotated[Nginx, Depends(get_nginx)],
service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)],
) -> OkResponse:
await registry_services.register_replica(
project_name=project_name.lower(),
Expand All @@ -71,6 +78,7 @@ async def register_replica(
ssh_proxy=body.ssh_proxy,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
return OkResponse()

Expand All @@ -82,13 +90,15 @@ async def unregister_replica(
job_id: str,
repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)],
nginx: Annotated[Nginx, Depends(get_nginx)],
service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)],
) -> OkResponse:
await registry_services.unregister_replica(
project_name=project_name.lower(),
run_name=run_name.lower(),
replica_id=job_id,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
return OkResponse()

Expand Down
75 changes: 58 additions & 17 deletions src/dstack/_internal/proxy/gateway/services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.proxy.lib.services.service_connection import (
ServiceReplicaConnection,
service_replica_connection_pool,
ServiceConnection,
ServiceConnectionPool,
)
from dstack._internal.utils.logging import get_logger

Expand All @@ -39,6 +39,7 @@ async def register_service(
ssh_private_key: str,
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
) -> None:
service = models.Service(
project_name=project_name,
Expand All @@ -64,7 +65,13 @@ async def register_service(

logger.debug("Registering service %s", service.fmt())

await apply_service(service=service, old_service=None, repo=repo, nginx=nginx)
await apply_service(
service=service,
old_service=None,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
await repo.set_service(service)

if model is not None:
Expand All @@ -82,7 +89,11 @@ async def register_service(


async def unregister_service(
project_name: str, run_name: str, repo: GatewayProxyRepo, nginx: Nginx
project_name: str,
run_name: str,
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
) -> None:
async with lock:
service = await repo.get_service(project_name, run_name)
Expand All @@ -93,7 +104,10 @@ async def unregister_service(

logger.debug("Unregistering service %s", service.fmt())

await stop_replica_connections(r.id for r in service.replicas)
await stop_replica_connections(
ids=(r.id for r in service.replicas),
service_conn_pool=service_conn_pool,
)
await nginx.unregister(service.domain_safe)
await repo.delete_models_by_run(project_name, run_name)
await repo.delete_service(project_name, run_name)
Expand All @@ -111,6 +125,7 @@ async def register_replica(
ssh_proxy: Optional[SSHConnectionParams],
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
) -> None:
replica = models.Replica(
id=replica_id,
Expand All @@ -134,7 +149,11 @@ async def register_replica(

logger.debug("Registering replica %s in service %s", replica.id, service.fmt())
failures = await apply_service(
service=service, old_service=old_service, repo=repo, nginx=nginx
service=service,
old_service=old_service,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
if replica in failures:
raise ProxyError(
Expand All @@ -152,6 +171,7 @@ async def unregister_replica(
replica_id: str,
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
) -> None:
async with lock:
old_service = await repo.get_service(project_name, run_name)
Expand All @@ -171,7 +191,13 @@ async def unregister_replica(

logger.debug("Unregistering replica %s in service %s", replica.id, service.fmt())

await apply_service(service=service, old_service=old_service, repo=repo, nginx=nginx)
await apply_service(
service=service,
old_service=old_service,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
await repo.set_service(service)

logger.info("Replica %s in service %s is unregistered now", replica_id, service.fmt())
Expand Down Expand Up @@ -200,6 +226,7 @@ async def apply_service(
old_service: Optional[models.Service],
repo: GatewayProxyRepo,
nginx: Nginx,
service_conn_pool: ServiceConnectionPool,
) -> dict[models.Replica, BaseException]:
if old_service is not None:
if service.domain != old_service.domain:
Expand All @@ -208,9 +235,14 @@ async def apply_service(
f" domain name to change ({old_service.domain} -> {service.domain})"
)
await stop_replica_connections(
replica.id for replica in old_service.replicas if replica not in service.replicas
ids=(
replica.id for replica in old_service.replicas if replica not in service.replicas
),
service_conn_pool=service_conn_pool,
)
replica_conns, replica_failures = await get_or_add_replica_connections(service, repo)
replica_conns, replica_failures = await get_or_add_replica_connections(
service, repo, service_conn_pool
)
replica_configs = [
ReplicaConfig(id=replica.id, socket=conn.app_socket_path)
for replica, conn in replica_conns.items()
Expand All @@ -221,8 +253,8 @@ async def apply_service(


async def get_or_add_replica_connections(
service: models.Service, repo: BaseProxyRepo
) -> tuple[dict[models.Replica, ServiceReplicaConnection], dict[models.Replica, BaseException]]:
service: models.Service, repo: BaseProxyRepo, service_conn_pool: ServiceConnectionPool
) -> tuple[dict[models.Replica, ServiceConnection], dict[models.Replica, BaseException]]:
project = await repo.get_project(service.project_name)
if project is None:
raise UnexpectedProxyError(
Expand All @@ -231,8 +263,7 @@ async def get_or_add_replica_connections(
)
replica_conns, replica_failures = {}, {}
tasks = [
service_replica_connection_pool.get_or_add(project, service, replica)
for replica in service.replicas
service_conn_pool.get_or_add(project, service, replica) for replica in service.replicas
]
results = await asyncio.gather(*tasks, return_exceptions=True)
for replica, conn_or_err in zip(service.replicas, results):
Expand All @@ -249,8 +280,10 @@ async def get_or_add_replica_connections(
return replica_conns, replica_failures


async def stop_replica_connections(ids: Iterable[str]) -> None:
tasks = map(service_replica_connection_pool.remove, ids)
async def stop_replica_connections(
ids: Iterable[str], service_conn_pool: ServiceConnectionPool
) -> None:
tasks = map(service_conn_pool.remove, ids)
results = await asyncio.gather(*tasks, return_exceptions=True)
for replica_id, exc in zip(ids, results):
if isinstance(exc, Exception):
Expand Down Expand Up @@ -284,9 +317,17 @@ async def apply_entrypoint(
await nginx.register(config, acme)


async def apply_all(repo: GatewayProxyRepo, nginx: Nginx) -> None:
async def apply_all(
repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool
) -> None:
service_tasks = [
apply_service(service=service, old_service=None, repo=repo, nginx=nginx)
apply_service(
service=service,
old_service=None,
repo=repo,
nginx=nginx,
service_conn_pool=service_conn_pool,
)
for service in await repo.list_services()
]
entrypoint_tasks = [
Expand Down
23 changes: 17 additions & 6 deletions src/dstack/_internal/proxy/lib/deps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional

from fastapi import Depends, FastAPI, Request, Security, status
Expand All @@ -8,6 +8,7 @@
from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider
from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.proxy.lib.services.service_connection import ServiceConnectionPool


class ProxyDependencyInjector(ABC):
Expand All @@ -17,15 +18,19 @@ class ProxyDependencyInjector(ABC):
a specific repo implementation.
"""

def __init__(self, repo: BaseProxyRepo, auth: BaseProxyAuthProvider) -> None:
self._repo = repo
self._auth = auth
def __init__(self) -> None:
self._service_conn_pool = ServiceConnectionPool()

@abstractmethod
async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]:
yield self._repo
pass

@abstractmethod
async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]:
yield self._auth
pass

async def get_service_connection_pool(self) -> ServiceConnectionPool:
return self._service_conn_pool


def get_injector_from_app(app: FastAPI) -> ProxyDependencyInjector:
Expand Down Expand Up @@ -53,6 +58,12 @@ async def get_proxy_auth_provider(
yield provider


async def get_service_connection_pool(
injector: Annotated[ProxyDependencyInjector, Depends(get_injector)],
) -> ServiceConnectionPool:
return await injector.get_service_connection_pool()


class ProxyAuthContext:
def __init__(self, project_name: str, token: Optional[str], provider: BaseProxyAuthProvider):
self._project_name = project_name
Expand Down
10 changes: 7 additions & 3 deletions src/dstack/_internal/proxy/lib/routers/model_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fastapi.responses import StreamingResponse
from typing_extensions import Annotated

from dstack._internal.proxy.lib.deps import ProxyAuth, get_proxy_repo
from dstack._internal.proxy.lib.deps import ProxyAuth, get_proxy_repo, get_service_connection_pool
from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError
from dstack._internal.proxy.lib.repo import BaseProxyRepo
from dstack._internal.proxy.lib.schemas.model_proxy import (
Expand All @@ -15,7 +15,10 @@
ModelsResponse,
)
from dstack._internal.proxy.lib.services.model_proxy.model_proxy import get_chat_client
from dstack._internal.proxy.lib.services.service_connection import get_service_replica_client
from dstack._internal.proxy.lib.services.service_connection import (
ServiceConnectionPool,
get_service_replica_client,
)

router = APIRouter(dependencies=[Depends(ProxyAuth(auto_enforce=True))])

Expand All @@ -37,6 +40,7 @@ async def post_chat_completions(
project_name: str,
body: ChatCompletionsRequest,
repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)],
service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)],
):
model = await repo.get_model(project_name, body.model)
if model is None:
Expand All @@ -49,7 +53,7 @@ async def post_chat_completions(
f"Model {model.name} in project {project_name} references run {model.run_name}"
" that does not exist or has no replicas"
)
http_client = await get_service_replica_client(service, repo)
http_client = await get_service_replica_client(service, repo, service_conn_pool)
client = get_chat_client(model, http_client)
if not body.stream:
return await client.generate(body)
Expand Down
Loading
Loading