-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def temp_extension_registry(): | ||
from wsimod.extensions import extensions_registry | ||
|
||
bkp = extensions_registry.copy() | ||
extensions_registry.clear() | ||
yield | ||
extensions_registry.clear() | ||
extensions_registry.update(bkp) | ||
|
||
|
||
def test_register_node_patch(temp_extension_registry): | ||
from wsimod.extensions import extensions_registry, register_node_patch | ||
|
||
# Define a dummy function to patch a node method | ||
@register_node_patch("node_name.method_name") | ||
def dummy_patch(): | ||
print("Patched method") | ||
|
||
# Check if the patch is registered correctly | ||
assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch | ||
|
||
# Another function with other arguments | ||
@register_node_patch("node_name.method_name", item="default", is_attr=True) | ||
def another_dummy_patch(): | ||
print("Another patched method") | ||
|
||
# Check if this other patch is registered correctly | ||
assert ( | ||
extensions_registry[("node_name.method_name", "default", True)] | ||
== another_dummy_patch | ||
) | ||
|
||
|
||
def test_apply_patches(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import ( | ||
apply_patches, | ||
extensions_registry, | ||
register_node_patch, | ||
) | ||
from wsimod.nodes import Node | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Node("dummy_node") | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
# 1. Patch a method | ||
@register_node_patch("dummy_node.apply_overrides") | ||
def dummy_patch(): | ||
pass | ||
|
||
# 2. Patch an attribute | ||
@register_node_patch("dummy_node.t", is_attr=True) | ||
def another_dummy_patch(node): | ||
return f"A pathced attribute for {node.name}" | ||
|
||
# 3. Patch a method with an item | ||
@register_node_patch("dummy_node.pull_set_handler", item="default") | ||
def yet_another_dummy_patch(): | ||
pass | ||
|
||
# 4. Path a method of an attribute | ||
@register_node_patch("dummy_node.dummy_arc.arc_mass_balance") | ||
def arc_dummy_patch(): | ||
pass | ||
|
||
# Check if all patches are registered | ||
assert len(extensions_registry) == 4 | ||
|
||
# Apply the patches | ||
apply_patches(model) | ||
|
||
# Verify that the patches are applied correctly | ||
assert model.nodes[node.name].apply_overrides == dummy_patch | ||
assert model.nodes[node.name].t == another_dummy_patch(node) | ||
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch | ||
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
"""This module contains the utilities to extend WSMOD with new features. | ||
The `register_node_patch` decorator is used to register a function that will be used | ||
instead of a method or attribute of a node. The `apply_patches` function applies all | ||
registered patches to a model. | ||
Example of patching a method: | ||
`empty_distributed` will be called instead of `my_node.pull_distributed`: | ||
>>> from wsimod.extensions import register_node_patch, apply_patches | ||
>>> @register_node_patch("my_node.pull_distributed") | ||
>>> def empty_distributed(self, vqip): | ||
>>> return {} | ||
Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a | ||
list or a dictionary can be patched if the item argument is provided. | ||
Example of patching an attribute: | ||
`10` will be assigned to `my_node.t`: | ||
>>> @register_node_patch("my_node.t", is_attr=True) | ||
>>> def patch_t(node): | ||
>>> return 10 | ||
Example of patching an attribute item: | ||
`patch_default_pull_set_handler` will be assigned to | ||
`my_node.pull_set_handler["default"]`: | ||
>>> @register_node_patch("my_node.pull_set_handler", item="default") | ||
>>> def patch_default_pull_set_handler(self, vqip): | ||
>>> return {} | ||
It should be noted that the patched function should have the same signature as the | ||
original method or attribute, and the return type should be the same as well, otherwise | ||
there will be a runtime error. | ||
""" | ||
|
||
from typing import Callable, Hashable | ||
|
||
from .orchestration.model import Model | ||
|
||
extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {} | ||
|
||
|
||
def register_node_patch( | ||
target: str, item: Hashable = None, is_attr: bool = False | ||
) -> Callable: | ||
"""Register a function to patch a node method or any of its attributes. | ||
Args: | ||
target (str): The target of the object to patch as a string with the node name | ||
attribute, sub-attribute, etc. and finally method (or attribue) to replace, | ||
sepparated with period, eg. `node_name.make_discharge` or | ||
`node_name.sewer_tank.pull_storage_exact`. | ||
item (Hashable): Typically a string or an integer indicating the item to replace | ||
in the selected attribue, which should be a list or a dictionary. | ||
is_attr (bool): If True, the decorated function will be called when applying | ||
the patch and the result assigned to the target, instead of assigning the | ||
function itself. In this case, the only argument passed to the function is | ||
the node object. | ||
""" | ||
target_id = (target, item, is_attr) | ||
if target_id in extensions_registry: | ||
raise ValueError(f"Patch for {target} already registered.") | ||
|
||
def decorator(func): | ||
extensions_registry[(target, item, is_attr)] = func | ||
return func | ||
|
||
return decorator | ||
|
||
|
||
def apply_patches(model: Model) -> None: | ||
"""Apply all registered patches to the model. | ||
TODO: Validate signature of the patched methods and type of patched attributes. | ||
Args: | ||
model (Model): The model to apply the patches to. | ||
""" | ||
for (target, item, is_attr), func in extensions_registry.items(): | ||
# Process the target string | ||
starget = target.split(".") | ||
if len(starget) < 2: | ||
raise ValueError( | ||
f"Invalid target {target}. At least two elements are required separated" | ||
"by a period, indicating the node name and the method/attribute to " | ||
"patch." | ||
) | ||
node_name = starget.pop(0) | ||
method = starget.pop() | ||
|
||
# Get the member to patch | ||
node = obj = model.nodes[node_name] | ||
for attr in starget: | ||
obj = getattr(obj, attr) | ||
|
||
# Apply the patch | ||
if item is not None: | ||
obj = getattr(obj, method) | ||
obj[item] = func(node) if is_attr else func | ||
else: | ||
setattr(obj, method, func(node) if is_attr else func) |