Skip to content

Commit

Permalink
Update used ids on reading an existing model. (#1818)
Browse files Browse the repository at this point in the history
Fixes #1806 and #1804
  • Loading branch information
evetion authored Sep 10, 2024
1 parent a04ef5b commit 9514b30
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
9 changes: 8 additions & 1 deletion python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pandera.dtypes import Int32
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoDataFrame, GeoSeries
from pydantic import NonNegativeInt, PrivateAttr
from pydantic import NonNegativeInt, PrivateAttr, model_validator
from shapely.geometry import LineString, MultiLineString, Point

from ribasim.input_base import SpatialTableModel
Expand Down Expand Up @@ -60,6 +60,13 @@ class EdgeTable(SpatialTableModel[EdgeSchema]):

_used_edge_ids: UsedIDs = PrivateAttr(default_factory=UsedIDs)

@model_validator(mode="after")
def _update_used_ids(self) -> "EdgeTable":
if self.df is not None and len(self.df.index) > 0:
self._used_edge_ids.node_ids.update(self.df.index)
self._used_edge_ids.max_node_id = self.df.index.max()
return self

def add(
self,
from_node: NodeData,
Expand Down
17 changes: 16 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def _ensure_edge_table_is_present(self) -> "Model":
self.edge.df.set_geometry("geometry", inplace=True, crs=self.crs)
return self

@model_validator(mode="after")
def _update_used_ids(self) -> "Model":
# Only update the used node IDs if we read from a database
if "database" in context_file_loading.get():
df = self.node_table().df
assert df is not None
if len(df.index) > 0:
self._used_node_ids.node_ids.update(df.index)
self._used_node_ids.max_node_id = df.index.max()
return self

@field_serializer("input_dir", "results_dir")
def _serialize_path(self, path: Path) -> str:
return str(path)
Expand Down Expand Up @@ -224,7 +235,11 @@ def _apply_crs_function(self, function_name: str, crs: str) -> None:
def node_table(self) -> NodeTable:
"""Compute the full sorted NodeTable from all node types."""
df_chunks = [node.node.df for node in self._nodes()]
df = pd.concat(df_chunks)
df = (
pd.concat(df_chunks)
if df_chunks
else pd.DataFrame(index=pd.Index([], name="node_id"))
)
node_table = NodeTable(df=df)
node_table.sort()
assert node_table.df is not None
Expand Down
13 changes: 13 additions & 0 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,19 @@ def test_node_autoincrement():
assert nbasin.node_id == 101


def test_node_autoincrement_existing_model(basic, tmp_path):
model = basic

model.write(tmp_path / "ribasim.toml")
nmodel = Model.read(tmp_path / "ribasim.toml")

assert nmodel._used_node_ids.max_node_id == 17
assert nmodel._used_node_ids.node_ids == set(range(1, 18)) - {13}

assert nmodel.edge._used_edge_ids.max_node_id == 16
assert nmodel.edge._used_edge_ids.node_ids == set(range(1, 17))


def test_node_empty_geometry():
model = Model(
starttime="2020-01-01",
Expand Down

0 comments on commit 9514b30

Please sign in to comment.