Skip to content

Commit

Permalink
Merge pull request #9034 from OpenMined/tauquir/prevent-reset-during-…
Browse files Browse the repository at this point in the history
…hotreload

Prevent server data reset due to Uvicorn hot-reload
  • Loading branch information
rasswanth-s authored Jul 11, 2024
2 parents ea5e612 + 790174e commit fe31b28
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 13 deletions.
16 changes: 6 additions & 10 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import logging
import os
from pathlib import Path
import shutil
import subprocess # nosec
import sys
import tempfile
from time import sleep
import traceback
from typing import Any
Expand Down Expand Up @@ -122,6 +120,9 @@
from .credentials import SyftSigningKey
from .credentials import SyftVerifyKey
from .service_registry import ServiceRegistry
from .utils import get_named_node_uid
from .utils import get_temp_dir_for_node
from .utils import remove_temp_dir_for_node
from .worker_settings import WorkerSettings

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -655,7 +656,7 @@ def named(
association_request_auto_approval: bool = False,
background_tasks: bool = False,
) -> Node:
uid = UID.with_seed(name)
uid = get_named_node_uid(name)
name_hash = hashlib.sha256(name.encode("utf8")).digest()
key = SyftSigningKey(signing_key=SigningKey(name_hash))
blob_storage_config = None
Expand Down Expand Up @@ -893,18 +894,13 @@ def get_temp_dir(self, dir_name: str = "") -> Path:
Get a temporary directory unique to the node.
Provide all dbs, blob dirs, and locks using this directory.
"""
root = os.getenv("SYFT_TEMP_ROOT", "syft")
p = Path(tempfile.gettempdir(), root, str(self.id), dir_name)
p.mkdir(parents=True, exist_ok=True)
return p
return get_temp_dir_for_node(self.id, dir_name)

def remove_temp_dir(self) -> None:
"""
Remove the temporary directory for this node.
"""
rootdir = self.get_temp_dir()
if rootdir.exists():
shutil.rmtree(rootdir, ignore_errors=True)
remove_temp_dir_for_node(self.id)

def update_self(self, settings: NodeSettings) -> None:
updateable_attrs = (
Expand Down
15 changes: 12 additions & 3 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from .gateway import Gateway
from .node import NodeType
from .routes import make_routes
from .utils import get_named_node_uid
from .utils import remove_temp_dir_for_node

if os_name() == "macOS":
# needed on MacOS to prevent [__NSCFConstantString initialize] may have been in
Expand Down Expand Up @@ -73,12 +75,11 @@ def app_factory() -> FastAPI:
kwargs = settings.model_dump()
if settings.dev_mode:
print(
f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. "
f"WARN: private key is based on node name: {settings.name} in dev_mode. "
"Don't run this in production."
)
worker = worker_class.named(**kwargs)
else:
del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode
worker = worker_class(**kwargs)

app = FastAPI(title=settings.name)
Expand Down Expand Up @@ -119,7 +120,15 @@ def run_uvicorn(
starting_uvicorn_event: multiprocessing.synchronize.Event,
**kwargs: Any,
) -> None:
if kwargs.get("reset"):
should_reset = kwargs.get("dev_mode") and kwargs.get("reset")

if should_reset:
print("Found `reset=True` in the launch configuration. Resetting the node...")
named_node_uid = get_named_node_uid(kwargs.get("name"))
remove_temp_dir_for_node(named_node_uid)
# Explicitly set `reset` to False to prevent multiple resets during hot-reload
kwargs["reset"] = False
# Kill all old python processes
try:
python_pids = find_python_processes_on_port(port)
for pid in python_pids:
Expand Down
38 changes: 38 additions & 0 deletions packages/syft/src/syft/node/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# future
from __future__ import annotations

# stdlib
import os
from pathlib import Path
import shutil
import tempfile

# relative
from ..types.uid import UID


def get_named_node_uid(name: str) -> UID:
"""
Get a unique identifier for a named node.
"""
return UID.with_seed(name)


def get_temp_dir_for_node(node_uid: UID, dir_name: str = "") -> Path:
"""
Get a temporary directory unique to the node.
Provide all dbs, blob dirs, and locks using this directory.
"""
root = os.getenv("SYFT_TEMP_ROOT", "syft")
p = Path(tempfile.gettempdir(), root, str(node_uid), dir_name)
p.mkdir(parents=True, exist_ok=True)
return p


def remove_temp_dir_for_node(node_uid: UID) -> None:
"""
Remove the temporary directory for this node.
"""
rootdir = get_temp_dir_for_node(node_uid)
if rootdir.exists():
shutil.rmtree(rootdir, ignore_errors=True)

0 comments on commit fe31b28

Please sign in to comment.