diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index a2b412e51..ed797553f 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -290,7 +290,7 @@ def _write_arrow(self, filepath: Path, directory: Path, input_dir: Path) -> None @classmethod def _from_db(cls, path: Path, table: str) -> pd.DataFrame | None: - with connect(path) as connection: + with closing(connect(path)) as connection: if exists(connection, table): query = f"select * from {esc_id(table)}" df = pd.read_sql_query( @@ -371,7 +371,7 @@ def sort(self): @classmethod def _from_db(cls, path: Path, table: str): - with connect(path) as connection: + with closing(connect(path)) as connection: if exists(connection, table): # pyogrio hardcodes fid name on reading df = gpd.read_file(path, layer=table, fid_as_index=True) diff --git a/python/ribasim/ribasim/styles.py b/python/ribasim/ribasim/styles.py index 450da9977..492b2eb65 100644 --- a/python/ribasim/ribasim/styles.py +++ b/python/ribasim/ribasim/styles.py @@ -1,7 +1,8 @@ import logging -import sqlite3 +from contextlib import closing from datetime import datetime from pathlib import Path +from sqlite3 import connect STYLES_DIR = Path(__file__).parent / "styles" @@ -98,19 +99,19 @@ def _no_existing_style(conn, style_name): def _add_styles_to_geopackage(gpkg_path: Path, layer: str): - with sqlite3.connect(gpkg_path) as conn: - if not conn.execute(SQL_STYLES_EXIST).fetchone()[0]: - conn.execute(CREATE_TABLE_SQL) - conn.execute(INSERT_CONTENTS_SQL) + with closing(connect(gpkg_path)) as connection: + if not connection.execute(SQL_STYLES_EXIST).fetchone()[0]: + connection.execute(CREATE_TABLE_SQL) + connection.execute(INSERT_CONTENTS_SQL) style_name = f"{layer.replace(' / ', '_')}Style" style_qml = STYLES_DIR / f"{style_name}.qml" - if style_qml.exists() and _no_existing_style(conn, style_name): + if style_qml.exists() and _no_existing_style(connection, style_name): description = f"Ribasim style for layer: {layer}" update_date_time = f"{datetime.now().isoformat()}Z" - conn.execute( + connection.execute( INSERT_ROW_SQL, { "layer": layer, @@ -120,5 +121,6 @@ def _add_styles_to_geopackage(gpkg_path: Path, layer: str): "update_date_time": update_date_time, }, ) + connection.commit() else: logging.warning(f"Style not found for layer: {layer}") diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index d73492998..65ab9aca3 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -323,3 +323,12 @@ def test_minimal_toml(): (toml_path.parent / "database.gpkg").touch() # database file must exist for `read` model = ribasim.Model.read(toml_path) assert model.crs == "EPSG:28992" + + +def test_closed_model(basic, tmp_path): + # Test whether we can write to a just opened model + # implicitly testing that the database is closed after read + toml_path = tmp_path / "basic/ribasim.toml" + basic.write(toml_path) + model = ribasim.Model.read(toml_path) + model.write(toml_path)