Skip to content

Commit

Permalink
Cleanup model.py (#1336)
Browse files Browse the repository at this point in the history
Co-authored-by: Martijn Visser <[email protected]>
  • Loading branch information
Hofer-Julian and visr authored Mar 28, 2024
1 parent ae2fdf9 commit 261643f
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class Model(FileModel):
starttime: datetime.datetime
endtime: datetime.datetime

input_dir: Path = Field(default_factory=lambda: Path("."))
results_dir: Path = Field(default_factory=lambda: Path("results"))
input_dir: Path = Field(default=Path("."))
results_dir: Path = Field(default=Path("results"))

logging: Logging = Field(default_factory=Logging)
solver: Solver = Field(default_factory=Solver)
Expand Down Expand Up @@ -102,7 +102,10 @@ def serialize_path(self, path: Path) -> str:
return str(path)

def model_post_init(self, __context: Any) -> None:
# Always write dir fields
# When serializing we exclude fields that are set to their default values
# However, we always want to write `input_dir` and `results_dir`
# By overriding `BaseModel.model_post_init` we can set them explicitly,
# and enforce that they are always written.
self.model_fields_set.update({"input_dir", "results_dir"})

def __repr__(self) -> str:
Expand All @@ -125,7 +128,20 @@ def __repr__(self) -> str:
content.append(")")
return "\n".join(content)

def _write_toml(self, fn: Path):
def _write_toml(self, fn: Path) -> Path:
"""
Write the model data to a TOML file.
Parameters
----------
fn : FilePath
The file path where the TOML file will be written.
Returns
-------
Path
The file path of the written TOML file.
"""
content = self.model_dump(exclude_unset=True, exclude_none=True, by_alias=True)
# Filter empty dicts (default Nodes)
content = dict(filter(lambda x: x[1], content.items()))
Expand Down Expand Up @@ -169,7 +185,7 @@ def _nodes(self) -> Generator[MultiNodeModel, Any, None]:
if (
isinstance(attr, MultiNodeModel)
and attr.node.df is not None
# Model.read creates empty node tables (#1278)
# TODO: Model.read creates empty node tables (#1278)
and not attr.node.df.empty
):
yield attr
Expand Down Expand Up @@ -348,7 +364,11 @@ def plot(self, ax=None, indicate_subnetworks: bool = True) -> Any:
return ax

def to_xugrid(self):
"""Convert the network to a xugrid.UgridDataset."""
"""
Convert the network to a `xugrid.UgridDataset`.
This method will throw `ImportError`,
if the optional dependency `xugrid` isn't installed.
"""
node_df = self.node_table().df

# This will need to be adopted for locally unique node IDs,
Expand Down

0 comments on commit 261643f

Please sign in to comment.