From 879e8c74a50a1be29dcb4b7bed6cf23258cd8edb Mon Sep 17 00:00:00 2001 From: Javier <jafermarq@users.noreply.github.com> Date: Thu, 12 Dec 2024 14:58:24 +0000 Subject: [PATCH] fix(framework) Make `load_app` thread safe (#4687) --- src/py/flwr/common/object_ref.py | 111 ++++++++++++++++--------------- 1 file changed, 57 insertions(+), 54 deletions(-) diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py index 91414ef210f8..bf34bf5f639c 100644 --- a/src/py/flwr/common/object_ref.py +++ b/src/py/flwr/common/object_ref.py @@ -21,6 +21,7 @@ from importlib.util import find_spec from logging import WARN from pathlib import Path +from threading import Lock from typing import Any, Optional, Union from .logger import log @@ -34,6 +35,7 @@ _current_sys_path: Optional[str] = None +_import_lock = Lock() def validate( @@ -146,60 +148,61 @@ def load_app( # pylint: disable= too-many-branches - This function will modify `sys.path` by inserting the provided `project_dir` and removing the previously inserted `project_dir`. """ - valid, error_msg = validate(module_attribute_str, check_module=False) - if not valid and error_msg: - raise error_type(error_msg) from None - - module_str, _, attributes_str = module_attribute_str.partition(":") - - try: - # Initialize project path - if project_dir is None: - project_dir = Path.cwd() - project_dir = Path(project_dir).absolute() - - # Unload modules if the project directory has changed - if _current_sys_path and _current_sys_path != str(project_dir): - _unload_modules(Path(_current_sys_path)) - - # Set the system path - _set_sys_path(project_dir) - - # Import the module - if module_str not in sys.modules: - module = importlib.import_module(module_str) - # Hack: `tabnet` does not work with `importlib.reload` - elif "tabnet" in sys.modules: - log( - WARN, - "Cannot reload module `%s` from disk due to compatibility issues " - "with the `tabnet` library. The module will be loaded from the " - "cache instead. If you experience issues, consider restarting " - "the application.", - module_str, - ) - module = sys.modules[module_str] - else: - module = sys.modules[module_str] - _reload_modules(project_dir) - - except ModuleNotFoundError as err: - raise error_type( - f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", - ) from err - - # Recursively load attribute - attribute = module - try: - for attribute_str in attributes_str.split("."): - attribute = getattr(attribute, attribute_str) - except AttributeError as err: - raise error_type( - f"Unable to load attribute {attributes_str} from module {module_str}" - f"{OBJECT_REF_HELP_STR}", - ) from err - - return attribute + with _import_lock: + valid, error_msg = validate(module_attribute_str, check_module=False) + if not valid and error_msg: + raise error_type(error_msg) from None + + module_str, _, attributes_str = module_attribute_str.partition(":") + + try: + # Initialize project path + if project_dir is None: + project_dir = Path.cwd() + project_dir = Path(project_dir).absolute() + + # Unload modules if the project directory has changed + if _current_sys_path and _current_sys_path != str(project_dir): + _unload_modules(Path(_current_sys_path)) + + # Set the system path + _set_sys_path(project_dir) + + # Import the module + if module_str not in sys.modules: + module = importlib.import_module(module_str) + # Hack: `tabnet` does not work with `importlib.reload` + elif "tabnet" in sys.modules: + log( + WARN, + "Cannot reload module `%s` from disk due to compatibility issues " + "with the `tabnet` library. The module will be loaded from the " + "cache instead. If you experience issues, consider restarting " + "the application.", + module_str, + ) + module = sys.modules[module_str] + else: + module = sys.modules[module_str] + _reload_modules(project_dir) + + except ModuleNotFoundError as err: + raise error_type( + f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}", + ) from err + + # Recursively load attribute + attribute = module + try: + for attribute_str in attributes_str.split("."): + attribute = getattr(attribute, attribute_str) + except AttributeError as err: + raise error_type( + f"Unable to load attribute {attributes_str} from module {module_str}" + f"{OBJECT_REF_HELP_STR}", + ) from err + + return attribute def _unload_modules(project_dir: Path) -> None: