Skip to content

Commit

Permalink
🐛 Fix node not being passed as first argument when item is present
Browse files Browse the repository at this point in the history
  • Loading branch information
dalonsoa committed Sep 10, 2024
1 parent a6636ce commit 566976e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 14 deletions.
18 changes: 5 additions & 13 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None):


def test_path_method_with_reuse(temp_extension_registry):
from functools import partial
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes.storage import Reservoir
Expand Down Expand Up @@ -149,11 +148,7 @@ def new_pull_distributed(self, vqip, of_type=None, tag="default"):

def test_handler_extensions(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import (
apply_patches,
extensions_registry,
register_node_patch,
)
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

Expand All @@ -165,16 +160,13 @@ def test_handler_extensions(temp_extension_registry):

# 1. Patch a handler
@register_node_patch("dummy_node", "pull_check_handler", item="default")
def dummy_patch(*args, **kwargs):
def dummy_patch(self, *args, **kwargs):
return "dummy_patch"

# 2. Patch a handler with access to self
@register_node_patch("dummy_node", "pull_set_handler", item="default", is_attr=True)
def dummy_patch(self):
def f(vqip, *args, **kwargs):
return f"{self.name} - {vqip['volume']}"

return f
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def dummy_patch(self, vqip, *args, **kwargs):
return f"{self.name} - {vqip['volume']}"

apply_patches(model)

Expand Down
2 changes: 1 addition & 1 deletion wsimod/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def apply_patches(model: Model) -> None:
# Apply the patch
if item is not None:
obj = getattr(obj, method)
obj[item] = func(node) if is_attr else func.__get__(obj, obj.__class__)
obj[item] = func(node) if is_attr else func.__get__(node, node.__class__)
else:
setattr(obj, f"_patched_{method}", getattr(obj, method))
setattr(
Expand Down

0 comments on commit 566976e

Please sign in to comment.