Skip to content

Commit

Permalink
don't require consecutive node IDs from 1 to n
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Oct 23, 2023
1 parent 097053c commit b04e489
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 50 deletions.
15 changes: 0 additions & 15 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def validate_model_node_field_IDs(self):

# Check node IDs of node fields
node_IDs_all = []
n_nodes = len(self.node.static)

for name in self.fields():
if name in node_names_all_snake_case:
Expand All @@ -240,20 +239,6 @@ def validate_model_node_field_IDs(self):
f"These node IDs were assigned to multiple node types: {node_IDs_unique[(node_ID_counts > 1)]}."
)

if not np.array_equal(node_IDs_unique, np.arange(n_nodes) + 1):
node_IDs_missing = set(np.arange(n_nodes) + 1) - set(node_IDs_unique)
node_IDs_over = set(node_IDs_unique) - set(np.arange(n_nodes) + 1)
msg = [
f"Expected node IDs from 1 to {n_nodes} (the number of rows in self.node.static)."
]
if len(node_IDs_missing) > 0:
msg.append(f"These node IDs are missing: {node_IDs_missing}.")

if len(node_IDs_over) > 0:
msg.append(f"These node IDs are unexpected: {node_IDs_over}.")

raise ValueError(" ".join(msg))

def validate_model_node_IDs(self):
"""Check whether the node IDs in the node field correspond to the node IDs on the node type fields."""

Expand Down
25 changes: 0 additions & 25 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import re

import pandas as pd
import pytest
from pydantic import ValidationError
from ribasim import Model, Solver
Expand Down Expand Up @@ -82,30 +81,6 @@ def test_node_ids_misassigned(basic):
model.validate_model_node_IDs()


def test_node_ids_unsequential(basic):
model = basic

basin = model.basin

basin.profile = pd.DataFrame(
data={
"node_id": [1, 1, 3, 3, 6, 6, 1000, 1000],
"area": [0.01, 1000.0] * 4,
"level": [0.0, 1.0] * 4,
}
)

basin.static["node_id"] = [1, 3, 6, 1000]

with pytest.raises(ValueError) as excinfo:
model.validate_model_node_field_IDs()

assert (
"Expected node IDs from 1 to 17 (the number of rows in self.node.static). These node IDs are missing: {9}. These node IDs are unexpected: {1000}."
in str(excinfo.value)
)


def test_tabulated_rating_curve_model(tabulated_rating_curve, tmp_path):
model_orig = tabulated_rating_curve
model_orig.write(tmp_path / "tabulated_rating_curve")
Expand Down
20 changes: 10 additions & 10 deletions python/ribasim_testmodels/ribasim_testmodels/trivial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ def trivial_model() -> ribasim.Model:
# Set up the nodes:
xy = np.array(
[
(400.0, 200.0), # 1: Basin
(450.0, 200.0), # 2: TabulatedRatingCurve
(500.0, 200.0), # 3: Terminal
(400.0, 200.0), # 2: Basin
(450.0, 200.0), # 4: TabulatedRatingCurve
(500.0, 200.0), # 6: Terminal
]
)
node_xy = gpd.points_from_xy(x=xy[:, 0], y=xy[:, 1])
Expand All @@ -25,15 +25,15 @@ def trivial_model() -> ribasim.Model:
node = ribasim.Node(
static=gpd.GeoDataFrame(
data={"type": node_type},
index=pd.Index(np.arange(len(xy)) + 1, name="fid"),
index=pd.Index([2, 4, 6], name="fid"),
geometry=node_xy,
crs="EPSG:28992",
)
)

# Setup the edges:
from_id = np.array([1, 2], dtype=np.int64)
to_id = np.array([2, 3], dtype=np.int64)
from_id = np.array([2, 4], dtype=np.int64)
to_id = np.array([4, 6], dtype=np.int64)
lines = ribasim.utils.geometry_from_connectivity(node, from_id, to_id)
edge = ribasim.Edge(
static=gpd.GeoDataFrame(
Expand All @@ -50,7 +50,7 @@ def trivial_model() -> ribasim.Model:
# Setup the basins:
profile = pd.DataFrame(
data={
"node_id": [1, 1],
"node_id": [2, 2],
"area": [0.01, 1000.0],
"level": [0.0, 1.0],
}
Expand All @@ -64,7 +64,7 @@ def trivial_model() -> ribasim.Model:

static = pd.DataFrame(
data={
"node_id": [1],
"node_id": [2],
"drainage": [0.0],
"potential_evaporation": [evaporation],
"infiltration": [0.0],
Expand All @@ -81,7 +81,7 @@ def trivial_model() -> ribasim.Model:
rating_curve = ribasim.TabulatedRatingCurve(
static=pd.DataFrame(
data={
"node_id": [2, 2],
"node_id": [4, 4],
"level": [0.0, 1.0],
"discharge": [0.0, q1000],
}
Expand All @@ -91,7 +91,7 @@ def trivial_model() -> ribasim.Model:
terminal = ribasim.Terminal(
static=pd.DataFrame(
data={
"node_id": [3],
"node_id": [6],
}
)
)
Expand Down

0 comments on commit b04e489

Please sign in to comment.