From e3c5776c9392b453ba5d970c28595172ebe8556f Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 12:37:18 +0100 Subject: [PATCH] :recycle: Add patched method when decorating. --- tests/test_extensions.py | 44 ++++++++++++++++++++++++++++++++++++++++ wsimod/extensions.py | 7 +++++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 6d69810..a92ec4e 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -82,3 +82,47 @@ def arc_dummy_patch(): assert model.nodes[node.name].t == another_dummy_patch(node) assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch + + +def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): + """Check if two dictionaries are almost equal. + + Args: + d1 (dict): The first dictionary. + d2 (dict): The second dictionary. + tol (float | None, optional): Relative tolerance. Defaults to 1e-6, + `pytest.approx` default. + """ + for key in d1.keys(): + assert d1[key] == pytest.approx(d2[key], rel=tol) + + +def test_path_method_with_reuse(temp_extension_registry): + from functools import partial + from wsimod.arcs.arcs import Arc + from wsimod.extensions import apply_patches, register_node_patch + from wsimod.nodes.storage import Reservoir + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Reservoir(name="dummy_node", initial_storage=10, capacity=10) + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + + vq = node.pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5)) + + model = Model() + model.nodes[node.name] = node + + @register_node_patch("dummy_node.pull_distributed") + def new_pull_distributed(self, vqip, of_type=None, tag="default"): + return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) + + # Apply the patches + apply_patches(model) + + # Check appropriate result + assert node.tank.storage["volume"] == 5 + vq = model.nodes[node.name].pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.empty_vqip()) + assert node.tank.storage["volume"] == 5 diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 940a4f5..721abe6 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -47,7 +47,6 @@ TODO: Update documentation on extensions files. """ - from typing import Callable, Hashable from .orchestration.model import Model @@ -113,4 +112,8 @@ def apply_patches(model: Model) -> None: obj = getattr(obj, method) obj[item] = func(node) if is_attr else func else: - setattr(obj, method, func(node) if is_attr else func) + if is_attr: + setattr(obj, method, func(node)) + else: + setattr(obj, f"_patched_{method}", getattr(obj, method)) + setattr(obj, method, func.__get__(obj, obj.__class__))