Skip to content

Commit

Permalink
♻️ Add patched method when decorating.
Browse files Browse the repository at this point in the history
  • Loading branch information
dalonsoa committed Sep 10, 2024
1 parent d457baf commit e3c5776
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
44 changes: 44 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,47 @@ def arc_dummy_patch():
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch


def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None):
"""Check if two dictionaries are almost equal.
Args:
d1 (dict): The first dictionary.
d2 (dict): The second dictionary.
tol (float | None, optional): Relative tolerance. Defaults to 1e-6,
`pytest.approx` default.
"""
for key in d1.keys():
assert d1[key] == pytest.approx(d2[key], rel=tol)


def test_path_method_with_reuse(temp_extension_registry):
from functools import partial
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes.storage import Reservoir
from wsimod.orchestration.model import Model

# Create a dummy model
node = Reservoir(name="dummy_node", initial_storage=10, capacity=10)
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)

vq = node.pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5))

model = Model()
model.nodes[node.name] = node

@register_node_patch("dummy_node.pull_distributed")
def new_pull_distributed(self, vqip, of_type=None, tag="default"):
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag)

# Apply the patches
apply_patches(model)

# Check appropriate result
assert node.tank.storage["volume"] == 5
vq = model.nodes[node.name].pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.empty_vqip())
assert node.tank.storage["volume"] == 5
7 changes: 5 additions & 2 deletions wsimod/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
TODO: Update documentation on extensions files.
"""

from typing import Callable, Hashable

from .orchestration.model import Model
Expand Down Expand Up @@ -113,4 +112,8 @@ def apply_patches(model: Model) -> None:
obj = getattr(obj, method)
obj[item] = func(node) if is_attr else func
else:
setattr(obj, method, func(node) if is_attr else func)
if is_attr:
setattr(obj, method, func(node))
else:
setattr(obj, f"_patched_{method}", getattr(obj, method))
setattr(obj, method, func.__get__(obj, obj.__class__))

0 comments on commit e3c5776

Please sign in to comment.