Skip to content

Commit

Permalink
♻️ Use a separate argument for the node name.
Browse files Browse the repository at this point in the history
  • Loading branch information
dalonsoa committed Sep 10, 2024
1 parent dc5c216 commit 8082f1f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 33 deletions.
18 changes: 9 additions & 9 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@ def test_register_node_patch(temp_extension_registry):
from wsimod.extensions import extensions_registry, register_node_patch

# Define a dummy function to patch a node method
@register_node_patch("node_name.method_name")
@register_node_patch("node_name", "method_name")
def dummy_patch():
print("Patched method")

# Check if the patch is registered correctly
assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch
assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch

# Another function with other arguments
@register_node_patch("node_name.method_name", item="default", is_attr=True)
@register_node_patch("node_name", "method_name", item="default", is_attr=True)
def another_dummy_patch():
print("Another patched method")

# Check if this other patch is registered correctly
assert (
extensions_registry[("node_name.method_name", "default", True)]
extensions_registry[("node_name", "method_name", "default", True)]
== another_dummy_patch
)

Expand All @@ -52,22 +52,22 @@ def test_apply_patches(temp_extension_registry):
model.nodes[node.name] = node

# 1. Patch a method
@register_node_patch("dummy_node.apply_overrides")
@register_node_patch("dummy_node", "apply_overrides")
def dummy_patch():
pass

# 2. Patch an attribute
@register_node_patch("dummy_node.t", is_attr=True)
@register_node_patch("dummy_node", "t", is_attr=True)
def another_dummy_patch(node):
return f"A pathced attribute for {node.name}"

# 3. Patch a method with an item
@register_node_patch("dummy_node.pull_set_handler", item="default")
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def yet_another_dummy_patch():
pass

# 4. Path a method of an attribute
@register_node_patch("dummy_node.dummy_arc.arc_mass_balance")
@register_node_patch("dummy_node", "dummy_arc.arc_mass_balance")
def arc_dummy_patch():
pass

Expand Down Expand Up @@ -131,7 +131,7 @@ def test_path_method_with_reuse(temp_extension_registry):
model = Model()
model.nodes[node.name] = node

@register_node_patch("dummy_node.pull_distributed")
@register_node_patch("dummy_node", "pull_distributed")
def new_pull_distributed(self, vqip, of_type=None, tag="default"):
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag)

Expand Down
41 changes: 17 additions & 24 deletions wsimod/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
Example of patching a method:
`empty_distributed` will be called instead of `my_node.pull_distributed`:
`empty_distributed` will be called instead of `pull_distributed` of "my_node":
>>> from wsimod.extensions import register_node_patch, apply_patches
>>> @register_node_patch("my_node.pull_distributed")
>>> @register_node_patch("my_node", "pull_distributed")
>>> def empty_distributed(self, vqip):
>>> return {}
Expand All @@ -18,24 +18,24 @@
Example of patching an attribute:
`10` will be assigned to `my_node.t`:
`10` will be assigned to `t`:
>>> @register_node_patch("my_node.t", is_attr=True)
>>> @register_node_patch("my_node", "t", is_attr=True)
>>> def patch_t(node):
>>> return 10
Example of patching an attribute item:
`patch_default_pull_set_handler` will be assigned to
`my_node.pull_set_handler["default"]`:
`pull_set_handler["default"]`:
>>> @register_node_patch("my_node.pull_set_handler", item="default")
>>> @register_node_patch("my_node", "pull_set_handler", item="default")
>>> def patch_default_pull_set_handler(self, vqip):
>>> return {}
If patching a method of an attribute, the `is_attr` argument should be set to `True` and
the target should include the node name and the attribute name and the method name, all
separated by periods, eg. `node_name.attribute_name.method_name`.
the target should include the attribute name and the method name, all separated by
periods, eg. `attribute_name.method_name`.
It should be noted that the patched function should have the same signature as the
original method or attribute, and the return type should be the same as well, otherwise
Expand All @@ -61,28 +61,29 @@


def register_node_patch(
target: str, item: Hashable = None, is_attr: bool = False
node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable:
"""Register a function to patch a node method or any of its attributes.
Args:
target (str): The target of the object to patch as a string with the node name
attribute, sub-attribute, etc. and finally method (or attribue) to replace,
sepparated with period, eg. `node_name.make_discharge` or
`node_name.sewer_tank.pull_storage_exact`.
node_name (str): The name of the node to patch.
target (str): The target of the object to patch in the form of a string with the
attribute, sub-attribute, etc. and finally method (or attribute) to replace,
sepparated with period, eg. `make_discharge` or
`sewer_tank.pull_storage_exact`.
item (Hashable): Typically a string or an integer indicating the item to replace
in the selected attribue, which should be a list or a dictionary.
is_attr (bool): If True, the decorated function will be called when applying
the patch and the result assigned to the target, instead of assigning the
function itself. In this case, the only argument passed to the function is
the node object.
"""
target_id = (target, item, is_attr)
target_id = (node_name, target, item, is_attr)
if target_id in extensions_registry:
raise ValueError(f"Patch for {target} already registered.")

def decorator(func):
extensions_registry[(target, item, is_attr)] = func
extensions_registry[target_id] = func
return func

return decorator
Expand All @@ -96,16 +97,8 @@ def apply_patches(model: Model) -> None:
Args:
model (Model): The model to apply the patches to.
"""
for (target, item, is_attr), func in extensions_registry.items():
# Process the target string
for (node_name, target, item, is_attr), func in extensions_registry.items():
starget = target.split(".")
if len(starget) < 2:
raise ValueError(
f"Invalid target {target}. At least two elements are required separated"
"by a period, indicating the node name and the method/attribute to "
"patch."
)
node_name = starget.pop(0)
method = starget.pop()

# Get the member to patch
Expand Down

0 comments on commit 8082f1f

Please sign in to comment.