Skip to content

Commit

Permalink
Merge pull request #9053 from OpenMined/tauquir/enclave-model-upload-…
Browse files Browse the repository at this point in the history
…optimization

Enclave model upload optimization
  • Loading branch information
rasswanth-s authored Jul 22, 2024
2 parents b3b03f1 + 80166a5 commit 3304443
Show file tree
Hide file tree
Showing 10 changed files with 672 additions and 40 deletions.
429 changes: 429 additions & 0 deletions notebooks/experimental/Digital Signatures/ds-benchmarks.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions packages/syft/src/syft/client/domain_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def upload_model(self, model: CreateModel) -> SyftSuccess | SyftError:
model_size += get_mb_size(asset.data)
model_ref_action_ids.append(twin.id)

# Clear the data and mock, as they are already uploaded as twin objects
asset.data = None
asset.mock = None

# Update the progress bar and set the dynamic description
pbar.set_description(f"Uploading: {asset.name}")
pbar.update(1)
Expand Down
47 changes: 46 additions & 1 deletion packages/syft/src/syft/node/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import base64
import binascii
from collections.abc import AsyncGenerator
from collections.abc import Callable
from datetime import datetime
import logging
from pathlib import Path
from typing import Annotated

# third party
Expand Down Expand Up @@ -34,12 +37,13 @@
from ..util.telemetry import TRACE_MODE
from .credentials import SyftVerifyKey
from .credentials import UserLoginCredentials
from .server_settings import ServerSettings
from .worker import Worker

logger = logging.getLogger(__name__)


def make_routes(worker: Worker) -> APIRouter:
def make_routes(worker: Worker, settings: ServerSettings | None = None) -> APIRouter:
if TRACE_MODE:
# third party
try:
Expand All @@ -49,6 +53,34 @@ def make_routes(worker: Worker) -> APIRouter:
except Exception as e:
logger.error("Failed to import opentelemetry", exc_info=e)

def _handle_profile(
    request: Request, handler_func: Callable, *args: list, **kwargs: dict
) -> Response:
    """Execute ``handler_func(*args, **kwargs)`` under the pyinstrument profiler.

    Writes an HTML profile report into ``<settings.profile_dir>/profiles``
    (falling back to the current working directory) and returns the handler's
    response unchanged. Reads ``settings`` and ``logger`` from the enclosing
    ``make_routes`` scope.

    Raises:
        Exception: if ``settings`` was not supplied to ``make_routes``.
    """
    if not settings:
        raise Exception("Server Settings are required to enable profiling")
    # third party
    from pyinstrument import Profiler  # Lazy Load — only needed when profiling

    # Report directory: <profile_dir>/profiles, created on demand.
    profiles_dir = Path(settings.profile_dir or Path.cwd()) / "profiles"
    profiles_dir.mkdir(parents=True, exist_ok=True)

    # Profile only the handler call itself, not the report writing below.
    with Profiler(
        interval=settings.profile_interval, async_mode="enabled"
    ) as profiler:
        response = handler_func(*args, **kwargs)

    # NOTE(review): "%H:%M:%S" puts ":" into the filename, which is invalid
    # on Windows — confirm this only ever runs on POSIX hosts.
    timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
    # e.g. "/api/v2/login" -> "-login"; used as a filename suffix.
    url_path = request.url.path.replace("/api/v2", "").replace("/", "-")
    profile_output_path = (
        profiles_dir / f"{settings.name}-{timestamp}{url_path}.html"
    )
    profiler.write_html(profile_output_path)

    logger.info(
        f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds"
    )
    return response

router = APIRouter()

async def get_body(request: Request) -> bytes:
Expand Down Expand Up @@ -165,6 +197,13 @@ def syft_new_api(
kind=trace.SpanKind.SERVER,
):
return handle_syft_new_api(user_verify_key, communication_protocol)
elif settings and settings.profile:
return _handle_profile(
request,
handle_syft_new_api,
user_verify_key,
communication_protocol,
)
else:
return handle_syft_new_api(user_verify_key, communication_protocol)

Expand All @@ -188,6 +227,8 @@ def syft_new_api_call(
kind=trace.SpanKind.SERVER,
):
return handle_new_api_call(data)
elif settings and settings.profile:
return _handle_profile(request, handle_new_api_call, data)
else:
return handle_new_api_call(data)

Expand Down Expand Up @@ -253,6 +294,8 @@ def login(
kind=trace.SpanKind.SERVER,
):
return handle_login(email, password, worker)
elif settings and settings.profile:
return _handle_profile(request, handle_login, email, password, worker)
else:
return handle_login(email, password, worker)

Expand All @@ -267,6 +310,8 @@ def register(
kind=trace.SpanKind.SERVER,
):
return handle_register(data, worker)
elif settings and settings.profile:
return _handle_profile(request, handle_register, data, worker)
else:
return handle_register(data, worker)

Expand Down
127 changes: 100 additions & 27 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from collections.abc import Callable
from datetime import datetime
import multiprocessing
import multiprocessing.synchronize
import os
Expand All @@ -14,8 +15,8 @@
# third party
from fastapi import APIRouter
from fastapi import FastAPI
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict
from fastapi import Request
from fastapi import Response
import requests
from starlette.middleware.cors import CORSMiddleware
import uvicorn
Expand All @@ -31,6 +32,7 @@
from .gateway import Gateway
from .node import NodeType
from .routes import make_routes
from .server_settings import ServerSettings
from .utils import get_named_node_uid
from .utils import remove_temp_dir_for_node

Expand All @@ -42,26 +44,8 @@
WAIT_TIME_SECONDS = 20


class AppSettings(BaseSettings):
name: str
node_type: NodeType = NodeType.DOMAIN
node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
processes: int = 1
reset: bool = False
dev_mode: bool = False
enable_warnings: bool = False
in_memory_workers: bool = True
queue_port: int | None = None
create_producer: bool = False
n_consumers: int = 0
association_request_auto_approval: bool = False
background_tasks: bool = False

model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")


def app_factory() -> FastAPI:
settings = AppSettings()
settings = ServerSettings()

worker_classes = {
NodeType.DOMAIN: Domain,
Expand All @@ -72,29 +56,105 @@ def app_factory() -> FastAPI:
raise NotImplementedError(f"node_type: {settings.node_type} is not supported")
worker_class = worker_classes[settings.node_type]

kwargs = settings.model_dump()
worker_kwargs = settings.model_dump()
# Remove Profiling inputs
worker_kwargs.pop("profile")
worker_kwargs.pop("profile_interval")
worker_kwargs.pop("profile_dir")
if settings.dev_mode:
print(
f"WARN: private key is based on node name: {settings.name} in dev_mode. "
"Don't run this in production."
)
worker = worker_class.named(**kwargs)
worker = worker_class.named(**worker_kwargs)
else:
worker = worker_class(**kwargs)
worker = worker_class(**worker_kwargs)

app = FastAPI(title=settings.name)
router = make_routes(worker=worker)
router = make_routes(worker=worker, settings=settings)
api_router = APIRouter()
api_router.include_router(router)
app.include_router(api_router, prefix="/api/v2")

# Register middlewares
_register_middlewares(app, settings)

return app


def _register_middlewares(app: FastAPI, settings: ServerSettings) -> None:
    """Register all HTTP middlewares on *app*."""
    _register_cors_middleware(app)

    # The profiler middleware is intentionally NOT registered:
    # pyinstrument does not currently support profiling sync routes, and most
    # of our routes in syft (routes.py) are sync — e.g. syft_new_api,
    # syft_new_api_call, login, register. We should either convert those
    # routes to async or wait until pyinstrument supports sync routes.
    # We cannot convert our sync routes to async because they perform blocking
    # IO (e.g. via the requests library): if one route calls back into this
    # same server, it would block the event loop and the server would hang.
    # if settings.profile:
    #     _register_profiler(app, settings)


def _register_cors_middleware(app: FastAPI) -> None:
    """Allow cross-origin requests on *app*.

    NOTE(review): fully permissive CORS (any origin/method/header, with
    credentials) — fine for development, should be tightened for production.

    Fix: the original body ended with ``return app`` even though the function
    is annotated ``-> None`` and its only caller ignores the value; the stray
    return (a leftover from when this code lived inside ``app_factory``) is
    removed so the annotation and behavior agree.
    """
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )


def _register_profiler(app: FastAPI, settings: ServerSettings) -> None:
    """Attach an HTTP middleware that profiles each request with pyinstrument.

    Every request is run under the profiler and an HTML report is written to
    ``<settings.profile_dir>/profiles`` (or ``<cwd>/profiles`` when no
    directory is configured), named ``<node name>-<timestamp><url path>.html``.
    """
    # third party
    from pyinstrument import Profiler

    base_dir = Path.cwd() if settings.profile_dir is None else Path(settings.profile_dir)
    profiles_dir = base_dir / "profiles"

    @app.middleware("http")
    async def profile_request(
        request: Request, call_next: Callable[[Request], Response]
    ) -> Response:
        profiler = Profiler(interval=settings.profile_interval, async_mode="enabled")
        with profiler:
            response = await call_next(request)

        # Profile File Name - Domain Name - Timestamp - URL Path
        profiles_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
        url_path = request.url.path.replace("/api/v2", "").replace("/", "-")
        report_file = profiles_dir / f"{settings.name}-{timestamp}{url_path}.html"

        # Write the profile to a HTML file
        profiler.write_html(report_file)

        print(
            f"Request to {request.url.path} took {profiler.last_session.duration:.2f} seconds"
        )

        return response


def _load_pyinstrument_jupyter_extension() -> None:
try:
# third party
from IPython import get_ipython

ipython = get_ipython() # noqa: F821
ipython.run_line_magic("load_ext", "pyinstrument")
print("Pyinstrument Jupyter extension loaded")
except Exception as e:
print(f"Error loading pyinstrument jupyter extension: {e}")


def attach_debugger() -> None:
Expand Down Expand Up @@ -142,7 +202,7 @@ def run_uvicorn(
attach_debugger()

# Set up all kwargs as environment variables so that they can be accessed in the app_factory function.
env_prefix = AppSettings.model_config.get("env_prefix", "")
env_prefix = ServerSettings.model_config.get("env_prefix", "")
for key, value in kwargs.items():
key_with_prefix = f"{env_prefix}{key.upper()}"
os.environ[key_with_prefix] = str(value)
Expand Down Expand Up @@ -187,13 +247,23 @@ def serve_node(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
debug: bool = False,
# Profiling inputs
profile: bool = False,
profile_interval: float = 0.001,
profile_dir: str | None = None,
) -> tuple[Callable, Callable]:
starting_uvicorn_event = multiprocessing.Event()

# Enable IPython autoreload if dev_mode is enabled.
if dev_mode:
enable_autoreload()

# Load the Pyinstrument Jupyter extension if profile is enabled.
if profile:
_load_pyinstrument_jupyter_extension()
if profile_dir is None:
profile_dir = str(Path.cwd())

server_process = multiprocessing.Process(
target=run_uvicorn,
kwargs={
Expand All @@ -214,6 +284,9 @@ def serve_node(
"background_tasks": background_tasks,
"debug": debug,
"starting_uvicorn_event": starting_uvicorn_event,
"profile": profile,
"profile_interval": profile_interval,
"profile_dir": profile_dir,
},
)

Expand Down
30 changes: 30 additions & 0 deletions packages/syft/src/syft/node/server_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# third party
from pydantic_settings import BaseSettings
from pydantic_settings import SettingsConfigDict

# relative
from ..abstract_node import NodeSideType
from ..abstract_node import NodeType


class ServerSettings(BaseSettings):
    """Environment-driven settings for the uvicorn/FastAPI node server.

    Values are read from environment variables prefixed with ``SYFT_``
    (e.g. ``SYFT_NAME``); the literal string ``"None"`` parses as ``None``.
    """

    # Node identity and topology
    name: str
    node_type: NodeType = NodeType.DOMAIN
    node_side_type: NodeSideType = NodeSideType.HIGH_SIDE
    processes: int = 1
    reset: bool = False
    dev_mode: bool = False
    enable_warnings: bool = False
    in_memory_workers: bool = True
    # Queue / worker configuration
    queue_port: int | None = None
    create_producer: bool = False
    n_consumers: int = 0
    association_request_auto_approval: bool = False
    background_tasks: bool = False

    # Profiling inputs (pyinstrument); these are popped from the settings
    # dump before the worker is constructed, so they never reach the worker.
    profile: bool = False  # enable per-request profiling
    profile_interval: float = 0.001  # profiler sampling interval in seconds
    profile_dir: str | None = None  # report output root; None -> cwd

    model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None")
13 changes: 13 additions & 0 deletions packages/syft/src/syft/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def deploy_to_python(
background_tasks: bool = False,
debug: bool = False,
migrate: bool = False,
profile: bool = False,
profile_interval: float = 0.001,
profile_dir: str | None = None,
) -> NodeHandle:
worker_classes = {
NodeType.DOMAIN: Domain,
Expand Down Expand Up @@ -204,6 +207,9 @@ def deploy_to_python(
"background_tasks": background_tasks,
"debug": debug,
"migrate": migrate,
"profile": profile,
"profile_interval": profile_interval,
"profile_dir": profile_dir,
}

if port:
Expand Down Expand Up @@ -305,6 +311,10 @@ def launch(
background_tasks: bool = False,
debug: bool = False,
migrate: bool = False,
# Profiling Related Input for in-memory fastapi server
profile: bool = False,
profile_interval: float = 0.001,
profile_dir: str | None = None,
) -> NodeHandle:
if dev_mode is True:
thread_workers = True
Expand Down Expand Up @@ -343,6 +353,9 @@ def launch(
background_tasks=background_tasks,
debug=debug,
migrate=migrate,
profile=profile,
profile_interval=profile_interval,
profile_dir=profile_dir,
)
display(
SyftInfo(
Expand Down
Loading

0 comments on commit 3304443

Please sign in to comment.