Support arrow input. (#798)
Fixes #170
evetion authored Nov 17, 2023
1 parent 0b7b72d commit 50c8675
Showing 10 changed files with 152 additions and 67 deletions.
12 changes: 7 additions & 5 deletions core/test/docs.toml
@@ -5,6 +5,8 @@ endtime = 2021-01-01 # required

# input files
database = "database.gpkg" # required
input_dir = "." # optional, default "."
results_dir = "results" # optional, default "."

# Specific tables can also go into Arrow files rather than the database.
# For large tables this can yield smaller compressed file sizes.
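# Editor's sketch, not part of this diff: a large table could then be
# redirected to an Arrow file with a hypothetical entry such as
#   [basin]
#   time = "basin_time.arrow"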
@@ -37,9 +39,9 @@ timing = false # optional, whether to log debug timing statements

[results]
# These results files are always written
flow = "results/flow.arrow" # optional, default "results/flow.arrow"
basin = "results/basin.arrow" # optional, default "results/basin.arrow"
control = "results/control.arrow" # optional, default "results/control.arrow"
flow = "results/flow.arrow" # optional, default "results/flow.arrow"
basin = "results/basin.arrow" # optional, default "results/basin.arrow"
control = "results/control.arrow" # optional, default "results/control.arrow"
allocation = "results/allocation.arrow" # optional, default "results/allocation.arrow"
compression = "zstd" # optional, default "zstd", also supports "lz4"
compression_level = 6 # optional, default 6
compression = "zstd" # optional, default "zstd", also supports "lz4"
compression_level = 6 # optional, default 6
8 changes: 8 additions & 0 deletions core/test/run_models.jl
@@ -66,6 +66,14 @@ end
end
end

@testset "basic arrow model" begin
toml_path = normpath(@__DIR__, "../../generated_testmodels/basic_arrow/ribasim.toml")
@test ispath(toml_path)
model = Ribasim.run(toml_path)
@test model isa Ribasim.Model
@test successful_retcode(model)
end

@testset "basic transient model" begin
toml_path =
normpath(@__DIR__, "../../generated_testmodels/basic_transient/ribasim.toml")
2 changes: 1 addition & 1 deletion docs/python/examples.ipynb
@@ -471,7 +471,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = ribasim.Model.from_toml(datadir / \"basic/ribasim.toml\")"
"model = ribasim.Model(filepath=datadir / \"basic/ribasim.toml\")"
]
},
{
75 changes: 53 additions & 22 deletions python/ribasim/ribasim/input_base.py
@@ -23,6 +23,7 @@
field_validator,
model_serializer,
model_validator,
validate_call,
)

from ribasim.types import FilePath
@@ -104,8 +105,8 @@ def check_filepath(cls, value: Any) -> Any:
if filepath is not None:
filepath = Path(filepath)
data = cls._load(filepath)
value.update(data)
return value
data.update(value)
return data
elif isinstance(value, Path | str):
# Pydantic Model init requires a dict
data = cls._load(Path(value))
@@ -114,15 +115,25 @@ def check_filepath(cls, value: Any) -> Any:
else:
return value
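The `data.update(value)` swap above is the behavioral core of this hunk; a minimal sketch of what it changes, using placeholder keys:

```python
# Editor's sketch (not part of this diff): swapping value.update(data) for
# data.update(value) reverses merge precedence, so explicitly passed fields
# now override values loaded from the referenced file.
data = {"crs": "EPSG:28992", "comment": "from file"}  # result of cls._load()
value = {"comment": "passed at init"}                 # explicit kwargs
data.update(value)
assert data == {"crs": "EPSG:28992", "comment": "passed at init"}
```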

@validate_call
def set_filepath(self, filepath: Path) -> None:
"""Set the filepath of this instance.
Args:
filepath (Path): The filepath to set.
"""
# Disable assignment validation, which would
# otherwise trigger check_filepath() and _load() again.
self.model_config["validate_assignment"] = False
self.filepath = filepath
self.model_config["validate_assignment"] = True
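The new `set_filepath()` is the user-facing hook for Arrow output: pointing a table at a `.arrow` file makes `_save()` (further down) write it as Feather instead of into the GeoPackage. A usage sketch; `model.basin.profile` and the file names are illustrative assumptions, not taken from this diff:

```python
from pathlib import Path

import ribasim

model = ribasim.Model.read("ribasim.toml")  # placeholder path
# Route one table to an Arrow file instead of database.gpkg.
model.basin.profile.set_filepath(Path("basin_profile.arrow"))
model.write("updated_model")  # the table is now written under input_dir
```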

@abstractmethod
def _save(self, directory: DirectoryPath) -> None:
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
"""Save this instance to disk.
This method needs to be implemented by any class deriving from
FileModel.
Args:
directory (DirectoryPath): The directory to save to.
input_dir (DirectoryPath): The input directory, relative to `directory`.
"""
raise NotImplementedError()

@@ -159,8 +170,8 @@ def prefix_extra_columns(cls, v: DataFrame[TableT]):
return v

@model_serializer
def set_model(self) -> Path | None:
return self.filepath
def set_model(self) -> str | None:
return str(self.filepath.name) if self.filepath is not None else None

@classmethod
def tablename(cls) -> str:
@@ -205,11 +216,17 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
return {}

def _save(
self, directory: DirectoryPath, sort_keys: list[str] = ["node_id"]
self,
directory: DirectoryPath,
input_dir: DirectoryPath,
sort_keys: list[str] = ["node_id"],
) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
if self.df is not None and db_path is not None:
if self.df is not None and self.filepath is not None:
self.sort(sort_keys)
self._write_arrow(self.filepath, directory, input_dir)
elif self.df is not None and db_path is not None:
self.sort(sort_keys)
self._write_table(db_path)

@@ -223,15 +240,26 @@ def _write_table(self, temp_path: Path) -> None:
SQLite connection to the database.
"""
table = self.tablename()
if self.df is not None: # double check to make mypy happy
with closing(connect(temp_path)) as connection:
self.df.to_sql(table, connection, index=False, if_exists="replace")

# Set geopackage attribute table
with closing(connection.cursor()) as cursor:
sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)"
cursor.execute(sql, (table, "attributes", table))
connection.commit()
assert self.df is not None
with closing(connect(temp_path)) as connection:
self.df.to_sql(table, connection, index=False, if_exists="replace")

# Set geopackage attribute table
with closing(connection.cursor()) as cursor:
sql = "INSERT INTO gpkg_contents (table_name, data_type, identifier) VALUES (?, ?, ?)"
cursor.execute(sql, (table, "attributes", table))
connection.commit()

def _write_arrow(self, filepath: Path, directory: Path, input_dir: Path) -> None:
"""Write the contents of the input to a an arrow file."""
assert self.df is not None
path = directory / input_dir / filepath
path.parent.mkdir(parents=True, exist_ok=True)
self.df.to_feather(
path,
compression="zstd",
compression_level=6,
)
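The file written here is a regular Feather (Arrow IPC) file, so it round-trips through pandas; a minimal, self-contained check (file name hypothetical):

```python
import pandas as pd

df = pd.DataFrame({"node_id": [1, 2], "area": [0.01, 1000.0]})
# Same call as _write_arrow() above: Feather with zstd compression.
df.to_feather("table.arrow", compression="zstd", compression_level=6)
assert pd.read_feather("table.arrow").equals(df)
```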

@classmethod
def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None:
@@ -246,7 +274,8 @@ def _from_db(cls, path: FilePath, table: str) -> pd.DataFrame | None:

@classmethod
def _from_arrow(cls, path: FilePath) -> pd.DataFrame:
return pd.read_feather(path)
directory = context_file_loading.get().get("directory", Path("."))
return pd.read_feather(directory / path)

def sort(self, sort_keys: list[str] = ["node_id"]):
"""Sort all input tables as required.
@@ -358,8 +387,10 @@ def node_ids_and_types(self) -> tuple[list[int], list[str]]:
ids = self.node_ids()
return list(ids), len(ids) * [self.get_input_type()]

def _save(self, directory: DirectoryPath):
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
for field in self.fields():
getattr(self, field)._save(
directory, sort_keys=self._sort_keys.get("field", ["node_id"])
directory,
input_dir,
sort_keys=self._sort_keys.get(field, ["node_id"]),
)
96 changes: 58 additions & 38 deletions python/ribasim/ribasim/model.py
@@ -8,7 +8,14 @@
import pandas as pd
import tomli
import tomli_w
from pydantic import DirectoryPath, Field, model_serializer, model_validator
from pydantic import (
DirectoryPath,
Field,
field_serializer,
field_validator,
model_serializer,
model_validator,
)

from ribasim.config import (
Allocation,
@@ -36,6 +43,10 @@


class Network(FileModel, NodeModel):
filepath: Path | None = Field(
default=Path("database.gpkg"), exclude=True, repr=False
)

node: Node = Field(default_factory=Node)
edge: Edge = Field(default_factory=Edge)

@@ -50,35 +61,40 @@ def n_nodes(self):
@classmethod
def _load(cls, filepath: Path | None) -> dict[str, Any]:
if filepath is not None:
context_file_loading.get()["database"] = filepath
directory = context_file_loading.get().get("directory", Path("."))
context_file_loading.get()["database"] = directory / filepath
return {}

@classmethod
def _layername(cls, field: str) -> str:
return field.capitalize()

def _save(self, directory):
def _save(self, directory, input_dir=Path(".")):
# We write all tables to a temporary database with a dot prefix,
# and at the end move this over the target file.
# This does not throw a PermissionError if the file is open in QGIS.
directory = Path(directory)
db_path = directory / "database.gpkg"
db_path = directory / input_dir / "database.gpkg"
db_path = db_path.resolve()
db_path.parent.mkdir(parents=True, exist_ok=True)
temp_path = db_path.with_stem(".database")

# avoid adding tables to existing model
temp_path.unlink(missing_ok=True)
context_file_loading.get()["database"] = temp_path

self.node._save(directory)
self.edge._save(directory)
self.node._save(directory, input_dir)
self.edge._save(directory, input_dir)

shutil.move(temp_path, db_path)
context_file_loading.get()["database"] = db_path

@model_serializer
def set_modelname(self) -> str:
return "database.gpkg"
if self.filepath is not None:
return str(self.filepath.name)
else:
return str(self.model_fields["filepath"].default)


class Model(FileModel):
@@ -95,13 +111,13 @@ class Model(FileModel):
endtime : datetime.datetime
End time of the simulation.
update_timestep: float = 86400
update_timestep: datetime.timedelta = timedelta(seconds=86400)
The output time step of the simulation in seconds (default of 1 day)
relative_dir: str = "."
relative_dir: Path = Path(".")
The relative directory of the input files.
input_dir: str = "."
input_dir: Path = Path(".")
The directory of the input files.
results_dir: str = "."
results_dir: Path = Path("results")
The directory of the results files.
network: Network
@@ -147,10 +163,10 @@ class Model(FileModel):
starttime: datetime.datetime
endtime: datetime.datetime

update_timestep: float = 86400
relative_dir: str = "."
input_dir: str = "."
results_dir: str = "."
update_timestep: datetime.timedelta = datetime.timedelta(seconds=86400)
relative_dir: Path = Path(".")
input_dir: Path = Path(".")
results_dir: Path = Path("results")

network: Network = Field(default_factory=Network, alias="database")
results: Results = Results()
@@ -174,6 +190,21 @@ class Model(FileModel):
pid_control: PidControl = Field(default_factory=PidControl)
user: User = Field(default_factory=User)

@field_validator("update_timestep")
@classmethod
def timestep_in_seconds(cls, v: Any) -> datetime.timedelta:
if not isinstance(v, datetime.timedelta):
v = datetime.timedelta(seconds=v)
return v

@field_serializer("update_timestep")
def serialize_dt(self, td: datetime.timedelta) -> int:
return int(td.total_seconds())

@field_serializer("relative_dir", "input_dir", "results_dir")
def serialize_path(self, path: Path) -> str:
return str(path)
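Together, the validator and serializers keep the on-disk TOML representation (plain seconds and path strings) while the Python API gains richer types; the `update_timestep` round trip in miniature:

```python
from datetime import timedelta

seconds = 86400                    # update_timestep as written in the TOML
td = timedelta(seconds=seconds)    # what timestep_in_seconds() produces
assert int(td.total_seconds()) == seconds  # what serialize_dt() writes back
```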

def __repr__(self) -> str:
first = []
second = []
@@ -196,15 +227,14 @@ def _write_toml(self, directory: FilePath):
content = self.model_dump(exclude_unset=True, exclude_none=True, by_alias=True)
# Filter empty dicts (default Nodes)
content = dict(filter(lambda x: x[1], content.items()))

fn = directory / "ribasim.toml"
with open(fn, "wb") as f:
tomli_w.dump(content, f)
return fn

def _save(self, directory: DirectoryPath):
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
for sub in self.nodes().values():
sub._save(directory)
sub._save(directory, input_dir)

def nodes(self):
return {
@@ -283,6 +313,11 @@ def validate_model(self):
self.validate_model_node_field_ids()
self.validate_model_node_ids()

@classmethod
def read(cls, filepath: FilePath) -> "Model":
"""Read model from TOML file."""
return cls(filepath=filepath) # type: ignore
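`Model.read()` replaces the `from_toml()` classmethod removed at the bottom of this file; both spellings below appear in this commit (the path is a placeholder):

```python
import ribasim

model = ribasim.Model.read("ribasim.toml")      # new classmethod
model = ribasim.Model(filepath="ribasim.toml")  # direct init, as in examples.ipynb
```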

def write(self, directory: FilePath) -> Path:
"""
Write the contents of the model to a database and a TOML configuration file.
@@ -297,7 +332,7 @@ def write(self, directory: FilePath) -> Path:
context_file_loading.set({})
directory = Path(directory)
directory.mkdir(parents=True, exist_ok=True)
self._save(directory)
self._save(directory, self.input_dir)
fn = self._write_toml(directory)

context_file_loading.set({})
@@ -311,8 +346,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
with open(filepath, "rb") as f:
config = tomli.load(f)

# Convert relative path to absolute path
config["database"] = filepath.parent / config["database"]
context_file_loading.get()["directory"] = filepath.parent / config.get(
"input_dir", "."
)

return config
else:
return {}
@@ -323,23 +360,6 @@ def reset_contextvar(self) -> "Model":
context_file_loading.set({})
return self

@classmethod
def from_toml(cls, path: Path | str) -> "Model":
"""
Initialize a model from the TOML configuration file.
Parameters
----------
path : FilePath
Path to the configuration TOML file.
Returns
-------
model : Model
"""
kwargs = cls._load(Path(path))
return cls(**kwargs)

def plot_control_listen(self, ax):
x_start, x_end = [], []
y_start, y_end = [], []