diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py
index d68d3f710..4f4d45fa4 100644
--- a/python/ribasim/ribasim/config.py
+++ b/python/ribasim/ribasim/config.py
@@ -1,3 +1,4 @@
+import numbers
from collections.abc import Sequence
from enum import Enum
from typing import Any
@@ -144,10 +145,18 @@ def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> No
else pd.concat([self.node.df, node_table])
)
- def __getitem__(self, index):
- row = self.node.df[self.node.df["node_id"] == index].iloc[0]
+ def __getitem__(self, index: int) -> NodeData:
+ # Unlike TableModel, support only indexing single rows.
+ if not isinstance(index, numbers.Integral):
+ node_model_name = type(self).__name__
+ indextype = type(index).__name__
+ raise TypeError(
+ f"{node_model_name} index must be an integer, not {indextype}"
+ )
+
+ row = self.node[index].iloc[0]
return NodeData(
- node_id=index, node_type=row["node_type"], geometry=row["geometry"]
+ node_id=int(index), node_type=row["node_type"], geometry=row["geometry"]
)
diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py
index c71f47aa5..ba92a824a 100644
--- a/python/ribasim/ribasim/geometry/edge.py
+++ b/python/ribasim/ribasim/geometry/edge.py
@@ -159,3 +159,6 @@ def plot(self, **kwargs) -> Axes:
)
return ax
+
+ def __getitem__(self, _):
+ raise NotImplementedError
diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py
index cf5373fc9..3b6ec99b6 100644
--- a/python/ribasim/ribasim/input_base.py
+++ b/python/ribasim/ribasim/input_base.py
@@ -14,6 +14,7 @@
)
import geopandas as gpd
+import numpy as np
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame
@@ -337,6 +338,20 @@ def _repr_html_(self):
else:
return f"
{self.tablename()}
" + self.df._repr_html_()
+ def __getitem__(self, index) -> pd.DataFrame | gpd.GeoDataFrame:
+ tablename = self.tablename()
+ if self.df is None:
+ raise ValueError(f"Cannot index into {tablename}: it contains no data.")
+
+ # Allow for indexing with multiple values.
+ np_index = np.atleast_1d(index)
+ missing = np.setdiff1d(np_index, self.df["node_id"].unique())
+ if missing.size > 0:
+ raise IndexError(f"{tablename} does not contain node_id: {missing}")
+
+ # Index with .loc[..., :] to always return a DataFrame.
+ return self.df.loc[self.df["node_id"].isin(np_index), :]
+
class SpatialTableModel(TableModel[TableT], Generic[TableT]):
df: GeoDataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)
diff --git a/python/ribasim/tests/test_edge.py b/python/ribasim/tests/test_edge.py
index 045f11645..ce1acd501 100644
--- a/python/ribasim/tests/test_edge.py
+++ b/python/ribasim/tests/test_edge.py
@@ -32,3 +32,8 @@ def test_validation(edge):
def test_edge_plot(edge):
edge.plot()
+
+
+def test_edge_indexing(edge):
+ with pytest.raises(NotImplementedError):
+ edge[1]
diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py
index eefc87e9c..d5f3b4207 100644
--- a/python/ribasim/tests/test_model.py
+++ b/python/ribasim/tests/test_model.py
@@ -6,6 +6,7 @@
import pytest
from pydantic import ValidationError
from ribasim.config import Solver
+from ribasim.geometry.edge import NodeData
from ribasim.input_base import esc_id
from ribasim.model import Model
from shapely import Point
@@ -182,3 +183,34 @@ def test_node_table(basic):
assert df.geometry.is_unique
assert df.node_type.iloc[0] == "Basin"
assert df.node_type.iloc[-1] == "Terminal"
+
+
+def test_indexing(basic):
+ model = basic
+
+ result = model.basin[1]
+ assert isinstance(result, NodeData)
+
+ # Also test with a numpy type
+ result = model.basin[np.int32(1)]
+ assert isinstance(result, NodeData)
+
+ with pytest.raises(TypeError, match="Basin index must be an integer, not list"):
+ model.basin[[1, 3, 6]]
+
+ result = model.basin.static[1]
+ assert isinstance(result, pd.DataFrame)
+
+ result = model.basin.static[[1, 3, 6]]
+ assert isinstance(result, pd.DataFrame)
+
+ with pytest.raises(
+ IndexError, match=re.escape("Basin / static does not contain node_id: [2]")
+ ):
+ model.basin.static[2]
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape("Cannot index into Basin / time: it contains no data."),
+ ):
+ model.basin.time[1]