Fix and change table sorting in Python #903

Merged
merged 5 commits into from
Jan 15, 2024

46 changes: 23 additions & 23 deletions docs/core/usage.qmd
@@ -348,8 +348,8 @@ Lets a fraction (in [0,1]) of the incoming flow trough.
column | type | unit | restriction
------------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
fraction | Float64 | - | in the interval [0,1]
control_state | String | - | (optional)

# TabulatedRatingCurve

@@ -359,10 +359,10 @@ relation between the storage of a connected Basin (via the outlet level) and its
column | type | unit | restriction
------------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
level | Float64 | $m$ | sorted per control_state
discharge | Float64 | $m^3 s^{-1}$ | non-negative
control_state | String | - | (optional) sorted per node_id

node_id | discharge | level
------- |----------- |-------
@@ -382,8 +382,8 @@ Note that a `node_id` can be either in this table or in the static one, but not

column | type | unit | restriction
--------- | ------- | ------------ | -----------
time | DateTime | - | sorted
node_id | Int | - | sorted per time
node_id | Int | - | sorted
time | DateTime | - | sorted per node_id
level | Float64 | $m$ | -
discharge | Float64 | $m^3 s^{-1}$ | non-negative

@@ -398,11 +398,11 @@ When PID controlled, the pump must point away from the controlled basin in terms
column | type | unit | restriction
--------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
flow_rate | Float64 | $m^3 s^{-1}$ | non-negative
min_flow_rate | Float64 | $m^3 s^{-1}$ | (optional, default 0.0)
max_flow_rate | Float64 | $m^3 s^{-1}$ | (optional)
control_state | String | - | (optional)

# Outlet

@@ -412,12 +412,12 @@ When PID controlled, the outlet must point towards the controlled basin in terms
column | type | unit | restriction
--------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
flow_rate | Float64 | $m^3 s^{-1}$ | non-negative
min_flow_rate | Float64 | $m^3 s^{-1}$ | (optional, default 0.0)
max_flow_rate | Float64 | $m^3 s^{-1}$ | (optional)
min_crest_level | Float64 | $m$ | (optional)
control_state | String | - | (optional)

# User

@@ -454,11 +454,11 @@ Note that a `node_id` can be either in this table or in the static one, but not
column | type | unit | restriction
------------- | -------- | ------------ | -----------
node_id | Int | - | sorted
priority | Int | - | sorted per node id
time | DateTime | - | sorted per priority per node id
demand | Float64 | $m^3 s^{-1}$ | -
return_factor | Float64 | - | between [0 - 1]
min_level | Float64 | $m$ | (optional)
priority | Int | - | sorted per node id

## Allocation results

@@ -503,8 +503,8 @@ Note that a `node_id` can be either in this table or in the static one, but not

column | type | unit | restriction
--------- | ------- | ------------ | -----------
time | DateTime | - | sorted
node_id | Int | - | sorted per time
node_id | Int | - | sorted
time | DateTime | - | sorted per node_id
level | Float64 | $m$ | -

# FlowBoundary
@@ -534,8 +534,8 @@ Note that a `node_id` can be either in this table or in the static one, but not

column | type | unit | restriction
--------- | ------- | ------------ | -----------
time | DateTime | - | sorted
node_id | Int | - | sorted per time
node_id | Int | - | sorted
time | DateTime | - | sorted per node_id
flow_rate | Float64 | $m^3 s^{-1}$ | non-negative

# LinearResistance
@@ -545,9 +545,9 @@ Flow proportional to the level difference between the connected basins.
column | type | unit | restriction
------------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
resistance | Float64 | $sm^{-2}$ | -
control_state | String | - | (optional)

# ManningResistance

@@ -556,12 +556,12 @@ Flow through this connection is estimated by conservation of energy and the Mann
column | type | unit | restriction
------------- | ------- | ------------ | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
length | Float64 | $m$ | positive
manning_n | Float64 | $s m^{-\frac{1}{3}}$ | positive
profile_width | Float64 | $m$ | positive
profile_slope | Float64 | - | -
control_state | String | - | (optional)

# Terminal

@@ -585,9 +585,9 @@ The condition schema defines conditions of the form 'the discrete_control node w
column | type | unit | restriction
----------------- | -------- | ------- | -----------
node_id | Int | - | sorted
listen_feature_id | Int | - | -
variable | String | - | must be "level" or "flow_rate"
greater_than | Float64 | various | -
listen_feature_id | Int | - | sorted per node_id
variable | String | - | must be "level" or "flow_rate", sorted per listen_feature_id
greater_than | Float64 | various | sorted per variable
look_ahead | Float64 | $s$ | Only on transient boundary conditions, non-negative (optional, default 0)
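
The chained restrictions in this table ("sorted per node_id", "sorted per listen_feature_id", "sorted per variable") can be read as one lexicographic ordering on the four columns, matching the `condition` sort keys declared in `config.py`. Below is a minimal sketch of such a check in plain Python. The helper `is_sorted` and the sample rows are hypothetical and not part of Ribasim; they only illustrate the ordering.

```python
from typing import Any


def is_sorted(rows: list[dict[str, Any]], sort_keys: list[str]) -> bool:
    """Return True if rows are in non-decreasing lexicographic order of sort_keys."""
    # Build one comparison tuple per row, then compare each row with its successor.
    keys = [tuple(row[k] for k in sort_keys) for row in rows]
    return all(a <= b for a, b in zip(keys, keys[1:]))


# Sort keys as listed for DiscreteControl / condition in this PR.
condition_keys = ["node_id", "listen_feature_id", "variable", "greater_than"]

# Made-up example rows, for illustration only.
rows = [
    {"node_id": 1, "listen_feature_id": 3, "variable": "level", "greater_than": 5.0},
    {"node_id": 1, "listen_feature_id": 3, "variable": "level", "greater_than": 15.0},
    {"node_id": 2, "listen_feature_id": 1, "variable": "flow_rate", "greater_than": 0.5},
]

print(is_sorted(rows, condition_keys))
print(is_sorted(rows[::-1], condition_keys))
```

The first call reports the rows as sorted; the reversed list fails the check because `node_id` 2 comes before `node_id` 1.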

## DiscreteControl / logic
@@ -596,8 +596,8 @@ The logic schema defines which control states are triggered based on the truth o
DiscreteControl is applied in the Julia core as follows:

- During the simulation it is checked whether the truth of any of the conditions changes.
- When a condition changes, the corresponding discrrete_control node id is retrieved (node_id in the condition schema above).
- The truth value of all the conditions this discrete_control node lisens to are retrieved, in the order as they are specified in the condition schema. This is then converted into a string of "T" for true and "F" for false. This string we call the truth state.*
- When a condition changes, the corresponding discrete_control node id is retrieved (node_id in the condition schema above).
- The truth value of all the conditions this discrete_control node listens to are retrieved, in the order as they are specified in the condition schema. This is then converted into a string of "T" for true and "F" for false. This string we call the truth state.*
- The table below determines for the given discrete_control node ID and truth state what the corresponding control state is.
- For all the nodes this discrete_control node affects (as given by the "control" edges in [Edges / static](usage.qmd#edge)), their parameters are set to those parameters in `NodeType / static` corresponding to the determined control state.
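
The steps above can be sketched roughly as follows. This is a simplified illustration and not the Julia core implementation: it assumes a condition is true when the observed value is strictly greater than `greater_than`, it handles only "T", "F" and the "*" wildcard (no "U"/"D" crossings), and it assumes the first matching logic row wins. All function names and the sample logic table are hypothetical.

```python
from typing import Optional


def truth_state(values: list[float], thresholds: list[float]) -> str:
    """Build the truth state string, one character per condition, in table order."""
    return "".join("T" if v > t else "F" for v, t in zip(values, thresholds))


def matches(pattern: str, state: str) -> bool:
    """Check a logic-table pattern against a truth state; '*' matches any character."""
    return len(pattern) == len(state) and all(
        p in ("*", s) for p, s in zip(pattern, state)
    )


def control_state(logic: dict[str, str], state: str) -> Optional[str]:
    """Return the control state of the first matching pattern, or None."""
    for pattern, name in logic.items():
        if matches(pattern, state):
            return name
    return None


# Hypothetical logic rows for one discrete_control node: truth_state -> control_state.
logic = {"TT": "high", "F*": "low"}

state = truth_state([6.0, 20.0], [5.0, 15.0])
print(state, control_state(logic, state))
```

Because the truth state is built from the conditions in table order, the order of the condition rows determines which character belongs to which condition.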

@@ -606,8 +606,8 @@ DiscreteControl is applied in the Julia core as follows:
column | type | unit | restriction
-------------- | -------- | ---- | -----------
node_id | Int | - | sorted
truth_state | String | - | Consists of the characters "T" (true), "F" (false), "U" (upcrossing), "D" (downcrossing) and "*" (any)
control_state | String | - |
control_state | String | - | -
truth_state | String | - | Consists of the characters "T" (true), "F" (false), "U" (upcrossing), "D" (downcrossing) and "*" (any), sorted per node_id

## DiscreteControl results

@@ -631,13 +631,13 @@ In the future controlling the flow on a particular edge could be supported.
column | type | unit | restriction
-------------- | -------- | -------- | -----------
node_id | Int | - | sorted
control_state | String | - | (optional) sorted per node_id
active | Bool | - | (optional, default true)
listen_node_id | Int | - | -
target | Float64 | $m$ | -
proportional | Float64 | $s^{-1}$ | -
integral | Float64 | $s^{-2}$ | -
derivative | Float64 | - | -
control_state | String | - | -

## PidControl / time

@@ -651,8 +651,8 @@ Note that a `node_id` can be either in this table or in the static one, but not

column | type | unit | restriction
-------------- | -------- | -------- | -----------
node_id | Int | - | sorted per time
time | DateTime | - | sorted
node_id | Int | - | sorted
time | DateTime | - | sorted per node_id
listen_node_id | Int | - | -
target | Float64 | $m$ | -
proportional | Float64 | $s^{-1}$ | -
40 changes: 34 additions & 6 deletions python/ribasim/ribasim/config.py
@@ -82,6 +82,8 @@ class Terminal(NodeModel):
default_factory=TableModel[TerminalStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id"]}
Contributor

I don't agree with adding this protected variable in multiple places. There are other ways to do this. For example, add an __init__ function everywhere to add the required keys to the dict.

But I am trying to look a bit further. It would be better if we receive this information from the julia code. For example, by adding it in the schemas and reading the schemas in python. Would that be something that is possible with pydantic?

Member Author

_sort_keys was already used, falling back to a wrong default if it wasn't present. This PR removes the wrong default and requires explicitly stating the correct sorting order.

If possible I agree it would be nice to share this or run this through pydantic. I'd like to keep this PR focused on a bugfix and not add a refactor on top.
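
For readers following the thread, here is a small sketch of the lookup problem this PR removes, based on the `_save` line changed in `input_base.py`. The old call passed the literal string `"field"` to `dict.get`, so the default was returned regardless of the table name. The dictionary and the two function names below are made up for illustration.

```python
# Hypothetical sort keys for one node type, in the style used in config.py.
sort_keys = {
    "static": ["node_id", "control_state"],
    "time": ["node_id", "time"],
}


def lookup_old(field: str) -> list[str]:
    # Old behaviour: the literal "field" is looked up, not the variable,
    # so the ["node_id"] default is returned for every table.
    return sort_keys.get("field", ["node_id"])


def lookup_new(field: str) -> list[str]:
    # New behaviour: index with the variable; a missing entry raises KeyError
    # instead of silently using a wrong default.
    return sort_keys[field]


print(lookup_old("time"))
print(lookup_new("time"))
```

With the new lookup, a node type that forgets to declare sort keys for one of its tables fails loudly at write time, which is why every class in `config.py` now states its keys explicitly.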

Contributor

@visr @evetion take a look at feat/sort-keys-in-tablemodel. I don't have it working yet, but it seems to me that we need to put the sort keys inside the table model, as some class variable whose value depends on the initialization. Or maybe in another way: just a Field?

Contributor

I have been playing a bit more with this, and it seems to me that the best option would be to make the sort_keys an instance Field:

sort_keys: list[str] = Field(init_var=True, exclude=True, repr=False)

I set it to init_var so it can never be changed afterwards, and I set it to exclude because the data is always set in Python.

The only remaining problem occurs when you set a pandas DataFrame, because then it tells me:

pydantic_core._pydantic_core.ValidationError: 1 validation error for Terminal
static.sort_keys
Input should be a valid list [type=list_type, input_value=None, input_type=NoneType]
For further information visit https://errors.pydantic.dev/2.5/v/list_type

I believe we could solve this by using the scheme for custom data types described here, instead of the @model_validator: https://docs.pydantic.dev/latest/concepts/types/#custom-types

Member

Is init_var=True needed? Because if you set a DataFrame, the TableModel already exists, including its specific sort_keys. Ideally we get that from somewhere.

> But I am trying to look a bit further. It would be better if we receive this information from the julia code. For example, by adding it in the schemas and reading the schemas in python. Would that be something that is possible with pydantic?

Fully agree, but I couldn't find a way of doing so. It requires passing some custom info into the JSON Schema, and having https://github.com/koxudaxi/datamodel-code-generator/ read it and pass it into some class as a private field or Config option that we can pick up on here.

Contributor

I created #956 for the first effort on getting the sort_keys in pydantic. You could decide to enhance the ticket, or otherwise just make a new ticket to get the info from Julia.



class PidControl(NodeModel):
static: TableModel[PidControlStaticSchema] = Field(
@@ -91,7 +93,10 @@ class PidControl(NodeModel):
default_factory=TableModel[PidControlTimeSchema]
)

_sort_keys: dict[str, list[str]] = {"time": ["time", "node_id"]}
_sort_keys: dict[str, list[str]] = {
"static": ["node_id", "control_state"],
"time": ["node_id", "time"],
}


class LevelBoundary(NodeModel):
@@ -102,14 +107,19 @@ class LevelBoundary(NodeModel):
default_factory=TableModel[LevelBoundaryTimeSchema]
)

_sort_keys: dict[str, list[str]] = {"time": ["time", "node_id"]}
_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"time": ["node_id", "time"],
}


class Pump(NodeModel):
static: TableModel[PumpStaticSchema] = Field(
default_factory=TableModel[PumpStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class TabulatedRatingCurve(NodeModel):
static: TableModel[TabulatedRatingCurveStaticSchema] = Field(
@@ -119,8 +129,8 @@ class TabulatedRatingCurve(NodeModel):
default_factory=TableModel[TabulatedRatingCurveTimeSchema]
)
_sort_keys: dict[str, list[str]] = {
"static": ["node_id", "level"],
"time": ["time", "node_id", "level"],
"static": ["node_id", "control_state", "level"],
"time": ["node_id", "time", "level"],
}


@@ -144,7 +154,10 @@ class FlowBoundary(NodeModel):
default_factory=TableModel[FlowBoundaryTimeSchema]
)

_sort_keys: dict[str, list[str]] = {"time": ["time", "node_id"]}
_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"time": ["node_id", "time"],
}


class Basin(NodeModel):
@@ -165,8 +178,10 @@ class Basin(NodeModel):
)

_sort_keys: dict[str, list[str]] = {
"static": ["node_id"],
"state": ["node_id"],
"profile": ["node_id", "level"],
"time": ["time", "node_id"],
"time": ["node_id", "time"],
"subgrid": ["subgrid_id", "basin_level"],
}

@@ -176,6 +191,8 @@ class ManningResistance(NodeModel):
default_factory=TableModel[ManningResistanceStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class DiscreteControl(NodeModel):
condition: TableModel[DiscreteControlConditionSchema] = Field(
@@ -185,20 +202,31 @@ class DiscreteControl(NodeModel):
default_factory=TableModel[DiscreteControlLogicSchema]
)

_sort_keys: dict[str, list[str]] = {
"condition": ["node_id", "listen_feature_id", "variable", "greater_than"],
"logic": ["node_id", "truth_state"],
}


class Outlet(NodeModel):
static: TableModel[OutletStaticSchema] = Field(
default_factory=TableModel[OutletStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class LinearResistance(NodeModel):
static: TableModel[LinearResistanceStaticSchema] = Field(
default_factory=TableModel[LinearResistanceStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}


class FractionalFlow(NodeModel):
static: TableModel[FractionalFlowStaticSchema] = Field(
default_factory=TableModel[FractionalFlowStaticSchema]
)

_sort_keys: dict[str, list[str]] = {"static": ["node_id", "control_state"]}
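
To make the effect of the reordered `time` keys concrete, here is a plain-Python stand-in that sorts the same rows by the old key order `["time", "node_id"]` and by the new one `["node_id", "time"]`. The helper `sort_rows` and the sample rows are hypothetical. The tests in this PR sort pandas DataFrames with `sort_values`; the standard library `sorted` is used here only to keep the sketch dependency-free.

```python
from operator import itemgetter

# Made-up rows in the shape of a LevelBoundary / time table.
rows = [
    {"node_id": 2, "time": "2020-01-01", "level": 1.0},
    {"node_id": 1, "time": "2020-02-01", "level": 1.5},
    {"node_id": 1, "time": "2020-01-01", "level": 1.2},
    {"node_id": 2, "time": "2020-02-01", "level": 0.9},
]


def sort_rows(rows: list[dict], sort_keys: list[str]) -> list[dict]:
    """Sort rows lexicographically by the given columns, first key most significant."""
    return sorted(rows, key=itemgetter(*sort_keys))


old_order = sort_rows(rows, ["time", "node_id"])
new_order = sort_rows(rows, ["node_id", "time"])

print([r["node_id"] for r in old_order])
print([r["node_id"] for r in new_order])
```

The old order interleaves the nodes per timestamp, while the new order keeps each node's time series contiguous, which matches the "sorted per node_id" restriction now documented for `time`.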
9 changes: 4 additions & 5 deletions python/ribasim/ribasim/input_base.py
@@ -289,10 +289,9 @@ def _from_arrow(cls, path: FilePath) -> pd.DataFrame:
directory = context_file_loading.get().get("directory", Path("."))
return pd.read_feather(directory / path)

def sort(self, sort_keys: list[str] = ["node_id"]):
"""Sort all input tables as required.
def sort(self, sort_keys: list[str]):
"""Sort the table as required.

Tables are sorted by "node_id", unless otherwise specified.
Sorting is done automatically before writing the table.
"""
if self.df is not None:
@@ -375,7 +374,7 @@ def _write_table(self, path: FilePath) -> None:

gdf.to_file(path, layer=self.tablename(), driver="GPKG")

def sort(self, sort_keys: list[str] = ["node_id"]):
def sort(self, sort_keys: list[str]):
self.df.sort_index(inplace=True)


@@ -435,7 +434,7 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
getattr(self, field)._save(
directory,
input_dir,
sort_keys=self._sort_keys.get("field", ["node_id"]),
sort_keys=self._sort_keys[field],
)

def _repr_content(self) -> str:
5 changes: 5 additions & 0 deletions python/ribasim/tests/conftest.py
@@ -32,3 +32,8 @@ def backwater() -> ribasim.Model:
@pytest.fixture()
def discrete_control_of_pid_control() -> ribasim.Model:
return ribasim_testmodels.discrete_control_of_pid_control_model()


@pytest.fixture()
def level_setpoint_with_minmax() -> ribasim.Model:
return ribasim_testmodels.level_setpoint_with_minmax_model()
23 changes: 23 additions & 0 deletions python/ribasim/tests/test_io.py
@@ -102,3 +102,26 @@ def test_extra_columns(basic_transient):
node = Node(df=df)

assert "meta_node_id" in node.df.columns


def test_sort(level_setpoint_with_minmax, tmp_path):
model = level_setpoint_with_minmax
table = model.discrete_control.condition

# apply a wrong sort, then call the sort method to restore order
table.df.sort_values("greater_than", ascending=False, inplace=True)
assert table.df.iloc[0]["greater_than"] == 15.0
sort_keys = model.discrete_control._sort_keys["condition"]
assert sort_keys == ["node_id", "listen_feature_id", "variable", "greater_than"]
table.sort(sort_keys)
assert table.df.iloc[0]["greater_than"] == 5.0

# re-apply wrong sort, then check if it gets sorted on write
table.df.sort_values("greater_than", ascending=False, inplace=True)
model.write(tmp_path / "basic/ribasim.toml")
# write sorts the model in place
assert table.df.iloc[0]["greater_than"] == 5.0
model_loaded = ribasim.Model(filepath=tmp_path / "basic/ribasim.toml")
table_loaded = model_loaded.discrete_control.condition
assert table_loaded.df.iloc[0]["greater_than"] == 5.0
__assert_equal(table.df, table_loaded.df)