diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 0741a223..bdd6cd56 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -81,12 +81,24 @@ def arc_dummy_patch(): assert ( model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ ) + assert ( + model.nodes[node.name]._patched_apply_overrides.__qualname__ + == "Node.apply_overrides" + ) assert model.nodes[node.name].t == another_dummy_patch(node) - assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch + assert model.nodes[node.name]._patched_t == None + assert ( + model.nodes[node.name].pull_set_handler["default"].__qualname__ + == yet_another_dummy_patch.__qualname__ + ) assert ( model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ == arc_dummy_patch.__qualname__ ) + assert ( + model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__ + == "Arc.arc_mass_balance" + ) def assert_dict_almost_equal(d1: dict, d2: dict, tol: float | None = None): diff --git a/wsimod/extensions.py b/wsimod/extensions.py index 721abe60..aef55c6c 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -39,7 +39,13 @@ It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise -there will be a runtime error. +there will be a runtime error. In particular, the first argument of the patched function +should be the node object itself, which will typically be named `self`. + +The overridden method or attribute can be accessed within the patched function using the +`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`. +The exception to this is when patching an item, in which case the original item is no +available to be used within the overriding function. Finally, the `apply_patches` is called within the `Model.load` method and will apply all patches in the order they were registered. This means that users need to be careful with @@ -110,10 +116,9 @@ def apply_patches(model: Model) -> None: # Apply the patch if item is not None: obj = getattr(obj, method) - obj[item] = func(node) if is_attr else func + obj[item] = func(node) if is_attr else func.__get__(obj, obj.__class__) else: - if is_attr: - setattr(obj, method, func(node)) - else: - setattr(obj, f"_patched_{method}", getattr(obj, method)) - setattr(obj, method, func.__get__(obj, obj.__class__)) + setattr(obj, f"_patched_{method}", getattr(obj, method)) + setattr( + obj, method, func(node) if is_attr else func.__get__(obj, obj.__class__) + )