Skip to content

Commit

Permalink
Discrete control with discrete callback (#1393)
Browse files Browse the repository at this point in the history
Fixes #455, fixes
#495. The latter is made
redundant by this PR since U/D truth values are removed.
  • Loading branch information
SouthEndMusic authored Apr 17, 2024
1 parent a40821d commit 169466e
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 319 deletions.
212 changes: 34 additions & 178 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,3 @@
"""
Set parameters of nodes that are controlled by DiscreteControl to the
values corresponding to the initial state of the model.
"""
function set_initial_discrete_controlled_parameters!(
integrator,
storage0::Vector{Float64},
)::Nothing
(; p) = integrator
(; discrete_control) = p

n_conditions = sum(length(vec) for vec in discrete_control.condition_value; init = 0)
condition_diffs = zeros(Float64, n_conditions)
discrete_control_condition(condition_diffs, storage0, integrator.t, integrator)

# Set the discrete control value (bool) per compound variable
idx_start = 1
for (compound_variable_idx, vec) in enumerate(discrete_control.condition_value)
l = length(vec)
idx_end = idx_start + l - 1
discrete_control.condition_value[compound_variable_idx] .=
(condition_diffs[idx_start:idx_end] .> 0)
idx_start += l
end

# For every discrete_control node find a condition_idx it listens to
for discrete_control_node_id in unique(discrete_control.node_id)
condition_idx =
searchsortedfirst(discrete_control.node_id, discrete_control_node_id)
discrete_control_affect!(integrator, condition_idx, missing)
end
end

"""
Create the different callbacks that are used to store results
Expand Down Expand Up @@ -89,13 +57,7 @@ function create_callbacks(

n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0)
if n_conditions > 0
discrete_control_cb = VectorContinuousCallback(
discrete_control_condition,
discrete_control_affect_upcrossing!,
discrete_control_affect_downcrossing!,
n_conditions;
save_positions = (false, false),
)
discrete_control_cb = FunctionCallingCallback(apply_discrete_control!)
push!(callbacks, discrete_control_cb)
end
callback = CallbackSet(callbacks...)
Expand Down Expand Up @@ -186,33 +148,50 @@ function save_vertical_flux(u, t, integrator)
return vertical_flux_mean
end

function apply_discrete_control!(u, t, integrator)::Nothing
(; p) = integrator
(; discrete_control) = p
condition_idx = 0

discrete_control_condition!(u, t, integrator)

# For every compound variable see whether it changes a control state
for compound_variable_idx in eachindex(discrete_control.node_id)
discrete_control_affect!(integrator, compound_variable_idx)
end
end

"""
Listens for changes in condition truths.
Update discrete control condition truths.
"""
function discrete_control_condition(out, u, t, integrator)
function discrete_control_condition!(u, t, integrator)
(; p) = integrator
(; discrete_control) = p
condition_idx = 0

# Loop over compound variables
for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip(
for (
listen_node_ids,
variables,
weights,
greater_thans,
look_aheads,
condition_values,
) in zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
discrete_control.condition_value,
)
value = 0.0
for (listen_node_id, variable, weight, look_ahead) in
zip(listen_node_ids, variables, weights, look_aheads)
value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t)
end
# Loop over greater_than values for this compound_variable
for greater_than in greater_thans
condition_idx += 1
diff = value - greater_than
out[condition_idx] = diff
end

condition_values .= false
condition_values[1:searchsortedlast(greater_thans, value)] .= true
end
end

Expand Down Expand Up @@ -262,105 +241,14 @@ function get_value(
return value
end

"""
An upcrossing means that a condition (always greater than) becomes true.
"""
function discrete_control_affect_upcrossing!(integrator, condition_idx)
(; p, u, t) = integrator
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = true

control_state_change = discrete_control_affect!(integrator, condition_idx, true)

# Check whether the control state change changed the direction of the crossing
# NOTE: This works for level conditions, but not for flow conditions on an
# arbitrary edge. That is because parameter changes do not change the instantaneous level,
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

# NOTE: The above no longer works when listen feature ids can be something other than node ids
# I think the more durable option is to give all possible condition types a different variable string,
# e.g. basin.level and level_boundary.level
if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
_, condition_basin_idx =
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if du[condition_basin_idx] < 0.0
condition_value[compound_variable_idx][greater_than_idx] = false
discrete_control_affect!(integrator, condition_idx, false)
end
end
end

"""
An downcrossing means that a condition (always greater than) becomes false.
"""
function discrete_control_affect_downcrossing!(integrator, condition_idx)
(; p, u, t) = integrator
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = false

control_state_change = discrete_control_affect!(integrator, condition_idx, false)

# Check whether the control state change changed the direction of the crossing
# NOTE: This works for level conditions, but not for flow conditions on an
# arbitrary edge. That is because parameter changes do not change the instantaneous level,
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
has_index, condition_basin_idx =
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if has_index && du[condition_basin_idx] > 0.0
condition_value[compound_variable_idx][greater_than_idx] = true
discrete_control_affect!(integrator, condition_idx, true)
end
end
end

"""
Change parameters based on the control logic.
"""
function discrete_control_affect!(
integrator,
condition_idx::Int,
upcrossing::Union{Bool, Missing},
)::Bool
function discrete_control_affect!(integrator, compound_variable_idx)
p = integrator.p
(; discrete_control, graph) = p

# Get the discrete_control node that listens to this condition

compound_variable_idx, _ = get_discrete_control_indices(discrete_control, condition_idx)
# Get the discrete_control node to which this compound variable belongs
discrete_control_node_id = discrete_control.node_id[compound_variable_idx]

# Get the indices of all conditions that this control node listens to
Expand All @@ -376,56 +264,24 @@ function discrete_control_affect!(
)
truth_state = join(truth_values, "")

# Get the truth specific about the latest crossing
if !ismissing(upcrossing)
truth_value_idx =
condition_idx - sum(
length(vec) for
vec in discrete_control.condition_value[1:(where_node_id.start - 1)];
init = 0,
)
truth_values[truth_value_idx] = upcrossing ? "U" : "D"
end
truth_state_crossing_specific = join(truth_values, "")

# What the local control state should be
control_state_new =
if haskey(
discrete_control.logic_mapping,
(discrete_control_node_id, truth_state_crossing_specific),
)
truth_state_used = truth_state_crossing_specific
discrete_control.logic_mapping[(
discrete_control_node_id,
truth_state_crossing_specific,
)]
elseif haskey(
discrete_control.logic_mapping,
(discrete_control_node_id, truth_state),
)
truth_state_used = truth_state
if haskey(discrete_control.logic_mapping, (discrete_control_node_id, truth_state))
discrete_control.logic_mapping[(discrete_control_node_id, truth_state)]
else
error(
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for $discrete_control_node_id.",
"No control state specified for $discrete_control_node_id for truth state $truth_state.",
)
end

# What the local control state is
# TODO: Check time elapsed since control change
control_state_now, _ = discrete_control.control_state[discrete_control_node_id]

control_state_change = false

if control_state_now != control_state_new
control_state_change = true

# Store control action in record
record = discrete_control.record

push!(record.time, integrator.t)
push!(record.control_node_id, Int32(discrete_control_node_id))
push!(record.truth_state, truth_state_used)
push!(record.truth_state, truth_state)
push!(record.control_state, control_state_new)

# Loop over nodes which are under control of this control node
Expand All @@ -437,7 +293,7 @@ function discrete_control_affect!(
discrete_control.control_state[discrete_control_node_id] =
(control_state_new, integrator.t)
end
return control_state_change
return nothing
end

function get_allocation_model(p::Parameters, allocation_network_id::Int32)::AllocationModel
Expand Down
2 changes: 0 additions & 2 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ function Model(config::Config)::Model
@show Ribasim.to
end

set_initial_discrete_controlled_parameters!(integrator, storage)

return Model(integrator, config, saved)
end

Expand Down
2 changes: 1 addition & 1 deletion core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ function expand_logic_mapping(
logic_mapping_expanded = Dict{Tuple{NodeID, String}, String}()

for (node_id, truth_state) in keys(logic_mapping)
pattern = r"^[TFUD\*]+$"
pattern = r"^[TF\*]+$"
if !occursin(pattern, truth_state)
error("Truth state \'$truth_state\' contains illegal characters or is empty.")
end
Expand Down
1 change: 0 additions & 1 deletion core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,6 @@ function valid_discrete_control(p::Parameters, config::Config)::Bool

if !isempty(undefined_control_states)
undefined_list = collect(undefined_control_states)
node_type = typeof(node).name.name
@error "These control states from $id are not defined for controlled $id_outneighbor: $undefined_list."
errors = true
end
Expand Down
32 changes: 1 addition & 31 deletions core/test/control_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,41 +145,11 @@ end
@test discrete_control.record.control_state == ["high", "low"]
@test discrete_control.record.time[1] == 0.0
t = Ribasim.datetime_since(discrete_control.record.time[2], model.config.starttime)
@test Date(t) == Date("2020-03-15")
@test Date(t) == Date("2020-03-16")
# then the rating curve is updated to the "low" control_state
@test only(p.tabulated_rating_curve.tables).t[2] == 1.2
end

@testitem "Setpoint with bounds control" begin
toml_path = normpath(
@__DIR__,
"../../generated_testmodels/level_setpoint_with_minmax/ribasim.toml",
)
@test ispath(toml_path)
model = Ribasim.run(toml_path)
p = model.integrator.p
(; discrete_control) = p
(; record, greater_than) = discrete_control
level = Ribasim.get_storages_and_levels(model).level[1, :]
t = Ribasim.tsaves(model)

t_none_1 = discrete_control.record.time[2]
t_in = discrete_control.record.time[3]
t_none_2 = discrete_control.record.time[4]

level_min = greater_than[1][1]
setpoint = greater_than[1][2]

t_1_none_index = findfirst(>=(t_none_1), t)
t_in_index = findfirst(>=(t_in), t)
t_2_none_index = findfirst(>=(t_none_2), t)

@test record.control_state == ["out", "none", "in", "none"]
@test level[t_1_none_index] <= setpoint
@test level[t_in_index] >= level_min
@test level[t_2_none_index] <= setpoint
end

@testitem "Set PID target with DiscreteControl" begin
using Ribasim: NodeID

Expand Down
Loading

0 comments on commit 169466e

Please sign in to comment.