Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Apr 16, 2024
1 parent 17d2a15 commit 75b56a0
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 63 deletions.
106 changes: 76 additions & 30 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ function set_initial_discrete_controlled_parameters!(
(; p) = integrator
(; discrete_control) = p

n_conditions = length(discrete_control.condition_value)
n_conditions = sum(length(vec) for vec in discrete_control.condition_value)
condition_diffs = zeros(Float64, n_conditions)
discrete_control_condition(condition_diffs, storage0, integrator.t, integrator)
discrete_control.condition_value .= (condition_diffs .> 0.0)

idx_start = 1
for (i, vec) in enumerate(discrete_control.condition_value)
l = length(vec)
idx_end = idx_start + l - 1
discrete_control.condition_value[i] .= (condition_diffs[idx_start:idx_end] .> 0)
idx_start += l
end

# For every discrete_control node find a condition_idx it listens to
for discrete_control_node_id in unique(discrete_control.node_id)
Expand Down Expand Up @@ -88,7 +95,7 @@ function create_callbacks(

saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0)
if n_conditions > 0
discrete_control_cb = VectorContinuousCallback(
discrete_control_condition,
Expand Down Expand Up @@ -193,23 +200,25 @@ Listens for changes in condition truths.
function discrete_control_condition(out, u, t, integrator)
(; p) = integrator
(; discrete_control) = p

for (i, (listen_node_ids, variables, weights, greater_than, look_aheads)) in enumerate(
zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
),
condition_idx = 0

for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
)
value = 0.0
for (listen_node_id, variable, weight, look_ahead) in
zip(listen_node_ids, variables, weights, look_aheads)
value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t)
end
diff = value - greater_than
out[i] = diff
for greater_than in greater_thans
condition_idx += 1
diff = value - greater_than
out[condition_idx] = diff
end
end
end

Expand Down Expand Up @@ -259,6 +268,22 @@ function get_value(
return value
end

function get_discrete_control_indices(discrete_control::DiscreteControl, condition_idx::Int)
(; greater_than) = discrete_control
condition_idx_now = 1

for (compound_variable_idx, vec) in enumerate(greater_than)
l = length(vec)

if condition_idx_now + l > condition_idx
greater_than_idx = condition_idx - condition_idx_now + 1
return compound_variable_idx, greater_than_idx
end

condition_idx_now += l
end
end

"""
An upcrossing means that a condition (always greater than) becomes true.
"""
Expand All @@ -267,7 +292,9 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = true
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = true

control_state_change = discrete_control_affect!(integrator, condition_idx, true)

Expand All @@ -277,23 +304,24 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
listen_node_ids = discrete_control.listen_node_id[condition_idx]
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

# NOTE: The above no longer works when listen feature ids can be something other than node ids
# I think the more durable option is to give all possible condition types a different variable string,
# e.g. basin.level and level_boundary.level
if variable[condition_idx][1] == "level" && control_state_change && is_basin
if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
_, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx][1])
_, condition_basin_idx =
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if du[condition_basin_idx] < 0.0
condition_value[condition_idx] = false
condition_value[compound_variable_idx][greater_than_idx] = false
discrete_control_affect!(integrator, condition_idx, false)
end
end
Expand All @@ -307,7 +335,9 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = false
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = false

control_state_change = discrete_control_affect!(integrator, condition_idx, false)

Expand All @@ -317,21 +347,23 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
listen_node_ids = discrete_control.listen_node_id[condition_idx]
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

if variable[condition_idx][1] == "level" && control_state_change && is_basin
if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
has_index, condition_basin_idx =
id_index(basin.node_id, listen_node_id[condition_idx][1])
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if has_index && du[condition_basin_idx] > 0.0
condition_value[condition_idx] = true
condition_value[compound_variable_idx][greater_than_idx] = true
discrete_control_affect!(integrator, condition_idx, true)
end
end
Expand All @@ -349,20 +381,34 @@ function discrete_control_affect!(
(; discrete_control, graph) = p

# Get the discrete_control node that listens to this condition
discrete_control_node_id = discrete_control.node_id[condition_idx]

compound_variable_idx, _ = get_discrete_control_indices(discrete_control, condition_idx)
discrete_control_node_id = discrete_control.node_id[compound_variable_idx]

# Get the indices of all conditions that this control node listens to
condition_ids = discrete_control.node_id .== discrete_control_node_id
where_node_id = searchsorted(discrete_control.node_id, discrete_control_node_id)

# Get the truth state for this discrete_control node
truth_values = [ifelse(b, "T", "F") for b in discrete_control.condition_value]
truth_state = join(truth_values[condition_ids], "")
truth_values = cat(
[
[ifelse(b, "T", "F") for b in discrete_control.condition_value[i]] for
i in where_node_id
]...;
dims = 1,
)
truth_state = join(truth_values, "")

# Get the truth specific about the latest crossing
if !ismissing(upcrossing)
truth_values[condition_idx] = upcrossing ? "U" : "D"
truth_value_idx =
condition_idx - sum(
length(vec) for
vec in discrete_control.condition_value[1:(where_node_id.start - 1)];
init = 0,
)
truth_values[truth_value_idx] = upcrossing ? "U" : "D"
end
truth_state_crossing_specific = join(truth_values[condition_ids], "")
truth_state_crossing_specific = join(truth_values, "")

# What the local control state should be
control_state_new =
Expand Down
7 changes: 5 additions & 2 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,15 @@ record: Namedtuple with discrete control information for results
"""
struct DiscreteControl <: AbstractParameterNode
node_id::Vector{NodeID}
# Definition of compound variables
listen_node_id::Vector{Vector{NodeID}}
variable::Vector{Vector{String}}
weight::Vector{Vector{Float64}}
look_ahead::Vector{Vector{Float64}}
greater_than::Vector{Float64}
condition_value::Vector{Bool}
# Definition of conditions (one or more greater_than per compound variable)
greater_than::Vector{Vector{Float64}}
condition_value::Vector{BitVector}
# Definition of logic
control_state::Dict{NodeID, Tuple{String, Float64}}
logic_mapping::Dict{Tuple{NodeID, String}, String}
record::@NamedTuple{
Expand Down
80 changes: 63 additions & 17 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,33 +542,79 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin
)
end

function DiscreteControl(db::DB, config::Config)::DiscreteControl
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)
condition = load_structvector(db, config, DiscreteControlConditionV1)

function parse_variables_and_conditions(compound_variable, condition)
node_id = NodeID[]
listen_node_id = Vector{NodeID}[]
variable = Vector{String}[]
weight = Vector{Float64}[]
look_ahead = Vector{Float64}[]
greater_than = Vector{Float64}[]
condition_value = BitVector[]
errors = false

for id in unique(condition.node_id)
group_id = filter(row -> row.node_id == id, compound_variable)
for group_variable in
StructVector.(IterTools.groupby(row -> row.compound_variable_id, group_id))
first_row = first(group_variable)
push!(node_id, NodeID(NodeType.DiscreteControl, first_row.node_id))
push!(
listen_node_id,
NodeID.(group_variable.listen_node_type, group_variable.listen_node_id),
condition_group_id = filter(row -> row.node_id == id, condition)
variable_group_id = filter(row -> row.node_id == id, compound_variable)
for compound_variable_id in unique(condition_group_id.compound_variable_id)
condition_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
condition_group_id,
)
push!(variable, group_variable.variable)
push!(weight, coalesce.(group_variable.weight, 1.0))
push!(look_ahead, coalesce.(group_variable.look_ahead, 0.0))
variable_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
variable_group_id,
)
discrete_control_id = NodeID(NodeType.DiscreteControl, id)
if isempty(variable_group_variable)
errors = true
@error "compound_variable_id $compound_variable_id for $discrete_control_id in condition table but not in variable table"
else
push!(node_id, discrete_control_id)
push!(
listen_node_id,
NodeID.(
variable_group_variable.listen_node_type,
variable_group_variable.listen_node_id,
),
)
push!(variable, variable_group_variable.variable)
push!(weight, coalesce.(variable_group_variable.weight, 1.0))
push!(look_ahead, coalesce.(variable_group_variable.look_ahead, 0.0))
push!(greater_than, condition_group_variable.greater_than)
push!(
condition_value,
BitVector(zeros(length(condition_group_variable.greater_than))),
)
end
end
end
return node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
!errors
end

function DiscreteControl(db::DB, config::Config)::DiscreteControl
condition = load_structvector(db, config, DiscreteControlConditionV1)
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)

node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
valid = parse_variables_and_conditions(compound_variable, condition)

if !valid
error("Problems encountered when parsing DiscreteControl variables and conditions.")
end

condition_value = fill(false, length(condition.node_id))
control_state::Dict{NodeID, Tuple{String, Float64}} = Dict()

rows = execute(db, "SELECT from_node_id, edge_type FROM Edge ORDER BY fid")
Expand Down Expand Up @@ -603,7 +649,7 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl
variable,
weight,
look_ahead,
condition.greater_than,
greater_than,
condition_value,
control_state,
logic_mapping,
Expand Down
12 changes: 5 additions & 7 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,9 @@ sort_by_id_state_level(row) = (row.node_id, row.control_state, row.level)
sort_by_priority(row) = (row.node_id, row.priority)
sort_by_priority_time(row) = (row.node_id, row.priority, row.time)
sort_by_subgrid_level(row) = (row.subgrid_id, row.basin_level)
sort_by_condition(row) =
(row.node_id, row.listen_node_type, row.listen_node_id, row.variable, row.greater_than)
sort_by_variable(row) =
(row.node_id, row.listen_node_type, row.listen_node_id, row.variable)
sort_by_id_greater_than(row) = (row.node_id, row.greater_than)
sort_by_truth_state(row) = (row.node_id, row.truth_state)
sort_by_condition(row) = (row.node_id, row.compound_variable_id, row.greater_than)

# get the right sort by function given the Schema, with sort_by_id as the default
sort_by_function(table::StructVector{<:Legolas.AbstractRecord}) = sort_by_id
Expand All @@ -117,7 +114,7 @@ sort_by_function(table::StructVector{UserDemandStaticV1}) = sort_by_priority
sort_by_function(table::StructVector{UserDemandTimeV1}) = sort_by_priority_time
sort_by_function(table::StructVector{BasinSubgridV1}) = sort_by_subgrid_level
sort_by_function(table::StructVector{DiscreteControlVariableV1}) = sort_by_variable
sort_by_function(table::StructVector{DiscreteControlConditionV1}) = sort_by_id_greater_than
sort_by_function(table::StructVector{DiscreteControlConditionV1}) = sort_by_condition

const TimeSchemas = Union{
BasinTimeV1,
Expand Down Expand Up @@ -499,7 +496,8 @@ Check:
"""
function valid_discrete_control(p::Parameters, config::Config)::Bool
(; discrete_control, graph) = p
(; node_id, logic_mapping, look_ahead, variable, listen_node_id) = discrete_control
(; node_id, logic_mapping, look_ahead, variable, listen_node_id, greater_than) =
discrete_control

t_end = seconds_since(config.endtime, config.starttime)
errors = false
Expand All @@ -512,7 +510,7 @@ function valid_discrete_control(p::Parameters, config::Config)::Bool
truth_states_wrong_length = String[]

# The number of conditions of this DiscreteControl node
n_conditions = length(searchsorted(node_id, id))
n_conditions = sum(length(greater_than[i]) for i in searchsorted(node_id, id))

for (key, control_state) in logic_mapping
id_, truth_state = key
Expand Down
Loading

0 comments on commit 75b56a0

Please sign in to comment.