diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index b22b9a2df..4f1188409 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -232,6 +232,7 @@ def _write_toml(self, fn: FilePath): return fn def _save(self, directory: DirectoryPath, input_dir: DirectoryPath): + self.network._save(directory, input_dir) for sub in self.nodes().values(): sub._save(directory, input_dir) @@ -239,7 +240,7 @@ def nodes(self): return { k: getattr(self, k) for k in self.model_fields.keys() - if isinstance(getattr(self, k), NodeModel) + if isinstance(getattr(self, k), NodeModel) and k != "network" } def children(self): @@ -279,9 +280,6 @@ def validate_model_node_ids(self): for node in self.nodes().values(): nodetype = node.get_input_type() - if nodetype == "Network": - # skip the reference - continue node_ids_data = set(node.node_ids()) node_ids_network = set( self.network.node.df.loc[self.network.node.df["type"] == nodetype].index diff --git a/python/ribasim/tests/conftest.py b/python/ribasim/tests/conftest.py index f6dbb730d..2f0f3ae9f 100644 --- a/python/ribasim/tests/conftest.py +++ b/python/ribasim/tests/conftest.py @@ -37,3 +37,8 @@ def discrete_control_of_pid_control() -> ribasim.Model: @pytest.fixture() def level_setpoint_with_minmax() -> ribasim.Model: return ribasim_testmodels.level_setpoint_with_minmax_model() + + +@pytest.fixture() +def trivial() -> ribasim.Model: + return ribasim_testmodels.trivial_model() diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 570d8297f..b87d5b0af 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -8,14 +8,18 @@ from ribasim import Node, Pump -def __assert_equal(a: DataFrame, b: DataFrame) -> None: +def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: """Like pandas.testing.assert_frame_equal, but ignoring the index.""" if a is None and b is None: return True - # TODO support assert basic == model, ignoring the index for all but node - a = a.reset_index(drop=True) - b = b.reset_index(drop=True) + if is_network: + # We set this on write, needed for GeoPackage. + a.index.name = "fid" + a.index.name = "fid" + else: + a = a.reset_index(drop=True) + b = b.reset_index(drop=True) # avoid comparing datetime64[ns] with datetime64[ms] if "time" in a: @@ -38,8 +42,12 @@ def test_basic(basic, tmp_path): index_a = model_orig.network.node.df.index.to_numpy(int) index_b = model_loaded.network.node.df.index.to_numpy(int) assert_array_equal(index_a, index_b) - __assert_equal(model_orig.network.node.df, model_loaded.network.node.df) - __assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df) + __assert_equal( + model_orig.network.node.df, model_loaded.network.node.df, is_network=True + ) + __assert_equal( + model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True + ) assert model_loaded.basin.time.df is None @@ -56,8 +64,12 @@ def test_basic_transient(basic_transient, tmp_path): model_orig.write(tmp_path / "basic_transient/ribasim.toml") model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml") - __assert_equal(model_orig.network.node.df, model_loaded.network.node.df) - __assert_equal(model_orig.network.edge.df, model_loaded.network.edge.df) + __assert_equal( + model_orig.network.node.df, model_loaded.network.node.df, is_network=True + ) + __assert_equal( + model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True + ) time = model_loaded.basin.time assert model_orig.basin.time.df.time[0] == time.df.time[0] @@ -129,3 +141,27 @@ def test_sort(level_setpoint_with_minmax, tmp_path): table_loaded = model_loaded.discrete_control.condition assert table_loaded.df.iloc[0]["greater_than"] == 5.0 __assert_equal(table.df, table_loaded.df) + + +def test_roundtrip(trivial, tmp_path): + model1 = trivial + model1dir = tmp_path / "model1" + model2dir = tmp_path / "model2" + # read a model and then write it to a different path + model1.write(model1dir / "ribasim.toml") + model2 = ribasim.Model(filepath=model1dir / "ribasim.toml") + model2.write(model2dir / "ribasim.toml") + + assert (model1dir / "database.gpkg").is_file() + assert (model2dir / "database.gpkg").is_file() + + assert (model1dir / "ribasim.toml").read_text() == ( + model2dir / "ribasim.toml" + ).read_text() + + # check if all tables are the same + __assert_equal(model1.network.node.df, model2.network.node.df, is_network=True) + __assert_equal(model1.network.edge.df, model2.network.edge.df, is_network=True) + for node1, node2 in zip(model1.nodes().values(), model2.nodes().values()): + for table1, table2 in zip(node1.tables(), node2.tables()): + __assert_equal(table1.df, table2.df)