diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index c6613c908..48e0e03fd 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -217,7 +217,6 @@ def _check_dataframe(cls, value: Any) -> Any: # Enable initialization with a DataFrame. if isinstance(value, pd.DataFrame | gpd.GeoDataFrame): - value.index.rename("fid", inplace=True) value = {"df": value} return value @@ -386,7 +385,6 @@ def _from_db(cls, path: Path, table: str): # tell pyarrow to map to pd.ArrowDtype rather than NumPy arrow_to_pandas_kwargs={"types_mapper": pd.ArrowDtype}, ) - df.index.rename(cls.tableschema()._index_name(), inplace=True) else: df = None diff --git a/python/ribasim/ribasim/schemas.py b/python/ribasim/ribasim/schemas.py index 2bcd22222..ac0292ccb 100644 --- a/python/ribasim/ribasim/schemas.py +++ b/python/ribasim/ribasim/schemas.py @@ -20,6 +20,11 @@ class Config: def _index_name(self) -> str: return "fid" + @pa.dataframe_parser + def _name_index(cls, df): + df.index.name = cls._index_name() + return df + @classmethod def migrate(cls, df: Any, schema_version: int) -> Any: f: Callable[[Any, Any], Any] = getattr( diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 0320dca87..7c2dcf7b7 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -97,6 +97,11 @@ def test_extra_columns(): def test_index_tables(): p = pump.Static(flow_rate=[1.2]) assert p.df.index.name == "fid" + # Index name is applied by _name_index + df = p.df.reset_index(drop=True) + assert df.index.name is None + p.df = df + assert p.df.index.name == "fid" def test_extra_spatial_columns(): diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index f5e341561..9336b5275 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -112,6 +112,12 @@ def test_write_adds_fid_in_tables(basic, tmp_path): assert model_orig.edge.df.index.name == "edge_id" assert model_orig.edge.df.index.equals(pd.RangeIndex(1, nrow + 1)) + # Index name is applied by _name_index + df = model_orig.edge.df.copy() + df.index.name = "other" + model_orig.edge.df = df + assert model_orig.edge.df.index.name == "edge_id" + model_orig.write(tmp_path / "basic/ribasim.toml") with connect(tmp_path / "basic/database.gpkg") as connection: query = f"select * from {esc_id('Basin / profile')}" diff --git a/utils/templates/schemas.py.jinja b/utils/templates/schemas.py.jinja index 19b53c7ce..d986e4dc0 100644 --- a/utils/templates/schemas.py.jinja +++ b/utils/templates/schemas.py.jinja @@ -19,6 +19,11 @@ class _BaseSchema(pa.DataFrameModel): def _index_name(self) -> str: return "fid" + @pa.dataframe_parser + def _name_index(cls, df): + df.index.name = cls._index_name() + return df + @classmethod def migrate(cls, df: Any, schema_version: int) -> Any: f: Callable[[Any, Any], Any] = getattr(