Support arrow input. #798

Merged 7 commits on Nov 17, 2023
12 changes: 7 additions & 5 deletions core/test/docs.toml
@@ -5,6 +5,8 @@ endtime = 2021-01-01 # required

# input files
database = "database.gpkg" # required
input_dir = "." # optional, default "."
results_dir = "results" # optional, default "."

# Specific tables can also go into Arrow files rather than the database.
# For large tables this can benefit from better compressed file sizes.
@@ -37,9 +39,9 @@ timing = false # optional, whether to log debug timing statements

[results]
# These results files are always written
flow = "results/flow.arrow" # optional, default "results/flow.arrow"
basin = "results/basin.arrow" # optional, default "results/basin.arrow"
control = "results/control.arrow" # optional, default "results/control.arrow"
flow = "results/flow.arrow" # optional, default "results/flow.arrow"
evetion marked this conversation as resolved.
Show resolved Hide resolved
basin = "results/basin.arrow" # optional, default "results/basin.arrow"
control = "results/control.arrow" # optional, default "results/control.arrow"
allocation = "results/allocation.arrow" # optional, default "results/allocation.arrow"
compression = "zstd" # optional, default "zstd", also supports "lz4"
compression_level = 6 # optional, default 6
compression = "zstd" # optional, default "zstd", also supports "lz4"
compression_level = 6 # optional, default 6
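Note: both new keys are resolved relative to the directory containing the TOML file. A minimal Python sketch of the resulting layout (hypothetical model directory; key values taken from the diff above):

from pathlib import Path

toml_dir = Path("model")       # directory containing ribasim.toml (hypothetical)
input_dir = Path(".")          # input_dir from the TOML
results_dir = Path("results")  # results_dir from the TOML

database = toml_dir / input_dir / "database.gpkg"  # -> model/database.gpkg
flow = toml_dir / results_dir / "flow.arrow"       # -> model/results/flow.arrow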
8 changes: 8 additions & 0 deletions core/test/run_models.jl
@@ -66,6 +66,14 @@ end
end
end

@testset "basic arrow model" begin
toml_path = normpath(@__DIR__, "../../generated_testmodels/basic_arrow/ribasim.toml")
@test ispath(toml_path)
model = Ribasim.run(toml_path)
@test model isa Ribasim.Model
@test successful_retcode(model)
end

@testset "basic transient model" begin
toml_path =
normpath(@__DIR__, "../../generated_testmodels/basic_transient/ribasim.toml")
2 changes: 1 addition & 1 deletion docs/python/examples.ipynb
@@ -471,7 +471,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = ribasim.Model.from_toml(datadir / \"basic/ribasim.toml\")"
"model = ribasim.Model(filepath=datadir / \"basic/ribasim.toml\")"
evetion marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
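Note: `from_toml` is removed later in this PR (see model.py below); loading now goes through the constructor, or through the new `Model.read` classmethod added in model.py. Both forms, assuming the notebook's `datadir` variable:

import ribasim

model = ribasim.Model(filepath=datadir / "basic/ribasim.toml")  # constructor, as in the notebook
model = ribasim.Model.read(datadir / "basic/ribasim.toml")      # equivalent classmethod added below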
75 changes: 53 additions & 22 deletions python/ribasim/ribasim/input_base.py
@@ -23,6 +23,7 @@
    field_validator,
    model_serializer,
    model_validator,
+    validate_call,
)

from ribasim.types import FilePath
@@ -104,8 +105,8 @@ def check_filepath(cls, value: Any) -> Any:
        if filepath is not None:
            filepath = Path(filepath)
            data = cls._load(filepath)
-            value.update(data)
-            return value
+            data.update(value)
+            return data
        elif isinstance(value, Path | str):
            # Pydantic Model init requires a dict
            data = cls._load(Path(value))
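Note: swapping `value.update(data)` for `data.update(value)` reverses the merge precedence: keys passed explicitly by the caller now override keys loaded from the file instead of being overwritten by them. A standalone sketch with hypothetical keys:

data = {"a": "from file", "b": "from file"}  # loaded by cls._load(filepath)
value = {"b": "explicit"}                    # passed in by the caller
data.update(value)                           # explicit values win
assert data == {"a": "from file", "b": "explicit"}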
@@ -114,15 +115,25 @@ def check_filepath(cls, value: Any) -> Any:
        else:
            return value

+    @validate_call
+    def set_filepath(self, filepath: Path) -> None:
+        """Set the filepath of this instance.
+
+        Args:
+            filepath (Path): The filepath to set.
+        """
+        # Disable assignment validation, which would
+        # otherwise trigger check_filepath() and _load() again.
+        self.model_config["validate_assignment"] = False
+        self.filepath = filepath
+        self.model_config["validate_assignment"] = True
+
    @abstractmethod
-    def _save(self, directory: DirectoryPath) -> None:
+    def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
        """Save this instance to disk.

        This method needs to be implemented by any class deriving from
        FileModel.

        Args:
            directory (DirectoryPath): The directory to save to.
            input_dir (DirectoryPath): The input directory, relative to ``directory``.
        """
        raise NotImplementedError()
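Note: toggling `validate_assignment` is what lets `set_filepath` repoint an instance without re-triggering `check_filepath()` and a reload from disk. Hypothetical usage, on a table whose DataFrame is already in memory:

from pathlib import Path

table.set_filepath(Path("basin_time.arrow"))  # 'table' and the file name are hypothetical; the df is kept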

@@ -159,8 +170,8 @@ def prefix_extra_columns(cls, v: DataFrame[TableT]):
        return v

    @model_serializer
-    def set_model(self) -> Path | None:
-        return self.filepath
+    def set_model(self) -> str | None:
+        return str(self.filepath.name) if self.filepath is not None else None

    @classmethod
    def tablename(cls) -> str:
@@ -205,11 +216,17 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
        return {}

    def _save(
-        self, directory: DirectoryPath, sort_keys: list[str] = ["node_id"]
+        self,
+        directory: DirectoryPath,
+        input_dir: DirectoryPath,
+        sort_keys: list[str] = ["node_id"],
    ) -> None:
        # TODO directory could be used to save an arrow file
        db_path = context_file_loading.get().get("database")
-        if self.df is not None and db_path is not None:
+        if self.df is not None and self.filepath is not None:
+            self.sort(sort_keys)
+            self._write_arrow(self.filepath, directory, input_dir)
+        elif self.df is not None and db_path is not None:
            self.sort(sort_keys)
            self._write_table(db_path)

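Note: the branch order above is the core of arrow input support: any table with an explicit `filepath` is routed to `_write_arrow`, and only tables without one fall back to the GeoPackage. A simplified sketch of the dispatch (function name hypothetical):

def save_target(df, filepath, db_path):
    # routing logic mirrored from the _save method above
    if df is not None and filepath is not None:
        return "arrow"       # self._write_arrow(self.filepath, directory, input_dir)
    elif df is not None and db_path is not None:
        return "geopackage"  # self._write_table(db_path)
    return None              # nothing to save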
@@ -223,15 +240,26 @@ def _write_table(self, temp_path: Path) -> None:
            SQLite connection to the database.
        """
        table = self.tablename()
-        if self.df is not None:  # double check to make mypy happy
-            with closing(connect(temp_path)) as connection:
-                self.df.to_sql(table, connection, index=False, if_exists="replace")
-
-                # Set geopackage attribute table
-                with closing(connection.cursor()) as cursor:
-                    sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)"
-                    cursor.execute(sql, (table, "attributes", table))
-                    connection.commit()
+        assert self.df is not None
+        with closing(connect(temp_path)) as connection:
+            self.df.to_sql(table, connection, index=False, if_exists="replace")
+
+            # Set geopackage attribute table
+            with closing(connection.cursor()) as cursor:
+                sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)"
+                cursor.execute(sql, (table, "attributes", table))
+                connection.commit()

+    def _write_arrow(self, filepath: Path, directory: Path, input_dir: Path) -> None:
+        """Write the contents of the input to an arrow file."""
+        assert self.df is not None
+        path = directory / input_dir / filepath
+        path.parent.mkdir(parents=True, exist_ok=True)
+        self.df.to_feather(
+            path,
+            compression="zstd",
+            compression_level=6,
+        )

    @classmethod
    def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None:
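Note: `_write_arrow` leans on pandas forwarding extra keyword arguments to pyarrow's feather writer. A self-contained round-trip under the same settings (hypothetical frame; requires pyarrow):

import pandas as pd

df = pd.DataFrame({"node_id": [1, 2], "level": [0.1, 0.2]})
df.to_feather("basin.arrow", compression="zstd", compression_level=6)
assert pd.read_feather("basin.arrow").equals(df)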
@@ -246,7 +274,8 @@ def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None:

    @classmethod
    def _from_arrow(cls, path: FilePath) -> pd.DataFrame:
-        return pd.read_feather(path)
+        directory = context_file_loading.get().get("directory", Path("."))
+        return pd.read_feather(directory / path)

    def sort(self, sort_keys: list[str] = ["node_id"]):
        """Sort all input tables as required.
@@ -358,8 +387,10 @@ def node_ids_and_types(self) -> tuple[list[int], list[str]]:
        ids = self.node_ids()
        return list(ids), len(ids) * [self.get_input_type()]

-    def _save(self, directory: DirectoryPath):
+    def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
        for field in self.fields():
            getattr(self, field)._save(
-                directory, sort_keys=self._sort_keys.get("field", ["node_id"])
+                directory,
+                input_dir,
+                sort_keys=self._sort_keys.get("field", ["node_id"]),
            )
96 changes: 58 additions & 38 deletions python/ribasim/ribasim/model.py
@@ -8,7 +8,14 @@
import pandas as pd
import tomli
import tomli_w
-from pydantic import DirectoryPath, Field, model_serializer, model_validator
+from pydantic import (
+    DirectoryPath,
+    Field,
+    field_serializer,
+    field_validator,
+    model_serializer,
+    model_validator,
+)

from ribasim.config import (
    Allocation,
@@ -36,6 +43,10 @@


class Network(FileModel, NodeModel):
+    filepath: Path | None = Field(
+        default=Path("database.gpkg"), exclude=True, repr=False
+    )
+
    node: Node = Field(default_factory=Node)
    edge: Edge = Field(default_factory=Edge)

@@ -50,35 +61,40 @@
    @classmethod
    def _load(cls, filepath: Path | None) -> dict[str, Any]:
        if filepath is not None:
-            context_file_loading.get()["database"] = filepath
+            directory = context_file_loading.get().get("directory", Path("."))
+            context_file_loading.get()["database"] = directory / filepath
        return {}

    @classmethod
    def _layername(cls, field: str) -> str:
        return field.capitalize()

-    def _save(self, directory):
+    def _save(self, directory, input_dir=Path(".")):
        # We write all tables to a temporary database with a dot prefix,
        # and at the end move this over the target file.
        # This does not throw a PermissionError if the file is open in QGIS.
        directory = Path(directory)
-        db_path = directory / "database.gpkg"
+        db_path = directory / input_dir / "database.gpkg"
+        db_path = db_path.resolve()
+        db_path.parent.mkdir(parents=True, exist_ok=True)
        temp_path = db_path.with_stem(".database")

        # avoid adding tables to existing model
        temp_path.unlink(missing_ok=True)
        context_file_loading.get()["database"] = temp_path

-        self.node._save(directory)
-        self.edge._save(directory)
+        self.node._save(directory, input_dir)
+        self.edge._save(directory, input_dir)

        shutil.move(temp_path, db_path)
        context_file_loading.get()["database"] = db_path

    @model_serializer
    def set_modelname(self) -> str:
-        return "database.gpkg"
+        if self.filepath is not None:
+            return str(self.filepath.name)
+        else:
+            return str(self.model_fields["filepath"].default)

[codecov/patch: added line python/ribasim/ribasim/model.py#L97 was not covered by tests]


class Model(FileModel):
@@ -95,13 +111,13 @@
    endtime : datetime.datetime
        End time of the simulation.

-    update_timestep: float = 86400
+    update_timestep: datetime.timedelta = timedelta(seconds=86400)
        The output time step of the simulation in seconds (default of 1 day)
-    relative_dir: str = "."
+    relative_dir: Path = Path(".")
        The relative directory of the input files.
-    input_dir: str = "."
+    input_dir: Path = Path(".")
        The directory of the input files.
-    results_dir: str = "."
+    results_dir: Path = Path(".")
        The directory of the results files.

    network: Network
@@ -147,10 +163,10 @@
    starttime: datetime.datetime
    endtime: datetime.datetime

-    update_timestep: float = 86400
-    relative_dir: str = "."
-    input_dir: str = "."
-    results_dir: str = "."
+    update_timestep: datetime.timedelta = datetime.timedelta(seconds=86400)
+    relative_dir: Path = Path(".")
+    input_dir: Path = Path(".")
+    results_dir: Path = Path("results")

    network: Network = Field(default_factory=Network, alias="database")
    results: Results = Results()
@@ -174,6 +190,21 @@
    pid_control: PidControl = Field(default_factory=PidControl)
    user: User = Field(default_factory=User)

@field_validator("update_timestep")
@classmethod
def timestep_in_seconds(cls, v: Any) -> datetime.timedelta:
if not isinstance(v, datetime.timedelta):
v = datetime.timedelta(seconds=v)

Check warning on line 197 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L197

Added line #L197 was not covered by tests
return v

@field_serializer("update_timestep")
def serialize_dt(self, td: datetime.timedelta) -> int:
return int(td.total_seconds())

Check warning on line 202 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L202

Added line #L202 was not covered by tests

@field_serializer("relative_dir", "input_dir", "results_dir")
def serialize_path(self, path: Path) -> str:
return str(path)

    def __repr__(self) -> str:
        first = []
        second = []
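Note: the validator/serializer pair keeps the TOML value a plain number of seconds while the Python API works with `timedelta`. The coercion, extracted into a standalone sketch:

import datetime

def timestep_in_seconds(v):
    # mirror of the field_validator above: accept seconds or a timedelta
    if not isinstance(v, datetime.timedelta):
        v = datetime.timedelta(seconds=v)
    return v

assert timestep_in_seconds(86400) == datetime.timedelta(days=1)
assert int(timestep_in_seconds(86400).total_seconds()) == 86400  # what serialize_dt writes back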
@@ -196,15 +227,14 @@
        content = self.model_dump(exclude_unset=True, exclude_none=True, by_alias=True)
        # Filter empty dicts (default Nodes)
        content = dict(filter(lambda x: x[1], content.items()))
-
        fn = directory / "ribasim.toml"
        with open(fn, "wb") as f:
            tomli_w.dump(content, f)
        return fn

-    def _save(self, directory: DirectoryPath):
+    def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
        for sub in self.nodes().values():
-            sub._save(directory)
+            sub._save(directory, input_dir)

    def nodes(self):
        return {
@@ -283,6 +313,11 @@
        self.validate_model_node_field_ids()
        self.validate_model_node_ids()

+    @classmethod
+    def read(cls, filepath: FilePath) -> "Model":
+        """Read model from TOML file."""
+        return cls(filepath=filepath)  # type: ignore

    def write(self, directory: FilePath) -> Path:
        """
        Write the contents of the model to a database and a TOML configuration file.
@@ -297,7 +332,7 @@
        context_file_loading.set({})
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
-        self._save(directory)
+        self._save(directory, self.input_dir)
        fn = self._write_toml(directory)

context_file_loading.set({})
@@ -311,8 +346,10 @@
with open(filepath, "rb") as f:
config = tomli.load(f)

# Convert relative path to absolute path
config["database"] = filepath.parent / config["database"]
context_file_loading.get()["directory"] = filepath.parent / config.get(
"input_dir", "."
)

return config
else:
return {}
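Note: instead of rewriting `config["database"]` to an absolute path, `_load` now stashes the TOML's parent joined with `input_dir` in the context, and `Network._load` / `_from_arrow` resolve against it. The path arithmetic in isolation (hypothetical locations):

from pathlib import Path

filepath = Path("models/run1/ribasim.toml")    # TOML being read (hypothetical)
config = {"input_dir": "input"}
directory = filepath.parent / config.get("input_dir", ".")
assert directory == Path("models/run1/input")  # later joined with e.g. database.gpkg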
@@ -323,23 +360,6 @@
        context_file_loading.set({})
        return self

-    @classmethod
-    def from_toml(cls, path: Path | str) -> "Model":
-        """
-        Initialize a model from the TOML configuration file.
-
-        Parameters
-        ----------
-        path : FilePath
-            Path to the configuration TOML file.
-
-        Returns
-        -------
-        model : Model
-        """
-        kwargs = cls._load(Path(path))
-        return cls(**kwargs)
-
    def plot_control_listen(self, ax):
        x_start, x_end = [], []
        y_start, y_end = [], []