Skip to content

Commit

Permalink
Clarify validate_model_node_ids (#867)
Browse files Browse the repository at this point in the history
Fixes #865.

Net output:

```
For LevelBoundary, the node IDs in the data tables don't match the node IDs in the network.
    Node IDs only in the data tables: {181, 182, 183, 184, 185, 186, 187, 188, 189}.
    Node IDs only in the network: set().
```
  • Loading branch information
visr authored Dec 7, 2023
1 parent 4606007 commit 725bca5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
23 changes: 16 additions & 7 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,28 @@ def validate_model_node_field_ids(self):
raise ValueError(" ".join(msg))

def validate_model_node_ids(self):
"""Check whether the node IDs in the node field correspond to the node IDs on the node type fields."""
"""Check whether the node IDs in the data tables correspond to the node IDs in the network."""

error_messages = []

for node in self.nodes().values():
node_ids_field = node.node_ids()
node_ids_from_node_field = self.network.node.df.loc[
self.network.node.df["type"] == node.get_input_type()
].index
nodetype = node.get_input_type()
if nodetype == "Network":
# skip the reference
continue
node_ids_data = set(node.node_ids())
node_ids_network = set(
self.network.node.df.loc[self.network.node.df["type"] == nodetype].index
)

if not set(node_ids_from_node_field) == set(node_ids_field):
if not node_ids_network == node_ids_data:
extra_in_network = node_ids_network.difference(node_ids_data)
extra_in_data = node_ids_data.difference(node_ids_network)
error_messages.append(
f"The node IDs in the field {node} {node_ids_field} do not correspond with the node IDs in the field node {node_ids_from_node_field.tolist()}."
f"""For {nodetype}, the node IDs in the data tables don't match the node IDs in the network.
Node IDs only in the data tables: {extra_in_data}.
Node IDs only in the network: {extra_in_network}.
"""
)

if len(error_messages) > 0:
Expand Down
5 changes: 4 additions & 1 deletion python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def test_node_ids_misassigned(basic):
model.pump.static.df.loc[0, "node_id"] = 8
model.fractional_flow.static.df.loc[1, "node_id"] = 7

with pytest.raises(ValueError, match="The node IDs in the field static.+"):
with pytest.raises(
ValueError,
match="For FractionalFlow, the node IDs in the data tables don't match the node IDs in the network.+",
):
model.validate_model_node_ids()


Expand Down

0 comments on commit 725bca5

Please sign in to comment.