Skip to content

Commit

Permalink
custom class
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Oct 3, 2024
1 parent 61573a9 commit 388ac40
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
33 changes: 33 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,36 @@ def dummy_patch(self, vqip, *args, **kwargs):

assert node.pull_check() == "dummy_patch"
assert node.pull_set({"volume": 1}) == "dummy_node - 1"


def test_custom_class():
"""Test a custom class."""

import tempfile

from wsimod.nodes.nodes import Node, NODES_REGISTRY
from wsimod.orchestration.model import Model, to_datetime

class CustomNode(Node):
def __init__(self, name):
super().__init__(name)
self.custom_attr = 1

def end_timestep(self):
self.custom_attr += 1
super().end_timestep()

NODES_REGISTRY["CustomNode"] = CustomNode

with tempfile.TemporaryDirectory() as temp_dir:
model = Model()
model.nodes["node_name"] = CustomNode("node_name")
model.save(temp_dir)

del model
model = Model()
model.load(temp_dir)
model.river_dishcarge_order = []
assert model.nodes["node_name"].custom_attr == 1
model.run(dates=[to_datetime("2000-01-01")])
assert model.nodes["node_name"].custom_attr == 2
4 changes: 2 additions & 2 deletions wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def load(self, address, config_name="config.yml", overrides={}):
FLAG:
E.G. ADDITION FOR NEW ORCHESTRATION
"""
load_extension_files(data.get("extensions", []))

if "orchestration" in data.keys():
# Update orchestration
Expand Down Expand Up @@ -220,7 +221,6 @@ def load(self, address, config_name="config.yml", overrides={}):
if "dates" in data.keys():
self.dates = [to_datetime(x) for x in data["dates"]]

load_extension_files(data.get("extensions", []))
apply_patches(self)

def save(self, address, config_name="config.yml", compress=False):
Expand Down Expand Up @@ -497,6 +497,7 @@ def add_arcs(self, arclist):
]:
river_arcs[name] = self.arcs[name]

self.river_discharge_order = []
if any(river_arcs):
upstreamness = (
{x: 0 for x in self.nodes_type["Waste"].keys()}
Expand All @@ -505,7 +506,6 @@ def add_arcs(self, arclist):
)
upstreamness = self.assign_upstream(river_arcs, upstreamness)

self.river_discharge_order = []
if "River" in self.nodes_type:
for node in sorted(
upstreamness.items(), key=lambda item: item[1], reverse=True
Expand Down

0 comments on commit 388ac40

Please sign in to comment.