From 879e8c74a50a1be29dcb4b7bed6cf23258cd8edb Mon Sep 17 00:00:00 2001
From: Javier <jafermarq@users.noreply.github.com>
Date: Thu, 12 Dec 2024 14:58:24 +0000
Subject: [PATCH] fix(framework) Make `load_app` thread safe (#4687)

---
 src/py/flwr/common/object_ref.py | 111 ++++++++++++++++---------------
 1 file changed, 57 insertions(+), 54 deletions(-)

diff --git a/src/py/flwr/common/object_ref.py b/src/py/flwr/common/object_ref.py
index 91414ef210f8..bf34bf5f639c 100644
--- a/src/py/flwr/common/object_ref.py
+++ b/src/py/flwr/common/object_ref.py
@@ -21,6 +21,7 @@
 from importlib.util import find_spec
 from logging import WARN
 from pathlib import Path
+from threading import Lock
 from typing import Any, Optional, Union
 
 from .logger import log
@@ -34,6 +35,7 @@
 
 
 _current_sys_path: Optional[str] = None
+_import_lock = Lock()
 
 
 def validate(
@@ -146,60 +148,61 @@ def load_app(  # pylint: disable= too-many-branches
     - This function will modify `sys.path` by inserting the provided `project_dir`
       and removing the previously inserted `project_dir`.
     """
-    valid, error_msg = validate(module_attribute_str, check_module=False)
-    if not valid and error_msg:
-        raise error_type(error_msg) from None
-
-    module_str, _, attributes_str = module_attribute_str.partition(":")
-
-    try:
-        # Initialize project path
-        if project_dir is None:
-            project_dir = Path.cwd()
-        project_dir = Path(project_dir).absolute()
-
-        # Unload modules if the project directory has changed
-        if _current_sys_path and _current_sys_path != str(project_dir):
-            _unload_modules(Path(_current_sys_path))
-
-        # Set the system path
-        _set_sys_path(project_dir)
-
-        # Import the module
-        if module_str not in sys.modules:
-            module = importlib.import_module(module_str)
-        # Hack: `tabnet` does not work with `importlib.reload`
-        elif "tabnet" in sys.modules:
-            log(
-                WARN,
-                "Cannot reload module `%s` from disk due to compatibility issues "
-                "with the `tabnet` library. The module will be loaded from the "
-                "cache instead. If you experience issues, consider restarting "
-                "the application.",
-                module_str,
-            )
-            module = sys.modules[module_str]
-        else:
-            module = sys.modules[module_str]
-            _reload_modules(project_dir)
-
-    except ModuleNotFoundError as err:
-        raise error_type(
-            f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
-        ) from err
-
-    # Recursively load attribute
-    attribute = module
-    try:
-        for attribute_str in attributes_str.split("."):
-            attribute = getattr(attribute, attribute_str)
-    except AttributeError as err:
-        raise error_type(
-            f"Unable to load attribute {attributes_str} from module {module_str}"
-            f"{OBJECT_REF_HELP_STR}",
-        ) from err
-
-    return attribute
+    with _import_lock:
+        valid, error_msg = validate(module_attribute_str, check_module=False)
+        if not valid and error_msg:
+            raise error_type(error_msg) from None
+
+        module_str, _, attributes_str = module_attribute_str.partition(":")
+
+        try:
+            # Initialize project path
+            if project_dir is None:
+                project_dir = Path.cwd()
+            project_dir = Path(project_dir).absolute()
+
+            # Unload modules if the project directory has changed
+            if _current_sys_path and _current_sys_path != str(project_dir):
+                _unload_modules(Path(_current_sys_path))
+
+            # Set the system path
+            _set_sys_path(project_dir)
+
+            # Import the module
+            if module_str not in sys.modules:
+                module = importlib.import_module(module_str)
+            # Hack: `tabnet` does not work with `importlib.reload`
+            elif "tabnet" in sys.modules:
+                log(
+                    WARN,
+                    "Cannot reload module `%s` from disk due to compatibility issues "
+                    "with the `tabnet` library. The module will be loaded from the "
+                    "cache instead. If you experience issues, consider restarting "
+                    "the application.",
+                    module_str,
+                )
+                module = sys.modules[module_str]
+            else:
+                module = sys.modules[module_str]
+                _reload_modules(project_dir)
+
+        except ModuleNotFoundError as err:
+            raise error_type(
+                f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
+            ) from err
+
+        # Recursively load attribute
+        attribute = module
+        try:
+            for attribute_str in attributes_str.split("."):
+                attribute = getattr(attribute, attribute_str)
+        except AttributeError as err:
+            raise error_type(
+                f"Unable to load attribute {attributes_str} from module {module_str}"
+                f"{OBJECT_REF_HELP_STR}",
+            ) from err
+
+        return attribute
 
 
 def _unload_modules(project_dir: Path) -> None: