diff --git a/src/dstack/_internal/proxy/gateway/app.py b/src/dstack/_internal/proxy/gateway/app.py index 82457e337..f43ebada0 100644 --- a/src/dstack/_internal/proxy/gateway/app.py +++ b/src/dstack/_internal/proxy/gateway/app.py @@ -28,7 +28,6 @@ from dstack._internal.proxy.gateway.services.server_client import HTTPMultiClient from dstack._internal.proxy.gateway.services.stats import StatsCollector from dstack._internal.proxy.lib.routers.model_proxy import router as model_proxy_router -from dstack._internal.proxy.lib.services.service_connection import service_replica_connection_pool from dstack._internal.utils.common import run_async from dstack.version import __version__ @@ -42,12 +41,13 @@ async def lifespan(app: FastAPI): injector = get_gateway_injector_from_app(app) repo = await get_gateway_proxy_repo(await injector.get_repo().__anext__()) nginx = injector.get_nginx() + service_conn_pool = await injector.get_service_connection_pool() await run_async(nginx.write_global_conf) - await apply_all(repo, nginx) + await apply_all(repo, nginx, service_conn_pool) yield - await service_replica_connection_pool.remove_all() + await service_conn_pool.remove_all() def make_app(repo: Optional[GatewayProxyRepo] = None, nginx: Optional[Nginx] = None) -> FastAPI: diff --git a/src/dstack/_internal/proxy/gateway/deps.py b/src/dstack/_internal/proxy/gateway/deps.py index d7ea22e3a..009ce8da1 100644 --- a/src/dstack/_internal/proxy/gateway/deps.py +++ b/src/dstack/_internal/proxy/gateway/deps.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, AsyncGenerator from fastapi import Depends, FastAPI, Request @@ -23,10 +23,18 @@ def __init__( nginx: Nginx, stats_collector: StatsCollector, ) -> None: - super().__init__(repo=repo, auth=auth) + super().__init__() + self._repo = repo + self._auth = auth self._nginx = nginx self._stats_collector = stats_collector + async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: + yield self._repo + + async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]: + yield self._auth + def get_nginx(self) -> Nginx: return self._nginx diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index c1fa5af96..c8e2c55b7 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -11,6 +11,8 @@ RegisterServiceRequest, ) from dstack._internal.proxy.gateway.services.nginx import Nginx +from dstack._internal.proxy.lib.deps import get_service_connection_pool +from dstack._internal.proxy.lib.services.service_connection import ServiceConnectionPool router = APIRouter(prefix="/{project_name}") @@ -21,6 +23,7 @@ async def register_service( body: RegisterServiceRequest, repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)], nginx: Annotated[Nginx, Depends(get_nginx)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ) -> OkResponse: await registry_services.register_service( project_name=project_name.lower(), @@ -33,6 +36,7 @@ async def register_service( ssh_private_key=body.ssh_private_key, repo=repo, nginx=nginx, + service_conn_pool=service_conn_pool, ) return OkResponse() @@ -43,12 +47,14 @@ async def unregister_service( run_name: str, repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)], nginx: Annotated[Nginx, Depends(get_nginx)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ) -> OkResponse: await registry_services.unregister_service( project_name=project_name.lower(), run_name=run_name.lower(), repo=repo, nginx=nginx, + service_conn_pool=service_conn_pool, ) return OkResponse() @@ -60,6 +66,7 @@ async def register_replica( body: RegisterReplicaRequest, repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)], nginx: Annotated[Nginx, Depends(get_nginx)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ) -> OkResponse: await registry_services.register_replica( project_name=project_name.lower(), @@ -71,6 +78,7 @@ async def register_replica( ssh_proxy=body.ssh_proxy, repo=repo, nginx=nginx, + service_conn_pool=service_conn_pool, ) return OkResponse() @@ -82,6 +90,7 @@ async def unregister_replica( job_id: str, repo: Annotated[GatewayProxyRepo, Depends(get_gateway_proxy_repo)], nginx: Annotated[Nginx, Depends(get_nginx)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ) -> OkResponse: await registry_services.unregister_replica( project_name=project_name.lower(), @@ -89,6 +98,7 @@ async def unregister_replica( replica_id=job_id, repo=repo, nginx=nginx, + service_conn_pool=service_conn_pool, ) return OkResponse() diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 0ef104997..fa935a147 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -18,8 +18,8 @@ from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.proxy.lib.services.service_connection import ( - ServiceReplicaConnection, - service_replica_connection_pool, + ServiceConnection, + ServiceConnectionPool, ) from dstack._internal.utils.logging import get_logger @@ -39,6 +39,7 @@ async def register_service( ssh_private_key: str, repo: GatewayProxyRepo, nginx: Nginx, + service_conn_pool: ServiceConnectionPool, ) -> None: service = models.Service( project_name=project_name, @@ -64,7 +65,13 @@ async def register_service( logger.debug("Registering service %s", service.fmt()) - await apply_service(service=service, old_service=None, repo=repo, nginx=nginx) + await apply_service( + service=service, + old_service=None, + repo=repo, + nginx=nginx, + service_conn_pool=service_conn_pool, + ) await repo.set_service(service) if model is not None: @@ -82,7 +89,11 @@ async def register_service( async def unregister_service( - project_name: str, run_name: str, repo: GatewayProxyRepo, nginx: Nginx + project_name: str, + run_name: str, + repo: GatewayProxyRepo, + nginx: Nginx, + service_conn_pool: ServiceConnectionPool, ) -> None: async with lock: service = await repo.get_service(project_name, run_name) @@ -93,7 +104,10 @@ async def unregister_service( logger.debug("Unregistering service %s", service.fmt()) - await stop_replica_connections(r.id for r in service.replicas) + await stop_replica_connections( + ids=(r.id for r in service.replicas), + service_conn_pool=service_conn_pool, + ) await nginx.unregister(service.domain_safe) await repo.delete_models_by_run(project_name, run_name) await repo.delete_service(project_name, run_name) @@ -111,6 +125,7 @@ async def register_replica( ssh_proxy: Optional[SSHConnectionParams], repo: GatewayProxyRepo, nginx: Nginx, + service_conn_pool: ServiceConnectionPool, ) -> None: replica = models.Replica( id=replica_id, @@ -134,7 +149,11 @@ async def register_replica( logger.debug("Registering replica %s in service %s", replica.id, service.fmt()) failures = await apply_service( - service=service, old_service=old_service, repo=repo, nginx=nginx + service=service, + old_service=old_service, + repo=repo, + nginx=nginx, + service_conn_pool=service_conn_pool, ) if replica in failures: raise ProxyError( @@ -152,6 +171,7 @@ async def unregister_replica( replica_id: str, repo: GatewayProxyRepo, nginx: Nginx, + service_conn_pool: ServiceConnectionPool, ) -> None: async with lock: old_service = await repo.get_service(project_name, run_name) @@ -171,7 +191,13 @@ async def unregister_replica( logger.debug("Unregistering replica %s in service %s", replica.id, service.fmt()) - await apply_service(service=service, old_service=old_service, repo=repo, nginx=nginx) + await apply_service( + service=service, + old_service=old_service, + repo=repo, + nginx=nginx, + service_conn_pool=service_conn_pool, + ) await repo.set_service(service) logger.info("Replica %s in service %s is unregistered now", replica_id, service.fmt()) @@ -200,6 +226,7 @@ async def apply_service( old_service: Optional[models.Service], repo: GatewayProxyRepo, nginx: Nginx, + service_conn_pool: ServiceConnectionPool, ) -> dict[models.Replica, BaseException]: if old_service is not None: if service.domain != old_service.domain: @@ -208,9 +235,14 @@ async def apply_service( f" domain name to change ({old_service.domain} -> {service.domain})" ) await stop_replica_connections( - replica.id for replica in old_service.replicas if replica not in service.replicas + ids=( + replica.id for replica in old_service.replicas if replica not in service.replicas + ), + service_conn_pool=service_conn_pool, ) - replica_conns, replica_failures = await get_or_add_replica_connections(service, repo) + replica_conns, replica_failures = await get_or_add_replica_connections( + service, repo, service_conn_pool + ) replica_configs = [ ReplicaConfig(id=replica.id, socket=conn.app_socket_path) for replica, conn in replica_conns.items() @@ -221,8 +253,8 @@ async def apply_service( async def get_or_add_replica_connections( - service: models.Service, repo: BaseProxyRepo -) -> tuple[dict[models.Replica, ServiceReplicaConnection], dict[models.Replica, BaseException]]: + service: models.Service, repo: BaseProxyRepo, service_conn_pool: ServiceConnectionPool +) -> tuple[dict[models.Replica, ServiceConnection], dict[models.Replica, BaseException]]: project = await repo.get_project(service.project_name) if project is None: raise UnexpectedProxyError( @@ -231,8 +263,7 @@ async def get_or_add_replica_connections( ) replica_conns, replica_failures = {}, {} tasks = [ - service_replica_connection_pool.get_or_add(project, service, replica) - for replica in service.replicas + service_conn_pool.get_or_add(project, service, replica) for replica in service.replicas ] results = await asyncio.gather(*tasks, return_exceptions=True) for replica, conn_or_err in zip(service.replicas, results): @@ -249,8 +280,10 @@ async def get_or_add_replica_connections( return replica_conns, replica_failures -async def stop_replica_connections(ids: Iterable[str]) -> None: - tasks = map(service_replica_connection_pool.remove, ids) +async def stop_replica_connections( + ids: Iterable[str], service_conn_pool: ServiceConnectionPool +) -> None: + tasks = map(service_conn_pool.remove, ids) results = await asyncio.gather(*tasks, return_exceptions=True) for replica_id, exc in zip(ids, results): if isinstance(exc, Exception): @@ -284,9 +317,17 @@ async def apply_entrypoint( await nginx.register(config, acme) -async def apply_all(repo: GatewayProxyRepo, nginx: Nginx) -> None: +async def apply_all( + repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool +) -> None: service_tasks = [ - apply_service(service=service, old_service=None, repo=repo, nginx=nginx) + apply_service( + service=service, + old_service=None, + repo=repo, + nginx=nginx, + service_conn_pool=service_conn_pool, + ) for service in await repo.list_services() ] entrypoint_tasks = [ diff --git a/src/dstack/_internal/proxy/lib/deps.py b/src/dstack/_internal/proxy/lib/deps.py index 45a23f45e..ae10be7ab 100644 --- a/src/dstack/_internal/proxy/lib/deps.py +++ b/src/dstack/_internal/proxy/lib/deps.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod from typing import AsyncGenerator, Optional from fastapi import Depends, FastAPI, Request, Security, status @@ -8,6 +8,7 @@ from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.proxy.lib.repo import BaseProxyRepo +from dstack._internal.proxy.lib.services.service_connection import ServiceConnectionPool class ProxyDependencyInjector(ABC): @@ -17,15 +18,19 @@ class ProxyDependencyInjector(ABC): a specific repo implementation. """ - def __init__(self, repo: BaseProxyRepo, auth: BaseProxyAuthProvider) -> None: - self._repo = repo - self._auth = auth + def __init__(self) -> None: + self._service_conn_pool = ServiceConnectionPool() + @abstractmethod async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: - yield self._repo + pass + @abstractmethod async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]: - yield self._auth + pass + + async def get_service_connection_pool(self) -> ServiceConnectionPool: + return self._service_conn_pool def get_injector_from_app(app: FastAPI) -> ProxyDependencyInjector: @@ -53,6 +58,12 @@ async def get_proxy_auth_provider( yield provider +async def get_service_connection_pool( + injector: Annotated[ProxyDependencyInjector, Depends(get_injector)], +) -> ServiceConnectionPool: + return await injector.get_service_connection_pool() + + class ProxyAuthContext: def __init__(self, project_name: str, token: Optional[str], provider: BaseProxyAuthProvider): self._project_name = project_name diff --git a/src/dstack/_internal/proxy/lib/routers/model_proxy.py b/src/dstack/_internal/proxy/lib/routers/model_proxy.py index 8f49c2e6e..e5a5c4cee 100644 --- a/src/dstack/_internal/proxy/lib/routers/model_proxy.py +++ b/src/dstack/_internal/proxy/lib/routers/model_proxy.py @@ -4,7 +4,7 @@ from fastapi.responses import StreamingResponse from typing_extensions import Annotated -from dstack._internal.proxy.lib.deps import ProxyAuth, get_proxy_repo +from dstack._internal.proxy.lib.deps import ProxyAuth, get_proxy_repo, get_service_connection_pool from dstack._internal.proxy.lib.errors import ProxyError, UnexpectedProxyError from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.proxy.lib.schemas.model_proxy import ( @@ -15,7 +15,10 @@ ModelsResponse, ) from dstack._internal.proxy.lib.services.model_proxy.model_proxy import get_chat_client -from dstack._internal.proxy.lib.services.service_connection import get_service_replica_client +from dstack._internal.proxy.lib.services.service_connection import ( + ServiceConnectionPool, + get_service_replica_client, +) router = APIRouter(dependencies=[Depends(ProxyAuth(auto_enforce=True))]) @@ -37,6 +40,7 @@ async def post_chat_completions( project_name: str, body: ChatCompletionsRequest, repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ): model = await repo.get_model(project_name, body.model) if model is None: @@ -49,7 +53,7 @@ async def post_chat_completions( f"Model {model.name} in project {project_name} references run {model.run_name}" " that does not exist or has no replicas" ) - http_client = await get_service_replica_client(service, repo) + http_client = await get_service_replica_client(service, repo, service_conn_pool) client = get_chat_client(model, http_client) if not body.stream: return await client.generate(body) diff --git a/src/dstack/_internal/proxy/lib/services/service_connection.py b/src/dstack/_internal/proxy/lib/services/service_connection.py index 37f54a58a..df6ef14c0 100644 --- a/src/dstack/_internal/proxy/lib/services/service_connection.py +++ b/src/dstack/_internal/proxy/lib/services/service_connection.py @@ -26,13 +26,13 @@ HTTP_TIMEOUT = 60 # Same as default Nginx proxy timeout -class ServiceReplicaClient(httpx.AsyncClient): +class ServiceClient(httpx.AsyncClient): def build_request(self, *args, **kwargs) -> httpx.Request: self.cookies.clear() # the client is shared by all users, don't leak cookies return super().build_request(*args, **kwargs) -class ServiceReplicaConnection: +class ServiceConnection: def __init__(self, project: Project, service: Service, replica: Replica) -> None: self._temp_dir = TemporaryDirectory() options = { @@ -58,7 +58,7 @@ def __init__(self, project: Project, service: Service, replica: Replica) -> None ], options=options, ) - self._client = ServiceReplicaClient( + self._client = ServiceClient( transport=AsyncHTTPTransport(uds=str(self._app_socket_path)), # The hostname in base_url is there for troubleshooting, as it may appear in # logs and in the Host header. The actual destination is the Unix socket. @@ -80,26 +80,26 @@ async def close(self) -> None: await self._client.aclose() await self._tunnel.aclose() - async def client(self) -> ServiceReplicaClient: + async def client(self) -> ServiceClient: await asyncio.wait_for(self._is_open.wait(), timeout=OPEN_TUNNEL_TIMEOUT) return self._client -class ServiceReplicaConnectionPool: +class ServiceConnectionPool: def __init__(self) -> None: # TODO(#1595): remove connections to stopped replicas in-server - self.connections: Dict[str, ServiceReplicaConnection] = {} + self.connections: Dict[str, ServiceConnection] = {} - async def get(self, replica_id: str) -> Optional[ServiceReplicaConnection]: + async def get(self, replica_id: str) -> Optional[ServiceConnection]: return self.connections.get(replica_id) async def get_or_add( self, project: Project, service: Service, replica: Replica - ) -> ServiceReplicaConnection: + ) -> ServiceConnection: connection = self.connections.get(replica.id) if connection is not None: return connection - connection = ServiceReplicaConnection(project, service, replica) + connection = ServiceConnection(project, service, replica) self.connections[replica.id] = connection try: await connection.open() @@ -125,7 +125,9 @@ async def remove_all(self) -> None: ) -async def get_service_replica_client(service: Service, repo: BaseProxyRepo) -> httpx.AsyncClient: +async def get_service_replica_client( + service: Service, repo: BaseProxyRepo, service_conn_pool: ServiceConnectionPool +) -> httpx.AsyncClient: """ `service` must have at least one replica """ @@ -139,16 +141,12 @@ async def get_service_replica_client(service: Service, repo: BaseProxyRepo) -> h # Nginx not available, forward directly to the tunnel # TODO(#1595): consider trying different replicas, e.g. using HTTPMultiClient replica = random.choice(service.replicas) - connection = await service_replica_connection_pool.get(replica.id) + connection = await service_conn_pool.get(replica.id) if connection is None: project = await repo.get_project(service.project_name) if project is None: raise UnexpectedProxyError( f"Expected to find project {service.project_name} but could not" ) - connection = await service_replica_connection_pool.get_or_add(project, service, replica) + connection = await service_conn_pool.get_or_add(project, service, replica) return await connection.client() - - -# TODO(#1595): do not use a global variable, it's shared by tests -service_replica_connection_pool: ServiceReplicaConnectionPool = ServiceReplicaConnectionPool() diff --git a/src/dstack/_internal/proxy/lib/testing/common.py b/src/dstack/_internal/proxy/lib/testing/common.py index de6f6e46d..6b335a339 100644 --- a/src/dstack/_internal/proxy/lib/testing/common.py +++ b/src/dstack/_internal/proxy/lib/testing/common.py @@ -1,6 +1,22 @@ -from typing import Optional +from typing import AsyncGenerator, Optional +from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider +from dstack._internal.proxy.lib.deps import ProxyDependencyInjector from dstack._internal.proxy.lib.models import Project, Replica, Service +from dstack._internal.proxy.lib.repo import BaseProxyRepo + + +class ProxyTestDependencyInjector(ProxyDependencyInjector): + def __init__(self, repo: BaseProxyRepo, auth: BaseProxyAuthProvider) -> None: + super().__init__() + self._repo = repo + self._auth = auth + + async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: + yield self._repo + + async def get_auth_provider(self) -> AsyncGenerator[BaseProxyAuthProvider, None]: + yield self._auth def make_project(name: str) -> Project: diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 0da0d9f9b..ca718f484 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -15,8 +15,8 @@ from dstack._internal.cli.utils.common import console from dstack._internal.core.errors import ForbiddenError, ServerClientError from dstack._internal.core.services.configs import update_default_project +from dstack._internal.proxy.lib.deps import get_injector_from_app from dstack._internal.proxy.lib.routers import model_proxy -from dstack._internal.proxy.lib.services.service_connection import service_replica_connection_pool from dstack._internal.server import settings from dstack._internal.server.background import start_background_tasks from dstack._internal.server.db import get_db, get_session_ctx, migrate @@ -142,7 +142,8 @@ async def lifespan(app: FastAPI): yield scheduler.shutdown() await gateway_connections_pool.remove_all() - await service_replica_connection_pool.remove_all() + service_conn_pool = await get_injector_from_app(app).get_service_connection_pool() + await service_conn_pool.remove_all() await get_db().engine.dispose() # Let checked-out DB connections close as dispose() only closes checked-in connections await asyncio.sleep(3) diff --git a/src/dstack/_internal/server/services/proxy/deps.py b/src/dstack/_internal/server/services/proxy/deps.py index ab0d73b67..558b143ab 100644 --- a/src/dstack/_internal/server/services/proxy/deps.py +++ b/src/dstack/_internal/server/services/proxy/deps.py @@ -9,9 +9,6 @@ class ServerProxyDependencyInjector(ProxyDependencyInjector): - def __init__(self) -> None: - pass - async def get_repo(self) -> AsyncGenerator[BaseProxyRepo, None]: async with get_session_ctx() as session: yield ServerProxyRepo(session) diff --git a/src/dstack/_internal/server/services/proxy/routers/service_proxy.py b/src/dstack/_internal/server/services/proxy/routers/service_proxy.py index ae7fbaa64..efc1bca7d 100644 --- a/src/dstack/_internal/server/services/proxy/routers/service_proxy.py +++ b/src/dstack/_internal/server/services/proxy/routers/service_proxy.py @@ -3,8 +3,14 @@ from fastapi.responses import RedirectResponse, Response from typing_extensions import Annotated -from dstack._internal.proxy.lib.deps import ProxyAuth, ProxyAuthContext, get_proxy_repo +from dstack._internal.proxy.lib.deps import ( + ProxyAuth, + ProxyAuthContext, + get_proxy_repo, + get_service_connection_pool, +) from dstack._internal.proxy.lib.repo import BaseProxyRepo +from dstack._internal.proxy.lib.services.service_connection import ServiceConnectionPool from dstack._internal.server.services.proxy.services import service_proxy REDIRECTED_HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"] @@ -29,8 +35,11 @@ async def service_reverse_proxy( request: Request, auth: Annotated[ProxyAuthContext, Depends(ProxyAuth(auto_enforce=False))], repo: Annotated[BaseProxyRepo, Depends(get_proxy_repo)], + service_conn_pool: Annotated[ServiceConnectionPool, Depends(get_service_connection_pool)], ) -> Response: - return await service_proxy.proxy(project_name, run_name, path, request, auth, repo) + return await service_proxy.proxy( + project_name, run_name, path, request, auth, repo, service_conn_pool + ) # TODO(#1595): support websockets diff --git a/src/dstack/_internal/server/services/proxy/services/service_proxy.py b/src/dstack/_internal/server/services/proxy/services/service_proxy.py index 27b6390cc..fc7836feb 100644 --- a/src/dstack/_internal/server/services/proxy/services/service_proxy.py +++ b/src/dstack/_internal/server/services/proxy/services/service_proxy.py @@ -8,7 +8,10 @@ from dstack._internal.proxy.lib.deps import ProxyAuthContext from dstack._internal.proxy.lib.errors import ProxyError from dstack._internal.proxy.lib.repo import BaseProxyRepo -from dstack._internal.proxy.lib.services.service_connection import get_service_replica_client +from dstack._internal.proxy.lib.services.service_connection import ( + ServiceConnectionPool, + get_service_replica_client, +) from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -21,6 +24,7 @@ async def proxy( request: fastapi.Request, auth: ProxyAuthContext, repo: BaseProxyRepo, + service_conn_pool: ServiceConnectionPool, ) -> fastapi.responses.Response: # TODO(#1595): enforce client_max_body_size @@ -33,7 +37,7 @@ async def proxy( if service.auth: await auth.enforce() - client = await get_service_replica_client(service, repo) + client = await get_service_replica_client(service, repo, service_conn_pool) try: upstream_request = await build_upstream_request(request, path, client) diff --git a/src/tests/_internal/proxy/gateway/conftest.py b/src/tests/_internal/proxy/gateway/conftest.py index bc57d97e5..008bf9460 100644 --- a/src/tests/_internal/proxy/gateway/conftest.py +++ b/src/tests/_internal/proxy/gateway/conftest.py @@ -14,8 +14,8 @@ def system_mocks() -> Generator[Mocks, None, None]: patch(f"{nginx}.sudo") as sudo, patch(f"{nginx}.Nginx.reload") as reload_nginx, patch(f"{nginx}.Nginx.run_certbot") as run_certbot, - patch(f"{connection}.ServiceReplicaConnection.open") as open_conn, - patch(f"{connection}.ServiceReplicaConnection.close") as close_conn, + patch(f"{connection}.ServiceConnection.open") as open_conn, + patch(f"{connection}.ServiceConnection.close") as close_conn, ): sudo.return_value = [] yield Mocks( diff --git a/src/tests/_internal/proxy/lib/routers/test_model_proxy.py b/src/tests/_internal/proxy/lib/routers/test_model_proxy.py index 1777184fc..c8bb1b286 100644 --- a/src/tests/_internal/proxy/lib/routers/test_model_proxy.py +++ b/src/tests/_internal/proxy/lib/routers/test_model_proxy.py @@ -9,7 +9,6 @@ from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider -from dstack._internal.proxy.lib.deps import ProxyDependencyInjector from dstack._internal.proxy.lib.models import ChatModel, OpenAIChatModelFormat from dstack._internal.proxy.lib.repo import BaseProxyRepo from dstack._internal.proxy.lib.routers.model_proxy import router @@ -24,7 +23,11 @@ ) from dstack._internal.proxy.lib.services.model_proxy.clients.base import ChatCompletionsClient from dstack._internal.proxy.lib.testing.auth import ProxyTestAuthProvider -from dstack._internal.proxy.lib.testing.common import make_project, make_service +from dstack._internal.proxy.lib.testing.common import ( + ProxyTestDependencyInjector, + make_project, + make_service, +) SAMPLE_RESPONSE = "Hello there, how may I assist you today?" @@ -87,7 +90,7 @@ def make_model( def make_http_client(repo: BaseProxyRepo, auth: BaseProxyAuthProvider) -> httpx.AsyncClient: app = FastAPI() - app.state.proxy_dependency_injector = ProxyDependencyInjector(repo=repo, auth=auth) + app.state.proxy_dependency_injector = ProxyTestDependencyInjector(repo=repo, auth=auth) app.include_router(router, prefix="/proxy/models") return httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) @@ -110,7 +113,7 @@ def make_openai_client( def mock_chat_client() -> Generator[None, None, None]: with ( patch( - "dstack._internal.proxy.lib.services.service_connection.ServiceReplicaConnectionPool.get_or_add" + "dstack._internal.proxy.lib.services.service_connection.ServiceConnectionPool.get_or_add" ), patch("dstack._internal.proxy.lib.routers.model_proxy.get_chat_client") as get_client_mock, ): diff --git a/src/tests/_internal/server/services/proxy/routers/test_service_proxy.py b/src/tests/_internal/server/services/proxy/routers/test_service_proxy.py index cfb4992f2..7892c9b70 100644 --- a/src/tests/_internal/server/services/proxy/routers/test_service_proxy.py +++ b/src/tests/_internal/server/services/proxy/routers/test_service_proxy.py @@ -7,11 +7,14 @@ from dstack._internal.proxy.gateway.repo.repo import GatewayProxyRepo from dstack._internal.proxy.lib.auth import BaseProxyAuthProvider -from dstack._internal.proxy.lib.deps import ProxyDependencyInjector from dstack._internal.proxy.lib.repo import BaseProxyRepo -from dstack._internal.proxy.lib.services.service_connection import ServiceReplicaClient +from dstack._internal.proxy.lib.services.service_connection import ServiceClient from dstack._internal.proxy.lib.testing.auth import ProxyTestAuthProvider -from dstack._internal.proxy.lib.testing.common import make_project, make_service +from dstack._internal.proxy.lib.testing.common import ( + ProxyTestDependencyInjector, + make_project, + make_service, +) from dstack._internal.server.services.proxy.routers.service_proxy import router MOCK_REPLICA_CLIENT_TIMEOUT = 8 @@ -23,9 +26,9 @@ @pytest.fixture def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]: with patch( - "dstack._internal.proxy.lib.services.service_connection.ServiceReplicaConnectionPool.get_or_add" + "dstack._internal.proxy.lib.services.service_connection.ServiceConnectionPool.get_or_add" ) as add_connection_mock: - add_connection_mock.return_value.client.return_value = ServiceReplicaClient( + add_connection_mock.return_value.client.return_value = ServiceClient( base_url=httpbin.url, timeout=MOCK_REPLICA_CLIENT_TIMEOUT ) yield @@ -35,7 +38,7 @@ def make_app( repo: BaseProxyRepo, auth: BaseProxyAuthProvider = ProxyTestAuthProvider() ) -> FastAPI: app = FastAPI() - app.state.proxy_dependency_injector = ProxyDependencyInjector(repo=repo, auth=auth) + app.state.proxy_dependency_injector = ProxyTestDependencyInjector(repo=repo, auth=auth) app.include_router(router, prefix="/proxy/services") return app