diff --git a/core/test/docs.toml b/core/test/docs.toml index fadc0ebbb..42d38f7c1 100644 --- a/core/test/docs.toml +++ b/core/test/docs.toml @@ -5,6 +5,8 @@ endtime = 2021-01-01 # required # input files database = "database.gpkg" # required +input_dir = "." # optional, default "." +results_dir = "results" # optional, default "." # Specific tables can also go into Arrow files rather than the database. # For large tables this can benefit from better compressed file sizes. @@ -37,9 +39,9 @@ timing = false # optional, whether to log debug timing statements [results] # These results files are always written -flow = "results/flow.arrow" # optional, default "results/flow.arrow" -basin = "results/basin.arrow" # optional, default "results/basin.arrow" -control = "results/control.arrow" # optional, default "results/control.arrow" +flow = "results/flow.arrow" # optional, default "results/flow.arrow" +basin = "results/basin.arrow" # optional, default "results/basin.arrow" +control = "results/control.arrow" # optional, default "results/control.arrow" allocation = "results/allocation.arrow" # optional, default "results/allocation.arrow" -compression = "zstd" # optional, default "zstd", also supports "lz4" -compression_level = 6 # optional, default 6 +compression = "zstd" # optional, default "zstd", also supports "lz4" +compression_level = 6 # optional, default 6 diff --git a/core/test/run_models.jl b/core/test/run_models.jl index 293a623f1..6bdb6eaf0 100644 --- a/core/test/run_models.jl +++ b/core/test/run_models.jl @@ -66,6 +66,14 @@ end end end +@testset "basic arrow model" begin + toml_path = normpath(@__DIR__, "../../generated_testmodels/basic_arrow/ribasim.toml") + @test ispath(toml_path) + model = Ribasim.run(toml_path) + @test model isa Ribasim.Model + @test successful_retcode(model) +end + @testset "basic transient model" begin toml_path = normpath(@__DIR__, "../../generated_testmodels/basic_transient/ribasim.toml") diff --git a/docs/python/examples.ipynb b/docs/python/examples.ipynb index 49130ac37..41a303701 100644 --- a/docs/python/examples.ipynb +++ b/docs/python/examples.ipynb @@ -471,7 +471,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = ribasim.Model.from_toml(datadir / \"basic/ribasim.toml\")" + "model = ribasim.Model(filepath=datadir / \"basic/ribasim.toml\")" ] }, { diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index c7346ad7c..d2b2b5f86 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -23,6 +23,7 @@ field_validator, model_serializer, model_validator, + validate_call, ) from ribasim.types import FilePath @@ -104,8 +105,8 @@ def check_filepath(cls, value: Any) -> Any: if filepath is not None: filepath = Path(filepath) data = cls._load(filepath) - value.update(data) - return value + data.update(value) + return data elif isinstance(value, Path | str): # Pydantic Model init requires a dict data = cls._load(Path(value)) @@ -114,15 +115,25 @@ def check_filepath(cls, value: Any) -> Any: else: return value + @validate_call + def set_filepath(self, filepath: Path) -> None: + """Set the filepath of this instance. + + Args: + filepath (Path): The filepath to set. + """ + # Disable assignment validation, which would + # otherwise trigger check_filepath() and _load() again. + self.model_config["validate_assignment"] = False + self.filepath = filepath + self.model_config["validate_assignment"] = True + @abstractmethod - def _save(self, directory: DirectoryPath) -> None: + def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None: """Save this instance to disk. This method needs to be implemented by any class deriving from FileModel. - - Args: - save_settings (ModelSaveSettings): The model save settings. """ raise NotImplementedError() @@ -159,8 +170,8 @@ def prefix_extra_columns(cls, v: DataFrame[TableT]): return v @model_serializer - def set_model(self) -> Path | None: - return self.filepath + def set_model(self) -> str | None: + return str(self.filepath.name) if self.filepath is not None else None @classmethod def tablename(cls) -> str: @@ -205,11 +216,17 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: return {} def _save( - self, directory: DirectoryPath, sort_keys: list[str] = ["node_id"] + self, + directory: DirectoryPath, + input_dir: DirectoryPath, + sort_keys: list[str] = ["node_id"], ) -> None: # TODO directory could be used to save an arrow file db_path = context_file_loading.get().get("database") - if self.df is not None and db_path is not None: + if self.df is not None and self.filepath is not None: + self.sort(sort_keys) + self._write_arrow(self.filepath, directory, input_dir) + elif self.df is not None and db_path is not None: self.sort(sort_keys) self._write_table(db_path) @@ -223,15 +240,26 @@ def _write_table(self, temp_path: Path) -> None: SQLite connection to the database. """ table = self.tablename() - if self.df is not None: # double check to make mypy happy - with closing(connect(temp_path)) as connection: - self.df.to_sql(table, connection, index=False, if_exists="replace") - - # Set geopackage attribute table - with closing(connection.cursor()) as cursor: - sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)" - cursor.execute(sql, (table, "attributes", table)) - connection.commit() + assert self.df is not None + with closing(connect(temp_path)) as connection: + self.df.to_sql(table, connection, index=False, if_exists="replace") + + # Set geopackage attribute table + with closing(connection.cursor()) as cursor: + sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)" + cursor.execute(sql, (table, "attributes", table)) + connection.commit() + + def _write_arrow(self, filepath: Path, directory: Path, input_dir: Path) -> None: + """Write the contents of the input to a an arrow file.""" + assert self.df is not None + path = directory / input_dir / filepath + path.parent.mkdir(parents=True, exist_ok=True) + self.df.to_feather( + path, + compression="zstd", + compression_level=6, + ) @classmethod def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None: @@ -246,7 +274,8 @@ def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None: @classmethod def _from_arrow(cls, path: FilePath) -> pd.DataFrame: - return pd.read_feather(path) + directory = context_file_loading.get().get("directory", Path(".")) + return pd.read_feather(directory / path) def sort(self, sort_keys: list[str] = ["node_id"]): """Sort all input tables as required. @@ -358,8 +387,10 @@ def node_ids_and_types(self) -> tuple[list[int], list[str]]: ids = self.node_ids() return list(ids), len(ids) * [self.get_input_type()] - def _save(self, directory: DirectoryPath): + def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs): for field in self.fields(): getattr(self, field)._save( - directory, sort_keys=self._sort_keys.get("field", ["node_id"]) + directory, + input_dir, + sort_keys=self._sort_keys.get("field", ["node_id"]), ) diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index 923078a3b..86c8cf4a0 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -8,7 +8,14 @@ import pandas as pd import tomli import tomli_w -from pydantic import DirectoryPath, Field, model_serializer, model_validator +from pydantic import ( + DirectoryPath, + Field, + field_serializer, + field_validator, + model_serializer, + model_validator, +) from ribasim.config import ( Allocation, @@ -36,6 +43,10 @@ class Network(FileModel, NodeModel): + filepath: Path | None = Field( + default=Path("database.gpkg"), exclude=True, repr=False + ) + node: Node = Field(default_factory=Node) edge: Edge = Field(default_factory=Edge) @@ -50,35 +61,40 @@ def n_nodes(self): @classmethod def _load(cls, filepath: Path | None) -> dict[str, Any]: if filepath is not None: - context_file_loading.get()["database"] = filepath + directory = context_file_loading.get().get("directory", Path(".")) + context_file_loading.get()["database"] = directory / filepath return {} @classmethod def _layername(cls, field: str) -> str: return field.capitalize() - def _save(self, directory): + def _save(self, directory, input_dir=Path(".")): # We write all tables to a temporary database with a dot prefix, # and at the end move this over the target file. # This does not throw a PermissionError if the file is open in QGIS. directory = Path(directory) - db_path = directory / "database.gpkg" + db_path = directory / input_dir / "database.gpkg" db_path = db_path.resolve() + db_path.parent.mkdir(parents=True, exist_ok=True) temp_path = db_path.with_stem(".database") # avoid adding tables to existing model temp_path.unlink(missing_ok=True) context_file_loading.get()["database"] = temp_path - self.node._save(directory) - self.edge._save(directory) + self.node._save(directory, input_dir) + self.edge._save(directory, input_dir) shutil.move(temp_path, db_path) context_file_loading.get()["database"] = db_path @model_serializer def set_modelname(self) -> str: - return "database.gpkg" + if self.filepath is not None: + return str(self.filepath.name) + else: + return str(self.model_fields["filepath"].default) class Model(FileModel): @@ -95,13 +111,13 @@ class Model(FileModel): endtime : datetime.datetime End time of the simulation. - update_timestep: float = 86400 + update_timestep: datetime.timedelta = timedelta(seconds=86400) The output time step of the simulation in seconds (default of 1 day) - relative_dir: str = "." + relative_dir: Path = Path(".") The relative directory of the input files. - input_dir: str = "." + input_dir: Path = Path(".") The directory of the input files. - results_dir: str = "." + results_dir: Path = Path(".") The directory of the results files. network: Network @@ -147,10 +163,10 @@ class Model(FileModel): starttime: datetime.datetime endtime: datetime.datetime - update_timestep: float = 86400 - relative_dir: str = "." - input_dir: str = "." - results_dir: str = "." + update_timestep: datetime.timedelta = datetime.timedelta(seconds=86400) + relative_dir: Path = Path(".") + input_dir: Path = Path(".") + results_dir: Path = Path("results") network: Network = Field(default_factory=Network, alias="database") results: Results = Results() @@ -174,6 +190,21 @@ class Model(FileModel): pid_control: PidControl = Field(default_factory=PidControl) user: User = Field(default_factory=User) + @field_validator("update_timestep") + @classmethod + def timestep_in_seconds(cls, v: Any) -> datetime.timedelta: + if not isinstance(v, datetime.timedelta): + v = datetime.timedelta(seconds=v) + return v + + @field_serializer("update_timestep") + def serialize_dt(self, td: datetime.timedelta) -> int: + return int(td.total_seconds()) + + @field_serializer("relative_dir", "input_dir", "results_dir") + def serialize_path(self, path: Path) -> str: + return str(path) + def __repr__(self) -> str: first = [] second = [] @@ -196,15 +227,14 @@ def _write_toml(self, directory: FilePath): content = self.model_dump(exclude_unset=True, exclude_none=True, by_alias=True) # Filter empty dicts (default Nodes) content = dict(filter(lambda x: x[1], content.items())) - fn = directory / "ribasim.toml" with open(fn, "wb") as f: tomli_w.dump(content, f) return fn - def _save(self, directory: DirectoryPath): + def _save(self, directory: DirectoryPath, input_dir: DirectoryPath): for sub in self.nodes().values(): - sub._save(directory) + sub._save(directory, input_dir) def nodes(self): return { @@ -283,6 +313,11 @@ def validate_model(self): self.validate_model_node_field_ids() self.validate_model_node_ids() + @classmethod + def read(cls, filepath: FilePath) -> "Model": + """Read model from TOML file.""" + return cls(filepath=filepath) # type: ignore + def write(self, directory: FilePath) -> Path: """ Write the contents of the model to a database and a TOML configuration file. @@ -297,7 +332,7 @@ def write(self, directory: FilePath) -> Path: context_file_loading.set({}) directory = Path(directory) directory.mkdir(parents=True, exist_ok=True) - self._save(directory) + self._save(directory, self.input_dir) fn = self._write_toml(directory) context_file_loading.set({}) @@ -311,8 +346,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: with open(filepath, "rb") as f: config = tomli.load(f) - # Convert relative path to absolute path - config["database"] = filepath.parent / config["database"] + context_file_loading.get()["directory"] = filepath.parent / config.get( + "input_dir", "." + ) + return config else: return {} @@ -323,23 +360,6 @@ def reset_contextvar(self) -> "Model": context_file_loading.set({}) return self - @classmethod - def from_toml(cls, path: Path | str) -> "Model": - """ - Initialize a model from the TOML configuration file. - - Parameters - ---------- - path : FilePath - Path to the configuration TOML file. - - Returns - ------- - model : Model - """ - kwargs = cls._load(Path(path)) - return cls(**kwargs) - def plot_control_listen(self, ax): x_start, x_end = [], [] y_start, y_end = [], [] diff --git a/python/ribasim/tests/conftest.py b/python/ribasim/tests/conftest.py index c2e23b670..c9acf17cf 100644 --- a/python/ribasim/tests/conftest.py +++ b/python/ribasim/tests/conftest.py @@ -9,6 +9,11 @@ def basic() -> ribasim.Model: return ribasim_testmodels.basic_model() +@pytest.fixture() +def basic_arrow() -> ribasim.Model: + return ribasim_testmodels.basic_arrow_model() + + @pytest.fixture() def basic_transient() -> ribasim.Model: return ribasim_testmodels.basic_transient_model() diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index e06972e12..129a01963 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -37,6 +37,14 @@ def test_basic(basic, tmp_path): assert model_loaded.basin.time.df is None +def test_basic_arrow(basic_arrow, tmp_path): + model_orig = basic_arrow + model_orig.write(tmp_path / "basic_arrow") + model_loaded = ribasim.Model(filepath=tmp_path / "basic_arrow/ribasim.toml") + + assert_equal(model_orig.basin.profile.df, model_loaded.basin.profile.df) + + def test_basic_transient(basic_transient, tmp_path): model_orig = basic_transient model_orig.write(tmp_path / "basic_transient") diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index 05d26c355..e8f5baa77 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -116,4 +116,4 @@ def test_node_ids_unsequential(basic): def test_tabulated_rating_curve_model(tabulated_rating_curve, tmp_path): model_orig = tabulated_rating_curve model_orig.write(tmp_path / "tabulated_rating_curve") - Model.from_toml(tmp_path / "tabulated_rating_curve/ribasim.toml") + Model.read(tmp_path / "tabulated_rating_curve/ribasim.toml") diff --git a/python/ribasim_testmodels/ribasim_testmodels/__init__.py b/python/ribasim_testmodels/ribasim_testmodels/__init__.py index cf011a35d..04d1a2e54 100644 --- a/python/ribasim_testmodels/ribasim_testmodels/__init__.py +++ b/python/ribasim_testmodels/ribasim_testmodels/__init__.py @@ -13,6 +13,7 @@ ) from ribasim_testmodels.backwater import backwater_model from ribasim_testmodels.basic import ( + basic_arrow_model, basic_model, basic_transient_model, outlet_model, @@ -50,6 +51,7 @@ __all__ = [ "backwater_model", "basic_model", + "basic_arrow_model", "basic_transient_model", "bucket_model", "pump_discrete_control_model", diff --git a/python/ribasim_testmodels/ribasim_testmodels/basic.py b/python/ribasim_testmodels/ribasim_testmodels/basic.py index eceea3ce1..26a5931cb 100644 --- a/python/ribasim_testmodels/ribasim_testmodels/basic.py +++ b/python/ribasim_testmodels/ribasim_testmodels/basic.py @@ -1,3 +1,5 @@ +from pathlib import Path + import geopandas as gpd import numpy as np import pandas as pd @@ -211,6 +213,13 @@ def basic_model() -> ribasim.Model: return model +def basic_arrow_model() -> ribasim.Model: + model = basic_model() + model.basin.profile.set_filepath(Path("profile.arrow")) + model.input_dir = Path("input") + return model + + def basic_transient_model() -> ribasim.Model: """Update the basic model with transient forcing"""