Skip to content

Commit

Permalink
Make edge_id use configurable.
Browse files Browse the repository at this point in the history
  • Loading branch information
evetion committed Aug 19, 2024
1 parent e8e235a commit 552d7c5
Show file tree
Hide file tree
Showing 28 changed files with 231 additions and 222 deletions.
29 changes: 18 additions & 11 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(

def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame:
extra = self.model_extra if self.model_extra is not None else {}
return GeoDataFrame(
gdf = GeoDataFrame(
data={
"node_id": pd.Series([node_id], dtype=np.int32),
"node_type": pd.Series([node_type], dtype=str),
Expand All @@ -192,6 +192,8 @@ def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame:
},
geometry=[self.geometry],
)
gdf.set_index("node_id", inplace=True)
return gdf


class MultiNodeModel(NodeModel):
Expand Down Expand Up @@ -229,8 +231,8 @@ def add(
)

if node_id is None:
node_id = self._parent.used_node_ids.new_id()
elif node_id in self._parent.used_node_ids:
node_id = self._parent._used_node_ids.new_id()
elif node_id in self._parent._used_node_ids:
raise ValueError(
f"Node IDs have to be unique, but {node_id} already exists."
)
Expand All @@ -243,17 +245,22 @@ def add(
)
assert table.df is not None
table_to_append = table.df.assign(node_id=node_id)
setattr(self, member_name, pd.concat([existing_table, table_to_append]))
setattr(
self,
member_name,
pd.concat([existing_table, table_to_append], ignore_index=True),
)

node_table = node.into_geodataframe(
node_type=self.__class__.__name__, node_id=node_id
)
self.node.df = (
node_table
if self.node.df is None
else pd.concat([self.node.df, node_table])
)
self._parent.used_node_ids.add(node_id)
if self.node.df is None:
self.node.df = node_table
else:
df = pd.concat([self.node.df, node_table])
self.node.df = df

self._parent._used_node_ids.add(node_id)
return self[node_id]

def __getitem__(self, index: int) -> NodeData:
Expand All @@ -265,7 +272,7 @@ def __getitem__(self, index: int) -> NodeData:
f"{node_model_name} index must be an integer, not {indextype}"
)

row = self.node[index].iloc[0]
row = self.node.df.loc[index]
return NodeData(
node_id=int(index), node_type=row["node_type"], geometry=row["geometry"]
)
Expand Down
3 changes: 2 additions & 1 deletion python/ribasim/ribasim/geometry/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import pandera as pa
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoSeries

from ribasim.schemas import _BaseSchema


class BasinAreaSchema(_BaseSchema):
fid: Index[Int32] = pa.Field(default=0, check_name=True)
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
geometry: GeoSeries[Any] = pa.Field(default=None, nullable=True)
38 changes: 24 additions & 14 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NamedTuple
from typing import Any, NamedTuple, Optional

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -8,11 +8,13 @@
from matplotlib.axes import Axes
from numpy.typing import NDArray
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoDataFrame, GeoSeries
from pydantic import NonNegativeInt, PrivateAttr
from shapely.geometry import LineString, MultiLineString, Point

from ribasim.input_base import SpatialTableModel
from ribasim.utils import UsedIDs

__all__ = ("EdgeTable",)

Expand All @@ -32,6 +34,7 @@ class NodeData(NamedTuple):


class EdgeSchema(pa.DataFrameModel):
edge_id: Index[Int32] = pa.Field(default=0, ge=0, check_name=True, coerce=True)
name: Series[str] = pa.Field(default="")
from_node_id: Series[Int32] = pa.Field(default=0, coerce=True)
to_node_id: Series[Int32] = pa.Field(default=0, coerce=True)
Expand All @@ -44,17 +47,24 @@ class EdgeSchema(pa.DataFrameModel):
class Config:
add_missing_columns = True

@classmethod
def _index_name(self) -> str:
return "edge_id"


class EdgeTable(SpatialTableModel[EdgeSchema]):
"""Defines the connections between nodes."""

_used_edge_ids: UsedIDs = PrivateAttr(default_factory=UsedIDs)

def add(
self,
from_node: NodeData,
to_node: NodeData,
geometry: LineString | MultiLineString | None = None,
name: str = "",
subnetwork_id: int | None = None,
edge_id: Optional[NonNegativeInt] = None,
**kwargs,
):
"""Add an edge between nodes. The type of the edge (flow or control)
Expand Down Expand Up @@ -84,9 +94,16 @@ def add(
"control" if from_node.node_type in SPATIALCONTROLNODETYPES else "flow"
)
assert self.df is not None
if edge_id is None:
edge_id = self._used_edge_ids.new_id()
elif edge_id in self._used_edge_ids:
raise ValueError(

Check warning on line 100 in python/ribasim/ribasim/geometry/edge.py

View check run for this annotation

Codecov / codecov/patch

python/ribasim/ribasim/geometry/edge.py#L100

Added line #L100 was not covered by tests
f"Edge IDs have to be unique, but {edge_id} already exists."
)

table_to_append = GeoDataFrame[EdgeSchema](
table_to_append = GeoDataFrame(
data={
"edge_id": pd.Series([edge_id], dtype=np.int32),
"from_node_id": pd.Series([from_node.node_id], dtype=np.int32),
"to_node_id": pd.Series([to_node.node_id], dtype=np.int32),
"edge_type": pd.Series([edge_type], dtype=str),
Expand All @@ -97,26 +114,19 @@ def add(
geometry=geometry_to_append,
crs=self.df.crs,
)
table_to_append.set_index("edge_id", inplace=True)

self.df = GeoDataFrame[EdgeSchema](
pd.concat([self.df, table_to_append], ignore_index=True)
)
self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))
if self.df.duplicated(subset=["from_node_id", "to_node_id"]).any():
raise ValueError(
f"Edges have to be unique, but edge ({from_node.node_id}, {to_node.node_id}) already exists."
f"Edges have to be unique, but edge with from_node_id {from_node.node_id} to_node_id {to_node.node_id} already exists."
)
self.df.index.name = "fid"
self._used_edge_ids.add(edge_id)

def _get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
return (self.df.edge_type == edge_type).to_numpy()

def sort(self):
# Only sort the index (fid / edge_id) since this needs to be sorted in a GeoPackage.
# Under most circumstances, this retains the input order,
# making the edge_id as stable as possible; useful for post-processing.
self.df.sort_index(inplace=True)

def plot(self, **kwargs) -> Axes:
"""Plot the edges of the model.
Expand Down
18 changes: 7 additions & 11 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandera as pa
from matplotlib.patches import Patch
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing import Index, Series
from pandera.typing.geopandas import GeoSeries

from ribasim.input_base import SpatialTableModel
Expand All @@ -16,7 +16,7 @@


class NodeSchema(pa.DataFrameModel):
node_id: Series[Int32] = pa.Field(ge=0)
node_id: Index[Int32] = pa.Field(default=0, check_name=True)
name: Series[str] = pa.Field(default="")
node_type: Series[str] = pa.Field(default="")
subnetwork_id: Series[pd.Int32Dtype] = pa.Field(
Expand All @@ -28,6 +28,10 @@ class Config:
add_missing_columns = True
coerce = True

@classmethod
def _index_name(self) -> str:
return "node_id"


class NodeTable(SpatialTableModel[NodeSchema]):
"""The Ribasim nodes as Point geometries."""
Expand All @@ -37,12 +41,6 @@ def filter(self, nodetype: str):
if self.df is not None:
mask = self.df[self.df["node_type"] != nodetype].index
self.df.drop(mask, inplace=True)
self.df.reset_index(inplace=True, drop=True)

def sort(self):
assert self.df is not None
sort_keys = ["node_type", "node_id"]
self.df.sort_values(sort_keys, ignore_index=True, inplace=True)

def plot_allocation_networks(self, ax=None, zorder=None) -> Any:
if ax is None:
Expand Down Expand Up @@ -156,9 +154,7 @@ def plot(self, ax=None, zorder=None) -> Any:

assert self.df is not None
geometry = self.df["geometry"]
for text, xy in zip(
self.df["node_id"], np.column_stack((geometry.x, geometry.y))
):
for text, xy in zip(self.df.index, np.column_stack((geometry.x, geometry.y))):
ax.annotate(text=text, xy=xy, xytext=(2.0, 2.0), textcoords="offset points")

return ax
54 changes: 34 additions & 20 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
context_file_loading: ContextVar[dict[str, Any]] = ContextVar(
"file_loading", default={}
)
context_file_writing: ContextVar[dict[str, Any]] = ContextVar(
"file_writing", default={}
)

TableT = TypeVar("TableT", bound=pa.DataFrameModel)

Expand Down Expand Up @@ -179,10 +182,7 @@ def _check_extra_columns(cls, v: DataFrame[TableT]):
"""Allow only extra columns with `meta_` prefix."""
if isinstance(v, (pd.DataFrame, gpd.GeoDataFrame)):
for colname in v.columns:
if colname == "fid":
# Autogenerated on writing, don't carry them
v = v.drop(columns=["fid"])
elif colname not in cls.columns() and not colname.startswith("meta_"):
if colname not in cls.columns() and not colname.startswith("meta_"):
raise ValueError(
f"Unrecognized column '{colname}'. Extra columns need a 'meta_' prefix."
)
Expand Down Expand Up @@ -216,8 +216,13 @@ def tablename(cls) -> str:
@model_validator(mode="before")
@classmethod
def _check_dataframe(cls, value: Any) -> Any:
# Enable initialization with a Dict.
if isinstance(value, dict) and len(value) > 0 and "df" not in value:
value = DataFrame(dict(**value))

# Enable initialization with a DataFrame.
if isinstance(value, pd.DataFrame | gpd.GeoDataFrame):
value.index.rename("fid", inplace=True)
value = {"df": value}

return value
Expand All @@ -232,7 +237,7 @@ def _node_ids(self) -> set[int]:
@classmethod
def _load(cls, filepath: Path | None) -> dict[str, Any]:
db = context_file_loading.get().get("database")
if filepath is not None:
if filepath is not None and db is not None:
adf = cls._from_arrow(filepath)
# TODO Store filepath?
return {"df": adf}
Expand All @@ -244,12 +249,11 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
db_path = context_file_writing.get().get("database")
self.sort()
if self.filepath is not None:
self.sort()
self._write_arrow(self.filepath, directory, input_dir)
elif db_path is not None:
self.sort()
self._write_geopackage(db_path)

def _write_geopackage(self, temp_path: Path) -> None:
Expand All @@ -264,16 +268,11 @@ def _write_geopackage(self, temp_path: Path) -> None:
assert self.df is not None
table = self.tablename()

# Add `fid` to all tables as primary key
# Enables editing values manually in QGIS
df = self.df.copy()
df["fid"] = range(1, len(df) + 1)

with closing(connect(temp_path)) as connection:
df.to_sql(
self.df.to_sql(
table,
connection,
index=False,
index=True,
if_exists="replace",
dtype={"fid": "INTEGER PRIMARY KEY AUTOINCREMENT"},
)
Expand Down Expand Up @@ -303,6 +302,7 @@ def _from_db(cls, path: Path, table: str) -> pd.DataFrame | None:
df = pd.read_sql_query(
query, connection, parse_dates={"time": {"format": "ISO8601"}}
)
df.set_index("fid", inplace=True)
else:
df = None

Expand All @@ -319,7 +319,9 @@ def sort(self):
Sorting is done automatically before writing the table.
"""
if self.df is not None:
self.df.sort_values(self._sort_keys, ignore_index=True, inplace=True)
df = self.df.sort_values(self._sort_keys, ignore_index=True)
df.index.rename("fid", inplace=True)
self.df = df # trigger validation and thus index coercion to int32

@classmethod
def tableschema(cls) -> TableT:
Expand All @@ -336,7 +338,7 @@ def tableschema(cls) -> TableT:
def columns(cls) -> list[str]:
"""Retrieve column names."""
T = cls.tableschema()
return list(T.to_schema().columns.keys())
return list(T.to_schema().columns.keys()) + [T.to_schema().index.name]

def __repr__(self) -> str:
# Make sure not to return just "None", because it gets extremely confusing
Expand Down Expand Up @@ -367,11 +369,19 @@ def __getitem__(self, index) -> pd.DataFrame | gpd.GeoDataFrame:
class SpatialTableModel(TableModel[TableT], Generic[TableT]):
df: GeoDataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)

def sort(self):
# Only sort the index (node_id / edge_id) since this needs to be sorted in a GeoPackage.
# Under most circumstances, this retains the input order,
# making the edge_id as stable as possible; useful for post-processing.
self.df.sort_index(inplace=True)

@classmethod
def _from_db(cls, path: Path, table: str):
with connect(path) as connection:
if exists(connection, table):
# pyogrio hardcodes fid name on reading
df = gpd.read_file(path, layer=table, fid_as_index=True)
df.index.rename(cls.tableschema()._index_name(), inplace=True)
else:
df = None

Expand All @@ -386,9 +396,13 @@ def _write_geopackage(self, path: Path) -> None:
path : Path
"""
assert self.df is not None
# the index name must be fid otherwise it will generate a separate fid column
self.df.index.name = "fid"
self.df.to_file(path, layer=self.tablename(), index=True, driver="GPKG")
self.df.to_file(
path,
layer=self.tablename(),
driver="GPKG",
index=True,
fid=self.df.index.name,
)
_add_styles_to_geopackage(path, self.tablename())


Expand Down
Loading

0 comments on commit 552d7c5

Please sign in to comment.