diff --git a/core/src/callback.jl b/core/src/callback.jl index 05b49dda4..fb1918831 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -13,11 +13,13 @@ function set_initial_discrete_controlled_parameters!( condition_diffs = zeros(Float64, n_conditions) discrete_control_condition(condition_diffs, storage0, integrator.t, integrator) + # Set the discrete control value (bool) per compound variable idx_start = 1 - for (i, vec) in enumerate(discrete_control.condition_value) + for (compound_variable_idx, vec) in enumerate(discrete_control.condition_value) l = length(vec) idx_end = idx_start + l - 1 - discrete_control.condition_value[i] .= (condition_diffs[idx_start:idx_end] .> 0) + discrete_control.condition_value[compound_variable_idx] .= + (condition_diffs[idx_start:idx_end] .> 0) idx_start += l end @@ -202,6 +204,7 @@ function discrete_control_condition(out, u, t, integrator) (; discrete_control) = p condition_idx = 0 + # Loop over compound variables for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip( discrete_control.listen_node_id, discrete_control.variable, @@ -214,6 +217,7 @@ function discrete_control_condition(out, u, t, integrator) zip(listen_node_ids, variables, weights, look_aheads) value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t) end + # Loop over greater_than values for this compound_variable for greater_than in greater_thans condition_idx += 1 diff = value - greater_than @@ -268,22 +272,6 @@ function get_value( return value end -function get_discrete_control_indices(discrete_control::DiscreteControl, condition_idx::Int) - (; greater_than) = discrete_control - condition_idx_now = 1 - - for (compound_variable_idx, vec) in enumerate(greater_than) - l = length(vec) - - if condition_idx_now + l > condition_idx - greater_than_idx = condition_idx - condition_idx_now + 1 - return compound_variable_idx, greater_than_idx - end - - condition_idx_now += l - end -end - """ An upcrossing means that a condition (always greater than) becomes true. """ diff --git a/core/src/parameter.jl b/core/src/parameter.jl index e898280ff..9ea874d74 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -443,14 +443,13 @@ struct Terminal <: AbstractParameterNode end """ -node_id: node ID of the DiscreteControl node; these are not unique but repeated - by the amount of conditions of this DiscreteControl node -listen_node_id: the IDs of the nodes being condition on -variable: the names of the variables in the condition -weight: the weight of the variables in the condition -look_ahead: the look ahead of variables in the condition in seconds -greater_than: The threshold value in the condition -condition_value: The current value of each condition +node_id: node ID of the DiscreteControl node per compound variable (can contain repeats) +listen_node_id: the IDs of the nodes being condition on per compound variable +variable: the names of the variables in the condition per compound variable +weight: the weight of the variables in the condition per compound variable +look_ahead: the look ahead of variables in the condition in seconds per compound_variable +greater_than: The threshold values per compound variable +condition_value: The current truth value of each condition per compound_variable per greater_than control_state: Dictionary: node ID => (control state, control state start) logic_mapping: Dictionary: (control node ID, truth state) => control state record: Namedtuple with discrete control information for results diff --git a/core/src/read.jl b/core/src/read.jl index 62a7ddc89..bd07fd17d 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -552,9 +552,11 @@ function parse_variables_and_conditions(compound_variable, condition) condition_value = BitVector[] errors = false + # Loop over unique discrete_control node IDs (on which at least one condition is defined) for id in unique(condition.node_id) condition_group_id = filter(row -> row.node_id == id, condition) variable_group_id = filter(row -> row.node_id == id, compound_variable) + # Loop over compound variables for this node ID for compound_variable_id in unique(condition_group_id.compound_variable_id) condition_group_variable = filter( row -> row.compound_variable_id == compound_variable_id, diff --git a/core/src/util.jl b/core/src/util.jl index 9f28670cf..d84567632 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -709,3 +709,19 @@ function get_influx(basin::Basin, basin_idx::Int)::Float64 return precipitation[basin_idx] - evaporation[basin_idx] + drainage[basin_idx] - infiltration[basin_idx] end + +function get_discrete_control_indices(discrete_control::DiscreteControl, condition_idx::Int) + (; greater_than) = discrete_control + condition_idx_now = 1 + + for (compound_variable_idx, vec) in enumerate(greater_than) + l = length(vec) + + if condition_idx_now + l > condition_idx + greater_than_idx = condition_idx - condition_idx_now + 1 + return compound_variable_idx, greater_than_idx + end + + condition_idx_now += l + end +end