From 790174e0bc9161908d7d14b4a467b02ad954ad1f Mon Sep 17 00:00:00 2001 From: Tauquir <30658453+itstauq@users.noreply.github.com> Date: Wed, 10 Jul 2024 20:45:33 +0530 Subject: [PATCH] Prevent server data reset due to uvicorn hotreload --- packages/syft/src/syft/node/node.py | 16 +++++------ packages/syft/src/syft/node/server.py | 15 ++++++++--- packages/syft/src/syft/node/utils.py | 38 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 13 deletions(-) create mode 100644 packages/syft/src/syft/node/utils.py diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index f13187852b1..e34e5090cd2 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -12,10 +12,8 @@ import logging import os from pathlib import Path -import shutil import subprocess # nosec import sys -import tempfile from time import sleep import traceback from typing import Any @@ -122,6 +120,9 @@ from .credentials import SyftSigningKey from .credentials import SyftVerifyKey from .service_registry import ServiceRegistry +from .utils import get_named_node_uid +from .utils import get_temp_dir_for_node +from .utils import remove_temp_dir_for_node from .worker_settings import WorkerSettings logger = logging.getLogger(__name__) @@ -655,7 +656,7 @@ def named( association_request_auto_approval: bool = False, background_tasks: bool = False, ) -> Node: - uid = UID.with_seed(name) + uid = get_named_node_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() key = SyftSigningKey(signing_key=SigningKey(name_hash)) blob_storage_config = None @@ -952,18 +953,13 @@ def get_temp_dir(self, dir_name: str = "") -> Path: Get a temporary directory unique to the node. Provide all dbs, blob dirs, and locks using this directory. """ - root = os.getenv("SYFT_TEMP_ROOT", "syft") - p = Path(tempfile.gettempdir(), root, str(self.id), dir_name) - p.mkdir(parents=True, exist_ok=True) - return p + return get_temp_dir_for_node(self.id, dir_name) def remove_temp_dir(self) -> None: """ Remove the temporary directory for this node. """ - rootdir = self.get_temp_dir() - if rootdir.exists(): - shutil.rmtree(rootdir, ignore_errors=True) + remove_temp_dir_for_node(self.id) def update_self(self, settings: NodeSettings) -> None: updateable_attrs = ( diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 4628cae6005..a3451f304a7 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -31,6 +31,8 @@ from .gateway import Gateway from .node import NodeType from .routes import make_routes +from .utils import get_named_node_uid +from .utils import remove_temp_dir_for_node if os_name() == "macOS": # needed on MacOS to prevent [__NSCFConstantString initialize] may have been in @@ -73,12 +75,11 @@ def app_factory() -> FastAPI: kwargs = settings.model_dump() if settings.dev_mode: print( - f"\nWARNING: private key is based on node name: {settings.name} in dev_mode. " + f"WARN: private key is based on node name: {settings.name} in dev_mode. " "Don't run this in production." ) worker = worker_class.named(**kwargs) else: - del kwargs["reset"] # Explicitly remove reset from kwargs for non-dev mode worker = worker_class(**kwargs) app = FastAPI(title=settings.name) @@ -119,7 +120,15 @@ def run_uvicorn( starting_uvicorn_event: multiprocessing.synchronize.Event, **kwargs: Any, ) -> None: - if kwargs.get("reset"): + should_reset = kwargs.get("dev_mode") and kwargs.get("reset") + + if should_reset: + print("Found `reset=True` in the launch configuration. Resetting the node...") + named_node_uid = get_named_node_uid(kwargs.get("name")) + remove_temp_dir_for_node(named_node_uid) + # Explicitly set `reset` to False to prevent multiple resets during hot-reload + kwargs["reset"] = False + # Kill all old python processes try: python_pids = find_python_processes_on_port(port) for pid in python_pids: diff --git a/packages/syft/src/syft/node/utils.py b/packages/syft/src/syft/node/utils.py new file mode 100644 index 00000000000..3048fd3fa94 --- /dev/null +++ b/packages/syft/src/syft/node/utils.py @@ -0,0 +1,38 @@ +# future +from __future__ import annotations + +# stdlib +import os +from pathlib import Path +import shutil +import tempfile + +# relative +from ..types.uid import UID + + +def get_named_node_uid(name: str) -> UID: + """ + Get a unique identifier for a named node. + """ + return UID.with_seed(name) + + +def get_temp_dir_for_node(node_uid: UID, dir_name: str = "") -> Path: + """ + Get a temporary directory unique to the node. + Provide all dbs, blob dirs, and locks using this directory. + """ + root = os.getenv("SYFT_TEMP_ROOT", "syft") + p = Path(tempfile.gettempdir(), root, str(node_uid), dir_name) + p.mkdir(parents=True, exist_ok=True) + return p + + +def remove_temp_dir_for_node(node_uid: UID) -> None: + """ + Remove the temporary directory for this node. + """ + rootdir = get_temp_dir_for_node(node_uid) + if rootdir.exists(): + shutil.rmtree(rootdir, ignore_errors=True)