Skip to content

Commit

Permalink
Add hot reloading capability using uvicorn app factory pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
itstauq committed Jun 27, 2024
1 parent 249735c commit c726060
Showing 1 changed file with 95 additions and 89 deletions.
184 changes: 95 additions & 89 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# stdlib
import asyncio
import base64
from collections.abc import Callable
from enum import Enum
import json
import multiprocessing
import os
from pathlib import Path
import platform
import signal
import subprocess # nosec
Expand All @@ -14,7 +16,6 @@
from fastapi import FastAPI
import requests
from starlette.middleware.cors import CORSMiddleware
import uvicorn

# relative
from ..abstract_node import NodeSideType
Expand Down Expand Up @@ -56,11 +57,61 @@ def make_app(name: str, router: APIRouter) -> FastAPI:
return app


worker_classes = {
NodeType.DOMAIN: Domain,
NodeType.GATEWAY: Gateway,
NodeType.ENCLAVE: Enclave,
}
def app_factory() -> FastAPI:
try:
kwargs_encoded = os.environ["APP_FACTORY_KWARGS"]
kwargs_json = base64.b64decode(kwargs_encoded)
kwargs = json.loads(kwargs_json)
name = kwargs["name"]
node_type = kwargs["node_type"]
node_side_type = kwargs["node_side_type"]
processes = kwargs["processes"]
reset = kwargs["reset"]
dev_mode = kwargs["dev_mode"]
enable_warnings = kwargs["enable_warnings"]
in_memory_workers = kwargs["in_memory_workers"]
queue_port = kwargs["queue_port"]
create_producer = kwargs["create_producer"]
n_consumers = kwargs["n_consumers"]
association_request_auto_approval = kwargs["association_request_auto_approval"]
background_tasks = kwargs["background_tasks"]
except KeyError as e:
raise KeyError(f"Missing required environment variable: {e}")

worker_classes = {
NodeType.DOMAIN: Domain,
NodeType.GATEWAY: Gateway,
NodeType.ENCLAVE: Enclave,
}
if node_type not in worker_classes:
raise NotImplementedError(f"node_type: {node_type} is not supported")
worker_class = worker_classes[node_type]
kwargs = {
"name": name,
"processes": processes,
"local_db": True,
"node_type": node_type,
"node_side_type": node_side_type,
"enable_warnings": enable_warnings,
"migrate": True,
"in_memory_workers": in_memory_workers,
"queue_port": queue_port,
"create_producer": create_producer,
"n_consumers": n_consumers,
"association_request_auto_approval": association_request_auto_approval,
"background_tasks": background_tasks,
}
if dev_mode:
print(
f"\nWARNING: private key is based on node name: {name} in dev_mode. "
"Don't run this in production."
)
kwargs["reset"] = reset

worker = worker_class.named(**kwargs) if dev_mode else worker_class(**kwargs)
router = make_routes(worker=worker)
app = make_app(worker.name, router=router)
return app


def run_uvicorn(
Expand All @@ -80,89 +131,44 @@ def run_uvicorn(
n_consumers: int,
background_tasks: bool,
) -> None:
async def _run_uvicorn(
name: str,
node_type: NodeType,
host: str,
port: int,
reset: bool,
dev_mode: bool,
node_side_type: Enum,
) -> None:
if node_type not in worker_classes:
raise NotImplementedError(f"node_type: {node_type} is not supported")
worker_class = worker_classes[node_type]
if dev_mode:
print(
f"\nWARNING: private key is based on node name: {name} in dev_mode. "
"Don't run this in production."
)

worker = worker_class.named(
name=name,
processes=processes,
reset=reset,
local_db=True,
node_type=node_type,
node_side_type=node_side_type,
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
create_producer=create_producer,
n_consumers=n_consumers,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
)
else:
worker = worker_class(
name=name,
processes=processes,
local_db=True,
node_type=node_type,
node_side_type=node_side_type,
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
create_producer=create_producer,
n_consumers=n_consumers,
association_request_auto_approval=association_request_auto_approval,
background_tasks=background_tasks,
)
router = make_routes(worker=worker)
app = make_app(worker.name, router=router)

if reset:
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
print(f"Stopping process on port: {port}")
kill_process(pid)
time.sleep(1)
except Exception: # nosec
print(f"Failed to kill python process on port: {port}")

config = uvicorn.Config(app, host=host, port=port, reload=dev_mode)
server = uvicorn.Server(config)

await server.serve()
asyncio.get_running_loop().stop()

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(
_run_uvicorn(
name,
node_type,
host,
port,
reset,
dev_mode,
node_side_type,
)
if reset:
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
print(f"Stopping process on port: {port}")
kill_process(pid)
time.sleep(1)
except Exception: # nosec
print(f"Failed to kill python process on port: {port}")

kwargs = {
"name": name,
"node_type": node_type,
"node_side_type": node_side_type,
"processes": processes,
"reset": reset,
"dev_mode": dev_mode,
"enable_warnings": enable_warnings,
"in_memory_workers": in_memory_workers,
"queue_port": queue_port,
"create_producer": create_producer,
"n_consumers": n_consumers,
"association_request_auto_approval": association_request_auto_approval,
"background_tasks": background_tasks,
}
kwargs_json = json.dumps(kwargs)
kwargs_encoded = base64.b64encode(kwargs_json.encode()).decode()
uvicorn_cmd = (
f"APP_FACTORY_KWARGS={kwargs_encoded}"
" uvicorn syft.node.server:app_factory"
" --factory"
f" --host {host}"
f" --port {port}"
)
loop.close()
if dev_mode:
uvicorn_cmd += f" --reload --reload-dir {Path(__file__).parent.parent}"
print(f"{uvicorn_cmd=}")
os.system(uvicorn_cmd)


def serve_node(
Expand Down

0 comments on commit c726060

Please sign in to comment.