diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 4f2d61946..f6285e1c5 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -49,10 +49,9 @@ class EdgeSchema(_GeoBaseSchema): edge_type: Series[str] = pa.Field(default="flow") geometry: GeoSeries[LineString] = pa.Field(default=None, nullable=True) - @pa.dataframe_parser - def _name_index(cls, df): - df.index.name = "edge_id" - return df + @classmethod + def _index_name(self) -> str: + return "edge_id" class EdgeTable(SpatialTableModel[EdgeSchema]): diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index b92354b63..691dc4d19 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -27,10 +27,9 @@ class NodeSchema(_GeoBaseSchema): ) geometry: GeoSeries[Point] = pa.Field(default=None, nullable=True) - @pa.dataframe_parser - def _name_index(cls, df): - df.index.name = "node_id" - return df + @classmethod + def _index_name(self) -> str: + return "node_id" class NodeTable(SpatialTableModel[NodeSchema]): diff --git a/python/ribasim/ribasim/schemas.py b/python/ribasim/ribasim/schemas.py index 8de01c774..ac0292ccb 100644 --- a/python/ribasim/ribasim/schemas.py +++ b/python/ribasim/ribasim/schemas.py @@ -16,11 +16,13 @@ class Config: add_missing_columns = True coerce = True + @classmethod + def _index_name(self) -> str: + return "fid" + @pa.dataframe_parser def _name_index(cls, df): - # Node and Edge have different index names, avoid running both parsers - if cls.__name__ not in ("NodeSchema", "EdgeSchema"): - df.index.name = "fid" + df.index.name = cls._index_name() return df @classmethod diff --git a/utils/templates/schemas.py.jinja b/utils/templates/schemas.py.jinja index 6e79f6953..d986e4dc0 100644 --- a/utils/templates/schemas.py.jinja +++ b/utils/templates/schemas.py.jinja @@ -15,11 +15,13 @@ class _BaseSchema(pa.DataFrameModel): add_missing_columns = True coerce = True + @classmethod + def _index_name(self) -> str: + return "fid" + @pa.dataframe_parser def _name_index(cls, df): - # Node and Edge have different index names, avoid running both parsers - if cls.__name__ not in ("NodeSchema", "EdgeSchema"): - df.index.name = "fid" + df.index.name = cls._index_name() return df @classmethod