Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split vertical fluxes on basin #1300

Merged
merged 20 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion core/src/allocation_optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,15 @@ function get_basin_data(
node_id::NodeID,
)
(; graph, basin, level_demand) = p
(; vertical_flux) = basin
(; Δt_allocation) = allocation_model
@assert node_id.type == NodeType.Basin
influx = get_flow(graph, node_id, 0.0)
vertical_flux = get_tmp(vertical_flux, 0)
_, basin_idx = id_index(basin.node_id, node_id)
# NOTE: Instantaneous
influx =
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved
vertical_flux.precipitation[basin_idx] - vertical_flux.evaporation[basin_idx] +
vertical_flux.drainage[basin_idx] - vertical_flux.infiltration[basin_idx]
_, basin_idx = id_index(basin.node_id, node_id)
storage_basin = u.storage[basin_idx]
control_inneighbors = inneighbor_labels_type(graph, node_id, EdgeType.control)
Expand Down
4 changes: 2 additions & 2 deletions core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ function BMI.get_value_ptr(model::Model, name::AbstractString)
elseif name == "basin.level"
get_tmp(model.integrator.p.basin.current_level, 0)
elseif name == "basin.infiltration"
model.integrator.p.basin.infiltration
get_tmp(model.integrator.p.basin.vertical_flux, 0).infiltration
elseif name == "basin.drainage"
model.integrator.p.basin.drainage
get_tmp(model.integrator.p.basin.vertical_flux, 0).drainage
elseif name == "basin.subgrid_level"
model.integrator.p.subgrid.level
elseif name == "user_demand.demand"
Expand Down
109 changes: 55 additions & 54 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ function create_callbacks(
(; starttime, basin, tabulated_rating_curve, discrete_control) = parameters
callbacks = SciMLBase.DECallback[]

integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false)
push!(callbacks, integrating_flows_cb)

tstops = get_tstops(basin.time.time, starttime)
basin_cb = PresetTimeCallback(tstops, update_basin; save_positions = (false, false))
push!(callbacks, basin_cb)

integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false)
push!(callbacks, integrating_flows_cb)

tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(
tstops,
Expand All @@ -61,16 +61,17 @@ function create_callbacks(
push!(callbacks, allocation_cb)
end

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
saved_vertical_flux = SavedValues(Float64, typeof(basin.vertical_flux_integrated))
save_vertical_flux_cb =
SavingCallback(save_vertical_flux, saved_vertical_flux; saveat, save_start = false)
push!(callbacks, save_vertical_flux_cb)

# save the flows over time, as a Vector of the nonzeros(flow)
saved_flow = SavedValues(Float64, Vector{Float64})
save_flow_cb = SavingCallback(
save_flow,
saved_flow;
# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat,
save_start = false,
)
save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
push!(callbacks, save_flow_cb)

# interpolate the levels
Expand All @@ -85,7 +86,7 @@ function create_callbacks(
push!(callbacks, export_cb)
end

saved = SavedResults(saved_flow, saved_subgrid_level)
saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
if n_conditions > 0
Expand All @@ -108,27 +109,18 @@ Integrate flows over the last timestep
"""
function integrate_flows!(u, t, integrator)::Nothing
(; p, dt) = integrator
(; graph, user_demand) = p
(;
flow,
flow_dict,
flow_vertical,
flow_prev,
flow_vertical_prev,
flow_integrated,
flow_vertical_integrated,
) = graph[]
(; graph, user_demand, basin) = p
(; flow, flow_dict, flow_prev, flow_integrated) = graph[]
(; vertical_flux, vertical_flux_prev, vertical_flux_integrated) = basin
flow = get_tmp(flow, 0)
flow_vertical = get_tmp(flow_vertical, 0)

vertical_flux = get_tmp(vertical_flux, 0)
if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
end

@. flow_integrated += 0.5 * (flow + flow_prev) * dt
@. flow_vertical_integrated += 0.5 * (flow_vertical + flow_vertical_prev) * dt
@. vertical_flux_integrated += 0.5 * (vertical_flux + vertical_flux_prev) * dt

for (i, id) in enumerate(user_demand.node_id)
src_id = inflow_id(graph, id)
Expand All @@ -137,7 +129,7 @@ function integrate_flows!(u, t, integrator)::Nothing
end

copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
copyto!(vertical_flux_prev, vertical_flux)
return nothing
end

Expand Down Expand Up @@ -433,31 +425,31 @@ end
"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
(; dt, p) = integrator
(; p) = integrator
(; graph) = p
(; flow_integrated, flow_vertical_integrated, saveat) = graph[]
(; flow_integrated) = graph[]

Δt = if iszero(saveat)
dt
elseif isinf(saveat)
t
else
t_end = integrator.sol.prob.tspan[2]
if t_end - t > saveat
saveat
else
# The last interval might be shorter than saveat
rem = t % saveat
iszero(rem) ? saveat : rem
end
end

mean_flow_all = vcat(flow_vertical_integrated, flow_integrated)
mean_flow_all ./= Δt
fill!(flow_vertical_integrated, 0.0)
Δt = get_Δt(integrator, graph)
flow_mean = copy(flow_integrated)
flow_mean ./= Δt
fill!(flow_integrated, 0.0)

return mean_flow_all
return flow_mean
end

"Compute the average vertical fluxes over the last saveat interval and write
them to SavedValues"
function save_vertical_flux(u, t, integrator)
(; p) = integrator
(; basin, graph) = p
(; vertical_flux_integrated) = basin

Δt = get_Δt(integrator, graph)
vertical_flux_mean = copy(vertical_flux_integrated)
vertical_flux_mean ./= Δt
fill!(vertical_flux_mean, 0.0)

return vertical_flux_mean
end

function update_subgrid_level!(integrator)::Nothing
Expand All @@ -476,18 +468,21 @@ end

"Load updates from 'Basin / time' into the parameters"
function update_basin(integrator)::Nothing
(; basin) = integrator.p
(; node_id, time) = basin
(; p, u) = integrator
(; basin) = p
(; storage) = u
(; node_id, time, vertical_flux_from_input, vertical_flux, vertical_flux_prev) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)

rows = searchsorted(time.time, t)
timeblock = view(time, rows)

table = (;
basin.precipitation,
basin.potential_evaporation,
basin.drainage,
basin.infiltration,
vertical_flux_from_input.precipitation,
vertical_flux_from_input.potential_evaporation,
vertical_flux_from_input.drainage,
vertical_flux_from_input.infiltration,
)

for row in timeblock
Expand All @@ -496,6 +491,12 @@ function update_basin(integrator)::Nothing
set_table_row!(table, row, i)
end

for (i, id) in enumerate(basin.node_id)
update_vertical_flux!(basin, storage, i)
end
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved

# Forget about vertical fluxes before basin update
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved
copyto!(vertical_flux_prev, vertical_flux)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be most accurate, maybe an exception should be made for vertical fluxes which are not set by the Basin \ time table.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically only some of the vertical fluxes will be in Basin \ time. Could you make an issue about this? I don't fully oversee the error introduced in this case.

return nothing
end

Expand Down
42 changes: 0 additions & 42 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow_counter = 0
# Dictionary from flow edge to index in flow vector
flow_dict = Dict{Tuple{NodeID, NodeID}, Int}()
# The number of nodes with vertical flow (interaction with outside of model)
flow_vertical_counter = 0
# Dictionary from node ID to index in vertical flow vector
flow_vertical_dict = Dict{NodeID, Int}()
graph = MetaGraph(
DiGraph();
label_type = NodeID,
Expand All @@ -49,10 +45,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
end
graph[node_id] =
NodeMetadata(Symbol(snake_case(row.node_type)), allocation_network_id)
if row.node_type in nonconservative_nodetypes
flow_vertical_counter += 1
flow_vertical_dict[node_id] = flow_vertical_counter
end
end

errors = false
Expand Down Expand Up @@ -105,12 +97,8 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow = zeros(flow_counter)
flow_prev = fill(NaN, flow_counter)
flow_integrated = zeros(flow_counter)
flow_vertical = zeros(flow_vertical_counter)
flow_vertical_prev = fill(NaN, flow_vertical_counter)
flow_vertical_integrated = zeros(flow_vertical_counter)
if config.solver.autodiff
flow = DiffCache(flow, chunk_sizes)
flow_vertical = DiffCache(flow_vertical, chunk_sizes)
end
graph_data = (;
node_ids,
Expand All @@ -120,10 +108,6 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
flow,
flow_prev,
flow_integrated,
flow_vertical_dict,
flow_vertical,
flow_vertical_prev,
flow_vertical_integrated,
config.solver.saveat,
)
graph = @set graph.graph_data = graph_data
Expand Down Expand Up @@ -195,15 +179,6 @@ function set_flow!(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, q::Number):
return nothing
end

"""
Set the given flow q on the horizontal (self-loop) edge from id to id.
"""
function set_flow!(graph::MetaGraph, id::NodeID, q::Number)::Nothing
(; flow_vertical_dict, flow_vertical) = graph[]
get_tmp(flow_vertical, q)[flow_vertical_dict[id]] = q
return nothing
end

"""
Add the given flow q to the existing flow over the edge between the given nodes.
"""
Expand All @@ -213,15 +188,6 @@ function add_flow!(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, q::Number):
return nothing
end

"""
Add the given flow q to the flow over the edge on the horizontal (self-loop) edge from id to id.
"""
function add_flow!(graph::MetaGraph, id::NodeID, q::Number)::Nothing
(; flow_vertical_dict, flow_vertical) = graph[]
get_tmp(flow_vertical, q)[flow_vertical_dict[id]] += q
return nothing
end

"""
Get the flow over the given edge (val is needed for get_tmp from ForwardDiff.jl).
"""
Expand All @@ -230,14 +196,6 @@ function get_flow(graph::MetaGraph, id_src::NodeID, id_dst::NodeID, val)::Number
return get_tmp(flow, val)[flow_dict[id_src, id_dst]]
end

"""
Get the flow over the given horizontal (selfloop) edge (val is needed for get_tmp from ForwardDiff.jl).
"""
function get_flow(graph::MetaGraph, id::NodeID, val)::Number
(; flow_vertical_dict, flow_vertical) = graph[]
return get_tmp(flow_vertical, val)[flow_vertical_dict[id]]
end

"""
Get the inneighbor node IDs of the given node ID (label)
over the given edge type in the graph.
Expand Down
3 changes: 2 additions & 1 deletion core/src/model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
struct SavedResults
struct SavedResults{V1 <: ComponentVector{Float64}}
flow::SavedValues{Float64, Vector{Float64}}
vertical_flux::SavedValues{Float64, V1}
subgrid_level::SavedValues{Float64, Vector{Float64}}
end

Expand Down
Loading
Loading