diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index 87e2b7d8a..1557e921e 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -182,7 +182,7 @@ def __init__( def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame: extra = self.model_extra if self.model_extra is not None else {} - return GeoDataFrame( + gdf = GeoDataFrame( data={ "node_id": pd.Series([node_id], dtype=np.int32), "node_type": pd.Series([node_type], dtype=str), @@ -192,6 +192,8 @@ def into_geodataframe(self, node_type: str, node_id: int) -> GeoDataFrame: }, geometry=[self.geometry], ) + gdf.set_index("node_id", inplace=True) + return gdf class MultiNodeModel(NodeModel): @@ -229,8 +231,8 @@ def add( ) if node_id is None: - node_id = self._parent.used_node_ids.new_id() - elif node_id in self._parent.used_node_ids: + node_id = self._parent._used_node_ids.new_id() + elif node_id in self._parent._used_node_ids: raise ValueError( f"Node IDs have to be unique, but {node_id} already exists." ) @@ -243,17 +245,22 @@ def add( ) assert table.df is not None table_to_append = table.df.assign(node_id=node_id) - setattr(self, member_name, pd.concat([existing_table, table_to_append])) + setattr( + self, + member_name, + pd.concat([existing_table, table_to_append], ignore_index=True), + ) node_table = node.into_geodataframe( node_type=self.__class__.__name__, node_id=node_id ) - self.node.df = ( - node_table - if self.node.df is None - else pd.concat([self.node.df, node_table]) - ) - self._parent.used_node_ids.add(node_id) + if self.node.df is None: + self.node.df = node_table + else: + df = pd.concat([self.node.df, node_table]) + self.node.df = df + + self._parent._used_node_ids.add(node_id) return self[node_id] def __getitem__(self, index: int) -> NodeData: @@ -265,7 +272,7 @@ def __getitem__(self, index: int) -> NodeData: f"{node_model_name} index must be an integer, not {indextype}" ) - row = self.node[index].iloc[0] + row = self.node.df.loc[index] return NodeData( node_id=int(index), node_type=row["node_type"], geometry=row["geometry"] ) diff --git a/python/ribasim/ribasim/geometry/area.py b/python/ribasim/ribasim/geometry/area.py index d6aa448c1..a3c0b846c 100644 --- a/python/ribasim/ribasim/geometry/area.py +++ b/python/ribasim/ribasim/geometry/area.py @@ -2,12 +2,13 @@ import pandera as pa from pandera.dtypes import Int32 -from pandera.typing import Series +from pandera.typing import Index, Series from pandera.typing.geopandas import GeoSeries from ribasim.schemas import _BaseSchema class BasinAreaSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=0, check_name=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) geometry: GeoSeries[Any] = pa.Field(default=None, nullable=True) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index f0c42b6fd..c20e9e6a6 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -1,4 +1,4 @@ -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional import matplotlib.pyplot as plt import numpy as np @@ -8,11 +8,13 @@ from matplotlib.axes import Axes from numpy.typing import NDArray from pandera.dtypes import Int32 -from pandera.typing import Series +from pandera.typing import Index, Series from pandera.typing.geopandas import GeoDataFrame, GeoSeries +from pydantic import NonNegativeInt, PrivateAttr from shapely.geometry import LineString, MultiLineString, Point from ribasim.input_base import SpatialTableModel +from ribasim.utils import UsedIDs __all__ = ("EdgeTable",) @@ -32,6 +34,7 @@ class NodeData(NamedTuple): class EdgeSchema(pa.DataFrameModel): + edge_id: Index[Int32] = pa.Field(default=0, ge=0, check_name=True, coerce=True) name: Series[str] = pa.Field(default="") from_node_id: Series[Int32] = pa.Field(default=0, coerce=True) to_node_id: Series[Int32] = pa.Field(default=0, coerce=True) @@ -44,10 +47,16 @@ class EdgeSchema(pa.DataFrameModel): class Config: add_missing_columns = True + @classmethod + def _index_name(self) -> str: + return "edge_id" + class EdgeTable(SpatialTableModel[EdgeSchema]): """Defines the connections between nodes.""" + _used_edge_ids: UsedIDs = PrivateAttr(default_factory=UsedIDs) + def add( self, from_node: NodeData, @@ -55,6 +64,7 @@ def add( geometry: LineString | MultiLineString | None = None, name: str = "", subnetwork_id: int | None = None, + edge_id: Optional[NonNegativeInt] = None, **kwargs, ): """Add an edge between nodes. The type of the edge (flow or control) @@ -84,9 +94,16 @@ def add( "control" if from_node.node_type in SPATIALCONTROLNODETYPES else "flow" ) assert self.df is not None + if edge_id is None: + edge_id = self._used_edge_ids.new_id() + elif edge_id in self._used_edge_ids: + raise ValueError( + f"Edge IDs have to be unique, but {edge_id} already exists." + ) - table_to_append = GeoDataFrame[EdgeSchema]( + table_to_append = GeoDataFrame( data={ + "edge_id": pd.Series([edge_id], dtype=np.int32), "from_node_id": pd.Series([from_node.node_id], dtype=np.int32), "to_node_id": pd.Series([to_node.node_id], dtype=np.int32), "edge_type": pd.Series([edge_type], dtype=str), @@ -97,26 +114,19 @@ def add( geometry=geometry_to_append, crs=self.df.crs, ) + table_to_append.set_index("edge_id", inplace=True) - self.df = GeoDataFrame[EdgeSchema]( - pd.concat([self.df, table_to_append], ignore_index=True) - ) + self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append])) if self.df.duplicated(subset=["from_node_id", "to_node_id"]).any(): raise ValueError( - f"Edges have to be unique, but edge ({from_node.node_id}, {to_node.node_id}) already exists." + f"Edges have to be unique, but edge with from_node_id {from_node.node_id} to_node_id {to_node.node_id} already exists." ) - self.df.index.name = "fid" + self._used_edge_ids.add(edge_id) def _get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: assert self.df is not None return (self.df.edge_type == edge_type).to_numpy() - def sort(self): - # Only sort the index (fid / edge_id) since this needs to be sorted in a GeoPackage. - # Under most circumstances, this retains the input order, - # making the edge_id as stable as possible; useful for post-processing. - self.df.sort_index(inplace=True) - def plot(self, **kwargs) -> Axes: """Plot the edges of the model. diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index 5cbc94b50..faf02d749 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -7,7 +7,7 @@ import pandera as pa from matplotlib.patches import Patch from pandera.dtypes import Int32 -from pandera.typing import Series +from pandera.typing import Index, Series from pandera.typing.geopandas import GeoSeries from ribasim.input_base import SpatialTableModel @@ -16,7 +16,7 @@ class NodeSchema(pa.DataFrameModel): - node_id: Series[Int32] = pa.Field(ge=0) + node_id: Index[Int32] = pa.Field(default=0, check_name=True) name: Series[str] = pa.Field(default="") node_type: Series[str] = pa.Field(default="") subnetwork_id: Series[pd.Int32Dtype] = pa.Field( @@ -28,6 +28,10 @@ class Config: add_missing_columns = True coerce = True + @classmethod + def _index_name(self) -> str: + return "node_id" + class NodeTable(SpatialTableModel[NodeSchema]): """The Ribasim nodes as Point geometries.""" @@ -37,12 +41,6 @@ def filter(self, nodetype: str): if self.df is not None: mask = self.df[self.df["node_type"] != nodetype].index self.df.drop(mask, inplace=True) - self.df.reset_index(inplace=True, drop=True) - - def sort(self): - assert self.df is not None - sort_keys = ["node_type", "node_id"] - self.df.sort_values(sort_keys, ignore_index=True, inplace=True) def plot_allocation_networks(self, ax=None, zorder=None) -> Any: if ax is None: @@ -156,9 +154,7 @@ def plot(self, ax=None, zorder=None) -> Any: assert self.df is not None geometry = self.df["geometry"] - for text, xy in zip( - self.df["node_id"], np.column_stack((geometry.x, geometry.y)) - ): + for text, xy in zip(self.df.index, np.column_stack((geometry.x, geometry.y))): ax.annotate(text=text, xy=xy, xytext=(2.0, 2.0), textcoords="offset points") return ax diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 4cf1bf99b..0cb8df99a 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -60,6 +60,9 @@ context_file_loading: ContextVar[dict[str, Any]] = ContextVar( "file_loading", default={} ) +context_file_writing: ContextVar[dict[str, Any]] = ContextVar( + "file_writing", default={} +) TableT = TypeVar("TableT", bound=pa.DataFrameModel) @@ -179,10 +182,7 @@ def _check_extra_columns(cls, v: DataFrame[TableT]): """Allow only extra columns with `meta_` prefix.""" if isinstance(v, (pd.DataFrame, gpd.GeoDataFrame)): for colname in v.columns: - if colname == "fid": - # Autogenerated on writing, don't carry them - v = v.drop(columns=["fid"]) - elif colname not in cls.columns() and not colname.startswith("meta_"): + if colname not in cls.columns() and not colname.startswith("meta_"): raise ValueError( f"Unrecognized column '{colname}'. Extra columns need a 'meta_' prefix." ) @@ -216,8 +216,13 @@ def tablename(cls) -> str: @model_validator(mode="before") @classmethod def _check_dataframe(cls, value: Any) -> Any: + # Enable initialization with a Dict. + if isinstance(value, dict) and len(value) > 0 and "df" not in value: + value = DataFrame(dict(**value)) + # Enable initialization with a DataFrame. if isinstance(value, pd.DataFrame | gpd.GeoDataFrame): + value.index.rename("fid", inplace=True) value = {"df": value} return value @@ -232,7 +237,7 @@ def _node_ids(self) -> set[int]: @classmethod def _load(cls, filepath: Path | None) -> dict[str, Any]: db = context_file_loading.get().get("database") - if filepath is not None: + if filepath is not None and db is not None: adf = cls._from_arrow(filepath) # TODO Store filepath? return {"df": adf} @@ -244,12 +249,11 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None: # TODO directory could be used to save an arrow file - db_path = context_file_loading.get().get("database") + db_path = context_file_writing.get().get("database") + self.sort() if self.filepath is not None: - self.sort() self._write_arrow(self.filepath, directory, input_dir) elif db_path is not None: - self.sort() self._write_geopackage(db_path) def _write_geopackage(self, temp_path: Path) -> None: @@ -264,16 +268,11 @@ def _write_geopackage(self, temp_path: Path) -> None: assert self.df is not None table = self.tablename() - # Add `fid` to all tables as primary key - # Enables editing values manually in QGIS - df = self.df.copy() - df["fid"] = range(1, len(df) + 1) - with closing(connect(temp_path)) as connection: - df.to_sql( + self.df.to_sql( table, connection, - index=False, + index=True, if_exists="replace", dtype={"fid": "INTEGER PRIMARY KEY AUTOINCREMENT"}, ) @@ -303,6 +302,7 @@ def _from_db(cls, path: Path, table: str) -> pd.DataFrame | None: df = pd.read_sql_query( query, connection, parse_dates={"time": {"format": "ISO8601"}} ) + df.set_index("fid", inplace=True) else: df = None @@ -319,7 +319,9 @@ def sort(self): Sorting is done automatically before writing the table. """ if self.df is not None: - self.df.sort_values(self._sort_keys, ignore_index=True, inplace=True) + df = self.df.sort_values(self._sort_keys, ignore_index=True) + df.index.rename("fid", inplace=True) + self.df = df # trigger validation and thus index coercion to int32 @classmethod def tableschema(cls) -> TableT: @@ -336,7 +338,7 @@ def tableschema(cls) -> TableT: def columns(cls) -> list[str]: """Retrieve column names.""" T = cls.tableschema() - return list(T.to_schema().columns.keys()) + return list(T.to_schema().columns.keys()) + [T.to_schema().index.name] def __repr__(self) -> str: # Make sure not to return just "None", because it gets extremely confusing @@ -367,11 +369,19 @@ def __getitem__(self, index) -> pd.DataFrame | gpd.GeoDataFrame: class SpatialTableModel(TableModel[TableT], Generic[TableT]): df: GeoDataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) + def sort(self): + # Only sort the index (node_id / edge_id) since this needs to be sorted in a GeoPackage. + # Under most circumstances, this retains the input order, + # making the edge_id as stable as possible; useful for post-processing. + self.df.sort_index(inplace=True) + @classmethod def _from_db(cls, path: Path, table: str): with connect(path) as connection: if exists(connection, table): + # pyogrio hardcodes fid name on reading df = gpd.read_file(path, layer=table, fid_as_index=True) + df.index.rename(cls.tableschema()._index_name(), inplace=True) else: df = None @@ -386,9 +396,13 @@ def _write_geopackage(self, path: Path) -> None: path : Path """ assert self.df is not None - # the index name must be fid otherwise it will generate a separate fid column - self.df.index.name = "fid" - self.df.to_file(path, layer=self.tablename(), index=True, driver="GPKG") + self.df.to_file( + path, + layer=self.tablename(), + driver="GPKG", + index=True, + fid=self.df.index.name, + ) _add_styles_to_geopackage(path, self.tablename()) diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index fc2f3e61a..11c50da96 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -13,7 +13,7 @@ from pydantic import ( DirectoryPath, Field, - NonNegativeInt, + PrivateAttr, field_serializer, model_validator, ) @@ -44,14 +44,15 @@ from ribasim.geometry.edge import EdgeSchema, EdgeTable from ribasim.geometry.node import NodeTable from ribasim.input_base import ( - BaseModel, ChildModel, FileModel, SpatialTableModel, context_file_loading, + context_file_writing, ) from ribasim.utils import ( MissingOptionalModule, + UsedIDs, _edge_lookup, _node_lookup, _node_lookup_numpy, @@ -64,28 +65,6 @@ xugrid = MissingOptionalModule("xugrid") -class UsedNodeIDs(BaseModel): - """A helper class to manage global unique node IDs. - - We keep track of all node IDs in the model, - and keep track of the maximum to provide new IDs. - MultiNodeModels will check this instance on `add`. - """ - - node_ids: set[int] = set() - max_node_id: NonNegativeInt = 0 - - def add(self, node_id: int) -> None: - self.node_ids.add(node_id) - self.max_node_id = max(self.max_node_id, node_id) - - def __contains__(self, value: int) -> bool: - return self.node_ids.__contains__(value) - - def new_id(self) -> int: - return self.max_node_id + 1 - - class Model(FileModel): """A model of inland water resources systems.""" @@ -93,8 +72,6 @@ class Model(FileModel): endtime: datetime.datetime crs: str - used_node_ids: UsedNodeIDs = Field(default_factory=UsedNodeIDs) - input_dir: Path = Field(default=Path(".")) results_dir: Path = Field(default=Path("results")) @@ -124,6 +101,8 @@ class Model(FileModel): edge: EdgeTable = Field(default_factory=EdgeTable) + _used_node_ids: UsedIDs = PrivateAttr(default_factory=UsedIDs) + @model_validator(mode="after") def _set_node_parent(self) -> "Model": for ( @@ -137,7 +116,7 @@ def _set_node_parent(self) -> "Model": @model_validator(mode="after") def _ensure_edge_table_is_present(self) -> "Model": if self.edge.df is None: - self.edge.df = GeoDataFrame[EdgeSchema]() + self.edge.df = GeoDataFrame[EdgeSchema](index=pd.Index([], name="edge_id")) self.edge.df.set_geometry("geometry", inplace=True, crs=self.crs) return self @@ -200,16 +179,11 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath): db_path = directory / input_dir / "database.gpkg" db_path.parent.mkdir(parents=True, exist_ok=True) db_path.unlink(missing_ok=True) - context_file_loading.get()["database"] = db_path + context_file_writing.get()["database"] = db_path self.edge._save(directory, input_dir) node = self.node_table() assert node.df is not None - if not node.df["node_id"].is_unique: - raise ValueError("node_id must be unique") - node.df.set_index("node_id", drop=False, inplace=True) - node.df.index.name = "fid" - node.df.sort_index(inplace=True) node._save(directory, input_dir) for sub in self._nodes(): @@ -232,23 +206,23 @@ def to_crs(self, crs: str) -> None: def _apply_crs_function(self, function_name: str, crs: str) -> None: """Apply `function_name`, with `crs` as the first and only argument to all spatial tables.""" - self.edge.df = getattr(self.edge.df, function_name)(crs) + getattr(self.edge.df, function_name)(crs, inplace=True) for sub in self._nodes(): if sub.node.df is not None: - sub.node.df = getattr(sub.node.df, function_name)(crs) + getattr(sub.node.df, function_name)(crs, inplace=True) for table in sub._tables(): if isinstance(table, SpatialTableModel) and table.df is not None: - table.df = getattr(table.df, function_name)(crs) + getattr(table.df, function_name)(crs, inplace=True) self.crs = crs def node_table(self) -> NodeTable: """Compute the full sorted NodeTable from all node types.""" df_chunks = [node.node.df.set_crs(self.crs) for node in self._nodes()] # type:ignore - df = pd.concat(df_chunks, ignore_index=True) + df = pd.concat(df_chunks) node_table = NodeTable(df=df) node_table.sort() assert node_table.df is not None - node_table.df.index.name = "fid" + assert node_table.df.index.is_unique, "node_id must be unique" return node_table def _nodes(self) -> Generator[MultiNodeModel, Any, None]: @@ -297,13 +271,13 @@ def write(self, filepath: str | PathLike[str]) -> Path: self.filepath = filepath if not filepath.suffix == ".toml": raise ValueError(f"Filepath '{filepath}' is not a .toml file.") - context_file_loading.set({}) + context_file_writing.set({}) directory = filepath.parent directory.mkdir(parents=True, exist_ok=True) self._save(directory, self.input_dir) fn = self._write_toml(filepath) - context_file_loading.set({}) + context_file_writing.set({}) return fn @classmethod @@ -472,15 +446,12 @@ def to_xugrid(self, add_flow: bool = False, add_allocation: bool = False): node_df = self.node_table().df assert node_df is not None - if not node_df.node_id.is_unique: - raise ValueError("node_id must be unique") - assert self.edge.df is not None edge_df = self.edge.df.copy() # We assume only the flow network is of interest. edge_df = edge_df[edge_df.edge_type == "flow"] - node_id = node_df.node_id.to_numpy() + node_id = node_df.index.to_numpy() edge_id = edge_df.index.to_numpy() from_node_id = edge_df.from_node_id.to_numpy() to_node_id = edge_df.to_node_id.to_numpy() diff --git a/python/ribasim/ribasim/nodes/basin.py b/python/ribasim/ribasim/nodes/basin.py index 7aa1a384b..1beb21336 100644 --- a/python/ribasim/ribasim/nodes/basin.py +++ b/python/ribasim/ribasim/nodes/basin.py @@ -1,8 +1,5 @@ -from geopandas import GeoDataFrame -from pandas import DataFrame - from ribasim.geometry.area import BasinAreaSchema -from ribasim.input_base import TableModel +from ribasim.input_base import SpatialTableModel, TableModel from ribasim.schemas import ( BasinConcentrationExternalSchema, BasinConcentrationSchema, @@ -26,45 +23,36 @@ class Static(TableModel[BasinStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[BasinTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class State(TableModel[BasinStateSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Profile(TableModel[BasinProfileSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Subgrid(TableModel[BasinSubgridSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass -class Area(TableModel[BasinAreaSchema]): - def __init__(self, **kwargs): - super().__init__(df=GeoDataFrame(dict(**kwargs))) +class Area(SpatialTableModel[BasinAreaSchema]): + pass class Concentration(TableModel[BasinConcentrationSchema]): - def __init__(self, **kwargs): - super().__init__(df=GeoDataFrame(dict(**kwargs))) + pass class ConcentrationExternal(TableModel[BasinConcentrationExternalSchema]): - def __init__(self, **kwargs): - super().__init__(df=GeoDataFrame(dict(**kwargs))) + pass class ConcentrationState(TableModel[BasinConcentrationStateSchema]): - def __init__(self, **kwargs): - super().__init__(df=GeoDataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/continuous_control.py b/python/ribasim/ribasim/nodes/continuous_control.py index 3b130ab90..a22cb9014 100644 --- a/python/ribasim/ribasim/nodes/continuous_control.py +++ b/python/ribasim/ribasim/nodes/continuous_control.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( ContinuousControlFunctionSchema, @@ -10,10 +8,8 @@ class Variable(TableModel[ContinuousControlVariableSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Function(TableModel[ContinuousControlFunctionSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/discrete_control.py b/python/ribasim/ribasim/nodes/discrete_control.py index df7b9f42b..1f61b3bbf 100644 --- a/python/ribasim/ribasim/nodes/discrete_control.py +++ b/python/ribasim/ribasim/nodes/discrete_control.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( DiscreteControlConditionSchema, @@ -11,15 +9,12 @@ class Variable(TableModel[DiscreteControlVariableSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Condition(TableModel[DiscreteControlConditionSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Logic(TableModel[DiscreteControlLogicSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/flow_boundary.py b/python/ribasim/ribasim/nodes/flow_boundary.py index 096652f6f..e32f5aa73 100644 --- a/python/ribasim/ribasim/nodes/flow_boundary.py +++ b/python/ribasim/ribasim/nodes/flow_boundary.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( FlowBoundaryConcentrationSchema, @@ -11,15 +9,12 @@ class Static(TableModel[FlowBoundaryStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[FlowBoundaryTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Concentration(TableModel[FlowBoundaryConcentrationSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/flow_demand.py b/python/ribasim/ribasim/nodes/flow_demand.py index 56ce7019a..425f4410c 100644 --- a/python/ribasim/ribasim/nodes/flow_demand.py +++ b/python/ribasim/ribasim/nodes/flow_demand.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( FlowDemandStaticSchema, @@ -10,10 +8,8 @@ class Static(TableModel[FlowDemandStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[FlowDemandTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/level_boundary.py b/python/ribasim/ribasim/nodes/level_boundary.py index 1fb9d5d81..f6278edba 100644 --- a/python/ribasim/ribasim/nodes/level_boundary.py +++ b/python/ribasim/ribasim/nodes/level_boundary.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( LevelBoundaryConcentrationSchema, @@ -11,15 +9,12 @@ class Static(TableModel[LevelBoundaryStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[LevelBoundaryTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Concentration(TableModel[LevelBoundaryConcentrationSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/level_demand.py b/python/ribasim/ribasim/nodes/level_demand.py index e65bd96d0..5553fe169 100644 --- a/python/ribasim/ribasim/nodes/level_demand.py +++ b/python/ribasim/ribasim/nodes/level_demand.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( LevelDemandStaticSchema, @@ -10,10 +8,8 @@ class Static(TableModel[LevelDemandStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[LevelDemandTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/linear_resistance.py b/python/ribasim/ribasim/nodes/linear_resistance.py index 5adb72af6..013c13e7f 100644 --- a/python/ribasim/ribasim/nodes/linear_resistance.py +++ b/python/ribasim/ribasim/nodes/linear_resistance.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( LinearResistanceStaticSchema, @@ -9,5 +7,4 @@ class Static(TableModel[LinearResistanceStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/manning_resistance.py b/python/ribasim/ribasim/nodes/manning_resistance.py index 46feea8a0..62306125a 100644 --- a/python/ribasim/ribasim/nodes/manning_resistance.py +++ b/python/ribasim/ribasim/nodes/manning_resistance.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( ManningResistanceStaticSchema, @@ -9,5 +7,4 @@ class Static(TableModel[ManningResistanceStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/outlet.py b/python/ribasim/ribasim/nodes/outlet.py index 8c6623cd9..9f38440c9 100644 --- a/python/ribasim/ribasim/nodes/outlet.py +++ b/python/ribasim/ribasim/nodes/outlet.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( OutletStaticSchema, @@ -9,5 +7,4 @@ class Static(TableModel[OutletStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/pid_control.py b/python/ribasim/ribasim/nodes/pid_control.py index e359e849e..a90353b63 100644 --- a/python/ribasim/ribasim/nodes/pid_control.py +++ b/python/ribasim/ribasim/nodes/pid_control.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import PidControlStaticSchema, PidControlTimeSchema @@ -7,10 +5,8 @@ class Static(TableModel[PidControlStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[PidControlTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/pump.py b/python/ribasim/ribasim/nodes/pump.py index c2a9a2250..afd8ab355 100644 --- a/python/ribasim/ribasim/nodes/pump.py +++ b/python/ribasim/ribasim/nodes/pump.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( PumpStaticSchema, @@ -9,5 +7,4 @@ class Static(TableModel[PumpStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/tabulated_rating_curve.py b/python/ribasim/ribasim/nodes/tabulated_rating_curve.py index 6afb46c5f..9397eb61a 100644 --- a/python/ribasim/ribasim/nodes/tabulated_rating_curve.py +++ b/python/ribasim/ribasim/nodes/tabulated_rating_curve.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( TabulatedRatingCurveStaticSchema, @@ -10,10 +8,8 @@ class Static(TableModel[TabulatedRatingCurveStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[TabulatedRatingCurveTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/nodes/user_demand.py b/python/ribasim/ribasim/nodes/user_demand.py index d33192dc0..362848e5f 100644 --- a/python/ribasim/ribasim/nodes/user_demand.py +++ b/python/ribasim/ribasim/nodes/user_demand.py @@ -1,5 +1,3 @@ -from pandas import DataFrame - from ribasim.input_base import TableModel from ribasim.schemas import ( UserDemandStaticSchema, @@ -10,10 +8,8 @@ class Static(TableModel[UserDemandStaticSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass class Time(TableModel[UserDemandTimeSchema]): - def __init__(self, **kwargs): - super().__init__(df=DataFrame(dict(**kwargs))) + pass diff --git a/python/ribasim/ribasim/schemas.py b/python/ribasim/ribasim/schemas.py index f5bccfb4d..c2b11c298 100644 --- a/python/ribasim/ribasim/schemas.py +++ b/python/ribasim/ribasim/schemas.py @@ -2,7 +2,7 @@ import pandera as pa from pandera.dtypes import Int32, Timestamp -from pandera.typing import Series +from pandera.typing import Index, Series class _BaseSchema(pa.DataFrameModel): @@ -10,8 +10,13 @@ class Config: add_missing_columns = True coerce = True + @classmethod + def _index_name(self) -> str: + return "fid" + class BasinConcentrationExternalSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) substance: Series[str] = pa.Field(nullable=False) @@ -19,12 +24,14 @@ class BasinConcentrationExternalSchema(_BaseSchema): class BasinConcentrationStateSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) substance: Series[str] = pa.Field(nullable=False) concentration: Series[float] = pa.Field(nullable=True) class BasinConcentrationSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) substance: Series[str] = pa.Field(nullable=False) @@ -33,17 +40,20 @@ class BasinConcentrationSchema(_BaseSchema): class BasinProfileSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) area: Series[float] = pa.Field(nullable=False) level: Series[float] = pa.Field(nullable=False) class BasinStateSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) level: Series[float] = pa.Field(nullable=False) class BasinStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) drainage: Series[float] = pa.Field(nullable=True) potential_evaporation: Series[float] = pa.Field(nullable=True) @@ -52,6 +62,7 @@ class BasinStaticSchema(_BaseSchema): class BasinSubgridSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) subgrid_id: Series[Int32] = pa.Field(nullable=False, default=0) node_id: Series[Int32] = pa.Field(nullable=False, default=0) basin_level: Series[float] = pa.Field(nullable=False) @@ -59,6 +70,7 @@ class BasinSubgridSchema(_BaseSchema): class BasinTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) drainage: Series[float] = pa.Field(nullable=True) @@ -68,6 +80,7 @@ class BasinTimeSchema(_BaseSchema): class ContinuousControlFunctionSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) input: Series[float] = pa.Field(nullable=False) output: Series[float] = pa.Field(nullable=False) @@ -75,6 +88,7 @@ class ContinuousControlFunctionSchema(_BaseSchema): class ContinuousControlVariableSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) listen_node_id: Series[Int32] = pa.Field(nullable=False, default=0) variable: Series[str] = pa.Field(nullable=False) @@ -83,18 +97,21 @@ class ContinuousControlVariableSchema(_BaseSchema): class DiscreteControlConditionSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) compound_variable_id: Series[Int32] = pa.Field(nullable=False, default=0) greater_than: Series[float] = pa.Field(nullable=False) class DiscreteControlLogicSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) truth_state: Series[str] = pa.Field(nullable=False) control_state: Series[str] = pa.Field(nullable=False) class DiscreteControlVariableSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) compound_variable_id: Series[Int32] = pa.Field(nullable=False, default=0) listen_node_id: Series[Int32] = pa.Field(nullable=False, default=0) @@ -104,6 +121,7 @@ class DiscreteControlVariableSchema(_BaseSchema): class FlowBoundaryConcentrationSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) substance: Series[str] = pa.Field(nullable=False) @@ -111,24 +129,28 @@ class FlowBoundaryConcentrationSchema(_BaseSchema): class FlowBoundaryStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) flow_rate: Series[float] = pa.Field(nullable=False) class FlowBoundaryTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) flow_rate: Series[float] = pa.Field(nullable=False) class FlowDemandStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) demand: Series[float] = pa.Field(nullable=False) priority: Series[Int32] = pa.Field(nullable=False, default=0) class FlowDemandTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) demand: Series[float] = pa.Field(nullable=False) @@ -136,6 +158,7 @@ class FlowDemandTimeSchema(_BaseSchema): class LevelBoundaryConcentrationSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) substance: Series[str] = pa.Field(nullable=False) @@ -143,18 +166,21 @@ class LevelBoundaryConcentrationSchema(_BaseSchema): class LevelBoundaryStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) level: Series[float] = pa.Field(nullable=False) class LevelBoundaryTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) level: Series[float] = pa.Field(nullable=False) class LevelDemandStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) min_level: Series[float] = pa.Field(nullable=True) max_level: Series[float] = pa.Field(nullable=True) @@ -162,6 +188,7 @@ class LevelDemandStaticSchema(_BaseSchema): class LevelDemandTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) min_level: Series[float] = pa.Field(nullable=True) @@ -170,6 +197,7 @@ class LevelDemandTimeSchema(_BaseSchema): class LinearResistanceStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) resistance: Series[float] = pa.Field(nullable=False) @@ -178,6 +206,7 @@ class LinearResistanceStaticSchema(_BaseSchema): class ManningResistanceStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) length: Series[float] = pa.Field(nullable=False) @@ -188,6 +217,7 @@ class ManningResistanceStaticSchema(_BaseSchema): class OutletStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) flow_rate: Series[float] = pa.Field(nullable=False) @@ -198,6 +228,7 @@ class OutletStaticSchema(_BaseSchema): class PidControlStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) listen_node_id: Series[Int32] = pa.Field(nullable=False, default=0) @@ -209,6 +240,7 @@ class PidControlStaticSchema(_BaseSchema): class PidControlTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) listen_node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) @@ -220,6 +252,7 @@ class PidControlTimeSchema(_BaseSchema): class PumpStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) flow_rate: Series[float] = pa.Field(nullable=False) @@ -229,6 +262,7 @@ class PumpStaticSchema(_BaseSchema): class TabulatedRatingCurveStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) level: Series[float] = pa.Field(nullable=False) @@ -237,6 +271,7 @@ class TabulatedRatingCurveStaticSchema(_BaseSchema): class TabulatedRatingCurveTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) level: Series[float] = pa.Field(nullable=False) @@ -244,6 +279,7 @@ class TabulatedRatingCurveTimeSchema(_BaseSchema): class UserDemandStaticSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) active: Series[pa.BOOL] = pa.Field(nullable=True) demand: Series[float] = pa.Field(nullable=True) @@ -253,6 +289,7 @@ class UserDemandStaticSchema(_BaseSchema): class UserDemandTimeSchema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) node_id: Series[Int32] = pa.Field(nullable=False, default=0) time: Series[Timestamp] = pa.Field(nullable=False) demand: Series[float] = pa.Field(nullable=False) diff --git a/python/ribasim/ribasim/utils.py b/python/ribasim/ribasim/utils.py index a67186483..cea0bb31c 100644 --- a/python/ribasim/ribasim/utils.py +++ b/python/ribasim/ribasim/utils.py @@ -4,6 +4,7 @@ import pandas as pd from pandera.dtypes import Int32 from pandera.typing import Series +from pydantic import BaseModel, NonNegativeInt def _pascal_to_snake(pascal_str): @@ -67,3 +68,25 @@ def _time_in_ns(df) -> None: """Convert the time column to datetime64[ns] dtype.""" # datetime64[ms] gives trouble; https://github.com/pydata/xarray/issues/6318 df["time"] = df["time"].astype("datetime64[ns]") + + +class UsedIDs(BaseModel): + """A helper class to manage global unique (node) IDs. + + We keep track of all IDs in the model, + and keep track of the maximum to provide new IDs. + MultiNodeModels and Edge will check this instance on `add`. + """ + + node_ids: set[int] = set() + max_node_id: NonNegativeInt = 0 + + def add(self, node_id: int) -> None: + self.node_ids.add(node_id) + self.max_node_id = max(self.max_node_id, node_id) + + def __contains__(self, value: int) -> bool: + return self.node_ids.__contains__(value) + + def new_id(self) -> int: + return self.max_node_id + 1 diff --git a/python/ribasim/tests/test_edge.py b/python/ribasim/tests/test_edge.py index f0c0690b1..754f65b6f 100644 --- a/python/ribasim/tests/test_edge.py +++ b/python/ribasim/tests/test_edge.py @@ -13,8 +13,10 @@ def edge() -> EdgeTable: d = (1.0, 1.0) geometry = [sg.LineString([a, b, c]), sg.LineString([a, d])] df = gpd.GeoDataFrame( - data={"from_node_id": [1, 1], "to_node_id": [2, 3]}, geometry=geometry + data={"edge_id": [0, 1], "from_node_id": [1, 1], "to_node_id": [2, 3]}, + geometry=geometry, ) + df.set_index("edge_id", inplace=True) edge = EdgeTable(df=df) return edge @@ -25,11 +27,13 @@ def test_validation(edge): with pytest.raises(ValidationError): df = gpd.GeoDataFrame( data={ + "edge_id": [0, 1], "from_node_id": [1, 1], "to_node_id": ["foo", 3], }, # None is coerced to 0 without errors geometry=[None, None], ) + df.set_index("edge_id", inplace=True) EdgeTable(df=df) diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 6da93ced9..1fbbd641e 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import pandas as pd import pytest import ribasim import tomli @@ -22,8 +23,6 @@ def __assert_equal(a: DataFrame, b: DataFrame) -> None: a = a.reset_index(drop=True) b = b.reset_index(drop=True) - a.drop(columns=["fid"], inplace=True, errors="ignore") - b.drop(columns=["fid"], inplace=True, errors="ignore") assert_frame_equal(a, b) @@ -94,6 +93,11 @@ def test_extra_columns(): pump.Static(extra=[-2], flow_rate=[1.2]) +def test_index_tables(): + p = pump.Static(flow_rate=[1.2]) + assert p.df.index.name == "fid" + + def test_extra_spatial_columns(): model = Model( starttime="2020-01-01", @@ -226,7 +230,7 @@ def test_sort(level_range, tmp_path): def test_roundtrip(trivial, tmp_path): model1 = trivial # set custom Edge index - model1.edge.df.index = [15, 12] + model1.edge.df.index = pd.Index([15, 12], name="edge_id") model1dir = tmp_path / "model1" model2dir = tmp_path / "model2" # read a model and then write it to a different path diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index db249f0a4..81e6800cd 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -98,12 +98,12 @@ def test_write_adds_fid_in_tables(basic, tmp_path): model_orig = basic # for node an explicit index was provided nrow = len(model_orig.basin.node.df) - assert model_orig.basin.node.df.index.name is None + assert model_orig.basin.node.df.index.name == "node_id" - # for edge no index was provided, but it still needs to write it to file + # for edge an explicit index was provided nrow = len(model_orig.edge.df) - assert model_orig.edge.df.index.name == "fid" - assert model_orig.edge.df.index.equals(pd.RangeIndex(nrow)) + assert model_orig.edge.df.index.name == "edge_id" + assert model_orig.edge.df.index.equals(pd.RangeIndex(1, nrow + 1)) model_orig.write(tmp_path / "basic/ribasim.toml") with connect(tmp_path / "basic/database.gpkg") as connection: @@ -111,13 +111,13 @@ def test_write_adds_fid_in_tables(basic, tmp_path): df = pd.read_sql_query(query, connection) assert "fid" in df.columns - query = "select fid from Node" + query = "select node_id from Node" df = pd.read_sql_query(query, connection) - assert "fid" in df.columns + assert "node_id" in df.columns - query = "select fid from Edge" + query = "select edge_id from Edge" df = pd.read_sql_query(query, connection) - assert "fid" in df.columns + assert "edge_id" in df.columns def test_node_table(basic): @@ -125,10 +125,10 @@ def test_node_table(basic): node = model.node_table() df = node.df assert df.geometry.is_unique - assert df.node_id.dtype == np.int32 + assert df.index.dtype == np.int32 assert df.subnetwork_id.dtype == pd.Int32Dtype() assert df.node_type.iloc[0] == "Basin" - assert df.node_type.iloc[-1] == "Terminal" + assert df.node_type.iloc[-1] == "LevelBoundary" assert df.crs == CRS.from_epsg(28992) @@ -145,7 +145,9 @@ def test_duplicate_edge(trivial): model = trivial with pytest.raises( ValueError, - match=re.escape("Edges have to be unique, but edge (6, 0) already exists."), + match=re.escape( + "Edges have to be unique, but edge with from_node_id 6 to_node_id 0 already exists." + ), ): model.edge.add( model.basin[6], diff --git a/python/ribasim_testmodels/ribasim_testmodels/trivial.py b/python/ribasim_testmodels/ribasim_testmodels/trivial.py index 0e84e1902..6e01911f1 100644 --- a/python/ribasim_testmodels/ribasim_testmodels/trivial.py +++ b/python/ribasim_testmodels/ribasim_testmodels/trivial.py @@ -48,10 +48,7 @@ def trivial_model() -> Model: [tabulated_rating_curve.Static(level=[0.0, 1.0], flow_rate=[0.0, 10 / 86400])], ) - model.edge.add( - basin6, - trc0, - ) + model.edge.add(basin6, trc0, edge_id=100) model.edge.add(trc0, term) return model diff --git a/ribasim_qgis/core/nodes.py b/ribasim_qgis/core/nodes.py index f5d9dd82e..f7be187f1 100644 --- a/ribasim_qgis/core/nodes.py +++ b/ribasim_qgis/core/nodes.py @@ -224,6 +224,7 @@ class Edge(Input): @classmethod def attributes(cls) -> list[QgsField]: return [ + QgsField("edge_id", QVariant.Int), QgsField("name", QVariant.String), QgsField("from_node_id", QVariant.Int), QgsField("to_node_id", QVariant.Int), @@ -260,7 +261,7 @@ def set_editor_widget(self) -> None: @property def labels(self) -> Any: pal_layer = QgsPalLayerSettings() - pal_layer.fieldName = """concat("name", ' #', "fid")""" + pal_layer.fieldName = """concat("name", ' #', "edge_id")""" pal_layer.isExpression = True pal_layer.placement = Qgis.LabelPlacement.Line pal_layer.dist = 1.0 diff --git a/utils/templates/schemas.py.jinja b/utils/templates/schemas.py.jinja index 1ff1a6a23..8a9cac921 100644 --- a/utils/templates/schemas.py.jinja +++ b/utils/templates/schemas.py.jinja @@ -2,7 +2,7 @@ import pandera as pa from pandera.dtypes import Int32, Timestamp -from pandera.typing import Series +from pandera.typing import Series, Index class _BaseSchema(pa.DataFrameModel): @@ -10,9 +10,14 @@ class _BaseSchema(pa.DataFrameModel): add_missing_columns = True coerce = True + @classmethod + def _index_name(self) -> str: + return "fid" + {% for m in models %} class {{m[:name]}}Schema(_BaseSchema): + fid: Index[Int32] = pa.Field(default=1, check_name=True, coerce=True) {% for f in m[:fields] %} {% if (f[2] == "Series[Int32]") %} {{ f[1] }}: {{ f[2] }} = pa.Field(nullable={{ f[3] }}, default=0)