Skip to content

Commit

Permalink
Split vertical fluxes on basin (#1300)
Browse files Browse the repository at this point in the history
Fixes #661.

---------

Co-authored-by: Martijn Visser <[email protected]>
  • Loading branch information
SouthEndMusic and visr authored Mar 28, 2024
1 parent c56a89a commit db46d1e
Show file tree
Hide file tree
Showing 15 changed files with 360 additions and 246 deletions.
6 changes: 5 additions & 1 deletion core/src/allocation_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,13 @@ function get_basin_data(
node_id::NodeID,
)
(; graph, basin, level_demand) = p
(; vertical_flux) = basin
(; Δt_allocation) = allocation_model
@assert node_id.type == NodeType.Basin
influx = get_flow(graph, node_id, 0.0)
vertical_flux = get_tmp(vertical_flux, 0)
_, basin_idx = id_index(basin.node_id, node_id)
# NOTE: Instantaneous
influx = get_influx(basin, node_id)
_, basin_idx = id_index(basin.node_id, node_id)
storage_basin = u.storage[basin_idx]
control_inneighbors = inneighbor_labels_type(graph, node_id, EdgeType.control)
Expand Down
8 changes: 6 additions & 2 deletions core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ function BMI.get_value_ptr(model::Model, name::AbstractString)
elseif name == "basin.level"
get_tmp(model.integrator.p.basin.current_level, 0)
elseif name == "basin.infiltration"
model.integrator.p.basin.infiltration
get_tmp(model.integrator.p.basin.vertical_flux, 0).infiltration
elseif name == "basin.drainage"
model.integrator.p.basin.drainage
get_tmp(model.integrator.p.basin.vertical_flux, 0).drainage
elseif name == "basin.infiltration_integrated"
model.integrator.p.basin.vertical_flux_bmi.infiltration
elseif name == "basin.drainage_integrated"
model.integrator.p.basin.vertical_flux_bmi.drainage
elseif name == "basin.subgrid_level"
model.integrator.p.subgrid.level
elseif name == "user_demand.demand"
Expand Down
126 changes: 63 additions & 63 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ function create_callbacks(
(; starttime, basin, tabulated_rating_curve, discrete_control) = parameters
callbacks = SciMLBase.DECallback[]

integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false)
push!(callbacks, integrating_flows_cb)

tstops = get_tstops(basin.time.time, starttime)
basin_cb = PresetTimeCallback(tstops, update_basin; save_positions = (false, false))
push!(callbacks, basin_cb)

integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false)
push!(callbacks, integrating_flows_cb)

tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(
tstops,
Expand All @@ -61,16 +61,17 @@ function create_callbacks(
push!(callbacks, allocation_cb)
end

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
saved_vertical_flux = SavedValues(Float64, typeof(basin.vertical_flux_integrated))
save_vertical_flux_cb =
SavingCallback(save_vertical_flux, saved_vertical_flux; saveat, save_start = false)
push!(callbacks, save_vertical_flux_cb)

# save the flows over time, as a Vector of the nonzeros(flow)
saved_flow = SavedValues(Float64, Vector{Float64})
save_flow_cb = SavingCallback(
save_flow,
saved_flow;
# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat,
save_start = false,
)
save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
push!(callbacks, save_flow_cb)

# interpolate the levels
Expand All @@ -85,7 +86,7 @@ function create_callbacks(
push!(callbacks, export_cb)
end

saved = SavedResults(saved_flow, saved_subgrid_level)
saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
if n_conditions > 0
Expand All @@ -108,27 +109,20 @@ Integrate flows over the last timestep
"""
function integrate_flows!(u, t, integrator)::Nothing
(; p, dt) = integrator
(; graph, user_demand) = p
(;
flow,
flow_dict,
flow_vertical,
flow_prev,
flow_vertical_prev,
flow_integrated,
flow_vertical_integrated,
) = graph[]
(; graph, user_demand, basin) = p
(; flow, flow_dict, flow_prev, flow_integrated) = graph[]
(; vertical_flux, vertical_flux_prev, vertical_flux_integrated, vertical_flux_bmi) =
basin
flow = get_tmp(flow, 0)
flow_vertical = get_tmp(flow_vertical, 0)

vertical_flux = get_tmp(vertical_flux, 0)
if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
end

@. flow_integrated += 0.5 * (flow + flow_prev) * dt
@. flow_vertical_integrated += 0.5 * (flow_vertical + flow_vertical_prev) * dt
@. vertical_flux_integrated += 0.5 * (vertical_flux + vertical_flux_prev) * dt
@. vertical_flux_bmi += 0.5 * (vertical_flux + vertical_flux_prev) * dt

for (i, id) in enumerate(user_demand.node_id)
src_id = inflow_id(graph, id)
Expand All @@ -137,10 +131,37 @@ function integrate_flows!(u, t, integrator)::Nothing
end

copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
copyto!(vertical_flux_prev, vertical_flux)
return nothing
end

"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
(; flow_integrated) = integrator.p.graph[]

Δt = get_Δt(integrator)
flow_mean = copy(flow_integrated)
flow_mean ./= Δt
fill!(flow_integrated, 0.0)

return flow_mean
end

"Compute the average vertical fluxes over the last saveat interval and write
them to SavedValues"
function save_vertical_flux(u, t, integrator)
(; basin) = integrator.p
(; vertical_flux_integrated) = basin

Δt = get_Δt(integrator)
vertical_flux_mean = copy(vertical_flux_integrated)
vertical_flux_mean ./= Δt
fill!(vertical_flux_integrated, 0.0)

return vertical_flux_mean
end

"""
Listens for changes in condition truths.
"""
Expand Down Expand Up @@ -430,36 +451,6 @@ function set_control_params!(p::Parameters, node_id::NodeID, control_state::Stri
end
end

"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
(; dt, p) = integrator
(; graph) = p
(; flow_integrated, flow_vertical_integrated, saveat) = graph[]

Δt = if iszero(saveat)
dt
elseif isinf(saveat)
t
else
t_end = integrator.sol.prob.tspan[2]
if t_end - t > saveat
saveat
else
# The last interval might be shorter than saveat
rem = t % saveat
iszero(rem) ? saveat : rem
end
end

mean_flow_all = vcat(flow_vertical_integrated, flow_integrated)
mean_flow_all ./= Δt
fill!(flow_vertical_integrated, 0.0)
fill!(flow_integrated, 0.0)

return mean_flow_all
end

function update_subgrid_level!(integrator)::Nothing
basin_level = get_tmp(integrator.p.basin.current_level, 0)
subgrid = integrator.p.subgrid
Expand All @@ -476,18 +467,21 @@ end

"Load updates from 'Basin / time' into the parameters"
function update_basin(integrator)::Nothing
(; basin) = integrator.p
(; node_id, time) = basin
(; p, u) = integrator
(; basin) = p
(; storage) = u
(; node_id, time, vertical_flux_from_input, vertical_flux, vertical_flux_prev) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)

rows = searchsorted(time.time, t)
timeblock = view(time, rows)

table = (;
basin.precipitation,
basin.potential_evaporation,
basin.drainage,
basin.infiltration,
vertical_flux_from_input.precipitation,
vertical_flux_from_input.potential_evaporation,
vertical_flux_from_input.drainage,
vertical_flux_from_input.infiltration,
)

for row in timeblock
Expand All @@ -496,6 +490,12 @@ function update_basin(integrator)::Nothing
set_table_row!(table, row, i)
end

for (i, id) in enumerate(basin.node_id)
update_vertical_flux!(basin, storage, i)
end

# Forget about vertical fluxes to handle discontinuous forcing from basin_update
copyto!(vertical_flux_prev, vertical_flux)
return nothing
end

Expand Down
42 changes: 0 additions & 42 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow_counter = 0
# Dictionary from flow edge to index in flow vector
flow_dict = Dict{Tuple{NodeID, NodeID}, Int}()
# The number of nodes with vertical flow (interaction with outside of model)
flow_vertical_counter = 0
# Dictionary from node ID to index in vertical flow vector
flow_vertical_dict = Dict{NodeID, Int}()
graph = MetaGraph(
DiGraph();
label_type = NodeID,
Expand All @@ -49,10 +45,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
end
graph[node_id] =
NodeMetadata(Symbol(snake_case(row.node_type)), allocation_network_id)
if row.node_type in nonconservative_nodetypes
flow_vertical_counter += 1
flow_vertical_dict[node_id] = flow_vertical_counter
end
end

errors = false
Expand Down Expand Up @@ -105,12 +97,8 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow = zeros(flow_counter)
flow_prev = fill(NaN, flow_counter)
flow_integrated = zeros(flow_counter)
flow_vertical = zeros(flow_vertical_counter)
flow_vertical_prev = fill(NaN, flow_vertical_counter)
flow_vertical_integrated = zeros(flow_vertical_counter)
if config.solver.autodiff
flow = DiffCache(flow, chunk_sizes)
flow_vertical = DiffCache(flow_vertical, chunk_sizes)
end
graph_data = (;
node_ids,
Expand All @@ -120,10 +108,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow,
flow_prev,
flow_integrated,
flow_vertical_dict,
flow_vertical,
flow_vertical_prev,
flow_vertical_integrated,
config.solver.saveat,
)
graph = @set graph.graph_data = graph_data
Expand Down Expand Up @@ -195,15 +179,6 @@ function set_flow!(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, q::Number):
return nothing
end

"""
Set the given flow q on the horizontal (self-loop) edge from id to id.
"""
function set_flow!(graph::MetaGraph, id::NodeID, q::Number)::Nothing
(; flow_vertical_dict, flow_vertical) = graph[]
get_tmp(flow_vertical, q)[flow_vertical_dict[id]] = q
return nothing
end

"""
Add the given flow q to the existing flow over the edge between the given nodes.
"""
Expand All @@ -213,15 +188,6 @@ function add_flow!(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, q::Number):
return nothing
end

"""
Add the given flow q to the flow over the edge on the horizontal (self-loop) edge from id to id.
"""
function add_flow!(graph::MetaGraph, id::NodeID, q::Number)::Nothing
(; flow_vertical_dict, flow_vertical) = graph[]
get_tmp(flow_vertical, q)[flow_vertical_dict[id]] += q
return nothing
end

"""
Get the flow over the given edge (val is needed for get_tmp from ForwardDiff.jl).
"""
Expand All @@ -230,14 +196,6 @@ function get_flow(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, val)::Number
return get_tmp(flow, val)[flow_dict[id_src, id_dst]]
end

"""
Get the flow over the given horizontal (selfloop) edge (val is needed for get_tmp from ForwardDiff.jl).
"""
function get_flow(graph::MetaGraph, id::NodeID, val)::Number
(; flow_vertical_dict, flow_vertical) = graph[]
return get_tmp(flow_vertical, val)[flow_vertical_dict[id]]
end

"""
Get the inneighbor node IDs of the given node ID (label)
over the given edge type in the graph.
Expand Down
3 changes: 2 additions & 1 deletion core/src/model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
struct SavedResults
struct SavedResults{V1 <: ComponentVector{Float64}}
flow::SavedValues{Float64, Vector{Float64}}
vertical_flux::SavedValues{Float64, V1}
subgrid_level::SavedValues{Float64, Vector{Float64}}
end

Expand Down
Loading

0 comments on commit db46d1e

Please sign in to comment.