From 809a7121f7641c12983766507600e250b28d170e Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Thu, 5 Sep 2024 07:01:16 +0100 Subject: [PATCH 01/12] Add node patch --- tests/test_extensions.py | 84 +++++++++++++++++++++++++++++++ wsimod/extensions.py | 106 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 tests/test_extensions.py create mode 100644 wsimod/extensions.py diff --git a/tests/test_extensions.py b/tests/test_extensions.py new file mode 100644 index 00000000..6d698100 --- /dev/null +++ b/tests/test_extensions.py @@ -0,0 +1,84 @@ +import pytest + + +@pytest.fixture +def temp_extension_registry(): + from wsimod.extensions import extensions_registry + + bkp = extensions_registry.copy() + extensions_registry.clear() + yield + extensions_registry.clear() + extensions_registry.update(bkp) + + +def test_register_node_patch(temp_extension_registry): + from wsimod.extensions import extensions_registry, register_node_patch + + # Define a dummy function to patch a node method + @register_node_patch("node_name.method_name") + def dummy_patch(): + print("Patched method") + + # Check if the patch is registered correctly + assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch + + # Another function with other arguments + @register_node_patch("node_name.method_name", item="default", is_attr=True) + def another_dummy_patch(): + print("Another patched method") + + # Check if this other patch is registered correctly + assert ( + extensions_registry[("node_name.method_name", "default", True)] + == another_dummy_patch + ) + + +def test_apply_patches(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import ( + apply_patches, + extensions_registry, + register_node_patch, + ) + from wsimod.nodes import Node + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Node("dummy_node") + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + model = Model() + model.nodes[node.name] = node + + # 1. Patch a method + @register_node_patch("dummy_node.apply_overrides") + def dummy_patch(): + pass + + # 2. Patch an attribute + @register_node_patch("dummy_node.t", is_attr=True) + def another_dummy_patch(node): + return f"A pathced attribute for {node.name}" + + # 3. Patch a method with an item + @register_node_patch("dummy_node.pull_set_handler", item="default") + def yet_another_dummy_patch(): + pass + + # 4. Path a method of an attribute + @register_node_patch("dummy_node.dummy_arc.arc_mass_balance") + def arc_dummy_patch(): + pass + + # Check if all patches are registered + assert len(extensions_registry) == 4 + + # Apply the patches + apply_patches(model) + + # Verify that the patches are applied correctly + assert model.nodes[node.name].apply_overrides == dummy_patch + assert model.nodes[node.name].t == another_dummy_patch(node) + assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch + assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch diff --git a/wsimod/extensions.py b/wsimod/extensions.py new file mode 100644 index 00000000..bf56ea12 --- /dev/null +++ b/wsimod/extensions.py @@ -0,0 +1,106 @@ +"""This module contains the utilities to extend WSMOD with new features. + +The `register_node_patch` decorator is used to register a function that will be used +instead of a method or attribute of a node. The `apply_patches` function applies all +registered patches to a model. + +Example of patching a method: + +`empty_distributed` will be called instead of `my_node.pull_distributed`: + + >>> from wsimod.extensions import register_node_patch, apply_patches + >>> @register_node_patch("my_node.pull_distributed") + >>> def empty_distributed(self, vqip): + >>> return {} + +Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a +list or a dictionary can be patched if the item argument is provided. + +Example of patching an attribute: + +`10` will be assigned to `my_node.t`: + + >>> @register_node_patch("my_node.t", is_attr=True) + >>> def patch_t(node): + >>> return 10 + +Example of patching an attribute item: + +`patch_default_pull_set_handler` will be assigned to +`my_node.pull_set_handler["default"]`: + + >>> @register_node_patch("my_node.pull_set_handler", item="default") + >>> def patch_default_pull_set_handler(self, vqip): + >>> return {} + +It should be noted that the patched function should have the same signature as the +original method or attribute, and the return type should be the same as well, otherwise +there will be a runtime error. +""" + +from typing import Callable, Hashable + +from .orchestration.model import Model + +extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {} + + +def register_node_patch( + target: str, item: Hashable = None, is_attr: bool = False +) -> Callable: + """Register a function to patch a node method or any of its attributes. + + Args: + target (str): The target of the object to patch as a string with the node name + attribute, sub-attribute, etc. and finally method (or attribue) to replace, + sepparated with period, eg. `node_name.make_discharge` or + `node_name.sewer_tank.pull_storage_exact`. + item (Hashable): Typically a string or an integer indicating the item to replace + in the selected attribue, which should be a list or a dictionary. + is_attr (bool): If True, the decorated function will be called when applying + the patch and the result assigned to the target, instead of assigning the + function itself. In this case, the only argument passed to the function is + the node object. + """ + target_id = (target, item, is_attr) + if target_id in extensions_registry: + raise ValueError(f"Patch for {target} already registered.") + + def decorator(func): + extensions_registry[(target, item, is_attr)] = func + return func + + return decorator + + +def apply_patches(model: Model) -> None: + """Apply all registered patches to the model. + + TODO: Validate signature of the patched methods and type of patched attributes. + + Args: + model (Model): The model to apply the patches to. + """ + for (target, item, is_attr), func in extensions_registry.items(): + # Process the target string + starget = target.split(".") + if len(starget) < 2: + raise ValueError( + f"Invalid target {target}. At least two elements are required separated" + "by a period, indicating the node name and the method/attribute to " + "patch." + ) + node_name = starget.pop(0) + method = starget.pop() + + # Get the member to patch + node = obj = model.nodes[node_name] + for attr in starget: + obj = getattr(obj, attr) + + # Apply the patch + if item is not None: + obj = getattr(obj, method) + obj[item] = func(node) if is_attr else func + else: + setattr(obj, method, func(node) if is_attr else func) From cf5a5aef4f8d93d5f41101a8c5f0c63433e5745e Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Thu, 5 Sep 2024 07:04:56 +0100 Subject: [PATCH 02/12] Update docs --- wsimod/extensions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index bf56ea12..eb325b30 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -33,6 +33,10 @@ >>> def patch_default_pull_set_handler(self, vqip): >>> return {} +If patching a method of an attribute, the `is_attr` argument should be set to `True` and +the target should include the node name and the attribute name and the method name, all +separated by periods, eg. `node_name.attribute_name.method_name`. + It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise there will be a runtime error. From 165d42a675c7ebb789b6435aaeab8ead8dee7d41 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Thu, 5 Sep 2024 07:53:04 +0100 Subject: [PATCH 03/12] Apply patches after loading config file. --- wsimod/orchestration/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 8e890175..50507249 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -157,6 +157,8 @@ def load(self, address, config_name="config.yml", overrides={}): config_name: overrides: """ + from ..extensions import apply_patches + with open(os.path.join(address, config_name), "r") as file: data = yaml.safe_load(file) @@ -191,6 +193,8 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] + apply_patches(self) + def save(self, address, config_name="config.yml", compress=False): """Save the model object to a yaml file and input data to csv.gz format in the directory specified. From d457baf419628f707fed28d9649091b1f904e01c Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Thu, 5 Sep 2024 07:57:11 +0100 Subject: [PATCH 04/12] Update docstring in extensions.py --- wsimod/extensions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index eb325b30..940a4f52 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -40,6 +40,12 @@ It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise there will be a runtime error. + +Finally, the `apply_patches` is called within the `Model.load` method and will apply all +patches in the order they were registered. This means that users need to be careful with +the order of the patches in their extensions files, as they may have interdependencies. + +TODO: Update documentation on extensions files. """ from typing import Callable, Hashable From e3c5776c9392b453ba5d970c28595172ebe8556f Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 12:37:18 +0100 Subject: [PATCH 05/12] :recycle: Add patched method when decorating. --- tests/test_extensions.py | 44 ++++++++++++++++++++++++++++++++++++++++ wsimod/extensions.py | 7 +++++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 6d698100..a92ec4ec 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -82,3 +82,47 @@ def arc_dummy_patch(): assert model.nodes[node.name].t == another_dummy_patch(node) assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch + + +def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): + """Check if two dictionaries are almost equal. + + Args: + d1 (dict): The first dictionary. + d2 (dict): The second dictionary. + tol (float | None, optional): Relative tolerance. Defaults to 1e-6, + `pytest.approx` default. + """ + for key in d1.keys(): + assert d1[key] == pytest.approx(d2[key], rel=tol) + + +def test_path_method_with_reuse(temp_extension_registry): + from functools import partial + from wsimod.arcs.arcs import Arc + from wsimod.extensions import apply_patches, register_node_patch + from wsimod.nodes.storage import Reservoir + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Reservoir(name="dummy_node", initial_storage=10, capacity=10) + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + + vq = node.pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5)) + + model = Model() + model.nodes[node.name] = node + + @register_node_patch("dummy_node.pull_distributed") + def new_pull_distributed(self, vqip, of_type=None, tag="default"): + return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) + + # Apply the patches + apply_patches(model) + + # Check appropriate result + assert node.tank.storage["volume"] == 5 + vq = model.nodes[node.name].pull_distributed({"volume": 5}) + assert_dict_almost_equal(vq, node.empty_vqip()) + assert node.tank.storage["volume"] == 5 diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 940a4f52..721abe60 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -47,7 +47,6 @@ TODO: Update documentation on extensions files. """ - from typing import Callable, Hashable from .orchestration.model import Model @@ -113,4 +112,8 @@ def apply_patches(model: Model) -> None: obj = getattr(obj, method) obj[item] = func(node) if is_attr else func else: - setattr(obj, method, func(node) if is_attr else func) + if is_attr: + setattr(obj, method, func(node)) + else: + setattr(obj, f"_patched_{method}", getattr(obj, method)) + setattr(obj, method, func.__get__(obj, obj.__class__)) From 80416c21bff7831be7f15ad2353340902ef3c065 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 12:48:59 +0100 Subject: [PATCH 06/12] :white_check_mark: Fix failing tests --- tests/test_extensions.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index a92ec4ec..0741a223 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -78,10 +78,15 @@ def arc_dummy_patch(): apply_patches(model) # Verify that the patches are applied correctly - assert model.nodes[node.name].apply_overrides == dummy_patch + assert ( + model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ + ) assert model.nodes[node.name].t == another_dummy_patch(node) assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch - assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch + assert ( + model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ + == arc_dummy_patch.__qualname__ + ) def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): From dc5c216a447339d9e66b2be9ade60779a8448a91 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 13:00:49 +0100 Subject: [PATCH 07/12] :memo: Updated docstring --- tests/test_extensions.py | 14 +++++++++++++- wsimod/extensions.py | 19 ++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 0741a223..bdd6cd56 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -81,12 +81,24 @@ def arc_dummy_patch(): assert ( model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ ) + assert ( + model.nodes[node.name]._patched_apply_overrides.__qualname__ + == "Node.apply_overrides" + ) assert model.nodes[node.name].t == another_dummy_patch(node) - assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch + assert model.nodes[node.name]._patched_t == None + assert ( + model.nodes[node.name].pull_set_handler["default"].__qualname__ + == yet_another_dummy_patch.__qualname__ + ) assert ( model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ == arc_dummy_patch.__qualname__ ) + assert ( + model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__ + == "Arc.arc_mass_balance" + ) def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 721abe60..aef55c6c 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -39,7 +39,13 @@ It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise -there will be a runtime error. +there will be a runtime error. In particular, the first argument of the patched function +should be the node object itself, which will typically be named `self`. + +The overridden method or attribute can be accessed within the patched function using the +`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`. +The exception to this is when patching an item, in which case the original item is no +available to be used within the overriding function. Finally, the `apply_patches` is called within the `Model.load` method and will apply all patches in the order they were registered. This means that users need to be careful with @@ -110,10 +116,9 @@ def apply_patches(model: Model) -> None: # Apply the patch if item is not None: obj = getattr(obj, method) - obj[item] = func(node) if is_attr else func + obj[item] = func(node) if is_attr else func.__get__(obj, obj.__class__) else: - if is_attr: - setattr(obj, method, func(node)) - else: - setattr(obj, f"_patched_{method}", getattr(obj, method)) - setattr(obj, method, func.__get__(obj, obj.__class__)) + setattr(obj, f"_patched_{method}", getattr(obj, method)) + setattr( + obj, method, func(node) if is_attr else func.__get__(obj, obj.__class__) + ) From 8082f1f05b4c41d3d6db9f977a52a1c37d630cae Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 13:09:43 +0100 Subject: [PATCH 08/12] :recycle: Use a separate argument for the node name. --- tests/test_extensions.py | 18 +++++++++--------- wsimod/extensions.py | 41 +++++++++++++++++----------------------- 2 files changed, 26 insertions(+), 33 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index bdd6cd56..c3356d57 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -16,21 +16,21 @@ def test_register_node_patch(temp_extension_registry): from wsimod.extensions import extensions_registry, register_node_patch # Define a dummy function to patch a node method - @register_node_patch("node_name.method_name") + @register_node_patch("node_name", "method_name") def dummy_patch(): print("Patched method") # Check if the patch is registered correctly - assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch + assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch # Another function with other arguments - @register_node_patch("node_name.method_name", item="default", is_attr=True) + @register_node_patch("node_name", "method_name", item="default", is_attr=True) def another_dummy_patch(): print("Another patched method") # Check if this other patch is registered correctly assert ( - extensions_registry[("node_name.method_name", "default", True)] + extensions_registry[("node_name", "method_name", "default", True)] == another_dummy_patch ) @@ -52,22 +52,22 @@ def test_apply_patches(temp_extension_registry): model.nodes[node.name] = node # 1. Patch a method - @register_node_patch("dummy_node.apply_overrides") + @register_node_patch("dummy_node", "apply_overrides") def dummy_patch(): pass # 2. Patch an attribute - @register_node_patch("dummy_node.t", is_attr=True) + @register_node_patch("dummy_node", "t", is_attr=True) def another_dummy_patch(node): return f"A pathced attribute for {node.name}" # 3. Patch a method with an item - @register_node_patch("dummy_node.pull_set_handler", item="default") + @register_node_patch("dummy_node", "pull_set_handler", item="default") def yet_another_dummy_patch(): pass # 4. Path a method of an attribute - @register_node_patch("dummy_node.dummy_arc.arc_mass_balance") + @register_node_patch("dummy_node", "dummy_arc.arc_mass_balance") def arc_dummy_patch(): pass @@ -131,7 +131,7 @@ def test_path_method_with_reuse(temp_extension_registry): model = Model() model.nodes[node.name] = node - @register_node_patch("dummy_node.pull_distributed") + @register_node_patch("dummy_node", "pull_distributed") def new_pull_distributed(self, vqip, of_type=None, tag="default"): return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index aef55c6c..1fa67387 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -6,10 +6,10 @@ Example of patching a method: -`empty_distributed` will be called instead of `my_node.pull_distributed`: +`empty_distributed` will be called instead of `pull_distributed` of "my_node": >>> from wsimod.extensions import register_node_patch, apply_patches - >>> @register_node_patch("my_node.pull_distributed") + >>> @register_node_patch("my_node", "pull_distributed") >>> def empty_distributed(self, vqip): >>> return {} @@ -18,24 +18,24 @@ Example of patching an attribute: -`10` will be assigned to `my_node.t`: +`10` will be assigned to `t`: - >>> @register_node_patch("my_node.t", is_attr=True) + >>> @register_node_patch("my_node", "t", is_attr=True) >>> def patch_t(node): >>> return 10 Example of patching an attribute item: `patch_default_pull_set_handler` will be assigned to -`my_node.pull_set_handler["default"]`: +`pull_set_handler["default"]`: - >>> @register_node_patch("my_node.pull_set_handler", item="default") + >>> @register_node_patch("my_node", "pull_set_handler", item="default") >>> def patch_default_pull_set_handler(self, vqip): >>> return {} If patching a method of an attribute, the `is_attr` argument should be set to `True` and -the target should include the node name and the attribute name and the method name, all -separated by periods, eg. `node_name.attribute_name.method_name`. +the target should include the attribute name and the method name, all separated by +periods, eg. `attribute_name.method_name`. It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise @@ -61,15 +61,16 @@ def register_node_patch( - target: str, item: Hashable = None, is_attr: bool = False + node_name: str, target: str, item: Hashable = None, is_attr: bool = False ) -> Callable: """Register a function to patch a node method or any of its attributes. Args: - target (str): The target of the object to patch as a string with the node name - attribute, sub-attribute, etc. and finally method (or attribue) to replace, - sepparated with period, eg. `node_name.make_discharge` or - `node_name.sewer_tank.pull_storage_exact`. + node_name (str): The name of the node to patch. + target (str): The target of the object to patch in the form of a string with the + attribute, sub-attribute, etc. and finally method (or attribute) to replace, + sepparated with period, eg. `make_discharge` or + `sewer_tank.pull_storage_exact`. item (Hashable): Typically a string or an integer indicating the item to replace in the selected attribue, which should be a list or a dictionary. is_attr (bool): If True, the decorated function will be called when applying @@ -77,12 +78,12 @@ def register_node_patch( function itself. In this case, the only argument passed to the function is the node object. """ - target_id = (target, item, is_attr) + target_id = (node_name, target, item, is_attr) if target_id in extensions_registry: raise ValueError(f"Patch for {target} already registered.") def decorator(func): - extensions_registry[(target, item, is_attr)] = func + extensions_registry[target_id] = func return func return decorator @@ -96,16 +97,8 @@ def apply_patches(model: Model) -> None: Args: model (Model): The model to apply the patches to. """ - for (target, item, is_attr), func in extensions_registry.items(): - # Process the target string + for (node_name, target, item, is_attr), func in extensions_registry.items(): starget = target.split(".") - if len(starget) < 2: - raise ValueError( - f"Invalid target {target}. At least two elements are required separated" - "by a period, indicating the node name and the method/attribute to " - "patch." - ) - node_name = starget.pop(0) method = starget.pop() # Get the member to patch From d10bbb1f61dc1030479cc6b518bab4761378835e Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 13:13:07 +0100 Subject: [PATCH 09/12] :rotating_light: Fix use of annotations incompatible with py3.9 --- tests/test_extensions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index c3356d57..0c598923 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -1,3 +1,5 @@ +from typing import Optional + import pytest @@ -101,7 +103,7 @@ def arc_dummy_patch(): ) -def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): +def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): """Check if two dictionaries are almost equal. Args: From 6c7168d2e3dfad9ca4d6d366c9dd45d71a91042b Mon Sep 17 00:00:00 2001 From: Dobson Date: Tue, 10 Sep 2024 16:30:11 +0100 Subject: [PATCH 10/12] Update test_extensions.py Add handler behaviour tests --- tests/test_extensions.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 0c598923..4faa8379 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -145,3 +145,36 @@ def new_pull_distributed(self, vqip, of_type=None, tag="default"): vq = model.nodes[node.name].pull_distributed({"volume": 5}) assert_dict_almost_equal(vq, node.empty_vqip()) assert node.tank.storage["volume"] == 5 + +def test_handler_extensions(temp_extension_registry): + from wsimod.arcs.arcs import Arc + from wsimod.extensions import ( + apply_patches, + extensions_registry, + register_node_patch, + ) + from wsimod.nodes import Node + from wsimod.orchestration.model import Model + + # Create a dummy model + node = Node("dummy_node") + node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) + model = Model() + model.nodes[node.name] = node + + # 1. Patch a handler + @register_node_patch("dummy_node", "pull_check_handler", item="default") + def dummy_patch(*args, **kwargs): + return 'dummy_patch' + + # 2. Patch a handler with access to self + @register_node_patch("dummy_node", "pull_set_handler", item="default", is_attr=True) + def dummy_patch(self): + def f(vqip, *args, **kwargs): + return f"{self.name} - {vqip['volume']}" + return f + + apply_patches(model) + + assert node.pull_check() == 'dummy_patch' + assert node.pull_set({'volume': 1}) == 'dummy_node - 1' \ No newline at end of file From a6636cedc7d854d7533355d27514c5c51997ea31 Mon Sep 17 00:00:00 2001 From: Dobson Date: Tue, 10 Sep 2024 16:36:12 +0100 Subject: [PATCH 11/12] Fix precommit --- tests/test_extensions.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 4faa8379..752de934 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -146,6 +146,7 @@ def new_pull_distributed(self, vqip, of_type=None, tag="default"): assert_dict_almost_equal(vq, node.empty_vqip()) assert node.tank.storage["volume"] == 5 + def test_handler_extensions(temp_extension_registry): from wsimod.arcs.arcs import Arc from wsimod.extensions import ( @@ -165,16 +166,17 @@ def test_handler_extensions(temp_extension_registry): # 1. Patch a handler @register_node_patch("dummy_node", "pull_check_handler", item="default") def dummy_patch(*args, **kwargs): - return 'dummy_patch' - + return "dummy_patch" + # 2. Patch a handler with access to self @register_node_patch("dummy_node", "pull_set_handler", item="default", is_attr=True) def dummy_patch(self): def f(vqip, *args, **kwargs): return f"{self.name} - {vqip['volume']}" + return f - + apply_patches(model) - assert node.pull_check() == 'dummy_patch' - assert node.pull_set({'volume': 1}) == 'dummy_node - 1' \ No newline at end of file + assert node.pull_check() == "dummy_patch" + assert node.pull_set({"volume": 1}) == "dummy_node - 1" From 566976ee6d905424ac5f34f7b2ac874ca551c087 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 16:48:01 +0100 Subject: [PATCH 12/12] :bug: Fix node not being passed as first argument when item is present --- tests/test_extensions.py | 18 +++++------------- wsimod/extensions.py | 2 +- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 752de934..7f76f681 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -117,7 +117,6 @@ def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): def test_path_method_with_reuse(temp_extension_registry): - from functools import partial from wsimod.arcs.arcs import Arc from wsimod.extensions import apply_patches, register_node_patch from wsimod.nodes.storage import Reservoir @@ -149,11 +148,7 @@ def new_pull_distributed(self, vqip, of_type=None, tag="default"): def test_handler_extensions(temp_extension_registry): from wsimod.arcs.arcs import Arc - from wsimod.extensions import ( - apply_patches, - extensions_registry, - register_node_patch, - ) + from wsimod.extensions import apply_patches, register_node_patch from wsimod.nodes import Node from wsimod.orchestration.model import Model @@ -165,16 +160,13 @@ def test_handler_extensions(temp_extension_registry): # 1. Patch a handler @register_node_patch("dummy_node", "pull_check_handler", item="default") - def dummy_patch(*args, **kwargs): + def dummy_patch(self, *args, **kwargs): return "dummy_patch" # 2. Patch a handler with access to self - @register_node_patch("dummy_node", "pull_set_handler", item="default", is_attr=True) - def dummy_patch(self): - def f(vqip, *args, **kwargs): - return f"{self.name} - {vqip['volume']}" - - return f + @register_node_patch("dummy_node", "pull_set_handler", item="default") + def dummy_patch(self, vqip, *args, **kwargs): + return f"{self.name} - {vqip['volume']}" apply_patches(model) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 1fa67387..6d4a1be4 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -109,7 +109,7 @@ def apply_patches(model: Model) -> None: # Apply the patch if item is not None: obj = getattr(obj, method) - obj[item] = func(node) if is_attr else func.__get__(obj, obj.__class__) + obj[item] = func(node) if is_attr else func.__get__(node, node.__class__) else: setattr(obj, f"_patched_{method}", getattr(obj, method)) setattr(