Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the extensions patch functionality #101

Merged
merged 12 commits into from
Sep 11, 2024
84 changes: 84 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest


@pytest.fixture
def temp_extension_registry():
from wsimod.extensions import extensions_registry

bkp = extensions_registry.copy()
extensions_registry.clear()
yield
extensions_registry.clear()
extensions_registry.update(bkp)


def test_register_node_patch(temp_extension_registry):
from wsimod.extensions import extensions_registry, register_node_patch

# Define a dummy function to patch a node method
@register_node_patch("node_name.method_name")
def dummy_patch():
print("Patched method")

# Check if the patch is registered correctly
assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch

# Another function with other arguments
@register_node_patch("node_name.method_name", item="default", is_attr=True)
def another_dummy_patch():
print("Another patched method")

# Check if this other patch is registered correctly
assert (
extensions_registry[("node_name.method_name", "default", True)]
== another_dummy_patch
)


def test_apply_patches(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import (
apply_patches,
extensions_registry,
register_node_patch,
)
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

# Create a dummy model
node = Node("dummy_node")
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)
model = Model()
model.nodes[node.name] = node

# 1. Patch a method
@register_node_patch("dummy_node.apply_overrides")
def dummy_patch():
pass

# 2. Patch an attribute
@register_node_patch("dummy_node.t", is_attr=True)
def another_dummy_patch(node):
return f"A pathced attribute for {node.name}"

# 3. Patch a method with an item
@register_node_patch("dummy_node.pull_set_handler", item="default")
def yet_another_dummy_patch():
pass

# 4. Path a method of an attribute
@register_node_patch("dummy_node.dummy_arc.arc_mass_balance")
def arc_dummy_patch():
pass

# Check if all patches are registered
assert len(extensions_registry) == 4

# Apply the patches
apply_patches(model)

# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can include an example that behaves like a conventional decorator (since this will be a common use case)?

Suggested change
# Check if all patches are registered
assert len(extensions_registry) == 4
# Apply the patches
apply_patches(model)
# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
# 5. Patch a decorator
@register_node_patch("dummy_node.pull_distributed")
def a_dummy_decorator(node, vqip):
#Only pull from Reservoir
return node.pull_distributed(vqip, of_type=['Reservoir'])
# Check if all patches are registered
assert len(extensions_registry) == 5
# Apply the patches
apply_patches(model)
# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
assert model.nodes[node.name].dummy_arc.pull_distributed == a_dummy_decorator

Copy link
Collaborator

@barneydobson barneydobson Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think I've misunderstood the decorator example.

If I add the line:

_ = model.nodes[node.name].pull_distributed(node.empty_vqip())

To the test (which would be the normal use case) - it fails. doesn't seem to help if I set is_attr=True... so definitely an example to cover that would be helpful ;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you mean by "a conventional decorator". And where are you putting that line?

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Collaborator

@barneydobson barneydobson Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK sorry for all the messages there - I was just trying to figure out how to use this properly.

The below test passes, but is it the correct way to extend an existing function (while still calling it)? If so it should be in tests and in the documentation as it will be one of the more common uses of extensions.

def assertDictAlmostEqual(d1, d2, accuracy=19):
    """

    Args:
        d1:
        d2:
        accuracy:
    """
    for d in [d1, d2]:
        for key, item in d.items():
            d[key] = round(item, accuracy)
    assert d1 == d2


def test_apply_dec(temp_extension_registry):
    from wsimod.arcs.arcs import Arc
    from wsimod.extensions import (
        apply_patches,
        extensions_registry,
        register_node_patch,
    )
    from wsimod.nodes.storage import Reservoir
    from wsimod.orchestration.model import Model

    # Create a dummy model
    node = Reservoir(name="dummy_node", initial_storage=10,capacity = 10)
    node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)

    vq = node.pull_distributed({'volume' : 5})
    assertDictAlmostEqual(vq, node.v_change_vqip(node.empty_vqip(),5))

    model = Model()
    model.nodes[node.name] = node

    # 5. Patch a decorator
    @register_node_patch("dummy_node.pull_distributed", is_attr=True)
    def extend_function(node):
        def wrapper(f_old):
            def f(vqip, *args, **kw):
                return f_old(vqip, of_type = ['Node'], *args, **kw)
            return f
        #Only pull from Reservoir
        return wrapper(node.pull_distributed)
        
    
    # Apply the patches
    apply_patches(model)

    # Check appropriate result
    assert node.tank.storage['volume'] == 5
    vq = model.nodes[node.name].pull_distributed({'volume' : 5})
    assertDictAlmostEqual(vq, node.empty_vqip())
    assert node.tank.storage['volume'] == 5

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, now I get what you want. So a common use case is to use the old function you are overriding in the override itself, something like calling super().some_function in a child class, right?

Ok, let's see if I can figure out the most elegant way of doing it, so the user doesn't need to deal with functions, within functions, within functions...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes - while operational my approach is not the most elegant ;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a look now. It might be useful to re-read the docstrings to make sure the explanations are clear.

116 changes: 116 additions & 0 deletions wsimod/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""This module contains the utilities to extend WSMOD with new features.

The `register_node_patch` decorator is used to register a function that will be used
instead of a method or attribute of a node. The `apply_patches` function applies all
registered patches to a model.

Example of patching a method:

`empty_distributed` will be called instead of `my_node.pull_distributed`:

>>> from wsimod.extensions import register_node_patch, apply_patches
>>> @register_node_patch("my_node.pull_distributed")
>>> def empty_distributed(self, vqip):
>>> return {}

Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a
list or a dictionary can be patched if the item argument is provided.

Example of patching an attribute:

`10` will be assigned to `my_node.t`:

>>> @register_node_patch("my_node.t", is_attr=True)
>>> def patch_t(node):
>>> return 10

Example of patching an attribute item:

`patch_default_pull_set_handler` will be assigned to
`my_node.pull_set_handler["default"]`:

>>> @register_node_patch("my_node.pull_set_handler", item="default")
>>> def patch_default_pull_set_handler(self, vqip):
>>> return {}

If patching a method of an attribute, the `is_attr` argument should be set to `True` and
the target should include the node name and the attribute name and the method name, all
separated by periods, eg. `node_name.attribute_name.method_name`.

It should be noted that the patched function should have the same signature as the
original method or attribute, and the return type should be the same as well, otherwise
there will be a runtime error.

Finally, the `apply_patches` is called within the `Model.load` method and will apply all
patches in the order they were registered. This means that users need to be careful with
the order of the patches in their extensions files, as they may have interdependencies.

TODO: Update documentation on extensions files.
"""

from typing import Callable, Hashable

from .orchestration.model import Model

extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {}


def register_node_patch(
target: str, item: Hashable = None, is_attr: bool = False
) -> Callable:
"""Register a function to patch a node method or any of its attributes.

Args:
target (str): The target of the object to patch as a string with the node name
attribute, sub-attribute, etc. and finally method (or attribue) to replace,
sepparated with period, eg. `node_name.make_discharge` or
`node_name.sewer_tank.pull_storage_exact`.
item (Hashable): Typically a string or an integer indicating the item to replace
in the selected attribue, which should be a list or a dictionary.
is_attr (bool): If True, the decorated function will be called when applying
the patch and the result assigned to the target, instead of assigning the
function itself. In this case, the only argument passed to the function is
the node object.
"""
target_id = (target, item, is_attr)
if target_id in extensions_registry:
raise ValueError(f"Patch for {target} already registered.")

def decorator(func):
extensions_registry[(target, item, is_attr)] = func
return func

return decorator


def apply_patches(model: Model) -> None:
"""Apply all registered patches to the model.

TODO: Validate signature of the patched methods and type of patched attributes.

Args:
model (Model): The model to apply the patches to.
"""
for (target, item, is_attr), func in extensions_registry.items():
# Process the target string
starget = target.split(".")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there isn't anything to stop nodes have a . in the node name.. It's not in my default model setup - but I wouldn't be surprised if others have introduced this unthinkingly.

Perhaps could we have target_node as a separate argument - since the name could be anything, and then any sub-attributes can by . delimited since they will follow python syntax?

Not sure - what do you think? If it's too awful then at least we validate to ensure no . in name during model.load

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a problem. That's an easy fix. I just put it all together in a single line because I felt it was easier to understand and to cover more cases - in particular the sub-attributes one - in one, consistent approach.

About users using . for the node names... it might be a bit opinionated, but you are the developer, so you tell the users how they should use the tool. If you tell them not to use . but _ or something else, they won't use .. It is not the other way around.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you're right. There's no point on artificially restricting what a node name can be. What about changing the decorator signature to?

def register_node_patch(
    node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable:

So node_name is provided independently and can therefore be anything?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works for me!

if len(starget) < 2:
raise ValueError(
f"Invalid target {target}. At least two elements are required separated"
"by a period, indicating the node name and the method/attribute to "
"patch."
)
node_name = starget.pop(0)
method = starget.pop()

# Get the member to patch
node = obj = model.nodes[node_name]
for attr in starget:
obj = getattr(obj, attr)

# Apply the patch
if item is not None:
obj = getattr(obj, method)
obj[item] = func(node) if is_attr else func
else:
setattr(obj, method, func(node) if is_attr else func)
4 changes: 4 additions & 0 deletions wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def load(self, address, config_name="config.yml", overrides={}):
config_name:
overrides:
"""
from ..extensions import apply_patches

with open(os.path.join(address, config_name), "r") as file:
data = yaml.safe_load(file)

Expand Down Expand Up @@ -191,6 +193,8 @@ def load(self, address, config_name="config.yml", overrides={}):
if "dates" in data.keys():
self.dates = [to_datetime(x) for x in data["dates"]]

apply_patches(self)

def save(self, address, config_name="config.yml", compress=False):
"""Save the model object to a yaml file and input data to csv.gz format in the
directory specified.
Expand Down
Loading