diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 00000000..7f76f681 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,174 @@ +from typing import Optional + +import pytest + + +@pytest.fixture +def temp_extension_registry(): + from wsimod.extensions import extensions_registry + + bkp = extensions_registry.copy() + extensions_registry.clear() + yield + extensions_registry.clear() + extensions_registry.update(bkp) + + +def test_register_node_patch(temp_extension_registry): + from wsimod.extensions import extensions_registry, register_node_patch + + # Define a dummy function to patch a node method + @register_node_patch("node_name", "method_name") + def dummy_patch(): + print("Patched method") + + # Check if the patch is registered correctly + assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch + + # Another function with other arguments + @register_node_patch("node_name", "method_name", item="default", is_attr=True) + def another_dummy_patch(): + print("Another patched method") + + # Check if this other patch is registered correctly + assert ( + extensions_registry[("node_name", "method_name", "default", True)] + == another_dummy_patch + ) + + +def test_apply_patches(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import ( + apply_patches, + extensions_registry, + register_node_patch, + ) + from wsimod.nodes import Node + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Node("dummy_node") + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + model = Model() + model.nodes[node.name] = node + + # 1. Patch a method + @register_node_patch("dummy_node", "apply_overrides") + def dummy_patch(): + pass + + # 2. Patch an attribute + @register_node_patch("dummy_node", "t", is_attr=True) + def another_dummy_patch(node): + return f"A pathced attribute for {node.name}" + + # 3. Patch a method with an item + @register_node_patch("dummy_node", "pull_set_handler", item="default") + def yet_another_dummy_patch(): + pass + + # 4. Path a method of an attribute + @register_node_patch("dummy_node", "dummy_arc.arc_mass_balance") + def arc_dummy_patch(): + pass + + # Check if all patches are registered + assert len(extensions_registry) == 4 + + # Apply the patches + apply_patches(model) + + # Verify that the patches are applied correctly + assert ( + model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ + ) + assert ( + model.nodes[node.name]._patched_apply_overrides.__qualname__ + == "Node.apply_overrides" + ) + assert model.nodes[node.name].t == another_dummy_patch(node) + assert model.nodes[node.name]._patched_t == None + assert ( + model.nodes[node.name].pull_set_handler["default"].__qualname__ + == yet_another_dummy_patch.__qualname__ + ) + assert ( + model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ + == arc_dummy_patch.__qualname__ + ) + assert ( + model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__ + == "Arc.arc_mass_balance" + ) + + +def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): + """Check if two dictionaries are almost equal. + + Args: + d1 (dict): The first dictionary. + d2 (dict): The second dictionary. + tol (float | None, optional): Relative tolerance. Defaults to 1e-6, + `pytest.approx` default. + """ + for key in d1.keys(): + assert d1[key] == pytest.approx(d2[key], rel=tol) + + +def test_path_method_with_reuse(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import apply_patches, register_node_patch + from wsimod.nodes.storage import Reservoir + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Reservoir(name="dummy_node", initial_storage=10, capacity=10) + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + + vq = node.pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5)) + + model = Model() + model.nodes[node.name] = node + + @register_node_patch("dummy_node", "pull_distributed") + def new_pull_distributed(self, vqip, of_type=None, tag="default"): + return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) + + # Apply the patches + apply_patches(model) + + # Check appropriate result + assert node.tank.storage["volume"] == 5 + vq = model.nodes[node.name].pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.empty_vqip()) + assert node.tank.storage["volume"] == 5 + + +def test_handler_extensions(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import apply_patches, register_node_patch + from wsimod.nodes import Node + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Node("dummy_node") + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + model = Model() + model.nodes[node.name] = node + + # 1. Patch a handler + @register_node_patch("dummy_node", "pull_check_handler", item="default") + def dummy_patch(self, *args, **kwargs): + return "dummy_patch" + + # 2. Patch a handler with access to self + @register_node_patch("dummy_node", "pull_set_handler", item="default") + def dummy_patch(self, vqip, *args, **kwargs): + return f"{self.name} - {vqip['volume']}" + + apply_patches(model) + + assert node.pull_check() == "dummy_patch" + assert node.pull_set({"volume": 1}) == "dummy_node - 1" diff --git a/wsimod/extensions.py b/wsimod/extensions.py new file mode 100644 index 00000000..6d4a1be4 --- /dev/null +++ b/wsimod/extensions.py @@ -0,0 +1,117 @@ +"""This module contains the utilities to extend WSMOD with new features. + +The `register_node_patch` decorator is used to register a function that will be used +instead of a method or attribute of a node. The `apply_patches` function applies all +registered patches to a model. + +Example of patching a method: + +`empty_distributed` will be called instead of `pull_distributed` of "my_node": + + >>> from wsimod.extensions import register_node_patch, apply_patches + >>> @register_node_patch("my_node", "pull_distributed") + >>> def empty_distributed(self, vqip): + >>> return {} + +Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a +list or a dictionary can be patched if the item argument is provided. + +Example of patching an attribute: + +`10` will be assigned to `t`: + + >>> @register_node_patch("my_node", "t", is_attr=True) + >>> def patch_t(node): + >>> return 10 + +Example of patching an attribute item: + +`patch_default_pull_set_handler` will be assigned to +`pull_set_handler["default"]`: + + >>> @register_node_patch("my_node", "pull_set_handler", item="default") + >>> def patch_default_pull_set_handler(self, vqip): + >>> return {} + +If patching a method of an attribute, the `is_attr` argument should be set to `True` and +the target should include the attribute name and the method name, all separated by +periods, eg. `attribute_name.method_name`. + +It should be noted that the patched function should have the same signature as the +original method or attribute, and the return type should be the same as well, otherwise +there will be a runtime error. In particular, the first argument of the patched function +should be the node object itself, which will typically be named `self`. + +The overridden method or attribute can be accessed within the patched function using the +`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`. +The exception to this is when patching an item, in which case the original item is no +available to be used within the overriding function. + +Finally, the `apply_patches` is called within the `Model.load` method and will apply all +patches in the order they were registered. This means that users need to be careful with +the order of the patches in their extensions files, as they may have interdependencies. + +TODO: Update documentation on extensions files. +""" +from typing import Callable, Hashable + +from .orchestration.model import Model + +extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {} + + +def register_node_patch( + node_name: str, target: str, item: Hashable = None, is_attr: bool = False +) -> Callable: + """Register a function to patch a node method or any of its attributes. + + Args: + node_name (str): The name of the node to patch. + target (str): The target of the object to patch in the form of a string with the + attribute, sub-attribute, etc. and finally method (or attribute) to replace, + sepparated with period, eg. `make_discharge` or + `sewer_tank.pull_storage_exact`. + item (Hashable): Typically a string or an integer indicating the item to replace + in the selected attribue, which should be a list or a dictionary. + is_attr (bool): If True, the decorated function will be called when applying + the patch and the result assigned to the target, instead of assigning the + function itself. In this case, the only argument passed to the function is + the node object. + """ + target_id = (node_name, target, item, is_attr) + if target_id in extensions_registry: + raise ValueError(f"Patch for {target} already registered.") + + def decorator(func): + extensions_registry[target_id] = func + return func + + return decorator + + +def apply_patches(model: Model) -> None: + """Apply all registered patches to the model. + + TODO: Validate signature of the patched methods and type of patched attributes. + + Args: + model (Model): The model to apply the patches to. + """ + for (node_name, target, item, is_attr), func in extensions_registry.items(): + starget = target.split(".") + method = starget.pop() + + # Get the member to patch + node = obj = model.nodes[node_name] + for attr in starget: + obj = getattr(obj, attr) + + # Apply the patch + if item is not None: + obj = getattr(obj, method) + obj[item] = func(node) if is_attr else func.__get__(node, node.__class__) + else: + setattr(obj, f"_patched_{method}", getattr(obj, method)) + setattr( + obj, method, func(node) if is_attr else func.__get__(obj, obj.__class__) + ) diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 8e890175..50507249 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -157,6 +157,8 @@ def load(self, address, config_name="config.yml", overrides={}): config_name: overrides: """ + from ..extensions import apply_patches + with open(os.path.join(address, config_name), "r") as file: data = yaml.safe_load(file) @@ -191,6 +193,8 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] + apply_patches(self) + def save(self, address, config_name="config.yml", compress=False): """Save the model object to a yaml file and input data to csv.gz format in the directory specified.