Skip to content

Commit

Permalink
Add results to xugrid
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Apr 10, 2024
1 parent 35dfc2c commit 96bbb5e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 6 deletions.
70 changes: 65 additions & 5 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def validate_model(self):
@classmethod
def read(cls, filepath: str | PathLike[str]) -> "Model":
"""Read model from TOML file."""
return cls(filepath=filepath) # type: ignore
return cls(filepath=Path(filepath)) # type: ignore

def write(self, filepath: str | PathLike[str]) -> Path:
"""
Expand All @@ -264,10 +264,10 @@ def write(self, filepath: str | PathLike[str]) -> Path:
# TODO
# self.validate_model()
filepath = Path(filepath)
self.filepath = filepath
if not filepath.suffix == ".toml":
raise ValueError(f"Filepath '{filepath}' is not a .toml file.")
context_file_loading.set({})
filepath = Path(filepath)
directory = filepath.parent
directory.mkdir(parents=True, exist_ok=True)
self._save(directory, self.input_dir)
Expand All @@ -280,7 +280,7 @@ def write(self, filepath: str | PathLike[str]) -> Path:
def _load(cls, filepath: Path | None) -> dict[str, Any]:
context_file_loading.set({})

if filepath is not None:
if filepath is not None and filepath.is_file():
with open(filepath, "rb") as f:
config = tomli.load(f)

Expand Down Expand Up @@ -395,9 +395,10 @@ def plot(self, ax=None, indicate_subnetworks: bool = True) -> Any:

return ax

def to_xugrid(self):
def to_xugrid(self, add_results: bool = True):
"""
Convert the network to a `xugrid.UgridDataset`.
Convert the network and results to a `xugrid.UgridDataset`.
To get the network only, set `add_results=False`.
This method will throw `ImportError`,
if the optional dependency `xugrid` isn't installed.
"""
Expand Down Expand Up @@ -449,4 +450,63 @@ def to_xugrid(self):
uds = uds.assign_coords(from_node_id=(edge_dim, from_node_id))
uds = uds.assign_coords(to_node_id=(edge_dim, to_node_id))

if add_results:
uds = self._add_results(uds)

return uds

def _add_results(self, uds):
toml_path = self.filepath
if toml_path is None:
raise FileNotFoundError("Model must be written to disk to add results.")

results_path = toml_path.parent / self.results_dir
basin_path = results_path / "basin.arrow"
flow_path = results_path / "flow.arrow"

if not basin_path.is_file() or not flow_path.is_file():
raise FileNotFoundError(
f"Cannot find results in '{results_path}', "
"perhaps the model needs to be run first."
)

basin_df = pd.read_feather(basin_path)
flow_df = pd.read_feather(flow_path)

Check warning on line 474 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L473-L474

Added lines #L473 - L474 were not covered by tests

edge_dim = uds.grid.edge_dimension
node_dim = uds.grid.node_dimension

Check warning on line 477 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L476-L477

Added lines #L476 - L477 were not covered by tests

# from node_id to the node_dim index
node_lookup = pd.Series(

Check warning on line 480 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L480

Added line #L480 was not covered by tests
index=uds["node_id"],
data=uds[edge_dim],
name="node_index",
)
# from edge_id to the edge_dim index
edge_lookup = pd.Series(

Check warning on line 486 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L486

Added line #L486 was not covered by tests
index=uds["edge_id"],
data=uds[edge_dim],
name="edge_index",
)

basin_df = pd.read_feather(basin_path)
flow_df = pd.read_feather(flow_path)

Check warning on line 493 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L492-L493

Added lines #L492 - L493 were not covered by tests

# datetime64[ms] gives trouble; https://github.com/pydata/xarray/issues/6318
flow_df["time"] = flow_df["time"].astype("datetime64[ns]")
basin_df["time"] = basin_df["time"].astype("datetime64[ns]")

Check warning on line 497 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L496-L497

Added lines #L496 - L497 were not covered by tests

# add flow results to the UgridDataset
flow_df[edge_dim] = edge_lookup[flow_df["edge_id"]].to_numpy()
flow_da = flow_df.set_index(["time", edge_dim])["flow_rate"].to_xarray()
uds[flow_da.name] = flow_da

Check warning on line 502 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L500-L502

Added lines #L500 - L502 were not covered by tests

# add basin results to the UgridDataset
basin_df[node_dim] = node_lookup[basin_df["node_id"]].to_numpy()
basin_df.drop(columns=["node_id"], inplace=True)
basin_ds = basin_df.set_index(["time", node_dim]).to_xarray()

Check warning on line 507 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L505-L507

Added lines #L505 - L507 were not covered by tests

for var_name, da in basin_ds.data_vars.items():
uds[var_name] = da

Check warning on line 510 in python/ribasim/ribasim/model.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/model.py#L509-L510

Added lines #L509 - L510 were not covered by tests

return uds
3 changes: 3 additions & 0 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ def __assert_equal(a: DataFrame, b: DataFrame) -> None:
def test_basic(basic, tmp_path):
model_orig = basic
toml_path = tmp_path / "basic/ribasim.toml"
assert model_orig.filepath is None
model_orig.write(toml_path)
assert model_orig.filepath == toml_path
model_loaded = Model.read(toml_path)
assert model_loaded.filepath == toml_path

with open(toml_path, "rb") as f:
toml_dict = tomli.load(f)
Expand Down
9 changes: 8 additions & 1 deletion python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_indexing(basic):


def test_xugrid(basic, tmp_path):
uds = basic.to_xugrid()
uds = basic.to_xugrid(add_results=False)
assert isinstance(uds, xugrid.UgridDataset)
assert uds.grid.edge_dimension == "ribasim_nEdges"
assert uds.grid.node_dimension == "ribasim_nNodes"
Expand All @@ -243,6 +243,13 @@ def test_xugrid(basic, tmp_path):
uds = xugrid.open_dataset(tmp_path / "ribasim.nc")
assert uds.attrs["Conventions"] == "CF-1.9 UGRID-1.0"

with pytest.raises(FileNotFoundError, match="Model must be written to disk"):
basic.to_xugrid(add_results=True)

basic.write(tmp_path / "ribasim.toml")
with pytest.raises(FileNotFoundError, match="Cannot find results"):
basic.to_xugrid(add_results=True)


def test_to_crs(bucket: Model):
model = bucket
Expand Down

0 comments on commit 96bbb5e

Please sign in to comment.