Skip to content

Commit

Permalink
Refactor sort_keys implementation (#970)
Browse files Browse the repository at this point in the history
Fixes #956

While ideally we set the sort_keys on the Schema of a TableModel, or on
the TableModel themselves, this is not directly possible AFAIK:
- Schemas are autogenerated, and sort_keys is not something we can
serialize or parse in a JSONSchema
- The TableModels are generic (TableModel[T]), so specific sort_keys
can't be hardcoded to a class.

The previous approach was to hold the sort keys in a `Dict[table,
sort_keys]` on a NodeModel (which holds multiple tables), and passing
them on when save/sort was called from the Node. This PR improves on
that approach by introducing a private _sort_keys on the TableModel,
which is always set (from an empty list by default) on initialization
from the NodeModel.

Instead of keeping a separate Dict with sort_keys on a NodeModel, we now
set them in the `Field` attribute `json_schema_extra` (the only
customizable attribute of a Field) for each field. Each field of a
NodeModel is now validated by default (also on assignment), and if a
sort_keys is present for a fieldtype, it is set on the TableModel.

Based on the initial exploration by @deltamarnix in
https://github.com/Deltares/Ribasim/tree/feat/sort-keys-in-tablemodel
  • Loading branch information
evetion authored Jan 23, 2024
1 parent 839c0e6 commit eb7f512
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 84 deletions.
122 changes: 50 additions & 72 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,154 +79,132 @@ class Logging(ChildModel):

class Terminal(NodeModel):
static: TableModel[TerminalStaticSchema] = Field(
default_factory=TableModel[TerminalStaticSchema]
default_factory=TableModel[TerminalStaticSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id"]}


class PidControl(NodeModel):
static: TableModel[PidControlStaticSchema] = Field(
default_factory=TableModel[PidControlStaticSchema]
default_factory=TableModel[PidControlStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)
time: TableModel[PidControlTimeSchema] = Field(
default_factory=TableModel[PidControlTimeSchema]
default_factory=TableModel[PidControlTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "time"]},
)

_sort_keys: dict[str, list[str]] = {
"static": ["node_id", "control_state"],
"time": ["node_id", "time"],
}


class LevelBoundary(NodeModel):
static: TableModel[LevelBoundaryStaticSchema] = Field(
default_factory=TableModel[LevelBoundaryStaticSchema]
default_factory=TableModel[LevelBoundaryStaticSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)
time: TableModel[LevelBoundaryTimeSchema] = Field(
default_factory=TableModel[LevelBoundaryTimeSchema]
default_factory=TableModel[LevelBoundaryTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "time"]},
)

_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"time": ["node_id", "time"],
}


class Pump(NodeModel):
static: TableModel[PumpStaticSchema] = Field(
default_factory=TableModel[PumpStaticSchema]
default_factory=TableModel[PumpStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class TabulatedRatingCurve(NodeModel):
static: TableModel[TabulatedRatingCurveStaticSchema] = Field(
default_factory=TableModel[TabulatedRatingCurveStaticSchema]
default_factory=TableModel[TabulatedRatingCurveStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state", "level"]},
)
time: TableModel[TabulatedRatingCurveTimeSchema] = Field(
default_factory=TableModel[TabulatedRatingCurveTimeSchema]
default_factory=TableModel[TabulatedRatingCurveTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "time", "level"]},
)
_sort_keys: dict[str, list[str]] = {
"static": ["node_id", "control_state", "level"],
"time": ["node_id", "time", "level"],
}


class User(NodeModel):
static: TableModel[UserStaticSchema] = Field(
default_factory=TableModel[UserStaticSchema]
default_factory=TableModel[UserStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "priority"]},
)
time: TableModel[UserTimeSchema] = Field(
default_factory=TableModel[UserTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "priority", "time"]},
)
time: TableModel[UserTimeSchema] = Field(default_factory=TableModel[UserTimeSchema])

_sort_keys: dict[str, list[str]] = {
"static": ["node_id", "priority"],
"time": ["node_id", "priority", "time"],
}


class FlowBoundary(NodeModel):
static: TableModel[FlowBoundaryStaticSchema] = Field(
default_factory=TableModel[FlowBoundaryStaticSchema]
default_factory=TableModel[FlowBoundaryStaticSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)
time: TableModel[FlowBoundaryTimeSchema] = Field(
default_factory=TableModel[FlowBoundaryTimeSchema]
default_factory=TableModel[FlowBoundaryTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "time"]},
)

_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"time": ["node_id", "time"],
}


class Basin(NodeModel):
profile: TableModel[BasinProfileSchema] = Field(
default_factory=TableModel[BasinProfileSchema]
default_factory=TableModel[BasinProfileSchema],
json_schema_extra={"sort_keys": ["node_id", "level"]},
)
state: TableModel[BasinStateSchema] = Field(
default_factory=TableModel[BasinStateSchema]
default_factory=TableModel[BasinStateSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)
static: TableModel[BasinStaticSchema] = Field(
default_factory=TableModel[BasinStaticSchema]
default_factory=TableModel[BasinStaticSchema],
json_schema_extra={"sort_keys": ["node_id"]},
)
time: TableModel[BasinTimeSchema] = Field(
default_factory=TableModel[BasinTimeSchema]
default_factory=TableModel[BasinTimeSchema],
json_schema_extra={"sort_keys": ["node_id", "time"]},
)
subgrid: TableModel[BasinSubgridSchema] = Field(
default_factory=TableModel[BasinSubgridSchema]
default_factory=TableModel[BasinSubgridSchema],
json_schema_extra={"sort_keys": ["subgrid_id", "basin_level"]},
)

_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"state": ["node_id"],
"profile": ["node_id", "level"],
"time": ["node_id", "time"],
"subgrid": ["subgrid_id", "basin_level"],
}


class ManningResistance(NodeModel):
static: TableModel[ManningResistanceStaticSchema] = Field(
default_factory=TableModel[ManningResistanceStaticSchema]
default_factory=TableModel[ManningResistanceStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class DiscreteControl(NodeModel):
condition: TableModel[DiscreteControlConditionSchema] = Field(
default_factory=TableModel[DiscreteControlConditionSchema]
default_factory=TableModel[DiscreteControlConditionSchema],
json_schema_extra={
"sort_keys": ["node_id", "listen_feature_id", "variable", "greater_than"]
},
)
logic: TableModel[DiscreteControlLogicSchema] = Field(
default_factory=TableModel[DiscreteControlLogicSchema]
default_factory=TableModel[DiscreteControlLogicSchema],
json_schema_extra={"sort_keys": ["node_id", "truth_state"]},
)

_sort_keys: dict[str, list[str]] = {
"condition": ["node_id", "listen_feature_id", "variable", "greater_than"],
"logic": ["node_id", "truth_state"],
}


class Outlet(NodeModel):
static: TableModel[OutletStaticSchema] = Field(
default_factory=TableModel[OutletStaticSchema]
default_factory=TableModel[OutletStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class LinearResistance(NodeModel):
static: TableModel[LinearResistanceStaticSchema] = Field(
default_factory=TableModel[LinearResistanceStaticSchema]
default_factory=TableModel[LinearResistanceStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class FractionalFlow(NodeModel):
static: TableModel[FractionalFlowStaticSchema] = Field(
default_factory=TableModel[FractionalFlowStaticSchema]
default_factory=TableModel[FractionalFlowStaticSchema],
json_schema_extra={"sort_keys": ["node_id", "control_state"]},
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}
31 changes: 22 additions & 9 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Generic,
TypeVar,
cast,
)

import geopandas as gpd
Expand All @@ -20,6 +21,8 @@
ConfigDict,
DirectoryPath,
Field,
PrivateAttr,
ValidationInfo,
field_validator,
model_serializer,
model_validator,
Expand Down Expand Up @@ -158,6 +161,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:

class TableModel(FileModel, Generic[TableT]):
df: DataFrame[TableT] | None = Field(default=None, exclude=True, repr=False)
_sort_keys: list[str] = PrivateAttr(default=[])

@field_validator("df")
@classmethod
Expand Down Expand Up @@ -219,15 +223,14 @@ def _save(
self,
directory: DirectoryPath,
input_dir: DirectoryPath,
sort_keys: list[str] = ["node_id"],
) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
if self.df is not None and self.filepath is not None:
self.sort(sort_keys)
self.sort()
self._write_arrow(self.filepath, directory, input_dir)
elif self.df is not None and db_path is not None:
self.sort(sort_keys)
self.sort()
self._write_table(db_path)

def _write_table(self, temp_path: Path) -> None:
Expand Down Expand Up @@ -289,13 +292,13 @@ def _from_arrow(cls, path: FilePath) -> pd.DataFrame:
directory = context_file_loading.get().get("directory", Path("."))
return pd.read_feather(directory / path)

def sort(self, sort_keys: list[str]):
def sort(self):
"""Sort the table as required.
Sorting is done automatically before writing the table.
"""
if self.df is not None:
self.df.sort_values(sort_keys, ignore_index=True, inplace=True)
self.df.sort_values(self._sort_keys, ignore_index=True, inplace=True)

@classmethod
def tableschema(cls) -> TableT:
Expand Down Expand Up @@ -374,7 +377,7 @@ def _write_table(self, path: FilePath) -> None:

gdf.to_file(path, layer=self.tablename(), driver="GPKG")

def sort(self, sort_keys: list[str]):
def sort(self):
self.df.sort_index(inplace=True)


Expand All @@ -392,15 +395,26 @@ def check_parent(self) -> "ChildModel":
class NodeModel(ChildModel):
"""Base class to handle combining the tables for a single node type."""

_sort_keys: dict[str, list[str]] = {}

@model_serializer(mode="wrap")
def set_modeld(
self, serializer: Callable[[type["NodeModel"]], dict[str, Any]]
) -> dict[str, Any]:
content = serializer(self)
return dict(filter(lambda x: x[1], content.items()))

@field_validator("*")
@classmethod
def set_sort_keys(cls, v: Any, info: ValidationInfo) -> Any:
"""Set sort keys for all TableModels if present in FieldInfo."""
if isinstance(v, (TableModel,)):
field = cls.model_fields[getattr(info, "field_name")]
extra = field.json_schema_extra
if extra is not None and isinstance(extra, dict):
# We set sort_keys ourselves as list[str] in json_schema_extra
# but mypy doesn't know.
v._sort_keys = cast(list[str], extra.get("sort_keys", []))
return v

@classmethod
def get_input_type(cls):
return cls.__name__
Expand Down Expand Up @@ -434,7 +448,6 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
getattr(self, field)._save(
directory,
input_dir,
sort_keys=self._sort_keys[field],
)

def _repr_content(self) -> str:
Expand Down
10 changes: 7 additions & 3 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ def test_sort(level_setpoint_with_minmax, tmp_path):
# apply a wrong sort, then call the sort method to restore order
table.df.sort_values("greater_than", ascending=False, inplace=True)
assert table.df.iloc[0]["greater_than"] == 15.0
sort_keys = model.discrete_control._sort_keys["condition"]
assert sort_keys == ["node_id", "listen_feature_id", "variable", "greater_than"]
table.sort(sort_keys)
assert table._sort_keys == [
"node_id",
"listen_feature_id",
"variable",
"greater_than",
]
table.sort()
assert table.df.iloc[0]["greater_than"] == 5.0

# re-apply wrong sort, then check if it gets sorted on write
Expand Down

0 comments on commit eb7f512

Please sign in to comment.