From 75b56a077f379e1510fbf9223f1ffa80e6f717e0 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Tue, 16 Apr 2024 11:49:38 +0200 Subject: [PATCH] Fix bugs --- core/src/callback.jl | 106 +++++++++++++----- core/src/parameter.jl | 7 +- core/src/read.jl | 80 ++++++++++--- core/src/validation.jl | 12 +- core/test/control_test.jl | 12 +- python/ribasim/ribasim/config.py | 3 +- .../ribasim_testmodels/discrete_control.py | 1 + 7 files changed, 158 insertions(+), 63 deletions(-) diff --git a/core/src/callback.jl b/core/src/callback.jl index 7e6daec53..9dbb48acd 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -9,10 +9,17 @@ function set_initial_discrete_controlled_parameters!( (; p) = integrator (; discrete_control) = p - n_conditions = length(discrete_control.condition_value) + n_conditions = sum(length(vec) for vec in discrete_control.condition_value) condition_diffs = zeros(Float64, n_conditions) discrete_control_condition(condition_diffs, storage0, integrator.t, integrator) - discrete_control.condition_value .= (condition_diffs .> 0.0) + + idx_start = 1 + for (i, vec) in enumerate(discrete_control.condition_value) + l = length(vec) + idx_end = idx_start + l - 1 + discrete_control.condition_value[i] .= (condition_diffs[idx_start:idx_end] .> 0) + idx_start += l + end # For every discrete_control node find a condition_idx it listens to for discrete_control_node_id in unique(discrete_control.node_id) @@ -88,7 +95,7 @@ function create_callbacks( saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level) - n_conditions = length(discrete_control.node_id) + n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0) if n_conditions > 0 discrete_control_cb = VectorContinuousCallback( discrete_control_condition, @@ -193,23 +200,25 @@ Listens for changes in condition truths. function discrete_control_condition(out, u, t, integrator) (; p) = integrator (; discrete_control) = p - - for (i, (listen_node_ids, variables, weights, greater_than, look_aheads)) in enumerate( - zip( - discrete_control.listen_node_id, - discrete_control.variable, - discrete_control.weight, - discrete_control.greater_than, - discrete_control.look_ahead, - ), + condition_idx = 0 + + for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip( + discrete_control.listen_node_id, + discrete_control.variable, + discrete_control.weight, + discrete_control.greater_than, + discrete_control.look_ahead, ) value = 0.0 for (listen_node_id, variable, weight, look_ahead) in zip(listen_node_ids, variables, weights, look_aheads) value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t) end - diff = value - greater_than - out[i] = diff + for greater_than in greater_thans + condition_idx += 1 + diff = value - greater_than + out[condition_idx] = diff + end end end @@ -259,6 +268,22 @@ function get_value( return value end +function get_discrete_control_indices(discrete_control::DiscreteControl, condition_idx::Int) + (; greater_than) = discrete_control + condition_idx_now = 1 + + for (compound_variable_idx, vec) in enumerate(greater_than) + l = length(vec) + + if condition_idx_now + l > condition_idx + greater_than_idx = condition_idx - condition_idx_now + 1 + return compound_variable_idx, greater_than_idx + end + + condition_idx_now += l + end +end + """ An upcrossing means that a condition (always greater than) becomes true. """ @@ -267,7 +292,9 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx) (; discrete_control, basin) = p (; variable, condition_value, listen_node_id) = discrete_control - condition_value[condition_idx] = true + compound_variable_idx, greater_than_idx = + get_discrete_control_indices(discrete_control, condition_idx) + condition_value[compound_variable_idx][greater_than_idx] = true control_state_change = discrete_control_affect!(integrator, condition_idx, true) @@ -277,7 +304,7 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx) # only possibly the du. Parameter changes can change the flow on an edge discontinuously, # giving the possibility of logical paradoxes where certain parameter changes immediately # undo the truth state that caused that parameter change. - listen_node_ids = discrete_control.listen_node_id[condition_idx] + listen_node_ids = discrete_control.listen_node_id[compound_variable_idx] is_basin = length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] : false @@ -285,15 +312,16 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx) # NOTE: The above no longer works when listen feature ids can be something other than node ids # I think the more durable option is to give all possible condition types a different variable string, # e.g. basin.level and level_boundary.level - if variable[condition_idx][1] == "level" && control_state_change && is_basin + if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin # Calling water_balance is expensive, but it is a sure way of getting # du for the basin of this level condition du = zero(u) water_balance!(du, u, p, t) - _, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx][1]) + _, condition_basin_idx = + id_index(basin.node_id, listen_node_id[compound_variable_idx][1]) if du[condition_basin_idx] < 0.0 - condition_value[condition_idx] = false + condition_value[compound_variable_idx][greater_than_idx] = false discrete_control_affect!(integrator, condition_idx, false) end end @@ -307,7 +335,9 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx) (; discrete_control, basin) = p (; variable, condition_value, listen_node_id) = discrete_control - condition_value[condition_idx] = false + compound_variable_idx, greater_than_idx = + get_discrete_control_indices(discrete_control, condition_idx) + condition_value[compound_variable_idx][greater_than_idx] = false control_state_change = discrete_control_affect!(integrator, condition_idx, false) @@ -317,21 +347,23 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx) # only possibly the du. Parameter changes can change the flow on an edge discontinuously, # giving the possibility of logical paradoxes where certain parameter changes immediately # undo the truth state that caused that parameter change. - listen_node_ids = discrete_control.listen_node_id[condition_idx] + compound_variable_idx, greater_than_idx = + get_discrete_control_indices(discrete_control, condition_idx) + listen_node_ids = discrete_control.listen_node_id[compound_variable_idx] is_basin = length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] : false - if variable[condition_idx][1] == "level" && control_state_change && is_basin + if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin # Calling water_balance is expensive, but it is a sure way of getting # du for the basin of this level condition du = zero(u) water_balance!(du, u, p, t) has_index, condition_basin_idx = - id_index(basin.node_id, listen_node_id[condition_idx][1]) + id_index(basin.node_id, listen_node_id[compound_variable_idx][1]) if has_index && du[condition_basin_idx] > 0.0 - condition_value[condition_idx] = true + condition_value[compound_variable_idx][greater_than_idx] = true discrete_control_affect!(integrator, condition_idx, true) end end @@ -349,20 +381,34 @@ function discrete_control_affect!( (; discrete_control, graph) = p # Get the discrete_control node that listens to this condition - discrete_control_node_id = discrete_control.node_id[condition_idx] + + compound_variable_idx, _ = get_discrete_control_indices(discrete_control, condition_idx) + discrete_control_node_id = discrete_control.node_id[compound_variable_idx] # Get the indices of all conditions that this control node listens to - condition_ids = discrete_control.node_id .== discrete_control_node_id + where_node_id = searchsorted(discrete_control.node_id, discrete_control_node_id) # Get the truth state for this discrete_control node - truth_values = [ifelse(b, "T", "F") for b in discrete_control.condition_value] - truth_state = join(truth_values[condition_ids], "") + truth_values = cat( + [ + [ifelse(b, "T", "F") for b in discrete_control.condition_value[i]] for + i in where_node_id + ]...; + dims = 1, + ) + truth_state = join(truth_values, "") # Get the truth specific about the latest crossing if !ismissing(upcrossing) - truth_values[condition_idx] = upcrossing ? "U" : "D" + truth_value_idx = + condition_idx - sum( + length(vec) for + vec in discrete_control.condition_value[1:(where_node_id.start - 1)]; + init = 0, + ) + truth_values[truth_value_idx] = upcrossing ? "U" : "D" end - truth_state_crossing_specific = join(truth_values[condition_ids], "") + truth_state_crossing_specific = join(truth_values, "") # What the local control state should be control_state_new = diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 231f0a9c4..e898280ff 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -457,12 +457,15 @@ record: Namedtuple with discrete control information for results """ struct DiscreteControl <: AbstractParameterNode node_id::Vector{NodeID} + # Definition of compound variables listen_node_id::Vector{Vector{NodeID}} variable::Vector{Vector{String}} weight::Vector{Vector{Float64}} look_ahead::Vector{Vector{Float64}} - greater_than::Vector{Float64} - condition_value::Vector{Bool} + # Definition of conditions (one or more greater_than per compound variable) + greater_than::Vector{Vector{Float64}} + condition_value::Vector{BitVector} + # Definition of logic control_state::Dict{NodeID, Tuple{String, Float64}} logic_mapping::Dict{Tuple{NodeID, String}, String} record::@NamedTuple{ diff --git a/core/src/read.jl b/core/src/read.jl index a0aec5fda..62a7ddc89 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -542,33 +542,79 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin ) end -function DiscreteControl(db::DB, config::Config)::DiscreteControl - compound_variable = load_structvector(db, config, DiscreteControlVariableV1) - condition = load_structvector(db, config, DiscreteControlConditionV1) - +function parse_variables_and_conditions(compound_variable, condition) node_id = NodeID[] listen_node_id = Vector{NodeID}[] variable = Vector{String}[] weight = Vector{Float64}[] look_ahead = Vector{Float64}[] + greater_than = Vector{Float64}[] + condition_value = BitVector[] + errors = false for id in unique(condition.node_id) - group_id = filter(row -> row.node_id == id, compound_variable) - for group_variable in - StructVector.(IterTools.groupby(row -> row.compound_variable_id, group_id)) - first_row = first(group_variable) - push!(node_id, NodeID(NodeType.DiscreteControl, first_row.node_id)) - push!( - listen_node_id, - NodeID.(group_variable.listen_node_type, group_variable.listen_node_id), + condition_group_id = filter(row -> row.node_id == id, condition) + variable_group_id = filter(row -> row.node_id == id, compound_variable) + for compound_variable_id in unique(condition_group_id.compound_variable_id) + condition_group_variable = filter( + row -> row.compound_variable_id == compound_variable_id, + condition_group_id, ) - push!(variable, group_variable.variable) - push!(weight, coalesce.(group_variable.weight, 1.0)) - push!(look_ahead, coalesce.(group_variable.look_ahead, 0.0)) + variable_group_variable = filter( + row -> row.compound_variable_id == compound_variable_id, + variable_group_id, + ) + discrete_control_id = NodeID(NodeType.DiscreteControl, id) + if isempty(variable_group_variable) + errors = true + @error "compound_variable_id $compound_variable_id for $discrete_control_id in condition table but not in variable table" + else + push!(node_id, discrete_control_id) + push!( + listen_node_id, + NodeID.( + variable_group_variable.listen_node_type, + variable_group_variable.listen_node_id, + ), + ) + push!(variable, variable_group_variable.variable) + push!(weight, coalesce.(variable_group_variable.weight, 1.0)) + push!(look_ahead, coalesce.(variable_group_variable.look_ahead, 0.0)) + push!(greater_than, condition_group_variable.greater_than) + push!( + condition_value, + BitVector(zeros(length(condition_group_variable.greater_than))), + ) + end end end + return node_id, + listen_node_id, + variable, + weight, + look_ahead, + greater_than, + condition_value, + !errors +end + +function DiscreteControl(db::DB, config::Config)::DiscreteControl + condition = load_structvector(db, config, DiscreteControlConditionV1) + compound_variable = load_structvector(db, config, DiscreteControlVariableV1) + + node_id, + listen_node_id, + variable, + weight, + look_ahead, + greater_than, + condition_value, + valid = parse_variables_and_conditions(compound_variable, condition) + + if !valid + error("Problems encountered when parsing DiscreteControl variables and conditions.") + end - condition_value = fill(false, length(condition.node_id)) control_state::Dict{NodeID, Tuple{String, Float64}} = Dict() rows = execute(db, "SELECT from_node_id, edge_type FROM Edge ORDER BY fid") @@ -603,7 +649,7 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl variable, weight, look_ahead, - condition.greater_than, + greater_than, condition_value, control_state, logic_mapping, diff --git a/core/src/validation.jl b/core/src/validation.jl index 580b4d5ba..9d6a621d2 100644 --- a/core/src/validation.jl +++ b/core/src/validation.jl @@ -102,12 +102,9 @@ sort_by_id_state_level(row) = (row.node_id, row.control_state, row.level) sort_by_priority(row) = (row.node_id, row.priority) sort_by_priority_time(row) = (row.node_id, row.priority, row.time) sort_by_subgrid_level(row) = (row.subgrid_id, row.basin_level) -sort_by_condition(row) = - (row.node_id, row.listen_node_type, row.listen_node_id, row.variable, row.greater_than) sort_by_variable(row) = (row.node_id, row.listen_node_type, row.listen_node_id, row.variable) -sort_by_id_greater_than(row) = (row.node_id, row.greater_than) -sort_by_truth_state(row) = (row.node_id, row.truth_state) +sort_by_condition(row) = (row.node_id, row.compound_variable_id, row.greater_than) # get the right sort by function given the Schema, with sort_by_id as the default sort_by_function(table::StructVector{<:Legolas.AbstractRecord}) = sort_by_id @@ -117,7 +114,7 @@ sort_by_function(table::StructVector{UserDemandStaticV1}) = sort_by_priority sort_by_function(table::StructVector{UserDemandTimeV1}) = sort_by_priority_time sort_by_function(table::StructVector{BasinSubgridV1}) = sort_by_subgrid_level sort_by_function(table::StructVector{DiscreteControlVariableV1}) = sort_by_variable -sort_by_function(table::StructVector{DiscreteControlConditionV1}) = sort_by_id_greater_than +sort_by_function(table::StructVector{DiscreteControlConditionV1}) = sort_by_condition const TimeSchemas = Union{ BasinTimeV1, @@ -499,7 +496,8 @@ Check: """ function valid_discrete_control(p::Parameters, config::Config)::Bool (; discrete_control, graph) = p - (; node_id, logic_mapping, look_ahead, variable, listen_node_id) = discrete_control + (; node_id, logic_mapping, look_ahead, variable, listen_node_id, greater_than) = + discrete_control t_end = seconds_since(config.endtime, config.starttime) errors = false @@ -512,7 +510,7 @@ function valid_discrete_control(p::Parameters, config::Config)::Bool truth_states_wrong_length = String[] # The number of conditions of this DiscreteControl node - n_conditions = length(searchsorted(node_id, id)) + n_conditions = sum(length(greater_than[i]) for i in searchsorted(node_id, id)) for (key, control_state) in logic_mapping id_, truth_state = key diff --git a/core/test/control_test.jl b/core/test/control_test.jl index e0913ba1d..21bcc653c 100644 --- a/core/test/control_test.jl +++ b/core/test/control_test.jl @@ -37,11 +37,11 @@ # Control times t_1 = discrete_control.record.time[3] t_1_index = findfirst(>=(t_1), t) - @test level[1, t_1_index] <= discrete_control.greater_than[1] + @test level[1, t_1_index] <= discrete_control.greater_than[1][1] t_2 = discrete_control.record.time[4] t_2_index = findfirst(>=(t_2), t) - @test level[2, t_2_index] >= discrete_control.greater_than[2] + @test level[2, t_2_index] >= discrete_control.greater_than[2][1] flow = get_tmp(graph[].flow, 0) @test all(iszero, flow) @@ -60,7 +60,7 @@ end t_control = discrete_control.record.time[2] t_control_index = searchsortedfirst(t, t_control) - greater_than = discrete_control.greater_than[1] + greater_than = discrete_control.greater_than[1][1] flow_t_control = flow_boundary.flow_rate[1](t_control) flow_t_control_ahead = flow_boundary.flow_rate[1](t_control + Δt) @@ -84,7 +84,7 @@ end t_control = discrete_control.record.time[2] t_control_index = searchsortedfirst(t, t_control) - greater_than = discrete_control.greater_than[1] + greater_than = discrete_control.greater_than[1][1] level_t_control = level_boundary.level[1](t_control) level_t_control_ahead = level_boundary.level[1](t_control + Δt) @@ -167,8 +167,8 @@ end t_in = discrete_control.record.time[3] t_none_2 = discrete_control.record.time[4] - level_min = greater_than[1] - setpoint = greater_than[2] + level_min = greater_than[1][1] + setpoint = greater_than[1][2] t_1_none_index = findfirst(>=(t_none_1), t) t_in_index = findfirst(>=(t_in), t) diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index d837fc8fc..30408fe05 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -304,13 +304,14 @@ class DiscreteControl(MultiNodeModel): json_schema_extra={ "sort_keys": [ "node_id", + "compound_variable_id", "greater_than", ] }, ) logic: TableModel[DiscreteControlLogicSchema] = Field( default_factory=TableModel[DiscreteControlLogicSchema], - json_schema_extra={"sort_keys": ["node_id"]}, + json_schema_extra={"sort_keys": ["node_id", "truth_state"]}, ) diff --git a/python/ribasim_testmodels/ribasim_testmodels/discrete_control.py b/python/ribasim_testmodels/ribasim_testmodels/discrete_control.py index d8d911674..650e73876 100644 --- a/python/ribasim_testmodels/ribasim_testmodels/discrete_control.py +++ b/python/ribasim_testmodels/ribasim_testmodels/discrete_control.py @@ -299,6 +299,7 @@ def tabulated_rating_curve_control_model() -> Model: ), discrete_control.Condition( greater_than=[0.5], + compound_variable_id=1, ), discrete_control.Logic( truth_state=["T", "F"], control_state=["low", "high"]