diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index d68d3f710..4f4d45fa4 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -1,3 +1,4 @@ +import numbers from collections.abc import Sequence from enum import Enum from typing import Any @@ -144,10 +145,18 @@ def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> No else pd.concat([self.node.df, node_table]) ) - def __getitem__(self, index): - row = self.node.df[self.node.df["node_id"] == index].iloc[0] + def __getitem__(self, index: int) -> NodeData: + # Unlike TableModel, support only indexing single rows. + if not isinstance(index, numbers.Integral): + node_model_name = type(self).__name__ + indextype = type(index).__name__ + raise TypeError( + f"{node_model_name} index must be an integer, not {indextype}" + ) + + row = self.node[index].iloc[0] return NodeData( - node_id=index, node_type=row["node_type"], geometry=row["geometry"] + node_id=int(index), node_type=row["node_type"], geometry=row["geometry"] ) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index c71f47aa5..ba92a824a 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -159,3 +159,6 @@ def plot(self, **kwargs) -> Axes: ) return ax + + def __getitem__(self, _): + raise NotImplementedError diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index cf5373fc9..3b6ec99b6 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -14,6 +14,7 @@ ) import geopandas as gpd +import numpy as np import pandas as pd import pandera as pa from pandera.typing import DataFrame @@ -337,6 +338,20 @@ def _repr_html_(self): else: return f"
{self.tablename()}
" + self.df._repr_html_() + def __getitem__(self, index) -> pd.DataFrame | gpd.GeoDataFrame: + tablename = self.tablename() + if self.df is None: + raise ValueError(f"Cannot index into {tablename}: it contains no data.") + + # Allow for indexing with multiple values. + np_index = np.atleast_1d(index) + missing = np.setdiff1d(np_index, self.df["node_id"].unique()) + if missing.size > 0: + raise IndexError(f"{tablename} does not contain node_id: {missing}") + + # Index with .loc[..., :] to always return a DataFrame. + return self.df.loc[self.df["node_id"].isin(np_index), :] + class SpatialTableModel(TableModel[TableT], Generic[TableT]): df: GeoDataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) diff --git a/python/ribasim/tests/test_edge.py b/python/ribasim/tests/test_edge.py index 045f11645..ce1acd501 100644 --- a/python/ribasim/tests/test_edge.py +++ b/python/ribasim/tests/test_edge.py @@ -32,3 +32,8 @@ def test_validation(edge): def test_edge_plot(edge): edge.plot() + + +def test_edge_indexing(edge): + with pytest.raises(NotImplementedError): + edge[1] diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index eefc87e9c..d5f3b4207 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -6,6 +6,7 @@ import pytest from pydantic import ValidationError from ribasim.config import Solver +from ribasim.geometry.edge import NodeData from ribasim.input_base import esc_id from ribasim.model import Model from shapely import Point @@ -182,3 +183,34 @@ def test_node_table(basic): assert df.geometry.is_unique assert df.node_type.iloc[0] == "Basin" assert df.node_type.iloc[-1] == "Terminal" + + +def test_indexing(basic): + model = basic + + result = model.basin[1] + assert isinstance(result, NodeData) + + # Also test with a numpy type + result = model.basin[np.int32(1)] + assert isinstance(result, NodeData) + + with pytest.raises(TypeError, match="Basin index must be an integer, not list"): + model.basin[[1, 3, 6]] + + result = model.basin.static[1] + assert isinstance(result, pd.DataFrame) + + result = model.basin.static[[1, 3, 6]] + assert isinstance(result, pd.DataFrame) + + with pytest.raises( + IndexError, match=re.escape("Basin / static does not contain node_id: [2]") + ): + model.basin.static[2] + + with pytest.raises( + ValueError, + match=re.escape("Cannot index into Basin / time: it contains no data."), + ): + model.basin.time[1]