diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 00000000..6d698100 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,84 @@ +import pytest + + +@pytest.fixture +def temp_extension_registry(): + from wsimod.extensions import extensions_registry + + bkp = extensions_registry.copy() + extensions_registry.clear() + yield + extensions_registry.clear() + extensions_registry.update(bkp) + + +def test_register_node_patch(temp_extension_registry): + from wsimod.extensions import extensions_registry, register_node_patch + + # Define a dummy function to patch a node method + @register_node_patch("node_name.method_name") + def dummy_patch(): + print("Patched method") + + # Check if the patch is registered correctly + assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch + + # Another function with other arguments + @register_node_patch("node_name.method_name", item="default", is_attr=True) + def another_dummy_patch(): + print("Another patched method") + + # Check if this other patch is registered correctly + assert ( + extensions_registry[("node_name.method_name", "default", True)] + == another_dummy_patch + ) + + +def test_apply_patches(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import ( + apply_patches, + extensions_registry, + register_node_patch, + ) + from wsimod.nodes import Node + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Node("dummy_node") + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + model = Model() + model.nodes[node.name] = node + + # 1. Patch a method + @register_node_patch("dummy_node.apply_overrides") + def dummy_patch(): + pass + + # 2. Patch an attribute + @register_node_patch("dummy_node.t", is_attr=True) + def another_dummy_patch(node): + return f"A pathced attribute for {node.name}" + + # 3. Patch a method with an item + @register_node_patch("dummy_node.pull_set_handler", item="default") + def yet_another_dummy_patch(): + pass + + # 4. Path a method of an attribute + @register_node_patch("dummy_node.dummy_arc.arc_mass_balance") + def arc_dummy_patch(): + pass + + # Check if all patches are registered + assert len(extensions_registry) == 4 + + # Apply the patches + apply_patches(model) + + # Verify that the patches are applied correctly + assert model.nodes[node.name].apply_overrides == dummy_patch + assert model.nodes[node.name].t == another_dummy_patch(node) + assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch + assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch diff --git a/wsimod/extensions.py b/wsimod/extensions.py new file mode 100644 index 00000000..bf56ea12 --- /dev/null +++ b/wsimod/extensions.py @@ -0,0 +1,106 @@ +"""This module contains the utilities to extend WSMOD with new features. + +The `register_node_patch` decorator is used to register a function that will be used +instead of a method or attribute of a node. The `apply_patches` function applies all +registered patches to a model. + +Example of patching a method: + +`empty_distributed` will be called instead of `my_node.pull_distributed`: + + >>> from wsimod.extensions import register_node_patch, apply_patches + >>> @register_node_patch("my_node.pull_distributed") + >>> def empty_distributed(self, vqip): + >>> return {} + +Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a +list or a dictionary can be patched if the item argument is provided. + +Example of patching an attribute: + +`10` will be assigned to `my_node.t`: + + >>> @register_node_patch("my_node.t", is_attr=True) + >>> def patch_t(node): + >>> return 10 + +Example of patching an attribute item: + +`patch_default_pull_set_handler` will be assigned to +`my_node.pull_set_handler["default"]`: + + >>> @register_node_patch("my_node.pull_set_handler", item="default") + >>> def patch_default_pull_set_handler(self, vqip): + >>> return {} + +It should be noted that the patched function should have the same signature as the +original method or attribute, and the return type should be the same as well, otherwise +there will be a runtime error. +""" + +from typing import Callable, Hashable + +from .orchestration.model import Model + +extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {} + + +def register_node_patch( + target: str, item: Hashable = None, is_attr: bool = False +) -> Callable: + """Register a function to patch a node method or any of its attributes. + + Args: + target (str): The target of the object to patch as a string with the node name + attribute, sub-attribute, etc. and finally method (or attribue) to replace, + sepparated with period, eg. `node_name.make_discharge` or + `node_name.sewer_tank.pull_storage_exact`. + item (Hashable): Typically a string or an integer indicating the item to replace + in the selected attribue, which should be a list or a dictionary. + is_attr (bool): If True, the decorated function will be called when applying + the patch and the result assigned to the target, instead of assigning the + function itself. In this case, the only argument passed to the function is + the node object. + """ + target_id = (target, item, is_attr) + if target_id in extensions_registry: + raise ValueError(f"Patch for {target} already registered.") + + def decorator(func): + extensions_registry[(target, item, is_attr)] = func + return func + + return decorator + + +def apply_patches(model: Model) -> None: + """Apply all registered patches to the model. + + TODO: Validate signature of the patched methods and type of patched attributes. + + Args: + model (Model): The model to apply the patches to. + """ + for (target, item, is_attr), func in extensions_registry.items(): + # Process the target string + starget = target.split(".") + if len(starget) < 2: + raise ValueError( + f"Invalid target {target}. At least two elements are required separated" + "by a period, indicating the node name and the method/attribute to " + "patch." + ) + node_name = starget.pop(0) + method = starget.pop() + + # Get the member to patch + node = obj = model.nodes[node_name] + for attr in starget: + obj = getattr(obj, attr) + + # Apply the patch + if item is not None: + obj = getattr(obj, method) + obj[item] = func(node) if is_attr else func + else: + setattr(obj, method, func(node) if is_attr else func)