Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Dec 21, 2023
1 parent 042e595 commit 553c956
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 1 deletion.
13 changes: 13 additions & 0 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
class TableModel(FileModel, Generic[TableT]):
df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)

def __eq__(self, other) -> bool:
if not type(self) == type(other):
return False

Check warning on line 165 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L165

Added line #L165 was not covered by tests
if self.filepath != other.filepath:
return False

Check warning on line 167 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L167

Added line #L167 was not covered by tests

if self.df is None and other.df is None:
return True
elif isinstance(self.df, (pd.DataFrame, gpd.GeoDataFrame)):
return self.df.equals(other.df)
else:
return False

Check warning on line 174 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L174

Added line #L174 was not covered by tests

@field_validator("df")
@classmethod
def prefix_extra_columns(cls, v: DataFrame[TableT]):
Expand Down
5 changes: 4 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,10 @@ def max_node_id(self) -> int:
return self.network.node.df.index.max()

def max_allocation_network_id(self) -> int:
return self.network.node.df.allocation_network_id.max()
m = self.network.node.df.allocation_network_id.max()
if pd.isna(m):
m = 0
return m

def merge_model(
self,
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ def backwater() -> ribasim.Model:
@pytest.fixture()
def discrete_control_of_pid_control() -> ribasim.Model:
return ribasim_testmodels.discrete_control_of_pid_control_model()


@pytest.fixture()
def subnetwork() -> ribasim.Model:
return ribasim_testmodels.subnetwork_model()
17 changes: 17 additions & 0 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from copy import deepcopy
from sqlite3 import connect

import pandas as pd
Expand Down Expand Up @@ -143,3 +144,19 @@ def test_write_adds_fid_in_tables(basic, tmp_path):
assert "fid" in df.columns
fids = df.get("fid")
assert fids.equals(pd.Series(range(1, len(fids) + 1)))


def test_model_merging(basic, subnetwork, tmp_path):
model = deepcopy(basic)
model_added = deepcopy(subnetwork)
model.merge_model(model_added)
model.merge_model(model_added)
assert (model.network.node.df.index == range(1, 44)).all()
assert model.max_allocation_network_id() == 2
for node_type, node_added in model_added.nodes().items():
node_subnetwork = getattr(subnetwork, node_type)
for table_added, table_subnetwork in zip(
node_added.tables(), node_subnetwork.tables()
):
assert table_added == table_subnetwork
model.write(tmp_path / "compound_model/ribasim.toml")

0 comments on commit 553c956

Please sign in to comment.