Skip to content

Commit

Permalink
Add roundtrip test for model files (#1007)
Browse files Browse the repository at this point in the history
Fixes #682. It was already fixed, but this ensures it with an extra
test.
  • Loading branch information
visr authored Feb 9, 2024
1 parent bde1e0c commit 88f6fde
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
6 changes: 2 additions & 4 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,15 @@ def _write_toml(self, fn: FilePath):
return fn

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
self.network._save(directory, input_dir)
for sub in self.nodes().values():
sub._save(directory, input_dir)

def nodes(self):
return {
k: getattr(self, k)
for k in self.model_fields.keys()
if isinstance(getattr(self, k), NodeModel)
if isinstance(getattr(self, k), NodeModel) and k != "network"
}

def children(self):
Expand Down Expand Up @@ -279,9 +280,6 @@ def validate_model_node_ids(self):

for node in self.nodes().values():
nodetype = node.get_input_type()
if nodetype == "Network":
# skip the reference
continue
node_ids_data = set(node.node_ids())
node_ids_network = set(
self.network.node.df.loc[self.network.node.df["type"] == nodetype].index
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ def discrete_control_of_pid_control() -> ribasim.Model:
@pytest.fixture()
def level_setpoint_with_minmax() -> ribasim.Model:
return ribasim_testmodels.level_setpoint_with_minmax_model()


@pytest.fixture()
def trivial() -> ribasim.Model:
return ribasim_testmodels.trivial_model()
52 changes: 44 additions & 8 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
from ribasim import Node, Pump


def __assert_equal(a: DataFrame, b: DataFrame) -> None:
def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None:
"""Like pandas.testing.assert_frame_equal, but ignoring the index."""
if a is None and b is None:
return True

# TODO support assert basic == model, ignoring the index for all but node
a = a.reset_index(drop=True)
b = b.reset_index(drop=True)
if is_network:
# We set this on write, needed for GeoPackage.
a.index.name = "fid"
a.index.name = "fid"
else:
a = a.reset_index(drop=True)
b = b.reset_index(drop=True)

# avoid comparing datetime64[ns] with datetime64[ms]
if "time" in a:
Expand All @@ -38,8 +42,12 @@ def test_basic(basic, tmp_path):
index_a = model_orig.network.node.df.index.to_numpy(int)
index_b = model_loaded.network.node.df.index.to_numpy(int)
assert_array_equal(index_a, index_b)
__assert_equal(model_orig.network.node.df, model_loaded.network.node.df)
__assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df)
__assert_equal(
model_orig.network.node.df, model_loaded.network.node.df, is_network=True
)
__assert_equal(
model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True
)
assert model_loaded.basin.time.df is None


Expand All @@ -56,8 +64,12 @@ def test_basic_transient(basic_transient, tmp_path):
model_orig.write(tmp_path / "basic_transient/ribasim.toml")
model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml")

__assert_equal(model_orig.network.node.df, model_loaded.network.node.df)
__assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df)
__assert_equal(
model_orig.network.node.df, model_loaded.network.node.df, is_network=True
)
__assert_equal(
model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True
)

time = model_loaded.basin.time
assert model_orig.basin.time.df.time[0] == time.df.time[0]
Expand Down Expand Up @@ -129,3 +141,27 @@ def test_sort(level_setpoint_with_minmax, tmp_path):
table_loaded = model_loaded.discrete_control.condition
assert table_loaded.df.iloc[0]["greater_than"] == 5.0
__assert_equal(table.df, table_loaded.df)


def test_roundtrip(trivial, tmp_path):
model1 = trivial
model1dir = tmp_path / "model1"
model2dir = tmp_path / "model2"
# read a model and then write it to a different path
model1.write(model1dir / "ribasim.toml")
model2 = ribasim.Model(filepath=model1dir / "ribasim.toml")
model2.write(model2dir / "ribasim.toml")

assert (model1dir / "database.gpkg").is_file()
assert (model2dir / "database.gpkg").is_file()

assert (model1dir / "ribasim.toml").read_text() == (
model2dir / "ribasim.toml"
).read_text()

# check if all tables are the same
__assert_equal(model1.network.node.df, model2.network.node.df, is_network=True)
__assert_equal(model1.network.edge.df, model2.network.edge.df, is_network=True)
for node1, node2 in zip(model1.nodes().values(), model2.nodes().values()):
for table1, table2 in zip(node1.tables(), node2.tables()):
__assert_equal(table1.df, table2.df)

0 comments on commit 88f6fde

Please sign in to comment.