Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support conditions on linear combinations of variables for DiscreteControl #1371

Merged
merged 33 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
4812bca
Python side
SouthEndMusic Apr 10, 2024
44fcaae
Change name of new schema
SouthEndMusic Apr 10, 2024
1bc18e7
Read compound variable data in core
SouthEndMusic Apr 10, 2024
0cec537
Handle compound variables in discrete_control callback
SouthEndMusic Apr 10, 2024
78e8364
Fix part of tests
SouthEndMusic Apr 10, 2024
5032045
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 10, 2024
ff54bb3
Fix the last test in a hacky way
SouthEndMusic Apr 10, 2024
acca949
Add tests
SouthEndMusic Apr 11, 2024
f7cb40c
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
72e65b5
small plotting fix
SouthEndMusic Apr 11, 2024
44eab71
small qgis fix
SouthEndMusic Apr 11, 2024
25997fe
Update schemas in usage.qmd
SouthEndMusic Apr 11, 2024
3144ef1
Update docstrings
SouthEndMusic Apr 11, 2024
e842bdf
update sorting and schema explanations
SouthEndMusic Apr 11, 2024
c975b0c
Fix table sorting
SouthEndMusic Apr 11, 2024
bb123a3
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
6257d6f
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 11, 2024
0ba1d6f
Refactor variable input
SouthEndMusic Apr 15, 2024
2a27fb9
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 15, 2024
17d2a15
Add compound_variable_id
SouthEndMusic Apr 15, 2024
75b56a0
Fix bugs
SouthEndMusic Apr 16, 2024
bf117d1
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
a65f6f6
Small fixes
SouthEndMusic Apr 16, 2024
41ae110
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
9765dc2
Fix examples notebook
SouthEndMusic Apr 16, 2024
267d412
Update docstrings
SouthEndMusic Apr 16, 2024
e6d42e4
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
c662fe7
Update usage.qmd
SouthEndMusic Apr 16, 2024
ae2479c
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
939403e
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
a520e2f
Merge branch 'main' into compound_condition_variables
SouthEndMusic Apr 16, 2024
3707f50
Fix examples notebook?
SouthEndMusic Apr 16, 2024
a140316
Comments adressed
SouthEndMusic Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 78 additions & 30 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@ function set_initial_discrete_controlled_parameters!(
(; p) = integrator
(; discrete_control) = p

n_conditions = length(discrete_control.condition_value)
n_conditions = sum(length(vec) for vec in discrete_control.condition_value; init = 0)
condition_diffs = zeros(Float64, n_conditions)
discrete_control_condition(condition_diffs, storage0, integrator.t, integrator)
discrete_control.condition_value .= (condition_diffs .> 0.0)

# Set the discrete control value (bool) per compound variable
idx_start = 1
for (compound_variable_idx, vec) in enumerate(discrete_control.condition_value)
l = length(vec)
idx_end = idx_start + l - 1
discrete_control.condition_value[compound_variable_idx] .=
(condition_diffs[idx_start:idx_end] .> 0)
idx_start += l
end

# For every discrete_control node find a condition_idx it listens to
for discrete_control_node_id in unique(discrete_control.node_id)
Expand Down Expand Up @@ -78,7 +87,7 @@ function create_callbacks(

saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0)
if n_conditions > 0
discrete_control_cb = VectorContinuousCallback(
discrete_control_condition,
Expand Down Expand Up @@ -183,18 +192,27 @@ Listens for changes in condition truths.
function discrete_control_condition(out, u, t, integrator)
(; p) = integrator
(; discrete_control) = p

for (i, (listen_node_id, variable, greater_than, look_ahead)) in enumerate(
zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.greater_than,
discrete_control.look_ahead,
),
condition_idx = 0

# Loop over compound variables
for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
)
value = get_value(p, listen_node_id, variable, look_ahead, u, t)
diff = value - greater_than
out[i] = diff
value = 0.0
for (listen_node_id, variable, weight, look_ahead) in
zip(listen_node_ids, variables, weights, look_aheads)
value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t)
end
# Loop over greater_than values for this compound_variable
for greater_than in greater_thans
condition_idx += 1
diff = value - greater_than
out[condition_idx] = diff
end
end
end

Expand Down Expand Up @@ -252,7 +270,9 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = true
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = true

control_state_change = discrete_control_affect!(integrator, condition_idx, true)

Expand All @@ -262,19 +282,24 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
is_basin = id_index(basin.node_id, discrete_control.listen_node_id[condition_idx])[1]
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

# NOTE: The above no longer works when listen feature ids can be something other than node ids
# I think the more durable option is to give all possible condition types a different variable string,
# e.g. basin.level and level_boundary.level
if variable[condition_idx] == "level" && control_state_change && is_basin
if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
_, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx])
_, condition_basin_idx =
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if du[condition_basin_idx] < 0.0
condition_value[condition_idx] = false
condition_value[compound_variable_idx][greater_than_idx] = false
discrete_control_affect!(integrator, condition_idx, false)
end
end
Expand All @@ -288,7 +313,9 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = false
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = false

control_state_change = discrete_control_affect!(integrator, condition_idx, false)

Expand All @@ -298,16 +325,23 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
if variable[condition_idx] == "level" && control_state_change
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
has_index, condition_basin_idx =
id_index(basin.node_id, listen_node_id[condition_idx])
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if has_index && du[condition_basin_idx] > 0.0
condition_value[condition_idx] = true
condition_value[compound_variable_idx][greater_than_idx] = true
discrete_control_affect!(integrator, condition_idx, true)
end
end
Expand All @@ -325,20 +359,34 @@ function discrete_control_affect!(
(; discrete_control, graph) = p

# Get the discrete_control node that listens to this condition
discrete_control_node_id = discrete_control.node_id[condition_idx]

compound_variable_idx, _ = get_discrete_control_indices(discrete_control, condition_idx)
discrete_control_node_id = discrete_control.node_id[compound_variable_idx]

# Get the indices of all conditions that this control node listens to
condition_ids = discrete_control.node_id .== discrete_control_node_id
where_node_id = searchsorted(discrete_control.node_id, discrete_control_node_id)

# Get the truth state for this discrete_control node
truth_values = [ifelse(b, "T", "F") for b in discrete_control.condition_value]
truth_state = join(truth_values[condition_ids], "")
truth_values = cat(
[
[ifelse(b, "T", "F") for b in discrete_control.condition_value[i]] for
i in where_node_id
]...;
dims = 1,
)
truth_state = join(truth_values, "")

# Get the truth specific about the latest crossing
if !ismissing(upcrossing)
truth_values[condition_idx] = upcrossing ? "U" : "D"
truth_value_idx =
condition_idx - sum(
length(vec) for
vec in discrete_control.condition_value[1:(where_node_id.start - 1)];
init = 0,
)
truth_values[truth_value_idx] = upcrossing ? "U" : "D"
end
truth_state_crossing_specific = join(truth_values[condition_ids], "")
truth_state_crossing_specific = join(truth_values, "")

# What the local control state should be
control_state_new =
Expand All @@ -359,7 +407,7 @@ function discrete_control_affect!(
discrete_control.logic_mapping[(discrete_control_node_id, truth_state)]
else
error(
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for DiscreteControl node $discrete_control_node_id.",
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for $discrete_control_node_id.",
)
end

Expand Down
27 changes: 16 additions & 11 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,23 +443,28 @@ struct Terminal <: AbstractParameterNode
end

"""
node_id: node ID of the DiscreteControl node; these are not unique but repeated
by the amount of conditions of this DiscreteControl node
listen_node_id: the ID of the node being condition on
variable: the name of the variable in the condition
greater_than: The threshold value in the condition
condition_value: The current value of each condition
node_id: node ID of the DiscreteControl node per compound variable (can contain repeats)
listen_node_id: the IDs of the nodes being condition on per compound variable
variable: the names of the variables in the condition per compound variable
weight: the weight of the variables in the condition per compound variable
look_ahead: the look ahead of variables in the condition in seconds per compound_variable
greater_than: The threshold values per compound variable
condition_value: The current truth value of each condition per compound_variable per greater_than
control_state: Dictionary: node ID => (control state, control state start)
logic_mapping: Dictionary: (control node ID, truth state) => control state
record: Namedtuple with discrete control information for results
"""
struct DiscreteControl <: AbstractParameterNode
node_id::Vector{NodeID}
listen_node_id::Vector{NodeID}
variable::Vector{String}
look_ahead::Vector{Float64}
greater_than::Vector{Float64}
condition_value::Vector{Bool}
# Definition of compound variables
listen_node_id::Vector{Vector{NodeID}}
variable::Vector{Vector{String}}
weight::Vector{Vector{Float64}}
look_ahead::Vector{Vector{Float64}}
# Definition of conditions (one or more greater_than per compound variable)
greater_than::Vector{Vector{Float64}}
condition_value::Vector{BitVector}
# Definition of logic
control_state::Dict{NodeID, Tuple{String, Float64}}
logic_mapping::Dict{Tuple{NodeID, String}, String}
record::@NamedTuple{
Expand Down
84 changes: 77 additions & 7 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,81 @@
)
end

function parse_variables_and_conditions(compound_variable, condition)
node_id = NodeID[]
listen_node_id = Vector{NodeID}[]
variable = Vector{String}[]
weight = Vector{Float64}[]
look_ahead = Vector{Float64}[]
greater_than = Vector{Float64}[]
condition_value = BitVector[]
errors = false

# Loop over unique discrete_control node IDs (on which at least one condition is defined)
for id in unique(condition.node_id)
condition_group_id = filter(row -> row.node_id == id, condition)
variable_group_id = filter(row -> row.node_id == id, compound_variable)
# Loop over compound variables for this node ID
for compound_variable_id in unique(condition_group_id.compound_variable_id)
condition_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
condition_group_id,
)
variable_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
variable_group_id,
)
discrete_control_id = NodeID(NodeType.DiscreteControl, id)
if isempty(variable_group_variable)
errors = true
@error "compound_variable_id $compound_variable_id for $discrete_control_id in condition table but not in variable table"

Check warning on line 572 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L571-L572

Added lines #L571 - L572 were not covered by tests
else
push!(node_id, discrete_control_id)
push!(
listen_node_id,
NodeID.(
variable_group_variable.listen_node_type,
variable_group_variable.listen_node_id,
),
)
push!(variable, variable_group_variable.variable)
push!(weight, coalesce.(variable_group_variable.weight, 1.0))
push!(look_ahead, coalesce.(variable_group_variable.look_ahead, 0.0))
push!(greater_than, condition_group_variable.greater_than)
push!(
condition_value,
BitVector(zeros(length(condition_group_variable.greater_than))),
)
end
end
end
return node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
!errors
end

function DiscreteControl(db::DB, config::Config)::DiscreteControl
condition = load_structvector(db, config, DiscreteControlConditionV1)
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)

node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
valid = parse_variables_and_conditions(compound_variable, condition)

if !valid
error("Problems encountered when parsing DiscreteControl variables and conditions.")

Check warning on line 617 in core/src/read.jl

View check run for this annotation

Codecov / codecov/patch

core/src/read.jl#L617

Added line #L617 was not covered by tests
end

condition_value = fill(false, length(condition.node_id))
control_state::Dict{NodeID, Tuple{String, Float64}} = Dict()

rows = execute(db, "SELECT from_node_id, edge_type FROM Edge ORDER BY fid")
Expand All @@ -557,7 +628,6 @@
end

logic = load_structvector(db, config, DiscreteControlLogicV1)

logic_mapping = Dict{Tuple{NodeID, String}, String}()

for (node_id, truth_state, control_state_) in
Expand All @@ -567,7 +637,6 @@
end

logic_mapping = expand_logic_mapping(logic_mapping)
look_ahead = coalesce.(condition.look_ahead, 0.0)

record = (
time = Float64[],
Expand All @@ -577,11 +646,12 @@
)

return DiscreteControl(
NodeID.(NodeType.DiscreteControl, condition.node_id), # Not unique
NodeID.(condition.listen_node_type, condition.listen_node_id),
condition.variable,
node_id, # Not unique
listen_node_id,
variable,
weight,
look_ahead,
condition.greater_than,
greater_than,
condition_value,
control_state,
logic_mapping,
Expand Down
12 changes: 10 additions & 2 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# These schemas define the name of database tables and the configuration file structure
# The identifier is parsed as ribasim.nodetype.kind, no capitals or underscores are allowed.
@schema "ribasim.discretecontrol.variable" DiscreteControlVariable
@schema "ribasim.discretecontrol.condition" DiscreteControlCondition
@schema "ribasim.discretecontrol.logic" DiscreteControlLogic
@schema "ribasim.basin.static" BasinStatic
Expand Down Expand Up @@ -183,15 +184,22 @@ end
node_id::Int32
end

@version DiscreteControlConditionV1 begin
@version DiscreteControlVariableV1 begin
node_id::Int32
compound_variable_id::Int32
listen_node_type::String
listen_node_id::Int32
variable::String
greater_than::Float64
weight::Union{Missing, Float64}
look_ahead::Union{Missing, Float64}
end

@version DiscreteControlConditionV1 begin
node_id::Int32
compound_variable_id::Int32
greater_than::Float64
end

@version DiscreteControlLogicV1 begin
node_id::Int32
truth_state::String
Expand Down
Loading
Loading