From 566976ee6d905424ac5f34f7b2ac874ca551c087 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Tue, 10 Sep 2024 16:48:01 +0100 Subject: [PATCH] :bug: Fix node not being passed as first argument when item is present --- tests/test_extensions.py | 18 +++++------------- wsimod/extensions.py | 2 +- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 752de934..7f76f681 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -117,7 +117,6 @@ def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): def test_path_method_with_reuse(temp_extension_registry): - from functools import partial from wsimod.arcs.arcs import Arc from wsimod.extensions import apply_patches, register_node_patch from wsimod.nodes.storage import Reservoir @@ -149,11 +148,7 @@ def new_pull_distributed(self, vqip, of_type=None, tag="default"): def test_handler_extensions(temp_extension_registry): from wsimod.arcs.arcs import Arc - from wsimod.extensions import ( - apply_patches, - extensions_registry, - register_node_patch, - ) + from wsimod.extensions import apply_patches, register_node_patch from wsimod.nodes import Node from wsimod.orchestration.model import Model @@ -165,16 +160,13 @@ def test_handler_extensions(temp_extension_registry): # 1. Patch a handler @register_node_patch("dummy_node", "pull_check_handler", item="default") - def dummy_patch(*args, **kwargs): + def dummy_patch(self, *args, **kwargs): return "dummy_patch" # 2. Patch a handler with access to self - @register_node_patch("dummy_node", "pull_set_handler", item="default", is_attr=True) - def dummy_patch(self): - def f(vqip, *args, **kwargs): - return f"{self.name} - {vqip['volume']}" - - return f + @register_node_patch("dummy_node", "pull_set_handler", item="default") + def dummy_patch(self, vqip, *args, **kwargs): + return f"{self.name} - {vqip['volume']}" apply_patches(model) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 1fa67387..6d4a1be4 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -109,7 +109,7 @@ def apply_patches(model: Model) -> None: # Apply the patch if item is not None: obj = getattr(obj, method) - obj[item] = func(node) if is_attr else func.__get__(obj, obj.__class__) + obj[item] = func(node) if is_attr else func.__get__(node, node.__class__) else: setattr(obj, f"_patched_{method}", getattr(obj, method)) setattr(