Skip to content

Commit

Permalink
Support arrow input.
Browse files Browse the repository at this point in the history
  • Loading branch information
evetion committed Nov 16, 2023
1 parent 548c6eb commit 35671dc
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
8 changes: 8 additions & 0 deletions core/test/run_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ end
end
end

@testset "basic arrow model" begin
toml_path = normpath(@__DIR__, "../../generated_testmodels/basic_arrow/ribasim.toml")
@test ispath(toml_path)
model = Ribasim.run(toml_path)
@test model isa Ribasim.Model
@test successful_retcode(model)
end

@testset "basic transient model" begin
toml_path =
normpath(@__DIR__, "../../generated_testmodels/basic_transient/ribasim.toml")
Expand Down
28 changes: 27 additions & 1 deletion python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ def check_filepath(cls, value: Any) -> Any:
else:
return value

def set_filepath(self, filepath: Path) -> None:
"""Set the filepath of this instance.
Args:
filepath (Path): The filepath to set.
"""
# Disable assignment validation, which would
# otherwise trigger check_filepath() and _load() again.
self.model_config["validate_assignment"] = False
self.filepath = filepath
self.model_config["validate_assignment"] = True

Check warning on line 127 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L125-L127

Added lines #L125 - L127 were not covered by tests

@abstractmethod
def _save(self, directory: DirectoryPath) -> None:
"""Save this instance to disk.
Expand Down Expand Up @@ -209,7 +221,10 @@ def _save(
) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
if self.df is not None and db_path is not None:
if self.df is not None and self.filepath is not None:
self.sort(sort_keys)
self._write_arrow(self.filepath, directory)

Check warning on line 226 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L225-L226

Added lines #L225 - L226 were not covered by tests
elif self.df is not None and db_path is not None:
self.sort(sort_keys)
self._write_table(db_path)

Expand All @@ -233,6 +248,17 @@ def _write_table(self, temp_path: Path) -> None:
cursor.execute(sql, (table, "attributes", table))
connection.commit()

def _write_arrow(self, filepath: Path, directory: Path) -> None:
"""Write the contents of the input to a an arrow file."""
if self.df is not None: # double check to make mypy happy
path = directory / filepath # TODO: Handle relative dir?
path.parent.mkdir(parents=True, exist_ok=True)
self.df.to_feather(

Check warning on line 256 in python/ribasim/ribasim/input_base.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/input_base.py#L253-L256

Added lines #L253 - L256 were not covered by tests
path,
compression="zstd",
compression_level=6,
)

@classmethod
def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None:
with connect(path) as connection:
Expand Down
2 changes: 2 additions & 0 deletions python/ribasim_testmodels/ribasim_testmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from ribasim_testmodels.backwater import backwater_model
from ribasim_testmodels.basic import (
basic_arrow_model,
basic_model,
basic_transient_model,
outlet_model,
Expand Down Expand Up @@ -50,6 +51,7 @@
__all__ = [
"backwater_model",
"basic_model",
"basic_arrow_model",
"basic_transient_model",
"bucket_model",
"pump_discrete_control_model",
Expand Down
6 changes: 6 additions & 0 deletions python/ribasim_testmodels/ribasim_testmodels/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def basic_model() -> ribasim.Model:
return model


def basic_arrow_model() -> ribasim.Model:
model = basic_model()
model.basin.profile.set_filepath("input/profile.arrow")
return model


def basic_transient_model() -> ribasim.Model:
"""Update the basic model with transient forcing"""

Expand Down

0 comments on commit 35671dc

Please sign in to comment.