Skip to content

Commit

Permalink
test initial method
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Jun 3, 2024
1 parent fd6f068 commit 155271e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 25 deletions.
30 changes: 5 additions & 25 deletions docs/demo/scripts/customise_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,31 +338,11 @@ def custom_handler_function(x):


# %%
def wrap(func):
"""
Args:
func:
Returns:
"""

def pull_distributed_wrapper(x):
"""
Args:
x:
Returns:
"""
return func(x, tag="FWTW")

return pull_distributed_wrapper


my_fwtw.pull_distributed = wrap(my_fwtw.pull_distributed)
from wsimod.extensions import extensions as extend
@extend.model_attribute(obj=my_fwtw, attribute_name="pull_distributed")
def new_distributed(pull_distributed, vqip):
"""pull_distributed with the tag 'FWTW'."""
return pull_distributed(vqip, tag="FWTW")

# %% [markdown]
# Explaining decorators is outside the scope of this tutorial, though you can
Expand Down
33 changes: 33 additions & 0 deletions wsimod/extensions/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
def model_attribute(obj, attribute_name):
"""
Decorator to extend or modify a model attribute.
Args:
obj: The model object whose attribute should be modified.
attribute_name (str): The name of the attribute to modify.
Returns:
A decorator function that takes the extension function as an argument.
"""
def decorator(func):
"""
Decorator function that applies the extension function to the model attribute.
Args:
func: The extension function that modifies the model attribute.
"""
attribute = getattr(obj, attribute_name)

def wrapped_attribute(*args, **kwargs):
return func(attribute, *args, **kwargs)

setattr(obj, attribute_name, wrapped_attribute)
return wrapped_attribute

return decorator


@model_attribute(obj=my_fwtw, attribute_name="pull_distributed")
def new_distributed(pull_distributed, vqip):
"""pull_distributed with the tag 'FWTW'."""
return pull_distributed(vqip, tag="FWTW")

0 comments on commit 155271e

Please sign in to comment.