diff --git a/tests/test_extensions.py b/tests/test_extensions.py index bdd6cd56..c3356d57 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -16,21 +16,21 @@ def test_register_node_patch(temp_extension_registry): from wsimod.extensions import extensions_registry, register_node_patch # Define a dummy function to patch a node method - @register_node_patch("node_name.method_name") + @register_node_patch("node_name", "method_name") def dummy_patch(): print("Patched method") # Check if the patch is registered correctly - assert extensions_registry[("node_name.method_name", None, False)] == dummy_patch + assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch # Another function with other arguments - @register_node_patch("node_name.method_name", item="default", is_attr=True) + @register_node_patch("node_name", "method_name", item="default", is_attr=True) def another_dummy_patch(): print("Another patched method") # Check if this other patch is registered correctly assert ( - extensions_registry[("node_name.method_name", "default", True)] + extensions_registry[("node_name", "method_name", "default", True)] == another_dummy_patch ) @@ -52,22 +52,22 @@ def test_apply_patches(temp_extension_registry): model.nodes[node.name] = node # 1. Patch a method - @register_node_patch("dummy_node.apply_overrides") + @register_node_patch("dummy_node", "apply_overrides") def dummy_patch(): pass # 2. Patch an attribute - @register_node_patch("dummy_node.t", is_attr=True) + @register_node_patch("dummy_node", "t", is_attr=True) def another_dummy_patch(node): return f"A pathced attribute for {node.name}" # 3. Patch a method with an item - @register_node_patch("dummy_node.pull_set_handler", item="default") + @register_node_patch("dummy_node", "pull_set_handler", item="default") def yet_another_dummy_patch(): pass # 4. Path a method of an attribute - @register_node_patch("dummy_node.dummy_arc.arc_mass_balance") + @register_node_patch("dummy_node", "dummy_arc.arc_mass_balance") def arc_dummy_patch(): pass @@ -131,7 +131,7 @@ def test_path_method_with_reuse(temp_extension_registry): model = Model() model.nodes[node.name] = node - @register_node_patch("dummy_node.pull_distributed") + @register_node_patch("dummy_node", "pull_distributed") def new_pull_distributed(self, vqip, of_type=None, tag="default"): return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) diff --git a/wsimod/extensions.py b/wsimod/extensions.py index aef55c6c..1fa67387 100644 --- a/wsimod/extensions.py +++ b/wsimod/extensions.py @@ -6,10 +6,10 @@ Example of patching a method: -`empty_distributed` will be called instead of `my_node.pull_distributed`: +`empty_distributed` will be called instead of `pull_distributed` of "my_node": >>> from wsimod.extensions import register_node_patch, apply_patches - >>> @register_node_patch("my_node.pull_distributed") + >>> @register_node_patch("my_node", "pull_distributed") >>> def empty_distributed(self, vqip): >>> return {} @@ -18,24 +18,24 @@ Example of patching an attribute: -`10` will be assigned to `my_node.t`: +`10` will be assigned to `t`: - >>> @register_node_patch("my_node.t", is_attr=True) + >>> @register_node_patch("my_node", "t", is_attr=True) >>> def patch_t(node): >>> return 10 Example of patching an attribute item: `patch_default_pull_set_handler` will be assigned to -`my_node.pull_set_handler["default"]`: +`pull_set_handler["default"]`: - >>> @register_node_patch("my_node.pull_set_handler", item="default") + >>> @register_node_patch("my_node", "pull_set_handler", item="default") >>> def patch_default_pull_set_handler(self, vqip): >>> return {} If patching a method of an attribute, the `is_attr` argument should be set to `True` and -the target should include the node name and the attribute name and the method name, all -separated by periods, eg. `node_name.attribute_name.method_name`. +the target should include the attribute name and the method name, all separated by +periods, eg. `attribute_name.method_name`. It should be noted that the patched function should have the same signature as the original method or attribute, and the return type should be the same as well, otherwise @@ -61,15 +61,16 @@ def register_node_patch( - target: str, item: Hashable = None, is_attr: bool = False + node_name: str, target: str, item: Hashable = None, is_attr: bool = False ) -> Callable: """Register a function to patch a node method or any of its attributes. Args: - target (str): The target of the object to patch as a string with the node name - attribute, sub-attribute, etc. and finally method (or attribue) to replace, - sepparated with period, eg. `node_name.make_discharge` or - `node_name.sewer_tank.pull_storage_exact`. + node_name (str): The name of the node to patch. + target (str): The target of the object to patch in the form of a string with the + attribute, sub-attribute, etc. and finally method (or attribute) to replace, + sepparated with period, eg. `make_discharge` or + `sewer_tank.pull_storage_exact`. item (Hashable): Typically a string or an integer indicating the item to replace in the selected attribue, which should be a list or a dictionary. is_attr (bool): If True, the decorated function will be called when applying @@ -77,12 +78,12 @@ def register_node_patch( function itself. In this case, the only argument passed to the function is the node object. """ - target_id = (target, item, is_attr) + target_id = (node_name, target, item, is_attr) if target_id in extensions_registry: raise ValueError(f"Patch for {target} already registered.") def decorator(func): - extensions_registry[(target, item, is_attr)] = func + extensions_registry[target_id] = func return func return decorator @@ -96,16 +97,8 @@ def apply_patches(model: Model) -> None: Args: model (Model): The model to apply the patches to. """ - for (target, item, is_attr), func in extensions_registry.items(): - # Process the target string + for (node_name, target, item, is_attr), func in extensions_registry.items(): starget = target.split(".") - if len(starget) < 2: - raise ValueError( - f"Invalid target {target}. At least two elements are required separated" - "by a period, indicating the node name and the method/attribute to " - "patch." - ) - node_name = starget.pop(0) method = starget.pop() # Get the member to patch