From a99d45c56438b0281393c6d203b41264aa56d745 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Fri, 23 Aug 2024 14:51:44 +0530 Subject: [PATCH] spin up in memory consumers on startup --- packages/syft/src/syft/server/server.py | 25 ++++++++++++++++++++++++ packages/syft/src/syft/server/uvicorn.py | 17 +++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 66bb446fbf5..eca1b56a8d8 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -608,6 +608,31 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None: message_handler=message_handler, ) + if self.in_memory_workers: + self.start_in_memory_workers( + address=address, message_handler=message_handler + ) + + def start_in_memory_workers( + self, address: str, message_handler: type[AbstractMessageHandler] + ) -> None: + """Starts in-memory workers for the server.""" + + worker_pools = self.pool_stash.get_all(credentials=self.verify_key).ok() + for worker_pool in worker_pools: + # Skip the default worker pool + if worker_pool.name == DEFAULT_WORKER_POOL_NAME: + continue + + # Create consumers for each worker pool + for linked_worker in worker_pool.worker_list: + self.add_consumer_for_service( + service_name=worker_pool.name, + syft_worker_id=linked_worker.object_uid, + address=address, + message_handler=message_handler, + ) + def add_consumer_for_service( self, service_name: str, diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 953d19a4c2e..3549d7a5987 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -1,5 +1,6 @@ # stdlib from collections.abc import Callable +from contextlib import asynccontextmanager import logging import multiprocessing import multiprocessing.synchronize @@ -31,6 +32,7 @@ from .enclave import Enclave from .gateway import Gateway from .routes import make_routes +from .server import Server from .server import ServerType from .utils import get_named_server_uid from .utils import remove_temp_dir_for_server @@ -61,6 +63,17 @@ class AppSettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") +def get_lifetime(worker: Server) -> Callable: + @asynccontextmanager + async def lifespan(app: FastAPI) -> Any: + try: + yield + finally: + worker.stop() + + return lifespan + + def app_factory() -> FastAPI: settings = AppSettings() @@ -85,7 +98,9 @@ def app_factory() -> FastAPI: else: worker = worker_class(**kwargs) - app = FastAPI(title=settings.name) + worker_lifespan = get_lifetime(worker=worker) + + app = FastAPI(title=settings.name, lifespan=worker_lifespan) router = make_routes(worker=worker) api_router = APIRouter() api_router.include_router(router)