Skip to content

Commit

Permalink
Automatically name index (#1974)
Browse files Browse the repository at this point in the history
Fixes #1968

We let pandera check the index name with `check_name = True`. This adds
a pandera parser called `_name_index` to set the index to the desired
name before validating. This still keeps the index names, but offers
more user convenience, since the index name is lost with many pandas
operations.
  • Loading branch information
visr authored Dec 16, 2024
1 parent c963001 commit fc02cfb
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 2 deletions.
2 changes: 0 additions & 2 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def _check_dataframe(cls, value: Any) -> Any:

# Enable initialization with a DataFrame.
if isinstance(value, pd.DataFrame | gpd.GeoDataFrame):
value.index.rename("fid", inplace=True)
value = {"df": value}

return value
Expand Down Expand Up @@ -386,7 +385,6 @@ def _from_db(cls, path: Path, table: str):
# tell pyarrow to map to pd.ArrowDtype rather than NumPy
arrow_to_pandas_kwargs={"types_mapper": pd.ArrowDtype},
)
df.index.rename(cls.tableschema()._index_name(), inplace=True)
else:
df = None

Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/ribasim/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class Config:
def _index_name(self) -> str:
return "fid"

@pa.dataframe_parser
def _name_index(cls, df):
df.index.name = cls._index_name()
return df

@classmethod
def migrate(cls, df: Any, schema_version: int) -> Any:
f: Callable[[Any, Any], Any] = getattr(
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def test_extra_columns():
def test_index_tables():
p = pump.Static(flow_rate=[1.2])
assert p.df.index.name == "fid"
# Index name is applied by _name_index
df = p.df.reset_index(drop=True)
assert df.index.name is None
p.df = df
assert p.df.index.name == "fid"


def test_extra_spatial_columns():
Expand Down
6 changes: 6 additions & 0 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def test_write_adds_fid_in_tables(basic, tmp_path):
assert model_orig.edge.df.index.name == "edge_id"
assert model_orig.edge.df.index.equals(pd.RangeIndex(1, nrow + 1))

# Index name is applied by _name_index
df = model_orig.edge.df.copy()
df.index.name = "other"
model_orig.edge.df = df
assert model_orig.edge.df.index.name == "edge_id"

model_orig.write(tmp_path / "basic/ribasim.toml")
with connect(tmp_path / "basic/database.gpkg") as connection:
query = f"select * from {esc_id('Basin / profile')}"
Expand Down
5 changes: 5 additions & 0 deletions utils/templates/schemas.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class _BaseSchema(pa.DataFrameModel):
def _index_name(self) -> str:
return "fid"

@pa.dataframe_parser
def _name_index(cls, df):
df.index.name = cls._index_name()
return df

@classmethod
def migrate(cls, df: Any, schema_version: int) -> Any:
f: Callable[[Any, Any], Any] = getattr(
Expand Down

0 comments on commit fc02cfb

Please sign in to comment.