diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 89c7beb2..612a6104 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -48,7 +48,7 @@ jobs: pip list - name: Tests - timeout-minutes: 10 + timeout-minutes: 30 run: | python -m pytest --cov=litserve src/ tests/ -v -s diff --git a/src/litserve/server.py b/src/litserve/server.py index dc98974c..a49946f5 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -26,9 +26,12 @@ from collections import deque from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager +from multiprocessing.context import Process +from threading import Thread from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import uvicorn +import uvicorn.server from fastapi import Depends, FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security import APIKeyHeader @@ -244,6 +247,7 @@ def __init__( self.model_metadata = model_metadata self._connector = _Connector(accelerator=accelerator, devices=devices) self._callback_runner = CallbackRunner(callbacks) + self._uvicorn_servers: List[uvicorn.Server] = [] self.use_zmq = fast_queue specs = spec if spec is not None else [] @@ -583,19 +587,40 @@ def run( elif api_server_worker_type is None: api_server_worker_type = "process" - manager, litserve_workers = self.launch_inference_worker(num_api_servers) + manager, inference_workers = self.launch_inference_worker(num_api_servers) self.verify_worker_status() try: - servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs) + uvicorn_workers = self._start_server( + port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs + ) print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") - for s in servers: - s.join() + if sys.platform != "win32": + # On Linux, kill signal will be captured by 
uvicorn. + # => They will join and raise a KeyboardInterrupt, allowing the server to shut down. + for uw in uvicorn_workers: + uw: Union[Process, Thread] + uw.join() + else: + # On Windows, the kill signal is captured by the inference workers. + # => They will join and raise a KeyboardInterrupt, allowing the server to shut down. + for iw in inference_workers: + iw: Process + iw.join() + except KeyboardInterrupt: + # KeyboardInterrupt received + if sys.platform == "win32": + # We kindly ask the uvicorn servers to exit. + # This properly ends their threads on Windows. + for us in self._uvicorn_servers: + us: uvicorn.Server + us.should_exit = True finally: print("Shutting down LitServe") - for w in litserve_workers: - w.terminate() - w.join() + for iw in inference_workers: + iw: Process + iw.terminate() + iw.join() manager.shutdown() def _prepare_app_run(self, app: FastAPI): @@ -605,7 +630,7 @@ def _prepare_app_run(self, app: FastAPI): app.add_middleware(RequestCountMiddleware, active_counter=active_counter) def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_worker_type, **kwargs): - servers = [] + workers = [] for response_queue_id in range(num_uvicorn_servers): self.app.response_queue_id = response_queue_id if self.lit_spec: @@ -613,8 +638,16 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w app: FastAPI = copy.copy(self.app) self._prepare_app_run(app) - config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) + if sys.platform == "win32" and num_uvicorn_servers > 1: + logger.debug("Enable Windows explicit socket sharing...") + # We make sure the sockets are listening... 
+ # This prevents a later [WinError 10022] + for sock in sockets: + sock.listen(config.backlog) + # We set workers to tell uvicorn to use a shared socket (win32) + # https://github.com/encode/uvicorn/pull/802 + config.workers = num_uvicorn_servers server = uvicorn.Server(config=config) if uvicorn_worker_type == "process": ctx = mp.get_context("fork") @@ -624,8 +657,9 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w else: raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'") w.start() - servers.append(w) - return servers + workers.append(w) + self._uvicorn_servers.append(server) + return workers def setup_auth(self): if hasattr(self.lit_api, "authorize") and callable(self.lit_api.authorize):