From 24d3864b724798e53fd44b7e69f5856db0722f13 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 19 Jan 2024 10:09:44 +0100 Subject: [PATCH 1/4] Refactor sort_keys. --- python/ribasim/ribasim/config.py | 122 +++++++++++---------------- python/ribasim/ribasim/input_base.py | 26 ++++-- python/ribasim/tests/test_io.py | 10 ++- 3 files changed, 74 insertions(+), 84 deletions(-) diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index e992e811b..167f8fb64 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -79,154 +79,132 @@ class Logging(ChildModel): class Terminal(NodeModel): static: TableModel[TerminalStaticSchema] = Field( - default_factory=TableModel[TerminalStaticSchema] + default_factory=TableModel[TerminalStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id"]} - class PidControl(NodeModel): static: TableModel[PidControlStaticSchema] = Field( - default_factory=TableModel[PidControlStaticSchema] + default_factory=TableModel[PidControlStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) time: TableModel[PidControlTimeSchema] = Field( - default_factory=TableModel[PidControlTimeSchema] + default_factory=TableModel[PidControlTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state"], - "time": ["node_id", "time"], - } - class LevelBoundary(NodeModel): static: TableModel[LevelBoundaryStaticSchema] = Field( - default_factory=TableModel[LevelBoundaryStaticSchema] + default_factory=TableModel[LevelBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[LevelBoundaryTimeSchema] = Field( - default_factory=TableModel[LevelBoundaryTimeSchema] + default_factory=TableModel[LevelBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Pump(NodeModel): static: TableModel[PumpStaticSchema] = Field( - default_factory=TableModel[PumpStaticSchema] + default_factory=TableModel[PumpStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class TabulatedRatingCurve(NodeModel): static: TableModel[TabulatedRatingCurveStaticSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveStaticSchema] + default_factory=TableModel[TabulatedRatingCurveStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state", "level"]}, ) time: TableModel[TabulatedRatingCurveTimeSchema] = Field( - default_factory=TableModel[TabulatedRatingCurveTimeSchema] + default_factory=TableModel[TabulatedRatingCurveTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time", "level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "control_state", "level"], - "time": ["node_id", "time", "level"], - } class User(NodeModel): static: TableModel[UserStaticSchema] = Field( - default_factory=TableModel[UserStaticSchema] + default_factory=TableModel[UserStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "priority"]}, + ) + time: TableModel[UserTimeSchema] = Field( + default_factory=TableModel[UserTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "priority", "time"]}, ) - time: TableModel[UserTimeSchema] = Field(default_factory=TableModel[UserTimeSchema]) - - _sort_keys: dict[str, list[str]] = { - "static": ["node_id", "priority"], - "time": ["node_id", "priority", "time"], - } class FlowBoundary(NodeModel): static: TableModel[FlowBoundaryStaticSchema] = Field( - default_factory=TableModel[FlowBoundaryStaticSchema] + default_factory=TableModel[FlowBoundaryStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[FlowBoundaryTimeSchema] = Field( - default_factory=TableModel[FlowBoundaryTimeSchema] + default_factory=TableModel[FlowBoundaryTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "time": ["node_id", "time"], - } - class Basin(NodeModel): profile: TableModel[BasinProfileSchema] = Field( - default_factory=TableModel[BasinProfileSchema] + default_factory=TableModel[BasinProfileSchema], + json_schema_extra={"sort_keys": ["node_id", "level"]}, ) state: TableModel[BasinStateSchema] = Field( - default_factory=TableModel[BasinStateSchema] + default_factory=TableModel[BasinStateSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) static: TableModel[BasinStaticSchema] = Field( - default_factory=TableModel[BasinStaticSchema] + default_factory=TableModel[BasinStaticSchema], + json_schema_extra={"sort_keys": ["node_id"]}, ) time: TableModel[BasinTimeSchema] = Field( - default_factory=TableModel[BasinTimeSchema] + default_factory=TableModel[BasinTimeSchema], + json_schema_extra={"sort_keys": ["node_id", "time"]}, ) subgrid: TableModel[BasinSubgridSchema] = Field( - default_factory=TableModel[BasinSubgridSchema] + default_factory=TableModel[BasinSubgridSchema], + json_schema_extra={"sort_keys": ["subgrid_id", "basin_level"]}, ) - _sort_keys: dict[str, list[str]] = { - "static": ["node_id"], - "state": ["node_id"], - "profile": ["node_id", "level"], - "time": ["node_id", "time"], - "subgrid": ["subgrid_id", "basin_level"], - } - class ManningResistance(NodeModel): static: TableModel[ManningResistanceStaticSchema] = Field( - default_factory=TableModel[ManningResistanceStaticSchema] + default_factory=TableModel[ManningResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class DiscreteControl(NodeModel): condition: TableModel[DiscreteControlConditionSchema] = Field( - default_factory=TableModel[DiscreteControlConditionSchema] + default_factory=TableModel[DiscreteControlConditionSchema], + json_schema_extra={ + "sort_keys": ["node_id", "listen_feature_id", "variable", "greater_than"] + }, ) logic: TableModel[DiscreteControlLogicSchema] = Field( - default_factory=TableModel[DiscreteControlLogicSchema] + default_factory=TableModel[DiscreteControlLogicSchema], + json_schema_extra={"sort_keys": ["node_id", "truth_state"]}, ) - _sort_keys: dict[str, list[str]] = { - "condition": ["node_id", "listen_feature_id", "variable", "greater_than"], - "logic": ["node_id", "truth_state"], - } - class Outlet(NodeModel): static: TableModel[OutletStaticSchema] = Field( - default_factory=TableModel[OutletStaticSchema] + default_factory=TableModel[OutletStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class LinearResistance(NodeModel): static: TableModel[LinearResistanceStaticSchema] = Field( - default_factory=TableModel[LinearResistanceStaticSchema] + default_factory=TableModel[LinearResistanceStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} - class FractionalFlow(NodeModel): static: TableModel[FractionalFlowStaticSchema] = Field( - default_factory=TableModel[FractionalFlowStaticSchema] + default_factory=TableModel[FractionalFlowStaticSchema], + json_schema_extra={"sort_keys": ["node_id", "control_state"]}, ) - - _sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]} diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 1538bfafa..82359ed61 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -20,6 +20,7 @@ ConfigDict, DirectoryPath, Field, + ValidationInfo, field_validator, model_serializer, model_validator, @@ -158,6 +159,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: class TableModel(FileModel, Generic[TableT]): df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) + sort_keys: list[str] = Field(default=[], exclude=True, repr=False) @field_validator("df") @classmethod @@ -219,15 +221,14 @@ def _save( self, directory: DirectoryPath, input_dir: DirectoryPath, - sort_keys: list[str] = ["node_id"], ) -> None: # TODO directory could be used to save an arrow file db_path = context_file_loading.get().get("database") if self.df is not None and self.filepath is not None: - self.sort(sort_keys) + self.sort() self._write_arrow(self.filepath, directory, input_dir) elif self.df is not None and db_path is not None: - self.sort(sort_keys) + self.sort() self._write_table(db_path) def _write_table(self, temp_path: Path) -> None: @@ -289,13 +290,13 @@ def _from_arrow(cls, path: FilePath) -> pd.DataFrame: directory = context_file_loading.get().get("directory", Path(".")) return pd.read_feather(directory / path) - def sort(self, sort_keys: list[str]): + def sort(self): """Sort the table as required. Sorting is done automatically before writing the table. """ if self.df is not None: - self.df.sort_values(sort_keys, ignore_index=True, inplace=True) + self.df.sort_values(self.sort_keys, ignore_index=True, inplace=True) @classmethod def tableschema(cls) -> TableT: @@ -374,7 +375,7 @@ def _write_table(self, path: FilePath) -> None: gdf.to_file(path, layer=self.tablename(), driver="GPKG") - def sort(self, sort_keys: list[str]): + def sort(self): self.df.sort_index(inplace=True) @@ -392,8 +393,6 @@ def check_parent(self) -> "ChildModel": class NodeModel(ChildModel): """Base class to handle combining the tables for a single node type.""" - _sort_keys: dict[str, list[str]] = {} - @model_serializer(mode="wrap") def set_modeld( self, serializer: Callable[[type["NodeModel"]], dict[str, Any]] @@ -401,6 +400,16 @@ def set_modeld( content = serializer(self) return dict(filter(lambda x: x[1], content.items())) + @field_validator("*") + @classmethod + def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: + """Set sort keys for all TableModels if present in FieldInfo.""" + if isinstance(v, (TableModel,)): + field = cls.model_fields[info.field_name] + if getattr(field, "json_schema_extra", None) is not None: + v.sort_keys = field.json_schema_extra.get("sort_keys", []) + return v + @classmethod def get_input_type(cls): return cls.__name__ @@ -434,7 +443,6 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs): getattr(self, field)._save( directory, input_dir, - sort_keys=self._sort_keys[field], ) def _repr_content(self) -> str: diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index b1ac52056..f6cf6c5da 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -111,9 +111,13 @@ def test_sort(level_setpoint_with_minmax, tmp_path): # apply a wrong sort, then call the sort method to restore order table.df.sort_values("greater_than", ascending=False, inplace=True) assert table.df.iloc[0]["greater_than"] == 15.0 - sort_keys = model.discrete_control._sort_keys["condition"] - assert sort_keys == ["node_id", "listen_feature_id", "variable", "greater_than"] - table.sort(sort_keys) + assert table.sort_keys == [ + "node_id", + "listen_feature_id", + "variable", + "greater_than", + ] + table.sort() assert table.df.iloc[0]["greater_than"] == 5.0 # re-apply wrong sort, then check if it gets sorted on write From c24238c898995b0f1730029bf7cb0110df7a9c5a Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 19 Jan 2024 12:10:15 +0100 Subject: [PATCH 2/4] Mypy happy. --- python/ribasim/ribasim/input_base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 82359ed61..ef3389b87 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -405,9 +405,10 @@ def set_modeld( def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: """Set sort keys for all TableModels if present in FieldInfo.""" if isinstance(v, (TableModel,)): - field = cls.model_fields[info.field_name] - if getattr(field, "json_schema_extra", None) is not None: - v.sort_keys = field.json_schema_extra.get("sort_keys", []) + field = cls.model_fields[getattr(info, "field_name")] + extra = field.json_schema_extra + if extra is not None and isinstance(extra, dict): + v.sort_keys = extra.get("sort_keys", []) # type: ignore return v @classmethod From bdd638e82bc74c1b8ac8e6f24df0efe2ac5de171 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 19 Jan 2024 12:20:01 +0100 Subject: [PATCH 3/4] Move to PrivateAttr. --- python/ribasim/ribasim/input_base.py | 7 ++++--- python/ribasim/tests/test_io.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index ef3389b87..97d020f1e 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -20,6 +20,7 @@ ConfigDict, DirectoryPath, Field, + PrivateAttr, ValidationInfo, field_validator, model_serializer, @@ -159,7 +160,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: class TableModel(FileModel, Generic[TableT]): df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False) - sort_keys: list[str] = Field(default=[], exclude=True, repr=False) + _sort_keys: list[str] = PrivateAttr(default=[]) @field_validator("df") @classmethod @@ -296,7 +297,7 @@ def sort(self): Sorting is done automatically before writing the table. """ if self.df is not None: - self.df.sort_values(self.sort_keys, ignore_index=True, inplace=True) + self.df.sort_values(self._sort_keys, ignore_index=True, inplace=True) @classmethod def tableschema(cls) -> TableT: @@ -408,7 +409,7 @@ def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: field = cls.model_fields[getattr(info, "field_name")] extra = field.json_schema_extra if extra is not None and isinstance(extra, dict): - v.sort_keys = extra.get("sort_keys", []) # type: ignore + v._sort_keys = extra.get("sort_keys", []) # type: ignore return v @classmethod diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index f6cf6c5da..e86b9456b 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -111,7 +111,7 @@ def test_sort(level_setpoint_with_minmax, tmp_path): # apply a wrong sort, then call the sort method to restore order table.df.sort_values("greater_than", ascending=False, inplace=True) assert table.df.iloc[0]["greater_than"] == 15.0 - assert table.sort_keys == [ + assert table._sort_keys == [ "node_id", "listen_feature_id", "variable", From fc17d7d79396e2b1af319fe99e38afa4481e4050 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Fri, 19 Jan 2024 14:44:00 +0100 Subject: [PATCH 4/4] Remove type: ignore. --- python/ribasim/ribasim/input_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index 97d020f1e..17b066d01 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -9,6 +9,7 @@ Any, Generic, TypeVar, + cast, ) import geopandas as gpd @@ -409,7 +410,9 @@ def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any: field = cls.model_fields[getattr(info, "field_name")] extra = field.json_schema_extra if extra is not None and isinstance(extra, dict): - v._sort_keys = extra.get("sort_keys", []) # type: ignore + # We set sort_keys ourselves as list[str] in json_schema_extra + # but mypy doesn't know. + v._sort_keys = cast(list[str], extra.get("sort_keys", [])) return v @classmethod