Skip to content

Commit

Permalink
Merge pull request #9200 from OpenMined/shubham/workers-launch-at-start
Browse files Browse the repository at this point in the history
spin up in memory consumers on startup
  • Loading branch information
IonesioJunior authored Aug 23, 2024
2 parents 5b71773 + a99d45c commit 4893271
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
25 changes: 25 additions & 0 deletions packages/syft/src/syft/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,31 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None:
message_handler=message_handler,
)

if self.in_memory_workers:
self.start_in_memory_workers(
address=address, message_handler=message_handler
)

def start_in_memory_workers(
self, address: str, message_handler: type[AbstractMessageHandler]
) -> None:
"""Starts in-memory workers for the server."""

worker_pools = self.pool_stash.get_all(credentials=self.verify_key).ok()
for worker_pool in worker_pools:
# Skip the default worker pool
if worker_pool.name == DEFAULT_WORKER_POOL_NAME:
continue

# Create consumers for each worker pool
for linked_worker in worker_pool.worker_list:
self.add_consumer_for_service(
service_name=worker_pool.name,
syft_worker_id=linked_worker.object_uid,
address=address,
message_handler=message_handler,
)

def add_consumer_for_service(
self,
service_name: str,
Expand Down
17 changes: 16 additions & 1 deletion packages/syft/src/syft/server/uvicorn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from collections.abc import Callable
from contextlib import asynccontextmanager
import logging
import multiprocessing
import multiprocessing.synchronize
Expand Down Expand Up @@ -31,6 +32,7 @@
from .enclave import Enclave
from .gateway import Gateway
from .routes import make_routes
from .server import Server
from .server import ServerType
from .utils import get_named_server_uid
from .utils import remove_temp_dir_for_server
Expand Down Expand Up @@ -61,6 +63,17 @@ class AppSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")


def get_lifetime(worker: Server) -> Callable:
@asynccontextmanager
async def lifespan(app: FastAPI) -> Any:
try:
yield
finally:
worker.stop()

return lifespan


def app_factory() -> FastAPI:
settings = AppSettings()

Expand All @@ -85,7 +98,9 @@ def app_factory() -> FastAPI:
else:
worker = worker_class(**kwargs)

app = FastAPI(title=settings.name)
worker_lifespan = get_lifetime(worker=worker)

app = FastAPI(title=settings.name, lifespan=worker_lifespan)
router = make_routes(worker=worker)
api_router = APIRouter()
api_router.include_router(router)
Expand Down

0 comments on commit 4893271

Please sign in to comment.