Skip to content

Commit

Permalink
Use pydantic_settings to manage FastAPI app settings
Browse files Browse the repository at this point in the history
  • Loading branch information
itstauq committed Jun 29, 2024
1 parent c726060 commit 7705464
Showing 1 changed file with 48 additions and 101 deletions.
149 changes: 48 additions & 101 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# stdlib
import base64
from collections.abc import Callable
from enum import Enum
import json
import multiprocessing
import os
from pathlib import Path
import platform
import signal
import subprocess # nosec
import time
from typing import Any

# third party
from fastapi import APIRouter
from fastapi import FastAPI
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
import requests
from starlette.middleware.cors import CORSMiddleware

Expand All @@ -36,102 +36,64 @@
WAIT_TIME_SECONDS = 20


def make_app(name: str, router: APIRouter) -> FastAPI:
app = FastAPI(
title=name,
)

api_router = APIRouter()

api_router.include_router(router)
app.include_router(api_router, prefix="/api/v2")

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class AppSettings(BaseSettings):
name: str
node_type: NodeType = NodeType.DOMAIN
node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
processes: int = 1
reset: bool = False
dev_mode: bool = False
enable_warnings: bool = False
in_memory_workers: bool = True
queue_port: int | None = None
create_producer: bool = False
n_consumers: int = 0
association_request_auto_approval: bool = False
background_tasks: bool = False

return app
model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")


def app_factory() -> FastAPI:
try:
kwargs_encoded = os.environ["APP_FACTORY_KWARGS"]
kwargs_json = base64.b64decode(kwargs_encoded)
kwargs = json.loads(kwargs_json)
name = kwargs["name"]
node_type = kwargs["node_type"]
node_side_type = kwargs["node_side_type"]
processes = kwargs["processes"]
reset = kwargs["reset"]
dev_mode = kwargs["dev_mode"]
enable_warnings = kwargs["enable_warnings"]
in_memory_workers = kwargs["in_memory_workers"]
queue_port = kwargs["queue_port"]
create_producer = kwargs["create_producer"]
n_consumers = kwargs["n_consumers"]
association_request_auto_approval = kwargs["association_request_auto_approval"]
background_tasks = kwargs["background_tasks"]
except KeyError as e:
raise KeyError(f"Missing required environment variable: {e}")
settings = AppSettings()

worker_classes = {
NodeType.DOMAIN: Domain,
NodeType.GATEWAY: Gateway,
NodeType.ENCLAVE: Enclave,
}
if node_type not in worker_classes:
raise NotImplementedError(f"node_type: {node_type} is not supported")
worker_class = worker_classes[node_type]
kwargs = {
"name": name,
"processes": processes,
"local_db": True,
"node_type": node_type,
"node_side_type": node_side_type,
"enable_warnings": enable_warnings,
"migrate": True,
"in_memory_workers": in_memory_workers,
"queue_port": queue_port,
"create_producer": create_producer,
"n_consumers": n_consumers,
"association_request_auto_approval": association_request_auto_approval,
"background_tasks": background_tasks,
}
if dev_mode:
if settings.node_type not in worker_classes:
raise NotImplementedError(f"node_type: {settings.node_type} is not supported")
worker_class = worker_classes[settings.node_type]

kwargs = settings.model_dump()
if settings.dev_mode:
print(
f"\nWARNING: private key is based on node name: {name} in dev_mode. "
f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. "
"Don't run this in production."
)
kwargs["reset"] = reset
worker = worker_class.named(**kwargs)
else:
del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode
worker = worker_class(**kwargs)

worker = worker_class.named(**kwargs) if dev_mode else worker_class(**kwargs)
app = FastAPI(title=settings.name)
router = make_routes(worker=worker)
app = make_app(worker.name, router=router)
api_router = APIRouter()
api_router.include_router(router)
app.include_router(api_router, prefix="/api/v2")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
return app


def run_uvicorn(
name: str,
node_type: Enum,
host: str,
port: int,
processes: int,
reset: bool,
dev_mode: bool,
node_side_type: str,
enable_warnings: bool,
in_memory_workers: bool,
queue_port: int | None,
create_producer: bool,
association_request_auto_approval: bool,
n_consumers: int,
background_tasks: bool,
) -> None:
if reset:
def run_uvicorn(host: str, port: int, **kwargs: Any) -> None:
if kwargs.get("reset"):
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
Expand All @@ -141,33 +103,18 @@ def run_uvicorn(
except Exception: # nosec
print(f"Failed to kill python process on port: {port}")

kwargs = {
"name": name,
"node_type": node_type,
"node_side_type": node_side_type,
"processes": processes,
"reset": reset,
"dev_mode": dev_mode,
"enable_warnings": enable_warnings,
"in_memory_workers": in_memory_workers,
"queue_port": queue_port,
"create_producer": create_producer,
"n_consumers": n_consumers,
"association_request_auto_approval": association_request_auto_approval,
"background_tasks": background_tasks,
}
kwargs_json = json.dumps(kwargs)
kwargs_encoded = base64.b64encode(kwargs_json.encode()).decode()
env_prefix = AppSettings.model_config.get("env_prefix", "")
env_variables = " ".join(f"{env_prefix}{k.upper()}={v}" for k, v in kwargs.items())

uvicorn_cmd = (
f"APP_FACTORY_KWARGS={kwargs_encoded}"
f"{env_variables}"
" uvicorn syft.node.server:app_factory"
" --factory"
f" --host {host}"
f" --port {port}"
)
if dev_mode:
if kwargs.get("dev_mode"):
uvicorn_cmd += f" --reload --reload-dir {Path(__file__).parent.parent}"
print(f"{uvicorn_cmd=}")
os.system(uvicorn_cmd)


Expand Down

0 comments on commit 7705464

Please sign in to comment.