From e84414b330b7b78f5e6a344dccd955410b86637f Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Mon, 23 Sep 2024 11:04:13 +0200 Subject: [PATCH] Use closing everywhere. Add commit explicitely. --- python/ribasim/ribasim/input_base.py | 4 ++-- python/ribasim/ribasim/styles.py | 1 + python/ribasim/tests/test_io.py | 9 +++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index a2b412e51..ed797553f 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -290,7 +290,7 @@ def _write_arrow(self, filepath: Path, directory: Path, input_dir: Path) -> None @classmethod def _from_db(cls, path: Path, table: str) -> pd.DataFrame | None: - with connect(path) as connection: + with closing(connect(path)) as connection: if exists(connection, table): query = f"select * from {esc_id(table)}" df = pd.read_sql_query( @@ -371,7 +371,7 @@ def sort(self): @classmethod def _from_db(cls, path: Path, table: str): - with connect(path) as connection: + with closing(connect(path)) as connection: if exists(connection, table): # pyogrio hardcodes fid name on reading df = gpd.read_file(path, layer=table, fid_as_index=True) diff --git a/python/ribasim/ribasim/styles.py b/python/ribasim/ribasim/styles.py index b74abb91f..492b2eb65 100644 --- a/python/ribasim/ribasim/styles.py +++ b/python/ribasim/ribasim/styles.py @@ -121,5 +121,6 @@ def _add_styles_to_geopackage(gpkg_path: Path, layer: str): "update_date_time": update_date_time, }, ) + connection.commit() else: logging.warning(f"Style not found for layer: {layer}") diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index d73492998..65ab9aca3 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -323,3 +323,12 @@ def test_minimal_toml(): (toml_path.parent / "database.gpkg").touch() # database file must exist for `read` model = ribasim.Model.read(toml_path) assert model.crs == "EPSG:28992" + + +def test_closed_model(basic, tmp_path): + # Test whether we can write to a just opened model + # implicitly testing that the database is closed after read + toml_path = tmp_path / "basic/ribasim.toml" + basic.write(toml_path) + model = ribasim.Model.read(toml_path) + model.write(toml_path)