diff --git a/docs/demo/scripts/customise_interactions.py b/docs/demo/scripts/customise_interactions.py index bc86bba8..02e70e44 100644 --- a/docs/demo/scripts/customise_interactions.py +++ b/docs/demo/scripts/customise_interactions.py @@ -338,31 +338,11 @@ def custom_handler_function(x): # %% -def wrap(func): - """ - - Args: - func: - - Returns: - - """ - - def pull_distributed_wrapper(x): - """ - - Args: - x: - - Returns: - - """ - return func(x, tag="FWTW") - - return pull_distributed_wrapper - - -my_fwtw.pull_distributed = wrap(my_fwtw.pull_distributed) +from wsimod.extensions import extensions as extend +@extend.model_attribute(obj=my_fwtw, attribute_name="pull_distributed") +def new_distributed(pull_distributed, vqip): + """pull_distributed with the tag 'FWTW'.""" + return pull_distributed(vqip, tag="FWTW") # %% [markdown] # Explaining decorators is outside the scope of this tutorial, though you can diff --git a/wsimod/extensions/extensions.py b/wsimod/extensions/extensions.py new file mode 100644 index 00000000..31a1b2e8 --- /dev/null +++ b/wsimod/extensions/extensions.py @@ -0,0 +1,33 @@ +def model_attribute(obj, attribute_name): + """ + Decorator to extend or modify a model attribute. + + Args: + obj: The model object whose attribute should be modified. + attribute_name (str): The name of the attribute to modify. + + Returns: + A decorator function that takes the extension function as an argument. + """ + def decorator(func): + """ + Decorator function that applies the extension function to the model attribute. + + Args: + func: The extension function that modifies the model attribute. + """ + attribute = getattr(obj, attribute_name) + + def wrapped_attribute(*args, **kwargs): + return func(attribute, *args, **kwargs) + + setattr(obj, attribute_name, wrapped_attribute) + return wrapped_attribute + + return decorator + + +@model_attribute(obj=my_fwtw, attribute_name="pull_distributed") +def new_distributed(pull_distributed, vqip): + """pull_distributed with the tag 'FWTW'.""" + return pull_distributed(vqip, tag="FWTW") \ No newline at end of file