diff --git a/core/src/callback.jl b/core/src/callback.jl index 831c216db..df7f7734c 100644 --- a/core/src/callback.jl +++ b/core/src/callback.jl @@ -56,13 +56,10 @@ function create_callbacks( push!(callbacks, export_cb) end - saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level) + discrete_control_cb = FunctionCallingCallback(apply_discrete_control!) + push!(callbacks, discrete_control_cb) - n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0) - if n_conditions > 0 - discrete_control_cb = FunctionCallingCallback(apply_discrete_control!) - push!(callbacks, discrete_control_cb) - end + saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level) callback = CallbackSet(callbacks...) return callback, saved @@ -197,77 +194,143 @@ function save_vertical_flux(u, t, integrator) return vertical_flux_mean end +""" +Apply the discrete control logic. There's somewhat of a complex structure: +- Each DiscreteControl node can have one or multiple compound variables it listens to +- A compound variable is defined as a linear combination of state/time derived parameters of the model +- Each compound variable has associated with it a sorted vector of greater_than values, which define an ordered + list of conditions of the form (compound variable value) => greater_than +- Thus, to find out which conditions are true, we only need to find the largest index in the greater than values + such that the above condition is true +- The truth value (true/false) of all these conditions for all variables of a DiscreteControl node are concatenated + (in preallocated memory) into what is called the nodes truth state. This concatenation happens in the order in which + the compound variables appear in discrete_control.compound_variables +- The DiscreteControl node maps this truth state via the logic mapping to a control state, which is a string +- The nodes that are controlled by this DiscreteControl node must have the same control state, for which they have + parameter values associated with that control state defined in their control_mapping +""" function apply_discrete_control!(u, t, integrator)::Nothing (; p) = integrator (; discrete_control) = p - condition_idx = 0 + (; node_id) = discrete_control + + # Loop over the discrete control nodes to determine their truth state + # and detect possible control state changes + for i in eachindex(node_id) + id = node_id[i] + truth_state = discrete_control.truth_state[i] + compound_variables = discrete_control.compound_variables[i] + + # Whether a change in truth state was detected, and thus whether + # a change in control state is possible + truth_state_change = false + + # As the truth state of this node is being updated for the different variables + # it listens to, this is the first index of the truth values for the current variable + truth_value_variable_idx = 1 + + # Loop over the variables listened to by this discrete control node + for compound_variable in compound_variables + + # Compute the value of the current variable + value = 0.0 + for subvariable in compound_variable.subvariables + value += subvariable.weight * get_value(p, subvariable, t) + end + + # The thresholds the value of this variable is being compared with + greater_thans = compound_variable.greater_than + n_greater_than = length(greater_thans) + + # Find the largest index i within the greater thans for this variable + # such that value >= greater_than and shift towards the index in the truth state + largest_true_index = + 
truth_value_variable_idx - 1 + searchsortedlast(greater_thans, value) + + # Update the truth values in the truth states for the current discrete control node + # corresponding to the conditions on the current variable + for truth_value_idx in + truth_value_variable_idx:(truth_value_variable_idx + n_greater_than - 1) + new_truth_state = (truth_value_idx <= largest_true_index) + # If no truth state change was detected yet, check whether there is a change + # at this position + if !truth_state_change + truth_state_change = (new_truth_state != truth_state[truth_value_idx]) + end + truth_state[truth_value_idx] = new_truth_state + end - discrete_control_condition!(u, t, integrator) + truth_value_variable_idx += n_greater_than + end + + # If no truth state change whas detected for this node, no control + # state change is possible either + if !((t == 0) || truth_state_change) + continue + end - # For every compound variable see whether it changes a control state - for compound_variable_idx in eachindex(discrete_control.node_id) - discrete_control_affect!(integrator, compound_variable_idx) + set_new_control_state!(integrator, id, truth_state) end + return nothing end -""" -Update discrete control condition truths. -""" -function discrete_control_condition!(u, t, integrator) +function set_new_control_state!( + integrator, + discrete_control_id::NodeID, + truth_state::Vector{Bool}, +)::Nothing (; p) = integrator (; discrete_control) = p - # Loop over compound variables - for ( - listen_node_ids, - variables, - weights, - greater_thans, - look_aheads, - condition_values, - ) in zip( - discrete_control.listen_node_id, - discrete_control.variable, - discrete_control.weight, - discrete_control.greater_than, - discrete_control.look_ahead, - discrete_control.condition_value, + # Get the control state corresponding to the new truth state, + # if one is defined + control_state_new = + get(discrete_control.logic_mapping, (discrete_control_id, truth_state), nothing) + isnothing(control_state_new) && error( + lazy"No control state specified for $discrete_control_node_id for truth state $truth_state.", ) - value = 0.0 - for (listen_node_id, variable, weight, look_ahead) in - zip(listen_node_ids, variables, weights, look_aheads) - value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t) + + # Check the new control state against the current control state + # If there is a change, update parameters and the discrete control record + control_state_now, _ = discrete_control.control_state[discrete_control_id] + if control_state_now != control_state_new + record = discrete_control.record + + push!(record.time, integrator.t) + push!(record.control_node_id, Int32(discrete_control_id)) + push!(record.truth_state, convert_truth_state(truth_state)) + push!(record.control_state, control_state_new) + + # Loop over nodes which are under control of this control node + for target_node_id in + outneighbor_labels_type(p.graph, discrete_control_id, EdgeType.control) + set_control_params!(p, target_node_id, control_state_new) end - condition_values .= false - condition_values[1:searchsortedlast(greater_thans, value)] .= true + discrete_control.control_state[discrete_control_id] = + (control_state_new, integrator.t) end + return nothing end """ Get a value for a condition. Currently supports getting levels from basins and flows from flow boundaries. 
""" -function get_value( - p::Parameters, - node_id::NodeID, - variable::String, - Δt::Float64, - u::AbstractVector{Float64}, - t::Float64, -) +function get_value(p::Parameters, subvariable::NamedTuple, t::Float64) (; basin, flow_boundary, level_boundary) = p + (; listen_node_id, look_ahead, variable) = subvariable if variable == "level" - if node_id.type == NodeType.Basin - has_index, basin_idx = id_index(basin.node_id, node_id) + if listen_node_id.type == NodeType.Basin + has_index, basin_idx = id_index(basin.node_id, listen_node_id) if !has_index - error("Discrete control listen node $node_id does not exist.") + error("Discrete control listen node $listen_node_id does not exist.") end - _, level = get_area_and_level(basin, basin_idx, u[basin_idx]) - elseif node_id.type == NodeType.LevelBoundary - level_boundary_idx = findsorted(level_boundary.node_id, node_id) - level = level_boundary.level[level_boundary_idx](t + Δt) + level = get_tmp(basin.current_level, 0)[basin_idx] + elseif listen_node_id.type == NodeType.LevelBoundary + level_boundary_idx = findsorted(level_boundary.node_id, listen_node_id) + level = level_boundary.level[level_boundary_idx](t + look_ahead) else error( "Level condition node '$node_id' is neither a basin nor a level boundary.", @@ -276,11 +339,11 @@ function get_value( value = level elseif variable == "flow_rate" - if node_id.type == NodeType.FlowBoundary - flow_boundary_idx = findsorted(flow_boundary.node_id, node_id) - value = flow_boundary.flow_rate[flow_boundary_idx](t + Δt) + if listen_node_id.type == NodeType.FlowBoundary + flow_boundary_idx = findsorted(flow_boundary.node_id, listen_node_id) + value = flow_boundary.flow_rate[flow_boundary_idx](t + look_ahead) else - error("Flow condition node $node_id is not a flow boundary.") + error("Flow condition node $listen_node_id is not a flow boundary.") end else @@ -290,61 +353,6 @@ function get_value( return value end -""" -Change parameters based on the control logic. 
-""" -function discrete_control_affect!(integrator, compound_variable_idx) - p = integrator.p - (; discrete_control, graph) = p - - # Get the discrete_control node to which this compound variable belongs - discrete_control_node_id = discrete_control.node_id[compound_variable_idx] - - # Get the indices of all conditions that this control node listens to - where_node_id = searchsorted(discrete_control.node_id, discrete_control_node_id) - - # Get the truth state for this discrete_control node - truth_values = cat( - [ - [ifelse(b, "T", "F") for b in discrete_control.condition_value[i]] for - i in where_node_id - ]...; - dims = 1, - ) - truth_state = join(truth_values, "") - - # What the local control state should be - control_state_new = - if haskey(discrete_control.logic_mapping, (discrete_control_node_id, truth_state)) - discrete_control.logic_mapping[(discrete_control_node_id, truth_state)] - else - error( - "No control state specified for $discrete_control_node_id for truth state $truth_state.", - ) - end - - control_state_now, _ = discrete_control.control_state[discrete_control_node_id] - if control_state_now != control_state_new - # Store control action in record - record = discrete_control.record - - push!(record.time, integrator.t) - push!(record.control_node_id, Int32(discrete_control_node_id)) - push!(record.truth_state, truth_state) - push!(record.control_state, control_state_new) - - # Loop over nodes which are under control of this control node - for target_node_id in - outneighbor_labels_type(graph, discrete_control_node_id, EdgeType.control) - set_control_params!(p, target_node_id, control_state_new) - end - - discrete_control.control_state[discrete_control_node_id] = - (control_state_new, integrator.t) - end - return nothing -end - function get_allocation_model(p::Parameters, subnetwork_id::Int32)::AllocationModel (; allocation) = p (; subnetwork_ids, allocation_models) = allocation @@ -404,22 +412,57 @@ function set_fractional_flow_in_allocation!( return nothing end -function set_control_params!(p::Parameters, node_id::NodeID, control_state::String) - node = getfield(p, p.graph[node_id].type) - idx = searchsortedfirst(node.node_id, node_id) - new_state = node.control_mapping[(node_id, control_state)] +function discrete_control_parameter_update!( + fractional_flow::FractionalFlow, + node_id::NodeID, + control_state::String, + p::Parameters, +)::Nothing + parameter_update = fractional_flow.control_mapping[(node_id, control_state)] + (; node_idx, fraction) = parameter_update + fractional_flow.fraction[node_idx] = fraction + + if is_active(p.allocation) + set_fractional_flow_in_allocation!(p, fractional_flow.node_id[node_idx], fraction) + end + return nothing +end +function discrete_control_parameter_update!( + node::AbstractParameterNode, + node_id::NodeID, + control_state::String, +)::Nothing + new_state = node.control_mapping[(node_id, control_state)] + (; node_idx) = new_state for (field, value) in zip(keys(new_state), new_state) - if !ismissing(value) - vec = get_tmp(getfield(node, field), 0) - vec[idx] = value + if field == :node_idx + continue end + vec = get_tmp(getfield(node, field), 0) + vec[node_idx] = value + end +end - # Set new fractional flow fractions in allocation problem - if is_active(p.allocation) && node isa FractionalFlow && field == :fraction - set_fractional_flow_in_allocation!(p, node_id, value) - end +function set_control_params!(p::Parameters, node_id::NodeID, control_state::String)::Nothing + + # Check node type here to avoid runtime dispatch on the 
node type + if node_id.type == NodeType.Pump + discrete_control_parameter_update!(p.pump, node_id, control_state) + elseif node_id.type == NodeType.Outlet + discrete_control_parameter_update!(p.outlet, node_id, control_state) + elseif node_id.type == NodeType.TabulatedRatingCurve + discrete_control_parameter_update!(p.tabulated_rating_curve, node_id, control_state) + elseif node_id.type == NodeType.FractionalFlow + discrete_control_parameter_update!(p.fractional_flow, node_id, control_state, p) + elseif node_id.type == NodeType.PidControl + discrete_control_parameter_update!(p.pid_control, node_id, control_state) + elseif node_id.type == NodeType.LinearResistance + discrete_control_parameter_update!(p.linear_resistance, node_id, control_state) + elseif node_id.type == NodeType.ManningResistance + discrete_control_parameter_update!(p.manning_resistance, node_id, control_state) end + return nothing end function update_subgrid_level!(integrator)::Nothing @@ -531,7 +574,7 @@ end "Load updates from 'TabulatedRatingCurve / time' into the parameters" function update_tabulated_rating_curve!(integrator)::Nothing - (; node_id, tables, time) = integrator.p.tabulated_rating_curve + (; node_id, table, time) = integrator.p.tabulated_rating_curve t = datetime_since(integrator.t, integrator.p.starttime) # get groups of consecutive node_id for the current timestamp @@ -544,7 +587,7 @@ function update_tabulated_rating_curve!(integrator)::Nothing level = [row.level for row in group] flow_rate = [row.flow_rate for row in group] i = searchsortedfirst(node_id, NodeID(NodeType.TabulatedRatingCurve, id)) - tables[i] = LinearInterpolation(flow_rate, level; extrapolate = true) + table[i] = LinearInterpolation(flow_rate, level; extrapolate = true) end return nothing end diff --git a/core/src/graph.jl b/core/src/graph.jl index 8351dd92d..1645fab44 100644 --- a/core/src/graph.jl +++ b/core/src/graph.jl @@ -105,9 +105,11 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra if config.solver.autodiff flow = DiffCache(flow, chunk_sizes) end + flow_edges = EdgeMetadata[] graph_data = (; node_ids, edges_source, + flow_edges, flow_dict, flow, flow_prev, diff --git a/core/src/parameter.jl b/core/src/parameter.jl index b6c32a5bd..0e71f75c5 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -36,8 +36,6 @@ end Base.to_index(id::NodeID) = Int(id.value) const ScalarInterpolation = LinearInterpolation{Vector{Float64}, Vector{Float64}, Float64} -const VectorInterpolation = - LinearInterpolation{Vector{Vector{Float64}}, Vector{Float64}, Vector{Float64}} """ Store information for a subnetwork used for allocation. 
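Note: the parameter.jl hunks below replace the loosely typed `Dict{Tuple{NodeID, String}, NamedTuple}` control mappings with concrete `@NamedTuple` value types that carry a `node_idx`, which is what the new `discrete_control_parameter_update!` methods above index into. A minimal, self-contained sketch of that lookup-and-apply pattern follows; `ToyNode`, `apply_control_state!` and their fields are hypothetical stand-ins, not part of the Ribasim API, and the real code additionally routes vectors through `get_tmp` for autodiff caches.

# Hypothetical stand-in for a controllable node with one parameter vector per field.
struct ToyNode
    node_id::Vector{Int}
    active::BitVector
    flow_rate::Vector{Float64}
    control_mapping::Dict{
        Tuple{Int, String},
        @NamedTuple{node_idx::Int, active::Bool, flow_rate::Float64}
    }
end

# Look up the concretely typed parameter update for (node_id, control_state)
# and write each field into the node's vectors at the stored node_idx.
function apply_control_state!(node::ToyNode, node_id::Int, control_state::String)
    update = node.control_mapping[(node_id, control_state)]
    node.active[update.node_idx] = update.active
    node.flow_rate[update.node_idx] = update.flow_rate
    return nothing
end

node = ToyNode(
    [4],
    BitVector([true]),
    [0.0],
    Dict((4, "on") => (node_idx = 1, active = true, flow_rate = 1.0e-5)),
)
apply_control_state!(node, 4, "on")
@assert node.flow_rate[1] == 1.0e-5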
@@ -209,7 +207,7 @@ inflow_edge: incoming flow edge metadata outflow_edges: outgoing flow edges metadata The ID of the source node is always the ID of the TabulatedRatingCurve node active: whether this node is active and thus contributes flows -tables: The current Q(h) relationships +table: The current Q(h) relationships time: The time table used for updating the tables control_mapping: dictionary from (node_id, control_state) to Q(h) and/or active state """ @@ -218,9 +216,12 @@ struct TabulatedRatingCurve{C} <: AbstractParameterNode inflow_edge::Vector{EdgeMetadata} outflow_edges::Vector{Vector{EdgeMetadata}} active::BitVector - tables::Vector{ScalarInterpolation} + table::Vector{ScalarInterpolation} time::StructVector{TabulatedRatingCurveTimeV1, C, Int} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, table::ScalarInterpolation} + } end """ @@ -241,7 +242,10 @@ struct LinearResistance <: AbstractParameterNode active::BitVector resistance::Vector{Float64} max_flow_rate::Vector{Float64} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, resistance::Float64} + } end """ @@ -293,7 +297,10 @@ struct ManningResistance <: AbstractParameterNode profile_slope::Vector{Float64} upstream_bottom::Vector{Float64} downstream_bottom::Vector{Float64} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, manning_n::Float64} + } end """ @@ -310,7 +317,10 @@ struct FractionalFlow <: AbstractParameterNode inflow_edge::Vector{EdgeMetadata} outflow_edge::Vector{EdgeMetadata} fraction::Vector{Float64} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, fraction::Float64} + } end """ @@ -326,11 +336,13 @@ end """ node_id: node ID of the FlowBoundary node +outflow_ids: The downsteam nodes of this FlowBoundary node active: whether this node is active and thus contributes flow flow_rate: target flow rate """ struct FlowBoundary <: AbstractParameterNode node_id::Vector{NodeID} + outflow_ids::Vector{Vector{NodeID}} active::BitVector flow_rate::Vector{ScalarInterpolation} end @@ -356,7 +368,10 @@ struct Pump{T} <: AbstractParameterNode flow_rate::T min_flow_rate::Vector{Float64} max_flow_rate::Vector{Float64} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, flow_rate::Float64} + } is_pid_controlled::BitVector function Pump( @@ -410,7 +425,10 @@ struct Outlet{T} <: AbstractParameterNode min_flow_rate::Vector{Float64} max_flow_rate::Vector{Float64} min_crest_level::Vector{Float64} - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, flow_rate::Float64} + } is_pid_controlled::BitVector function Outlet( @@ -452,30 +470,41 @@ struct Terminal <: AbstractParameterNode end """ -node_id: node ID of the DiscreteControl node per compound variable (can contain repeats) -listen_node_id: the IDs of the nodes being condition on per compound variable -variable: the names of the variables in the condition per compound variable -weight: the weight of the variables in the condition per compound variable -look_ahead: the look ahead of variables in the 
condition in seconds per compound_variable -greater_than: The threshold values per compound variable -condition_value: The current truth value of each condition per compound_variable per greater_than +The data for a single compound variable +node_id:: The ID of the DiscreteControl that listens to this variable +subvariables: data for one single subvariable +greater_than: the thresholds this compound variable will be + compared against +""" +struct CompoundVariable + node_id::NodeID + subvariables::Vector{ + @NamedTuple{ + listen_node_id::NodeID, + variable::String, + weight::Float64, + look_ahead::Float64, + } + } + greater_than::Vector{Float64} +end + +""" +node_id: node ID of the DiscreteControl (if it has at least one condition defined on it) +compound_variables: The compound variables the DiscreteControl node listens to +truth_state: Memory allocated for storing the truth state control_state: Dictionary: node ID => (control state, control state start) logic_mapping: Dictionary: (control node ID, truth state) => control state record: Namedtuple with discrete control information for results """ struct DiscreteControl <: AbstractParameterNode node_id::Vector{NodeID} - # Definition of compound variables - listen_node_id::Vector{Vector{NodeID}} - variable::Vector{Vector{String}} - weight::Vector{Vector{Float64}} - look_ahead::Vector{Vector{Float64}} - # Definition of conditions (one or more greater_than per compound variable) - greater_than::Vector{Vector{Float64}} - condition_value::Vector{BitVector} + compound_variables::Vector{Vector{CompoundVariable}} + # truth_state per discrete control node + truth_state::Vector{Vector{Bool}} # Definition of logic control_state::Dict{NodeID, Tuple{String, Float64}} - logic_mapping::Dict{Tuple{NodeID, String}, String} + logic_mapping::Dict{Tuple{NodeID, Vector{Bool}}, String} record::@NamedTuple{ time::Vector{Float64}, control_node_id::Vector{Int32}, @@ -500,9 +529,21 @@ struct PidControl{T} <: AbstractParameterNode active::BitVector listen_node_id::Vector{NodeID} target::Vector{ScalarInterpolation} - pid_params::Vector{VectorInterpolation} + proportional::Vector{ScalarInterpolation} + integral::Vector{ScalarInterpolation} + derivative::Vector{ScalarInterpolation} error::T - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple} + control_mapping::Dict{ + Tuple{NodeID, String}, + @NamedTuple{ + node_idx::Int, + active::Bool, + target::ScalarInterpolation, + proportional::ScalarInterpolation, + integral::ScalarInterpolation, + derivative::ScalarInterpolation, + } + } end """ @@ -587,6 +628,12 @@ struct LevelDemand <: AbstractDemandNode priority::Vector{Int32} end +""" +node_id: node ID of the FlowDemand node +demand_itp: The time interpolation of the demand of the node +demand: The current demand of the node +priority: The priority of the demand of the node +""" struct FlowDemand <: AbstractDemandNode node_id::Vector{NodeID} demand_itp::Vector{ScalarInterpolation} @@ -602,27 +649,45 @@ struct Subgrid level::Vector{Float64} end +""" +The metadata of the graph (the fields of the NamedTuple) can be accessed + e.g. using graph[].flow. 
+node_ids: mapping subnetwork ID -> node IDs in that subnetwork +edges_source: mapping subnetwork ID -> metadata of allocation + source edges in that subnetwork +flow_edges: The metadata of all flow edges +flow dict: mapping (source ID, destination ID) -> index in the flow vector + of the flow over that edge +flow: Flow per flow edge in the order prescribed by flow_dict +flow_prev: The flow vector of the previous timestep, used for integration +flow_integrated: Flow integrated over time, used for mean flow computation + over saveat intervals +saveat: The time interval between saves of output data (storage, flow, ...) +""" +const ModelGraph{T} = MetaGraph{ + Int64, + DiGraph{Int64}, + NodeID, + NodeMetadata, + EdgeMetadata, + @NamedTuple{ + node_ids::Dict{Int32, Set{NodeID}}, + edges_source::Dict{Int32, Set{EdgeMetadata}}, + flow_edges::Vector{EdgeMetadata}, + flow_dict::Dict{Tuple{NodeID, NodeID}, Int32}, + flow::T, + flow_prev::Vector{Float64}, + flow_integrated::Vector{Float64}, + saveat::Float64, + }, + MetaGraphsNext.var"#11#13", + Float64, +} where {T} + # TODO Automatically add all nodetypes here struct Parameters{T, C1, C2, V1, V2, V3} starttime::DateTime - graph::MetaGraph{ - Int64, - DiGraph{Int64}, - NodeID, - NodeMetadata, - EdgeMetadata, - @NamedTuple{ - node_ids::Dict{Int32, Set{NodeID}}, - edges_source::Dict{Int32, Set{EdgeMetadata}}, - flow_dict::Dict{Tuple{NodeID, NodeID}, Int32}, - flow::T, - flow_prev::Vector{Float64}, - flow_integrated::Vector{Float64}, - saveat::Float64, - }, - MetaGraphsNext.var"#11#13", - Float64, - } + graph::ModelGraph{T} allocation::Allocation basin::Basin{T, C1, V1, V2, V3} linear_resistance::LinearResistance diff --git a/core/src/read.jl b/core/src/read.jl index 71072a2b8..97f5eb936 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -10,7 +10,7 @@ than one row in a table, as is the case for TabulatedRatingCurve. 
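Note: the read.jl hunk below changes `parse_static_and_time` to receive the node struct type itself (e.g. `Pump`) instead of a node type string, so the parser can build the concretely typed `control_mapping` introduced in parameter.jl. A hedged sketch of how that value type can be recovered from a struct definition by reflection follows; `ToyPump` and `control_state_type` are hypothetical, and where the new `get_control_state_type` helper in util.jl inspects the `Dict` field types, `Base.valtype` is used here for brevity.

# Hypothetical node struct mirroring the control_mapping pattern of Pump/Outlet.
struct ToyPump
    node_id::Vector{Int}
    flow_rate::Vector{Float64}
    control_mapping::Dict{
        Tuple{Int, String},
        @NamedTuple{node_idx::Int, active::Bool, flow_rate::Float64}
    }
end

# Recover the NamedTuple value type of a node's control_mapping via reflection,
# so a concretely typed dictionary can be allocated before parsing the tables.
function control_state_type(node_type::Type)::Type
    i = findfirst(==(:control_mapping), fieldnames(node_type))
    isnothing(i) && return Nothing
    return Base.valtype(fieldtypes(node_type)[i])
end

@assert control_state_type(ToyPump) ==
        @NamedTuple{node_idx::Int, active::Bool, flow_rate::Float64}
control_mapping = Dict{Tuple{Int, String}, control_state_type(ToyPump)}()
@assert isempty(control_mapping)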
function parse_static_and_time( db::DB, config::Config, - nodetype::String; + node_type::Type; static::Union{StructVector, Nothing} = nothing, time::Union{StructVector, Nothing} = nothing, defaults::NamedTuple = (; active = true), @@ -32,7 +32,8 @@ function parse_static_and_time( # of the current type vals_out = [] - node_ids = NodeID.(nodetype, get_ids(db, nodetype)) + node_type_string = split(string(node_type), '.')[end] + node_ids = NodeID.(node_type_string, get_ids(db, node_type_string)) n_nodes = length(node_ids) # Initialize the vectors for the output @@ -59,9 +60,10 @@ function parse_static_and_time( push!(keys_out, :node_id) push!(vals_out, node_ids) - # The control mapping is a dictionary with keys (node_id, control_state) to a named tuple of + # The control mapping is a dictionary with keys (node_id, control_state) to a named tuple of parameter values # parameter values to be assigned to the node with this node_id in the case of this control_state - control_mapping = Dict{Tuple{NodeID, String}, NamedTuple}() + control_state_type = get_control_state_type(node_type) + control_mapping = Dict{Tuple{NodeID, String}, control_state_type}() push!(keys_out, :control_mapping) push!(vals_out, control_mapping) @@ -78,7 +80,7 @@ function parse_static_and_time( static_node_id_vec = NodeID[] static_node_ids = Set{NodeID}() else - static_node_id_vec = NodeID.(nodetype, static.node_id) + static_node_id_vec = NodeID.(node_type_string, static.node_id) static_node_ids = Set(static_node_id_vec) end @@ -87,7 +89,7 @@ function parse_static_and_time( time_node_id_vec = NodeID[] time_node_ids = Set{NodeID}() else - time_node_id_vec = NodeID.(nodetype, time.node_id) + time_node_id_vec = NodeID.(node_type_string, time.node_id) time_node_ids = Set(time_node_id_vec) end @@ -127,8 +129,24 @@ function parse_static_and_time( end # Add the parameter values to the control mapping control_state_key = coalesce(control_state, "") - control_mapping[(node_id, control_state_key)] = - NamedTuple{Tuple(parameter_names)}(Tuple(parameter_values)) + controllable_mask = + collect(parameter_names .∈ Ref(fieldnames(control_state_type))) + if any(controllable_mask) + node_idx = searchsortedfirst(node_ids, node_id) + controllable_parameter_names = + collect(parameter_names[controllable_mask]) + controllable_parameter_values = + collect(parameter_values[controllable_mask]) + pushfirst!(controllable_parameter_names, :node_idx) + pushfirst!( + controllable_parameter_values, + searchsortedfirst(node_ids, node_id), + ) + control_mapping[(node_id, control_state_key)] = + NamedTuple{Tuple(controllable_parameter_names)}( + Tuple(controllable_parameter_values), + ) + end end elseif node_id in time_node_ids # TODO replace (time, node_id) order by (node_id, time) @@ -234,7 +252,7 @@ function LinearResistance(db::DB, config::Config, graph::MetaGraph)::LinearResis static = load_structvector(db, config, LinearResistanceStaticV1) defaults = (; max_flow_rate = Inf, active = true) parsed_parameters, valid = - parse_static_and_time(db, config, "LinearResistance"; static, defaults) + parse_static_and_time(db, config, LinearResistance; static, defaults) if !valid error( @@ -273,11 +291,14 @@ function TabulatedRatingCurve( end interpolations = ScalarInterpolation[] - control_mapping = Dict{Tuple{NodeID, String}, NamedTuple}() + control_mapping = Dict{ + Tuple{NodeID, String}, + @NamedTuple{node_idx::Int, active::Bool, table::ScalarInterpolation} + }() active = BitVector() errors = false - for node_id in node_ids + for (i, node_id) in 
enumerate(node_ids) if node_id in static_node_ids # Loop over all static rating curves (groups) with this node_id. # If it has a control_state add it to control_mapping. @@ -303,7 +324,7 @@ function TabulatedRatingCurve( control_mapping[( NodeID(NodeType.TabulatedRatingCurve, node_id), control_state, - )] = (; tables = interpolation, active = is_active) + )] = (; node_idx = i, active = is_active, table = interpolation) end end push!(interpolations, interpolation) @@ -328,7 +349,6 @@ function TabulatedRatingCurve( if errors error("Errors occurred when parsing TabulatedRatingCurve data.") end - return TabulatedRatingCurve( node_ids, inflow_edge.(Ref(graph), node_ids), @@ -347,8 +367,7 @@ function ManningResistance( basin::Basin, )::ManningResistance static = load_structvector(db, config, ManningResistanceStaticV1) - parsed_parameters, valid = - parse_static_and_time(db, config, "ManningResistance"; static) + parsed_parameters, valid = parse_static_and_time(db, config, ManningResistance; static) if !valid error("Errors occurred when parsing ManningResistance data.") @@ -375,7 +394,7 @@ end function FractionalFlow(db::DB, config::Config, graph::MetaGraph)::FractionalFlow static = load_structvector(db, config, FractionalFlowStaticV1) - parsed_parameters, valid = parse_static_and_time(db, config, "FractionalFlow"; static) + parsed_parameters, valid = parse_static_and_time(db, config, FractionalFlow; static) if !valid error("Errors occurred when parsing FractionalFlow data.") @@ -403,14 +422,8 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary end time_interpolatables = [:level] - parsed_parameters, valid = parse_static_and_time( - db, - config, - "LevelBoundary"; - static, - time, - time_interpolatables, - ) + parsed_parameters, valid = + parse_static_and_time(db, config, LevelBoundary; static, time, time_interpolatables) if !valid error("Errors occurred when parsing LevelBoundary data.") @@ -419,7 +432,7 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary return LevelBoundary(node_ids, parsed_parameters.active, parsed_parameters.level) end -function FlowBoundary(db::DB, config::Config)::FlowBoundary +function FlowBoundary(db::DB, config::Config, graph::MetaGraph)::FlowBoundary static = load_structvector(db, config, FlowBoundaryStaticV1) time = load_structvector(db, config, FlowBoundaryTimeV1) @@ -430,14 +443,8 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary end time_interpolatables = [:flow_rate] - parsed_parameters, valid = parse_static_and_time( - db, - config, - "FlowBoundary"; - static, - time, - time_interpolatables, - ) + parsed_parameters, valid = + parse_static_and_time(db, config, FlowBoundary; static, time, time_interpolatables) for itp in parsed_parameters.flow_rate if any(itp.u .< 0.0) @@ -452,13 +459,18 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary error("Errors occurred when parsing FlowBoundary data.") end - return FlowBoundary(node_ids, parsed_parameters.active, parsed_parameters.flow_rate) + return FlowBoundary( + node_ids, + [collect(outflow_ids(graph, id)) for id in node_ids], + parsed_parameters.active, + parsed_parameters.flow_rate, + ) end function Pump(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{Int})::Pump static = load_structvector(db, config, PumpStaticV1) defaults = (; min_flow_rate = 0.0, max_flow_rate = Inf, active = true) - parsed_parameters, valid = parse_static_and_time(db, config, "Pump"; static, defaults) + parsed_parameters, valid = parse_static_and_time(db, config, Pump; static, 
defaults) is_pid_controlled = falses(length(NodeID.(NodeType.Pump, parsed_parameters.node_id))) if !valid @@ -491,7 +503,7 @@ function Outlet(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{In static = load_structvector(db, config, OutletStaticV1) defaults = (; min_flow_rate = 0.0, max_flow_rate = Inf, min_crest_level = -Inf, active = true) - parsed_parameters, valid = parse_static_and_time(db, config, "Outlet"; static, defaults) + parsed_parameters, valid = parse_static_and_time(db, config, Outlet; static, defaults) is_pid_controlled = falses(length(NodeID.(NodeType.Outlet, parsed_parameters.node_id))) if !valid @@ -596,19 +608,16 @@ function Basin(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{Int end function parse_variables_and_conditions(compound_variable, condition) - node_id = NodeID[] - listen_node_id = Vector{NodeID}[] - variable = Vector{String}[] - weight = Vector{Float64}[] - look_ahead = Vector{Float64}[] - greater_than = Vector{Float64}[] - condition_value = BitVector[] + compound_variables = Vector{CompoundVariable}[] errors = false # Loop over unique discrete_control node IDs (on which at least one condition is defined) for id in unique(condition.node_id) condition_group_id = filter(row -> row.node_id == id, condition) variable_group_id = filter(row -> row.node_id == id, compound_variable) + + compound_variables_node = CompoundVariable[] + # Loop over compound variables for this node ID for compound_variable_id in unique(condition_group_id.compound_variable_id) condition_group_variable = filter( @@ -624,47 +633,39 @@ function parse_variables_and_conditions(compound_variable, condition) errors = true @error "compound_variable_id $compound_variable_id for $discrete_control_id in condition table but not in variable table" else - push!(node_id, discrete_control_id) - push!( - listen_node_id, - NodeID.( - variable_group_variable.listen_node_type, - variable_group_variable.listen_node_id, - ), - ) - push!(variable, variable_group_variable.variable) - push!(weight, coalesce.(variable_group_variable.weight, 1.0)) - push!(look_ahead, coalesce.(variable_group_variable.look_ahead, 0.0)) - push!(greater_than, condition_group_variable.greater_than) + greater_than = condition_group_variable.greater_than + + # Collect subvariable data for this compound variable in + # NamedTuples + subvariables = NamedTuple[] + for i in eachindex(variable_group_variable.variable) + listen_node_id = + NodeID.( + variable_group_variable.listen_node_type[i], + variable_group_variable.listen_node_id[i], + ) + variable = variable_group_variable.variable[i] + weight = coalesce.(variable_group_variable.weight[i], 1.0) + look_ahead = coalesce.(variable_group_variable.look_ahead[i], 0.0) + push!(subvariables, (; listen_node_id, variable, weight, look_ahead)) + end + push!( - condition_value, - BitVector(zeros(length(condition_group_variable.greater_than))), + compound_variables_node, + CompoundVariable(discrete_control_id, subvariables, greater_than), ) end end + push!(compound_variables, compound_variables_node) end - return node_id, - listen_node_id, - variable, - weight, - look_ahead, - greater_than, - condition_value, - !errors + return compound_variables, !errors end function DiscreteControl(db::DB, config::Config)::DiscreteControl condition = load_structvector(db, config, DiscreteControlConditionV1) compound_variable = load_structvector(db, config, DiscreteControlVariableV1) - node_id, - listen_node_id, - variable, - weight, - look_ahead, - greater_than, - condition_value, - 
valid = parse_variables_and_conditions(compound_variable, condition) + compound_variables, valid = parse_variables_and_conditions(compound_variable, condition) if !valid error("Problems encountered when parsing DiscreteControl variables and conditions.") @@ -698,14 +699,19 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl control_state = String[], ) + node_id = + [first(compound_variables_).node_id for compound_variables_ in compound_variables] + + truth_state = Vector{Bool}[] + for i in eachindex(node_id) + truth_state_length = sum(length(var.greater_than) for var in compound_variables[i]) + push!(truth_state, zeros(Bool, truth_state_length)) + end + return DiscreteControl( - node_id, # Not unique - listen_node_id, - variable, - weight, - look_ahead, - greater_than, - condition_value, + node_id, + compound_variables, + truth_state, control_state, logic_mapping, record, @@ -724,7 +730,7 @@ function PidControl(db::DB, config::Config, chunk_sizes::Vector{Int})::PidContro time_interpolatables = [:target, :proportional, :integral, :derivative] parsed_parameters, valid = - parse_static_and_time(db, config, "PidControl"; static, time, time_interpolatables) + parse_static_and_time(db, config, PidControl; static, time, time_interpolatables) if !valid error("Errors occurred when parsing PidControl data.") @@ -735,39 +741,14 @@ function PidControl(db::DB, config::Config, chunk_sizes::Vector{Int})::PidContro if config.solver.autodiff pid_error = DiffCache(pid_error, chunk_sizes) end - - # Combine PID parameters into one vector interpolation object - pid_parameters = VectorInterpolation[] - (; proportional, integral, derivative) = parsed_parameters - - for i in eachindex(node_ids) - times = proportional[i].t - K_p = proportional[i].u - K_i = integral[i].u - K_d = derivative[i].u - - itp = LinearInterpolation(collect.(zip(K_p, K_i, K_d)), times) - push!(pid_parameters, itp) - end - - for (key, params) in parsed_parameters.control_mapping - (; proportional, integral, derivative) = params - - times = params.proportional.t - K_p = proportional.u - K_i = integral.u - K_d = derivative.u - pid_params = LinearInterpolation(collect.(zip(K_p, K_i, K_d)), times) - parsed_parameters.control_mapping[key] = - (; params.target, params.active, pid_params) - end - return PidControl( node_ids, BitVector(parsed_parameters.active), NodeID.(parsed_parameters.listen_node_type, parsed_parameters.listen_node_id), parsed_parameters.target, - pid_parameters, + parsed_parameters.proportional, + parsed_parameters.integral, + parsed_parameters.derivative, pid_error, parsed_parameters.control_mapping, ) @@ -929,7 +910,7 @@ function LevelDemand(db::DB, config::Config)::LevelDemand parsed_parameters, valid = parse_static_and_time( db, config, - "LevelDemand"; + LevelDemand; static, time, time_interpolatables = [:min_level, :max_level], @@ -955,7 +936,7 @@ function FlowDemand(db::DB, config::Config)::FlowDemand parsed_parameters, valid = parse_static_and_time( db, config, - "FlowDemand"; + FlowDemand; static, time, time_interpolatables = [:demand], @@ -1121,7 +1102,7 @@ function Parameters(db::DB, config::Config)::Parameters tabulated_rating_curve = TabulatedRatingCurve(db, config, graph) fractional_flow = FractionalFlow(db, config, graph) level_boundary = LevelBoundary(db, config) - flow_boundary = FlowBoundary(db, config) + flow_boundary = FlowBoundary(db, config, graph) pump = Pump(db, config, graph, chunk_sizes) outlet = Outlet(db, config, graph, chunk_sizes) terminal = Terminal(db, config) diff --git 
a/core/src/solve.jl b/core/src/solve.jl index ea20cd74d..12ece26e1 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -120,7 +120,8 @@ function continuous_control!( max_flow_rate_pump = pump.max_flow_rate min_flow_rate_outlet = outlet.min_flow_rate max_flow_rate_outlet = outlet.max_flow_rate - (; node_id, active, target, pid_params, listen_node_id, error) = pid_control + (; node_id, active, target, proportional, integral, derivative, listen_node_id, error) = + pid_control (; current_area) = basin current_area = get_tmp(current_area, u) @@ -185,7 +186,9 @@ function continuous_control!( factor = factor_basin * factor_outlet flow_rate = 0.0 - K_p, K_i, K_d = pid_params[i](t) + K_p = proportional[i](t) + K_i = integral[i](t) + K_d = derivative[i](t) if !iszero(K_d) # dlevel/dstorage = 1/area @@ -389,7 +392,7 @@ function formulate_flow!( t::Number, )::Nothing (; graph) = p - (; node_id, active, tables, inflow_edge, outflow_edges) = tabulated_rating_curve + (; node_id, active, table, inflow_edge, outflow_edges) = tabulated_rating_curve for (i, id) in enumerate(node_id) upstream_edge = inflow_edge[i] @@ -400,7 +403,7 @@ function formulate_flow!( factor = low_storage_factor(storage, upstream_edge, upstream_basin_id, 10.0) q = factor * - tables[i](get_level(p, upstream_edge, upstream_basin_id, t; storage)[2]) + table[i](get_level(p, upstream_edge, upstream_basin_id, t; storage)[2]) else q = 0.0 end @@ -548,11 +551,11 @@ function formulate_flow!( t::Number, )::Nothing (; graph) = p - (; node_id, active, flow_rate) = flow_boundary + (; node_id, active, flow_rate, outflow_ids) = flow_boundary for (i, id) in enumerate(node_id) # Requirement: edge points away from the flow boundary - for outflow_id in outflow_ids(graph, id) + for outflow_id in outflow_ids[i] if !active[i] continue end @@ -665,11 +668,8 @@ function formulate_du!( # loop over basins # subtract all outgoing flows # add all ingoing flows - for edge_metadata in values(graph.edge_data) - (; type, edge, basin_idx_src, basin_idx_dst) = edge_metadata - if type !== EdgeType.flow - continue - end + for edge_metadata in values(graph[].flow_edges) + (; edge, basin_idx_src, basin_idx_dst) = edge_metadata from_id, to_id = edge if from_id.type == NodeType.Basin diff --git a/core/src/util.jl b/core/src/util.jl index 34bc01454..5cd7db424 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -405,8 +405,8 @@ all possible explicit truth states. """ function expand_logic_mapping( logic_mapping::Dict{Tuple{NodeID, String}, String}, -)::Dict{Tuple{NodeID, String}, String} - logic_mapping_expanded = Dict{Tuple{NodeID, String}, String}() +)::Dict{Tuple{NodeID, Vector{Bool}}, String} + logic_mapping_expanded = Dict{Tuple{NodeID, Vector{Bool}}, String}() for (node_id, truth_state) in keys(logic_mapping) pattern = r"^[TF\*]+$" @@ -418,23 +418,23 @@ function expand_logic_mapping( n_wildcards = count(==('*'), truth_state) substitutions = if n_wildcards > 0 - substitutions = Iterators.product(fill(['T', 'F'], n_wildcards)...) + substitutions = Iterators.product(fill([true, false], n_wildcards)...) 
else [nothing] end # Loop over all substitution sets for the wildcards for substitution in substitutions - truth_state_new = "" + truth_state_new = Bool[] s_index = 0 # If a wildcard is found replace it, otherwise take the old truth value for truth_value in truth_state - truth_state_new *= if truth_value == '*' + if truth_value == '*' s_index += 1 - substitution[s_index] + push!(truth_state_new, substitution[s_index]) else - truth_value + push!(truth_state_new, truth_value == 'T') end end @@ -737,6 +737,16 @@ function set_basin_idxs!(graph::MetaGraph, basin::Basin)::Nothing @set edge_metadata.basin_idx_dst = id_index(basin.node_id, id_dst)[2] graph[edge...] = edge_metadata end + + # Collect the flow edges. This significantly speeds up + # formulate_du! + append!( + graph[].flow_edges, + filter( + edge_metadata -> edge_metadata.type == EdgeType.flow, + collect(values(graph.edge_data)), + ), + ) return nothing end @@ -793,3 +803,24 @@ function set_initial_allocation_mean_flows!(integrator)::Nothing return nothing end + +""" +Convert a truth state in terms of a BitVector of Vector{Bool} into a string of 'T' and 'F' +""" +function convert_truth_state(boolean_vector)::String + String(UInt8.(ifelse.(boolean_vector, 'T', 'F'))) +end + +""" +Given the type of a node struct, e.g. Pump, get the type of the values of the +control_mapping dict if this field exists. +""" +function get_control_state_type(node_type::Type)::Type + control_mapping_index = findfirst(==(:control_mapping), fieldnames(node_type)) + if !isnothing(control_mapping_index) + control_mapping_type = fieldtypes(node_type)[control_mapping_index] + control_state_type = eltype(fieldtypes(control_mapping_type)[3]) + return control_state_type + end + return Nothing +end diff --git a/core/src/validation.jl b/core/src/validation.jl index 1b5ac809c..322fea30a 100644 --- a/core/src/validation.jl +++ b/core/src/validation.jl @@ -239,7 +239,7 @@ Test whether static or discrete controlled flow rates are indeed non-negative. function valid_flow_rates( node_id::Vector{NodeID}, flow_rate::Vector, - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple}, + control_mapping::Dict, )::Bool errors = false @@ -247,10 +247,11 @@ function valid_flow_rates( # if their initial value is also invalid. ids_controlled = NodeID[] - for (key, control_values) in pairs(control_mapping) + for (key, parameter_update) in pairs(control_mapping) id_controlled = key[1] push!(ids_controlled, id_controlled) - flow_rate_ = get(control_values, :flow_rate, 1) + flow_rate_ = parameter_update.flow_rate + flow_rate_ = isnan(flow_rate_) ? 
1.0 : flow_rate_ if flow_rate_ < 0.0 errors = true @@ -311,7 +312,7 @@ outneighbor, that the fractions leaving a node add up to ≈1 and that the fract function valid_fractional_flow( graph::MetaGraph, node_id::Vector{NodeID}, - control_mapping::Dict{Tuple{NodeID, String}, NamedTuple}, + control_mapping::Dict, )::Bool errors = false @@ -552,21 +553,23 @@ Check: """ function valid_discrete_control(p::Parameters, config::Config)::Bool (; discrete_control, graph) = p - (; node_id, logic_mapping, look_ahead, variable, listen_node_id, greater_than) = - discrete_control + (; node_id, logic_mapping) = discrete_control t_end = seconds_since(config.endtime, config.starttime) errors = false - for id in unique(node_id) + for (id, compound_variables) in zip(node_id, discrete_control.compound_variables) # The control states of this DiscreteControl node control_states_discrete_control = Set{String}() # The truth states of this DiscreteControl node with the wrong length - truth_states_wrong_length = String[] + truth_states_wrong_length = Vector{Bool}[] # The number of conditions of this DiscreteControl node - n_conditions = sum(length(greater_than[i]) for i in searchsorted(node_id, id)) + n_conditions = sum( + length(compound_variable.greater_than) for + compound_variable in compound_variables + ) for (key, control_state) in logic_mapping id_, truth_state = key @@ -582,7 +585,7 @@ function valid_discrete_control(p::Parameters, config::Config)::Bool if !isempty(truth_states_wrong_length) errors = true - @error "$id has $n_conditions condition(s), which is inconsistent with these truth state(s): $truth_states_wrong_length." + @error "$id has $n_conditions condition(s), which is inconsistent with these truth state(s): $(convert_truth_state.(truth_states_wrong_length))." end # Check whether these control states are defined for the @@ -612,30 +615,32 @@ function valid_discrete_control(p::Parameters, config::Config)::Bool errors = true end end - end - for (look_aheads, variables, listen_node_ids) in - zip(look_ahead, variable, listen_node_id) - for (Δt, var, node_id) in zip(look_aheads, variables, listen_node_ids) - if !iszero(Δt) - node_type = node_id.type - if node_type ∉ [NodeType.FlowBoundary, NodeType.LevelBoundary] - errors = true - @error "Look ahead supplied for non-timeseries listen variable '$var' from listen node $node_id." - else - if Δt < 0 + + # Validate look_ahead + for compound_variable in compound_variables + for subvariable in compound_variable.subvariables + if !iszero(subvariable.look_ahead) + node_type = subvariable.listen_node_id.type + if node_type ∉ [NodeType.FlowBoundary, NodeType.LevelBoundary] errors = true - @error "Negative look ahead supplied for listen variable '$var' from listen node $node_id." + @error "Look ahead supplied for non-timeseries listen variable '$(subvariable.variable)' from listen node $(subvariable.listen_node_id)." else - node = getfield(p, graph[node_id].type) - idx = if node_type == NodeType.Basin - id_index(node.node_id, node_id) - else - searchsortedfirst(node.node_id, node_id) - end - interpolation = getfield(node, Symbol(var))[idx] - if t_end + Δt > interpolation.t[end] + if subvariable.look_ahead < 0 errors = true - @error "Look ahead for listen variable '$var' from listen node $node_id goes past timeseries end during simulation." + @error "Negative look ahead supplied for listen variable '$(subvariable.variable)' from listen node $(subvariable.listen_node_id)." 
+ else + node = getfield(p, graph[subvariable.listen_node_id].type) + idx = if node_type == NodeType.Basin + id_index(node.node_id, subvariable.listen_node_id) + else + searchsortedfirst(node.node_id, subvariable.listen_node_id) + end + interpolation = + getfield(node, Symbol(subvariable.variable))[idx] + if t_end + subvariable.look_ahead > interpolation.t[end] + errors = true + @error "Look ahead for listen variable '$(subvariable.variable)' from listen node $(subvariable.listen_node_id) goes past timeseries end during simulation." + end end end end diff --git a/core/test/control_test.jl b/core/test/control_test.jl index 1eabc4f1a..ec9a2dab5 100644 --- a/core/test/control_test.jl +++ b/core/test/control_test.jl @@ -17,13 +17,13 @@ @test pump_control_mapping[(NodeID(:Pump, 4), "off")].flow_rate == 0 @test pump_control_mapping[(NodeID(:Pump, 4), "on")].flow_rate == 1.0e-5 - logic_mapping::Dict{Tuple{NodeID, String}, String} = Dict( - (NodeID(:DiscreteControl, 5), "TT") => "on", - (NodeID(:DiscreteControl, 6), "F") => "active", - (NodeID(:DiscreteControl, 5), "TF") => "off", - (NodeID(:DiscreteControl, 5), "FF") => "on", - (NodeID(:DiscreteControl, 5), "FT") => "off", - (NodeID(:DiscreteControl, 6), "T") => "inactive", + logic_mapping::Dict{Tuple{NodeID, Vector{Bool}}, String} = Dict( + (NodeID(:DiscreteControl, 5), [true, true]) => "on", + (NodeID(:DiscreteControl, 6), [false]) => "active", + (NodeID(:DiscreteControl, 5), [true, false]) => "off", + (NodeID(:DiscreteControl, 5), [false, false]) => "on", + (NodeID(:DiscreteControl, 5), [false, true]) => "off", + (NodeID(:DiscreteControl, 6), [true]) => "inactive", ) @test discrete_control.logic_mapping == logic_mapping @@ -49,11 +49,11 @@ # Control times t_1 = discrete_control.record.time[3] t_1_index = findfirst(>=(t_1), t) - @test level[1, t_1_index] <= discrete_control.greater_than[1][1] + @test level[1, t_1_index] <= discrete_control.compound_variables[1][1].greater_than[1] t_2 = discrete_control.record.time[4] t_2_index = findfirst(>=(t_2), t) - @test level[2, t_2_index] >= discrete_control.greater_than[2][1] + @test level[2, t_2_index] >= discrete_control.compound_variables[1][2].greater_than[1] flow = get_tmp(graph[].flow, 0) @test all(iszero, flow) @@ -66,13 +66,13 @@ end p = model.integrator.p (; discrete_control, flow_boundary) = p - Δt = discrete_control.look_ahead[1][1] + Δt = discrete_control.compound_variables[1][1].subvariables[1].look_ahead t = Ribasim.tsaves(model) t_control = discrete_control.record.time[2] t_control_index = searchsortedfirst(t, t_control) - greater_than = discrete_control.greater_than[1][1] + greater_than = discrete_control.compound_variables[1][1].greater_than[1] flow_t_control = flow_boundary.flow_rate[1](t_control) flow_t_control_ahead = flow_boundary.flow_rate[1](t_control + Δt) @@ -90,13 +90,13 @@ end p = model.integrator.p (; discrete_control, level_boundary) = p - Δt = discrete_control.look_ahead[1][1] + Δt = discrete_control.compound_variables[1][1].subvariables[1].look_ahead t = Ribasim.tsaves(model) t_control = discrete_control.record.time[2] t_control_index = searchsortedfirst(t, t_control) - greater_than = discrete_control.greater_than[1][1] + greater_than = discrete_control.compound_variables[1][1].greater_than[1] level_t_control = level_boundary.level[1](t_control) level_t_control_ahead = level_boundary.level[1](t_control + Δt) @@ -118,7 +118,8 @@ end t_target_change = target_itp.t[2] idx_target_change = searchsortedlast(t, t_target_change) - K_p, K_i, _ = pid_control.pid_params[2](0) + K_p = 
pid_control.proportional[2](0) + K_i = pid_control.integral[2](0) level_demand = pid_control.target[2](0) A = basin.area[1][1] @@ -159,7 +160,7 @@ end t = Ribasim.datetime_since(discrete_control.record.time[2], model.config.starttime) @test Date(t) == Date("2020-03-16") # then the rating curve is updated to the "low" control_state - @test last(only(p.tabulated_rating_curve.tables).t) == 1.2 + @test last(only(p.tabulated_rating_curve.table).t) == 1.2 end @testitem "Set PID target with DiscreteControl" begin @@ -199,11 +200,22 @@ end @test ispath(toml_path) model = Ribasim.run(toml_path) (; discrete_control) = model.integrator.p - (; listen_node_id, variable, weight, record) = discrete_control + (; compound_variables, record) = discrete_control - @test listen_node_id == [[NodeID(:FlowBoundary, 2), NodeID(:FlowBoundary, 3)]] - @test variable == [["flow_rate", "flow_rate"]] - @test weight == [[0.5, 0.5]] + compound_variable = only(only(compound_variables)) + + @test compound_variable.subvariables[1] == (; + listen_node_id = NodeID(:FlowBoundary, 2), + variable = "flow_rate", + weight = 0.5, + look_ahead = 0.0, + ) + @test compound_variable.subvariables[2] == (; + listen_node_id = NodeID(:FlowBoundary, 3), + variable = "flow_rate", + weight = 0.5, + look_ahead = 0.0, + ) @test record.time ≈ [0.0, model.integrator.sol.t[end] / 2] @test record.truth_state == ["F", "T"] @test record.control_state == ["Off", "On"] diff --git a/core/test/equations_test.jl b/core/test/equations_test.jl index 95ca06d03..a89198565 100644 --- a/core/test/equations_test.jl +++ b/core/test/equations_test.jl @@ -133,7 +133,9 @@ end storage = Ribasim.get_storages_and_levels(model).storage[:] t = Ribasim.tsaves(model) SP = pid_control.target[1](0) - K_p, K_i, K_d = pid_control.pid_params[1](0) + K_p = pid_control.proportional[1](0) + K_i = pid_control.integral[1](0) + K_d = pid_control.derivative[1](0) storage_min = 50.005 level_min = basin.level[1][2] diff --git a/core/test/run_models_test.jl b/core/test/run_models_test.jl index 3efcd9995..d9e476890 100644 --- a/core/test/run_models_test.jl +++ b/core/test/run_models_test.jl @@ -249,7 +249,7 @@ end Sys.isapple() end -@testitem "allocation example model" begin +@testitem "Allocation example model" begin using SciMLBase: successful_retcode toml_path = @@ -301,7 +301,7 @@ end @test successful_retcode(model) @test model.integrator.sol.u[end] ≈ Float32[7.783636, 726.16394] skip = Sys.isapple() # the highest level in the dynamic table is updated to 1.2 from the callback - @test model.integrator.p.tabulated_rating_curve.tables[end].t[end] == 1.2 + @test model.integrator.p.tabulated_rating_curve.table[end].t[end] == 1.2 end @testitem "Profile" begin diff --git a/core/test/utils_test.jl b/core/test/utils_test.jl index 50bfe9076..3c50b5edf 100644 --- a/core/test/utils_test.jl +++ b/core/test/utils_test.jl @@ -139,11 +139,11 @@ end logic_mapping[(NodeID(:DiscreteControl, 2), "FF")] = "bar" logic_mapping_expanded = Ribasim.expand_logic_mapping(logic_mapping) - @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), "TTT")] == "foo" - @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), "FTT")] == "foo" - @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), "TTF")] == "foo" - @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), "FTF")] == "foo" - @test logic_mapping_expanded[(NodeID(:DiscreteControl, 2), "FF")] == "bar" + @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), Bool[1, 1, 1])] == "foo" + @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), 
Bool[0, 1, 1])] == "foo" + @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), Bool[1, 1, 0])] == "foo" + @test logic_mapping_expanded[(NodeID(:DiscreteControl, 1), Bool[0, 1, 0])] == "foo" + @test logic_mapping_expanded[(NodeID(:DiscreteControl, 2), Bool[0, 0])] == "bar" @test length(logic_mapping_expanded) == 5 new_key = (NodeID(:DiscreteControl, 3), "duck") @@ -173,7 +173,7 @@ end new_key = (NodeID(:DiscreteControl, 1), "TTF") logic_mapping[new_key] = "bar" - @test_throws "AssertionError: Multiple control states found for DiscreteControl #1 for truth state `TTF`: [\"bar\", \"foo\"]." Ribasim.expand_logic_mapping( + @test_throws "AssertionError: Multiple control states found for DiscreteControl #1 for truth state `Bool[1, 1, 0]`: [\"bar\", \"foo\"]." Ribasim.expand_logic_mapping( logic_mapping, ) end diff --git a/core/test/validation_test.jl b/core/test/validation_test.jl index 2213e485b..0cf21f10a 100644 --- a/core/test/validation_test.jl +++ b/core/test/validation_test.jl @@ -287,7 +287,7 @@ end @testitem "Pump/outlet flow rate sign validation" begin using Logging - using Ribasim: NodeID + using Ribasim: NodeID, NodeType logger = TestLogger() @@ -322,7 +322,7 @@ end [NaN], [NaN], Dict{Tuple{NodeID, String}, NamedTuple}( - (NodeID(:Pump, 1), "foo") => (; flow_rate = -1.0), + (NodeID(:Pump, 1), "foo") => (; active = true, flow_rate = -1.0), ), [false], )
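To close, a self-contained sketch of the truth-state evaluation that the apply_discrete_control! docstring describes: a compound variable's weighted value is compared against its sorted greater_than thresholds with searchsortedlast, the resulting Vector{Bool} truth state selects a control state via the logic mapping, and convert_truth_state renders it as 'T'/'F' for the results record. The weights, values, and thresholds below are illustrative only, not taken from any Ribasim test model.

# One compound variable: a weighted sum of listened values (illustrative numbers,
# e.g. two listened flow rates).
weights = [0.5, 0.5]
listened_values = [1.0, 3.0]
greater_than = [1.0, 1.5, 2.5]  # sorted thresholds

value = sum(w * v for (w, v) in zip(weights, listened_values))  # 2.0

# Condition i is true iff value >= greater_than[i]; since greater_than is sorted,
# the largest true index is found with a single binary search.
largest_true_index = searchsortedlast(greater_than, value)  # 2
truth_state = [i <= largest_true_index for i in eachindex(greater_than)]
@assert truth_state == [true, true, false]

# Same conversion as the convert_truth_state helper added in util.jl.
convert_truth_state(v::Vector{Bool}) = String(UInt8.(ifelse.(v, 'T', 'F')))
@assert convert_truth_state(truth_state) == "TTF"

# The truth state is the key (together with the DiscreteControl node ID) into the
# logic mapping, which yields the control state to apply.
logic_mapping = Dict([true, true, false] => "on", [false, false, false] => "off")
@assert logic_mapping[truth_state] == "on"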