From eb7f512d003bdc4810c9503bce18423d151aab65 Mon Sep 17 00:00:00 2001 From: Maarten Pronk <8655030+evetion@users.noreply.github.com> Date: Tue, 23 Jan 2024 10:07:52 +0100 Subject: [PATCH] Refactor `sort_keys` implementation (#970) Fixes #956 While ideally we set the sort_keys on the Schema of a TableModel, or on the TableModel themselves, this is not directly possible AFAIK: - Schemas are autogenerated, and sort_keys is not something we can serialize or parse in a JSONSchema - The TableModels are generic (TableModel[T]), so specific sort_keys can't be hardcoded to a class. The previous approach was to hold the sort keys in a `Dict[table, sort_keys]` on a NodeModel (which holds multiple tables), and passing them on when save/sort was called from the Node. This PR improves on that approach by introducing a private _sort_keys on the TableModel, which is always set (from an empty list by default) on initialization from the NodeModel. Instead of keeping a separate Dict with sort_keys on a NodeModel, we now set them in the `Field` attribute `json_schema_extra` (the only customizable attribute of a Field) for each field. Each field of a NodeModel is now validated by default (also on assignment), and if a sort_keys is present for a fieldtype, it is set on the TableModel. Based on the initial exploration by @deltamarnix in https://github.com/Deltares/Ribasim/tree/feat/sort-keys-in-tablemodel --- python/ribasim/ribasim/config.py | 122 +++++++++++---------------- python/ribasim/ribasim/input_base.py | 31 +++++-- python/ribasim/tests/test_io.py | 10 ++- 3 files changed, 79 insertions(+), 84 deletions(-) diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index e992e811b..167f8fb64 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -79,154 +79,132 @@ class Logging(ChildModel): class Terminal(NodeModel): static: TableModel[TerminalStaticSchema] = Field( - default_factory=TableModel[TerminalStaticSchema] + default_factory=TableModel[TerminalStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id"]} - class PidControl(NodeModel): static: TableModel[PidControlStaticSchema] = Field( - default_factory=TableModel[PidControlStaticSchema] + default_factory=TableModel[PidControlStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) time: TableModel[PidControlTimeSchema] = Field( - default_factory=TableModel[PidControlTimeSchema] + default_factory=TableModel[PidControlTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state"], - "time": ["node_id", "time"], - } - class LevelBoundary(NodeModel): static: TableModel[LevelBoundaryStaticSchema] = Field( - default_factory=TableModel[LevelBoundaryStaticSchema] + default_factory=TableModel[LevelBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[LevelBoundaryTimeSchema] = Field( - default_factory=TableModel[LevelBoundaryTimeSchema] + default_factory=TableModel[LevelBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Pump(NodeModel): static: TableModel[PumpStaticSchema] = Field( - default_factory=TableModel[PumpStaticSchema] + default_factory=TableModel[PumpStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class TabulatedRatingCurve(NodeModel): static: TableModel[TabulatedRatingCurveStaticSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveStaticSchema] + default_factory=TableModel[TabulatedRatingCurveStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state", "level"]}, ) time: TableModel[TabulatedRatingCurveTimeSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveTimeSchema] + default_factory=TableModel[TabulatedRatingCurveTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time", "level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state", "level"], - "time": ["node_id", "time", "level"], - } class User(NodeModel): static: TableModel[UserStaticSchema] = Field( - default_factory=TableModel[UserStaticSchema] + default_factory=TableModel[UserStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "priority"]}, + ) + time: TableModel[UserTimeSchema] = Field( + default_factory=TableModel[UserTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "priority", "time"]}, ) - time: TableModel[UserTimeSchema] = Field(default_factory=TableModel[UserTimeSchema]) - - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "priority"], - "time": ["node_id", "priority", "time"], - } class FlowBoundary(NodeModel): static: TableModel[FlowBoundaryStaticSchema] = Field( - default_factory=TableModel[FlowBoundaryStaticSchema] + default_factory=TableModel[FlowBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[FlowBoundaryTimeSchema] = Field( - default_factory=TableModel[FlowBoundaryTimeSchema] + default_factory=TableModel[FlowBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Basin(NodeModel): profile: TableModel[BasinProfileSchema] = Field( - default_factory=TableModel[BasinProfileSchema] + default_factory=TableModel[BasinProfileSchema], + json_schema_extra={"sort_keys": ["node_id", "level"]}, ) state: TableModel[BasinStateSchema] = Field( - default_factory=TableModel[BasinStateSchema] + default_factory=TableModel[BasinStateSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) static: TableModel[BasinStaticSchema] = Field( - default_factory=TableModel[BasinStaticSchema] + default_factory=TableModel[BasinStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[BasinTimeSchema] = Field( - default_factory=TableModel[BasinTimeSchema] + default_factory=TableModel[BasinTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) subgrid: TableModel[BasinSubgridSchema] = Field( - default_factory=TableModel[BasinSubgridSchema] + default_factory=TableModel[BasinSubgridSchema], + json_schema_extra={"sort_keys": ["subgrid_id", "basin_level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "state": ["node_id"], - "profile": ["node_id", "level"], - "time": ["node_id", "time"], - "subgrid": ["subgrid_id", "basin_level"], - } - class ManningResistance(NodeModel): static: TableModel[ManningResistanceStaticSchema] = Field( - default_factory=TableModel[ManningResistanceStaticSchema] + default_factory=TableModel[ManningResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class DiscreteControl(NodeModel): condition: TableModel[DiscreteControlConditionSchema] = Field( - default_factory=TableModel[DiscreteControlConditionSchema] + default_factory=TableModel[DiscreteControlConditionSchema], + json_schema_extra={ + "sort_keys": ["node_id", "listen_feature_id", "variable", "greater_than"] + }, ) logic: TableModel[DiscreteControlLogicSchema] = Field( - default_factory=TableModel[DiscreteControlLogicSchema] + default_factory=TableModel[DiscreteControlLogicSchema], + json_schema_extra={"sort_keys": ["node_id", "truth_state"]}, ) - _sort_keys: dict[str, list[str]] = { - "condition": ["node_id", "listen_feature_id", "variable", "greater_than"], - "logic": ["node_id", "truth_state"], - } - class Outlet(NodeModel): static: TableModel[OutletStaticSchema] = Field( - default_factory=TableModel[OutletStaticSchema] + default_factory=TableModel[OutletStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class LinearResistance(NodeModel): static: TableModel[LinearResistanceStaticSchema] = Field( - default_factory=TableModel[LinearResistanceStaticSchema] + default_factory=TableModel[LinearResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class FractionalFlow(NodeModel): static: TableModel[FractionalFlowStaticSchema] = Field( - default_factory=TableModel[FractionalFlowStaticSchema] + default_factory=TableModel[FractionalFlowStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 1538bfafa..17b066d01 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -9,6 +9,7 @@ Any, Generic, TypeVar, + cast, ) import geopandas as gpd @@ -20,6 +21,8 @@ ConfigDict, DirectoryPath, Field, + PrivateAttr, + ValidationInfo, field_validator, model_serializer, model_validator, @@ -158,6 +161,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: class TableModel(FileModel, Generic[TableT]): df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) + _sort_keys: list[str] = PrivateAttr(default=[]) @field_validator("df") @classmethod @@ -219,15 +223,14 @@ def _save( self, directory: DirectoryPath, input_dir: DirectoryPath, - sort_keys: list[str] = ["node_id"], ) -> None: # TODO directory could be used to save an arrow file db_path = context_file_loading.get().get("database") if self.df is not None and self.filepath is not None: - self.sort(sort_keys) + self.sort() self._write_arrow(self.filepath, directory, input_dir) elif self.df is not None and db_path is not None: - self.sort(sort_keys) + self.sort() self._write_table(db_path) def _write_table(self, temp_path: Path) -> None: @@ -289,13 +292,13 @@ def _from_arrow(cls, path: FilePath) -> pd.DataFrame: directory = context_file_loading.get().get("directory", Path(".")) return pd.read_feather(directory / path) - def sort(self, sort_keys: list[str]): + def sort(self): """Sort the table as required. Sorting is done automatically before writing the table. """ if self.df is not None: - self.df.sort_values(sort_keys, ignore_index=True, inplace=True) + self.df.sort_values(self._sort_keys, ignore_index=True, inplace=True) @classmethod def tableschema(cls) -> TableT: @@ -374,7 +377,7 @@ def _write_table(self, path: FilePath) -> None: gdf.to_file(path, layer=self.tablename(), driver="GPKG") - def sort(self, sort_keys: list[str]): + def sort(self): self.df.sort_index(inplace=True) @@ -392,8 +395,6 @@ def check_parent(self) -> "ChildModel": class NodeModel(ChildModel): """Base class to handle combining the tables for a single node type.""" - _sort_keys: dict[str, list[str]] = {} - @model_serializer(mode="wrap") def set_modeld( self, serializer: Callable[[type["NodeModel"]], dict[str, Any]] @@ -401,6 +402,19 @@ def set_modeld( content = serializer(self) return dict(filter(lambda x: x[1], content.items())) + @field_validator("*") + @classmethod + def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: + """Set sort keys for all TableModels if present in FieldInfo.""" + if isinstance(v, (TableModel,)): + field = cls.model_fields[getattr(info, "field_name")] + extra = field.json_schema_extra + if extra is not None and isinstance(extra, dict): + # We set sort_keys ourselves as list[str] in json_schema_extra + # but mypy doesn't know. + v._sort_keys = cast(list[str], extra.get("sort_keys", [])) + return v + @classmethod def get_input_type(cls): return cls.__name__ @@ -434,7 +448,6 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs): getattr(self, field)._save( directory, input_dir, - sort_keys=self._sort_keys[field], ) def _repr_content(self) -> str: diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index b1ac52056..e86b9456b 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -111,9 +111,13 @@ def test_sort(level_setpoint_with_minmax, tmp_path): # apply a wrong sort, then call the sort method to restore order table.df.sort_values("greater_than", ascending=False, inplace=True) assert table.df.iloc[0]["greater_than"] == 15.0 - sort_keys = model.discrete_control._sort_keys["condition"] - assert sort_keys == ["node_id", "listen_feature_id", "variable", "greater_than"] - table.sort(sort_keys) + assert table._sort_keys == [ + "node_id", + "listen_feature_id", + "variable", + "greater_than", + ] + table.sort() assert table.df.iloc[0]["greater_than"] == 5.0 # re-apply wrong sort, then check if it gets sorted on write