Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conditions on linear combinations of variables for DiscreteControl #1371

Merged
merged 33 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4812bca
Python side
SouthEndMusic Apr 10, 2024
44fcaae
Change name of new schema
SouthEndMusic Apr 10, 2024
1bc18e7
Read compound variable data in core
SouthEndMusic Apr 10, 2024
0cec537
Handle compound variables in discrete_control callback
SouthEndMusic Apr 10, 2024
78e8364
Fix part of tests
SouthEndMusic Apr 10, 2024
5032045
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 10, 2024
ff54bb3
Fix the last test in a hacky way
SouthEndMusic Apr 10, 2024
acca949
Add tests
SouthEndMusic Apr 11, 2024
f7cb40c
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
72e65b5
small plotting fix
SouthEndMusic Apr 11, 2024
44eab71
small qgis fix
SouthEndMusic Apr 11, 2024
25997fe
Update schemas in usage.qmd
SouthEndMusic Apr 11, 2024
3144ef1
Update docstrings
SouthEndMusic Apr 11, 2024
e842bdf
update sorting and schema explanations
SouthEndMusic Apr 11, 2024
c975b0c
Fix table sorting
SouthEndMusic Apr 11, 2024
bb123a3
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
6257d6f
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
0ba1d6f
Refactor variable input
SouthEndMusic Apr 15, 2024
2a27fb9
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 15, 2024
17d2a15
Add compound_variable_id
SouthEndMusic Apr 15, 2024
75b56a0
Fix bugs
SouthEndMusic Apr 16, 2024
bf117d1
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
a65f6f6
Small fixes
SouthEndMusic Apr 16, 2024
41ae110
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
9765dc2
Fix examples notebook
SouthEndMusic Apr 16, 2024
267d412
Update docstrings
SouthEndMusic Apr 16, 2024
e6d42e4
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
c662fe7
Update usage.qmd
SouthEndMusic Apr 16, 2024
ae2479c
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
939403e
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
a520e2f
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
3707f50
Fix examples notebook?
SouthEndMusic Apr 16, 2024
a140316
Comments adressed
SouthEndMusic Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,20 @@ function discrete_control_condition(out, u, t, integrator)
(; p) = integrator
(; discrete_control) = p

for (i, (listen_node_id, variable, greater_than, look_ahead)) in enumerate(
for (i, (listen_node_ids, variables, weights, greater_than, look_aheads)) in enumerate(
zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
),
)
value = get_value(p, listen_node_id, variable, look_ahead, u, t)
value = 0.0
for (listen_node_id, variable, weight, look_ahead) in
zip(listen_node_ids, variables, weights, look_aheads)
value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t)
end
diff = value - greater_than
out[i] = diff
end
Expand Down Expand Up @@ -272,16 +277,20 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
is_basin = id_index(basin.node_id, discrete_control.listen_node_id[condition_idx])[1]
listen_node_ids = discrete_control.listen_node_id[condition_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

# NOTE: The above no longer works when listen feature ids can be something other than node ids
# I think the more durable option is to give all possible condition types a different variable string,
# e.g. basin.level and level_boundary.level
if variable[condition_idx] == "level" && control_state_change && is_basin
if variable[condition_idx][1] == "level" && control_state_change && is_basin
evetion marked this conversation as resolved.
Show resolved Hide resolved
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
_, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx])
_, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx][1])

if du[condition_basin_idx] < 0.0
condition_value[condition_idx] = false
Expand All @@ -308,13 +317,18 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
if variable[condition_idx] == "level" && control_state_change
listen_node_ids = discrete_control.listen_node_id[condition_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

if variable[condition_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
has_index, condition_basin_idx =
id_index(basin.node_id, listen_node_id[condition_idx])
id_index(basin.node_id, listen_node_id[condition_idx][1])

if has_index && du[condition_basin_idx] > 0.0
condition_value[condition_idx] = true
Expand Down Expand Up @@ -369,7 +383,7 @@ function discrete_control_affect!(
discrete_control.logic_mapping[(discrete_control_node_id, truth_state)]
else
error(
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for DiscreteControl node $discrete_control_node_id.",
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for $discrete_control_node_id.",
)
end

Expand Down
7 changes: 4 additions & 3 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,10 @@ record: Namedtuple with discrete control information for results
"""
struct DiscreteControl <: AbstractParameterNode
node_id::Vector{NodeID}
listen_node_id::Vector{NodeID}
variable::Vector{String}
look_ahead::Vector{Float64}
listen_node_id::Vector{Vector{NodeID}}
variable::Vector{Vector{String}}
weight::Vector{Vector{Float64}}
look_ahead::Vector{Vector{Float64}}
greater_than::Vector{Float64}
condition_value::Vector{Bool}
control_state::Dict{NodeID, Tuple{String, Float64}}
Expand Down
47 changes: 43 additions & 4 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,49 @@
)
end

function get_compound_variables(compound_variable, condition)
listen_node_id = Vector{NodeID}[]
variable = Vector{String}[]
weight = Vector{Float64}[]
look_ahead = Vector{Float64}[]

for cond in condition
if cond.listen_node_type == "compound"
compound_variable_data = filter(
row -> (row.node_id, row.name) == (cond.node_id, cond.variable),

Check warning on line 554 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L553-L554

Added lines #L553 - L554 were not covered by tests
compound_variable,
)
listen_node_id_data =

Check warning on line 557 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L557

Added line #L557 was not covered by tests
NodeID.(
compound_variable_data.listen_node_type,
compound_variable_data.listen_node_id,
)
@assert !isempty(listen_node_id_data) "No compound variable data found for name $(cond.variable)."
variable_data = compound_variable_data.variable
weight_data = compound_variable_data.weight
look_ahead_data = coalesce.(compound_variable_data.look_ahead, 0.0)

Check warning on line 565 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L562-L565

Added lines #L562 - L565 were not covered by tests
else
listen_node_id_data = [NodeID(cond.listen_node_type, cond.listen_node_id)]
variable_data = [cond.variable]
weight_data = [1.0]
look_ahead_data = [coalesce(cond.look_ahead, 0.0)]
end

push!(listen_node_id, listen_node_id_data)
push!(variable, variable_data)
push!(weight, weight_data)
push!(look_ahead, look_ahead_data)
end
return listen_node_id, variable, weight, look_ahead
end

function DiscreteControl(db::DB, config::Config)::DiscreteControl
compound_variable = load_structvector(db, config, DiscreteControlCompoundvariableV1)
condition = load_structvector(db, config, DiscreteControlConditionV1)

listen_node_id, variable, weight, look_ahead =
get_compound_variables(compound_variable, condition)

condition_value = fill(false, length(condition.node_id))
control_state::Dict{NodeID, Tuple{String, Float64}} = Dict()

Expand All @@ -557,7 +597,6 @@
end

logic = load_structvector(db, config, DiscreteControlLogicV1)

logic_mapping = Dict{Tuple{NodeID, String}, String}()

for (node_id, truth_state, control_state_) in
Expand All @@ -567,7 +606,6 @@
end

logic_mapping = expand_logic_mapping(logic_mapping)
look_ahead = coalesce.(condition.look_ahead, 0.0)

record = (
time = Float64[],
Expand All @@ -578,8 +616,9 @@

return DiscreteControl(
NodeID.(NodeType.DiscreteControl, condition.node_id), # Not unique
NodeID.(condition.listen_node_type, condition.listen_node_id),
condition.variable,
listen_node_id,
variable,
weight,
look_ahead,
condition.greater_than,
condition_value,
Expand Down
11 changes: 11 additions & 0 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# The identifier is parsed as ribasim.nodetype.kind, no capitals or underscores are allowed.
@schema "ribasim.discretecontrol.condition" DiscreteControlCondition
@schema "ribasim.discretecontrol.logic" DiscreteControlLogic
@schema "ribasim.discretecontrol.compoundvariable" DiscreteControlCompoundvariable
@schema "ribasim.basin.static" BasinStatic
@schema "ribasim.basin.time" BasinTime
@schema "ribasim.basin.profile" BasinProfile
Expand Down Expand Up @@ -183,6 +184,16 @@ end
node_id::Int32
end

@version DiscreteControlCompoundvariableV1 begin
node_id::Int32
name::String
listen_node_type::String
listen_node_id::Int
variable::String
weight::Float64
look_ahead::Union{Missing, Float64}
end

@version DiscreteControlConditionV1 begin
node_id::Int32
listen_node_type::String
Expand Down
43 changes: 23 additions & 20 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,29 +552,32 @@
end
end
end
for (Δt, var, node_id) in zip(look_ahead, variable, listen_node_id)
if !iszero(Δt)
node_type = node_id.type
# TODO: If more transient listen variables must be supported, this validation must be more specific
# (e.g. for some node some variables are transient, some not).
if node_type ∉ [NodeType.FlowBoundary, NodeType.LevelBoundary]
errors = true
@error "Look ahead supplied for non-timeseries listen variable '$var' from listen node $node_id."
else
if Δt < 0
for (look_aheads, variables, listen_node_ids) in
zip(look_ahead, variable, listen_node_id)
for (Δt, var, node_id) in zip(look_aheads, variables, listen_node_ids)
if !iszero(Δt)
node_type = node_id.type
# TODO: If more transient listen variables must be supported, this validation must be more specific
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved
# (e.g. for some node some variables are transient, some not).
if node_type ∉ [NodeType.FlowBoundary, NodeType.LevelBoundary]
errors = true
@error "Negative look ahead supplied for listen variable '$var' from listen node $node_id."
@error "Look ahead supplied for non-timeseries listen variable '$var' from listen node $node_id."
else
node = getfield(p, graph[node_id].type)
idx = if node_type == NodeType.Basin
id_index(node.node_id, node_id)
else
searchsortedfirst(node.node_id, node_id)
end
interpolation = getfield(node, Symbol(var))[idx]
if t_end + Δt > interpolation.t[end]
if Δt < 0
errors = true
@error "Look ahead for listen variable '$var' from listen node $node_id goes past timeseries end during simulation."
@error "Negative look ahead supplied for listen variable '$var' from listen node $node_id."
else
node = getfield(p, graph[node_id].type)
idx = if node_type == NodeType.Basin
id_index(node.node_id, node_id)

Check warning on line 572 in core/src/validation.jl

View check run for this annotation

Codecov / codecov/patch

core/src/validation.jl#L572

Added line #L572 was not covered by tests
else
searchsortedfirst(node.node_id, node_id)
end
interpolation = getfield(node, Symbol(var))[idx]
if t_end + Δt > interpolation.t[end]
errors = true
@error "Look ahead for listen variable '$var' from listen node $node_id goes past timeseries end during simulation."
end
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions core/test/control_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ end
p = model.integrator.p
(; discrete_control, flow_boundary) = p

Δt = discrete_control.look_ahead[1]
Δt = discrete_control.look_ahead[1][1]
evetion marked this conversation as resolved.
Show resolved Hide resolved

t = Ribasim.tsaves(model)
t_control = discrete_control.record.time[2]
Expand All @@ -78,7 +78,7 @@ end
p = model.integrator.p
(; discrete_control, level_boundary) = p

Δt = discrete_control.look_ahead[1]
Δt = discrete_control.look_ahead[1][1]

t = Ribasim.tsaves(model)
t_control = discrete_control.record.time[2]
Expand Down
5 changes: 5 additions & 0 deletions python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BasinStaticSchema,
BasinSubgridSchema,
BasinTimeSchema,
DiscreteControlCompoundvariableSchema,
DiscreteControlConditionSchema,
DiscreteControlLogicSchema,
FlowBoundaryStaticSchema,
Expand Down Expand Up @@ -287,6 +288,10 @@ class ManningResistance(MultiNodeModel):


class DiscreteControl(MultiNodeModel):
compoundvariable: TableModel[DiscreteControlCompoundvariableSchema] = Field(
default_factory=TableModel[DiscreteControlCompoundvariableSchema],
json_schema_extra={"sort_keys": ["name"]},
)
condition: TableModel[DiscreteControlConditionSchema] = Field(
default_factory=TableModel[DiscreteControlConditionSchema],
json_schema_extra={
Expand Down
12 changes: 9 additions & 3 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,17 @@ def plot_control_listen(self, ax):
df_listen_edge = pd.concat([df_listen_edge, to_add])

# Listen edges from DiscreteControl
condition = self.discrete_control.condition.df
if condition is not None:
to_add = condition[
for table in (
self.discrete_control.condition.df,
self.discrete_control.compound_variable.df,
):
if table is None:
continue

to_add = table[
["node_id", "listen_node_id", "listen_node_type"]
].drop_duplicates()
to_add = to_add[to_add["listen_node_type"] != "compound"]
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved
to_add.columns = ["control_node_id", "listen_node_id", "listen_node_type"]
to_add["control_node_type"] = "DiscreteControl"
df_listen_edge = pd.concat([df_listen_edge, to_add])
Expand Down
13 changes: 11 additions & 2 deletions python/ribasim/ribasim/nodes/discrete_control.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from pandas import DataFrame

from ribasim.input_base import TableModel
from ribasim.schemas import DiscreteControlConditionSchema, DiscreteControlLogicSchema
from ribasim.schemas import (
DiscreteControlCompoundvariableSchema,
DiscreteControlConditionSchema,
DiscreteControlLogicSchema,
)

__all__ = ["Condition", "Logic"]
__all__ = ["Condition", "Logic", "Compoundvariable"]


class Compoundvariable(TableModel[DiscreteControlCompoundvariableSchema]):
def __init__(self, **kwargs):
super().__init__(df=DataFrame(dict(**kwargs)))


class Condition(TableModel[DiscreteControlConditionSchema]):
Expand Down
10 changes: 10 additions & 0 deletions python/ribasim/ribasim/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ class BasinTimeSchema(_BaseSchema):
urban_runoff: Series[float] = pa.Field(nullable=True)


class DiscreteControlCompoundvariableSchema(_BaseSchema):
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
name: Series[str] = pa.Field(nullable=False)
listen_node_type: Series[str] = pa.Field(nullable=False)
listen_node_id: Series[Int32] = pa.Field(nullable=False, default=0)
variable: Series[str] = pa.Field(nullable=False)
weight: Series[float] = pa.Field(nullable=False)
look_ahead: Series[float] = pa.Field(nullable=True)


class DiscreteControlConditionSchema(_BaseSchema):
node_id: Series[Int32] = pa.Field(nullable=False, default=0)
listen_node_type: Series[str] = pa.Field(nullable=False)
Expand Down
2 changes: 2 additions & 0 deletions python/ribasim_testmodels/ribasim_testmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from ribasim_testmodels.bucket import bucket_model, leaky_bucket_model
from ribasim_testmodels.discrete_control import (
compound_variable_condition_model,
flow_condition_model,
level_boundary_condition_model,
level_setpoint_with_minmax_model,
Expand Down Expand Up @@ -64,6 +65,7 @@
"basic_model",
"basic_transient_model",
"bucket_model",
"compound_variable_condition_model",
"discrete_control_of_pid_control_model",
"dutch_waterways_model",
"flow_boundary_time_model",
Expand Down
Loading
Loading