Skip to content

Commit

Permalink
fix(framework) Make load_app thread safe (#4687)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Dec 12, 2024
1 parent b3cf211 commit 879e8c7
Showing 1 changed file with 57 additions and 54 deletions.
111 changes: 57 additions & 54 deletions src/py/flwr/common/object_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from importlib.util import find_spec
from logging import WARN
from pathlib import Path
from threading import Lock
from typing import Any, Optional, Union

from .logger import log
Expand All @@ -34,6 +35,7 @@


_current_sys_path: Optional[str] = None
_import_lock = Lock()


def validate(
Expand Down Expand Up @@ -146,60 +148,61 @@ def load_app( # pylint: disable= too-many-branches
- This function will modify `sys.path` by inserting the provided `project_dir`
and removing the previously inserted `project_dir`.
"""
valid, error_msg = validate(module_attribute_str, check_module=False)
if not valid and error_msg:
raise error_type(error_msg) from None

module_str, _, attributes_str = module_attribute_str.partition(":")

try:
# Initialize project path
if project_dir is None:
project_dir = Path.cwd()
project_dir = Path(project_dir).absolute()

# Unload modules if the project directory has changed
if _current_sys_path and _current_sys_path != str(project_dir):
_unload_modules(Path(_current_sys_path))

# Set the system path
_set_sys_path(project_dir)

# Import the module
if module_str not in sys.modules:
module = importlib.import_module(module_str)
# Hack: `tabnet` does not work with `importlib.reload`
elif "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
module = sys.modules[module_str]
else:
module = sys.modules[module_str]
_reload_modules(project_dir)

except ModuleNotFoundError as err:
raise error_type(
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
) from err

# Recursively load attribute
attribute = module
try:
for attribute_str in attributes_str.split("."):
attribute = getattr(attribute, attribute_str)
except AttributeError as err:
raise error_type(
f"Unable to load attribute {attributes_str} from module {module_str}"
f"{OBJECT_REF_HELP_STR}",
) from err

return attribute
with _import_lock:
valid, error_msg = validate(module_attribute_str, check_module=False)
if not valid and error_msg:
raise error_type(error_msg) from None

module_str, _, attributes_str = module_attribute_str.partition(":")

try:
# Initialize project path
if project_dir is None:
project_dir = Path.cwd()
project_dir = Path(project_dir).absolute()

# Unload modules if the project directory has changed
if _current_sys_path and _current_sys_path != str(project_dir):
_unload_modules(Path(_current_sys_path))

# Set the system path
_set_sys_path(project_dir)

# Import the module
if module_str not in sys.modules:
module = importlib.import_module(module_str)
# Hack: `tabnet` does not work with `importlib.reload`
elif "tabnet" in sys.modules:
log(
WARN,
"Cannot reload module `%s` from disk due to compatibility issues "
"with the `tabnet` library. The module will be loaded from the "
"cache instead. If you experience issues, consider restarting "
"the application.",
module_str,
)
module = sys.modules[module_str]
else:
module = sys.modules[module_str]
_reload_modules(project_dir)

except ModuleNotFoundError as err:
raise error_type(
f"Unable to load module {module_str}{OBJECT_REF_HELP_STR}",
) from err

# Recursively load attribute
attribute = module
try:
for attribute_str in attributes_str.split("."):
attribute = getattr(attribute, attribute_str)
except AttributeError as err:
raise error_type(
f"Unable to load attribute {attributes_str} from module {module_str}"
f"{OBJECT_REF_HELP_STR}",
) from err

return attribute


def _unload_modules(project_dir: Path) -> None:
Expand Down

0 comments on commit 879e8c7

Please sign in to comment.