From 24d3864b724798e53fd44b7e69f5856db0722f13 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 19 Jan 2024 10:09:44 +0100 Subject: [PATCH] Refactor sort_keys. --- python/ribasim/ribasim/config.py | 122 +++++++++++---------------- python/ribasim/ribasim/input_base.py | 26 ++++-- python/ribasim/tests/test_io.py | 10 ++- 3 files changed, 74 insertions(+), 84 deletions(-) diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index e992e811b..167f8fb64 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -79,154 +79,132 @@ class Logging(ChildModel): class Terminal(NodeModel): static: TableModel[TerminalStaticSchema] = Field( - default_factory=TableModel[TerminalStaticSchema] + default_factory=TableModel[TerminalStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id"]} - class PidControl(NodeModel): static: TableModel[PidControlStaticSchema] = Field( - default_factory=TableModel[PidControlStaticSchema] + default_factory=TableModel[PidControlStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) time: TableModel[PidControlTimeSchema] = Field( - default_factory=TableModel[PidControlTimeSchema] + default_factory=TableModel[PidControlTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state"], - "time": ["node_id", "time"], - } - class LevelBoundary(NodeModel): static: TableModel[LevelBoundaryStaticSchema] = Field( - default_factory=TableModel[LevelBoundaryStaticSchema] + default_factory=TableModel[LevelBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[LevelBoundaryTimeSchema] = Field( - default_factory=TableModel[LevelBoundaryTimeSchema] + default_factory=TableModel[LevelBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Pump(NodeModel): static: TableModel[PumpStaticSchema] = Field( - default_factory=TableModel[PumpStaticSchema] + default_factory=TableModel[PumpStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class TabulatedRatingCurve(NodeModel): static: TableModel[TabulatedRatingCurveStaticSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveStaticSchema] + default_factory=TableModel[TabulatedRatingCurveStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state", "level"]}, ) time: TableModel[TabulatedRatingCurveTimeSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveTimeSchema] + default_factory=TableModel[TabulatedRatingCurveTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time", "level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state", "level"], - "time": ["node_id", "time", "level"], - } class User(NodeModel): static: TableModel[UserStaticSchema] = Field( - default_factory=TableModel[UserStaticSchema] + default_factory=TableModel[UserStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "priority"]}, + ) + time: TableModel[UserTimeSchema] = Field( + default_factory=TableModel[UserTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "priority", "time"]}, ) - time: TableModel[UserTimeSchema] = Field(default_factory=TableModel[UserTimeSchema]) - - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "priority"], - "time": ["node_id", "priority", "time"], - } class FlowBoundary(NodeModel): static: TableModel[FlowBoundaryStaticSchema] = Field( - default_factory=TableModel[FlowBoundaryStaticSchema] + default_factory=TableModel[FlowBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[FlowBoundaryTimeSchema] = Field( - default_factory=TableModel[FlowBoundaryTimeSchema] + default_factory=TableModel[FlowBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Basin(NodeModel): profile: TableModel[BasinProfileSchema] = Field( - default_factory=TableModel[BasinProfileSchema] + default_factory=TableModel[BasinProfileSchema], + json_schema_extra={"sort_keys": ["node_id", "level"]}, ) state: TableModel[BasinStateSchema] = Field( - default_factory=TableModel[BasinStateSchema] + default_factory=TableModel[BasinStateSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) static: TableModel[BasinStaticSchema] = Field( - default_factory=TableModel[BasinStaticSchema] + default_factory=TableModel[BasinStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[BasinTimeSchema] = Field( - default_factory=TableModel[BasinTimeSchema] + default_factory=TableModel[BasinTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) subgrid: TableModel[BasinSubgridSchema] = Field( - default_factory=TableModel[BasinSubgridSchema] + default_factory=TableModel[BasinSubgridSchema], + json_schema_extra={"sort_keys": ["subgrid_id", "basin_level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "state": ["node_id"], - "profile": ["node_id", "level"], - "time": ["node_id", "time"], - "subgrid": ["subgrid_id", "basin_level"], - } - class ManningResistance(NodeModel): static: TableModel[ManningResistanceStaticSchema] = Field( - default_factory=TableModel[ManningResistanceStaticSchema] + default_factory=TableModel[ManningResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class DiscreteControl(NodeModel): condition: TableModel[DiscreteControlConditionSchema] = Field( - default_factory=TableModel[DiscreteControlConditionSchema] + default_factory=TableModel[DiscreteControlConditionSchema], + json_schema_extra={ + "sort_keys": ["node_id", "listen_feature_id", "variable", "greater_than"] + }, ) logic: TableModel[DiscreteControlLogicSchema] = Field( - default_factory=TableModel[DiscreteControlLogicSchema] + default_factory=TableModel[DiscreteControlLogicSchema], + json_schema_extra={"sort_keys": ["node_id", "truth_state"]}, ) - _sort_keys: dict[str, list[str]] = { - "condition": ["node_id", "listen_feature_id", "variable", "greater_than"], - "logic": ["node_id", "truth_state"], - } - class Outlet(NodeModel): static: TableModel[OutletStaticSchema] = Field( - default_factory=TableModel[OutletStaticSchema] + default_factory=TableModel[OutletStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class LinearResistance(NodeModel): static: TableModel[LinearResistanceStaticSchema] = Field( - default_factory=TableModel[LinearResistanceStaticSchema] + default_factory=TableModel[LinearResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class FractionalFlow(NodeModel): static: TableModel[FractionalFlowStaticSchema] = Field( - default_factory=TableModel[FractionalFlowStaticSchema] + default_factory=TableModel[FractionalFlowStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 1538bfafa..82359ed61 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -20,6 +20,7 @@ ConfigDict, DirectoryPath, Field, + ValidationInfo, field_validator, model_serializer, model_validator, @@ -158,6 +159,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: class TableModel(FileModel, Generic[TableT]): df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) + sort_keys: list[str] = Field(default=[], exclude=True, repr=False) @field_validator("df") @classmethod @@ -219,15 +221,14 @@ def _save( self, directory: DirectoryPath, input_dir: DirectoryPath, - sort_keys: list[str] = ["node_id"], ) -> None: # TODO directory could be used to save an arrow file db_path = context_file_loading.get().get("database") if self.df is not None and self.filepath is not None: - self.sort(sort_keys) + self.sort() self._write_arrow(self.filepath, directory, input_dir) elif self.df is not None and db_path is not None: - self.sort(sort_keys) + self.sort() self._write_table(db_path) def _write_table(self, temp_path: Path) -> None: @@ -289,13 +290,13 @@ def _from_arrow(cls, path: FilePath) -> pd.DataFrame: directory = context_file_loading.get().get("directory", Path(".")) return pd.read_feather(directory / path) - def sort(self, sort_keys: list[str]): + def sort(self): """Sort the table as required. Sorting is done automatically before writing the table. """ if self.df is not None: - self.df.sort_values(sort_keys, ignore_index=True, inplace=True) + self.df.sort_values(self.sort_keys, ignore_index=True, inplace=True) @classmethod def tableschema(cls) -> TableT: @@ -374,7 +375,7 @@ def _write_table(self, path: FilePath) -> None: gdf.to_file(path, layer=self.tablename(), driver="GPKG") - def sort(self, sort_keys: list[str]): + def sort(self): self.df.sort_index(inplace=True) @@ -392,8 +393,6 @@ def check_parent(self) -> "ChildModel": class NodeModel(ChildModel): """Base class to handle combining the tables for a single node type.""" - _sort_keys: dict[str, list[str]] = {} - @model_serializer(mode="wrap") def set_modeld( self, serializer: Callable[[type["NodeModel"]], dict[str, Any]] @@ -401,6 +400,16 @@ def set_modeld( content = serializer(self) return dict(filter(lambda x: x[1], content.items())) + @field_validator("*") + @classmethod + def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: + """Set sort keys for all TableModels if present in FieldInfo.""" + if isinstance(v, (TableModel,)): + field = cls.model_fields[info.field_name] + if getattr(field, "json_schema_extra", None) is not None: + v.sort_keys = field.json_schema_extra.get("sort_keys", []) + return v + @classmethod def get_input_type(cls): return cls.__name__ @@ -434,7 +443,6 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs): getattr(self, field)._save( directory, input_dir, - sort_keys=self._sort_keys[field], ) def _repr_content(self) -> str: diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index b1ac52056..f6cf6c5da 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -111,9 +111,13 @@ def test_sort(level_setpoint_with_minmax, tmp_path): # apply a wrong sort, then call the sort method to restore order table.df.sort_values("greater_than", ascending=False, inplace=True) assert table.df.iloc[0]["greater_than"] == 15.0 - sort_keys = model.discrete_control._sort_keys["condition"] - assert sort_keys == ["node_id", "listen_feature_id", "variable", "greater_than"] - table.sort(sort_keys) + assert table.sort_keys == [ + "node_id", + "listen_feature_id", + "variable", + "greater_than", + ] + table.sort() assert table.df.iloc[0]["greater_than"] == 5.0 # re-apply wrong sort, then check if it gets sorted on write