Skip to content

Commit

Permalink
Add support for indexing into TableModel, except Edge (#1304)
Browse files Browse the repository at this point in the history
  • Loading branch information
Huite authored Mar 22, 2024
1 parent 1fe87c2 commit f43010b
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 3 deletions.
15 changes: 12 additions & 3 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numbers
from collections.abc import Sequence
from enum import Enum
from typing import Any
Expand Down Expand Up @@ -144,10 +145,18 @@ def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> No
else pd.concat([self.node.df, node_table])
)

def __getitem__(self, index):
row = self.node.df[self.node.df["node_id"] == index].iloc[0]
def __getitem__(self, index: int) -> NodeData:
# Unlike TableModel, support only indexing single rows.
if not isinstance(index, numbers.Integral):
node_model_name = type(self).__name__
indextype = type(index).__name__
raise TypeError(
f"{node_model_name} index must be an integer, not {indextype}"
)

row = self.node[index].iloc[0]
return NodeData(
node_id=index, node_type=row["node_type"], geometry=row["geometry"]
node_id=int(index), node_type=row["node_type"], geometry=row["geometry"]
)


Expand Down
3 changes: 3 additions & 0 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,6 @@ def plot(self, **kwargs) -> Axes:
)

return ax

def __getitem__(self, _):
raise NotImplementedError
15 changes: 15 additions & 0 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

import geopandas as gpd
import numpy as np
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame
Expand Down Expand Up @@ -337,6 +338,20 @@ def _repr_html_(self):
else:
return f"<div>{self.tablename()}</div>" + self.df._repr_html_()

def __getitem__(self, index) -> pd.DataFrame | gpd.GeoDataFrame:
tablename = self.tablename()
if self.df is None:
raise ValueError(f"Cannot index into {tablename}: it contains no data.")

# Allow for indexing with multiple values.
np_index = np.atleast_1d(index)
missing = np.setdiff1d(np_index, self.df["node_id"].unique())
if missing.size > 0:
raise IndexError(f"{tablename} does not contain node_id: {missing}")

# Index with .loc[..., :] to always return a DataFrame.
return self.df.loc[self.df["node_id"].isin(np_index), :]


class SpatialTableModel(TableModel[TableT], Generic[TableT]):
df: GeoDataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/tests/test_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ def test_validation(edge):

def test_edge_plot(edge):
edge.plot()


def test_edge_indexing(edge):
with pytest.raises(NotImplementedError):
edge[1]
32 changes: 32 additions & 0 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from pydantic import ValidationError
from ribasim.config import Solver
from ribasim.geometry.edge import NodeData
from ribasim.input_base import esc_id
from ribasim.model import Model
from shapely import Point
Expand Down Expand Up @@ -182,3 +183,34 @@ def test_node_table(basic):
assert df.geometry.is_unique
assert df.node_type.iloc[0] == "Basin"
assert df.node_type.iloc[-1] == "Terminal"


def test_indexing(basic):
model = basic

result = model.basin[1]
assert isinstance(result, NodeData)

# Also test with a numpy type
result = model.basin[np.int32(1)]
assert isinstance(result, NodeData)

with pytest.raises(TypeError, match="Basin index must be an integer, not list"):
model.basin[[1, 3, 6]]

result = model.basin.static[1]
assert isinstance(result, pd.DataFrame)

result = model.basin.static[[1, 3, 6]]
assert isinstance(result, pd.DataFrame)

with pytest.raises(
IndexError, match=re.escape("Basin / static does not contain node_id: [2]")
):
model.basin.static[2]

with pytest.raises(
ValueError,
match=re.escape("Cannot index into Basin / time: it contains no data."),
):
model.basin.time[1]

0 comments on commit f43010b

Please sign in to comment.