Skip to content

Commit

Permalink
Mean output flows (#1159)
Browse files Browse the repository at this point in the history
Fixes #935.

---------

Co-authored-by: Martijn Visser <[email protected]>
  • Loading branch information
SouthEndMusic and visr authored Feb 28, 2024
1 parent ffb0bf9 commit 51ed6ed
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 97 deletions.
82 changes: 72 additions & 10 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function set_initial_discrete_controlled_parameters!(
storage0::Vector{Float64},
)::Nothing
(; p) = integrator
(; basin, discrete_control) = p
(; discrete_control) = p

n_conditions = length(discrete_control.condition_value)
condition_diffs = zeros(Float64, n_conditions)
Expand All @@ -30,7 +30,7 @@ Returns the CallbackSet and the SavedValues for flow.
"""
function create_callbacks(
parameters::Parameters,
config::Config;
config::Config,
saveat,
)::Tuple{CallbackSet, SavedResults}
(; starttime, basin, tabulated_rating_curve, discrete_control) = parameters
Expand All @@ -40,6 +40,9 @@ function create_callbacks(
basin_cb = PresetTimeCallback(tstops, update_basin)
push!(callbacks, basin_cb)

integrating_flows_cb = FunctionCallingCallback(integrate_flows!; func_start = false)
push!(callbacks, integrating_flows_cb)

tstops = get_tstops(tabulated_rating_curve.time.time, starttime)
tabulated_rating_curve_cb = PresetTimeCallback(tstops, update_tabulated_rating_curve!)
push!(callbacks, tabulated_rating_curve_cb)
Expand All @@ -55,7 +58,14 @@ function create_callbacks(

# save the flows over time, as a Vector of the nonzeros(flow)
saved_flow = SavedValues(Float64, Vector{Float64})
save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
save_flow_cb = SavingCallback(
save_flow,
saved_flow;
# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat,
save_start = false,
)
push!(callbacks, save_flow_cb)

# interpolate the levels
Expand Down Expand Up @@ -87,6 +97,37 @@ function create_callbacks(
return callback, saved
end

"""
Integrate flows over the last timestep
"""
function integrate_flows!(u, t, integrator)::Nothing
(; p, dt) = integrator
(; graph) = p
(;
flow,
flow_vertical,
flow_prev,
flow_vertical_prev,
flow_integrated,
flow_vertical_integrated,
) = graph[]
flow = get_tmp(flow, 0)
flow_vertical = get_tmp(flow_vertical, 0)

if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
end

@. flow_integrated += 0.5 * (flow + flow_prev) * dt
@. flow_vertical_integrated += 0.5 * (flow_vertical + flow_vertical_prev) * dt

copyto!(flow_prev, flow)
copyto!(flow_vertical_prev, flow_vertical)
return nothing
end

"""
Listens for changes in condition truths.
"""
Expand Down Expand Up @@ -272,8 +313,7 @@ function discrete_control_affect!(

# What the local control state is
# TODO: Check time elapsed since control change
control_state_now, control_state_start =
discrete_control.control_state[discrete_control_node_id]
control_state_now, _ = discrete_control.control_state[discrete_control_node_id]

control_state_change = false

Expand Down Expand Up @@ -377,12 +417,34 @@ function set_control_params!(p::Parameters, node_id::NodeID, control_state::Stri
end
end

"Copy the current flow to the SavedValues"
"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
vcat(
get_tmp(integrator.p.graph[].flow_vertical, 0.0),
get_tmp(integrator.p.graph[].flow, 0.0),
)
(; dt, p) = integrator
(; graph) = p
(; flow_integrated, flow_vertical_integrated, saveat) = graph[]

Δt = if iszero(saveat)
dt
elseif isinf(saveat)
t
else
t_end = integrator.sol.prob.tspan[2]
if t_end - t > saveat
saveat
else
# The last interval might be shorter than saveat
rem = t % saveat
iszero(rem) ? saveat : rem
end
end

mean_flow_all = vcat(flow_vertical_integrated, flow_integrated)
mean_flow_all ./= Δt
fill!(flow_vertical_integrated, 0.0)
fill!(flow_integrated, 0.0)

return mean_flow_all
end

function update_subgrid_level!(integrator)::Nothing
Expand Down
15 changes: 11 additions & 4 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,19 +233,26 @@ end

"Convert the saveat Float64 from our Config to SciML's saveat"
function convert_saveat(saveat::Float64, t_end::Float64)::Union{Float64, Vector{Float64}}
errors = false
if iszero(saveat)
# every step
Float64[]
saveat = Float64[]
elseif saveat == Inf
# only the start and end
[0.0, t_end]
saveat = [0.0, t_end]
elseif isfinite(saveat)
# every saveat seconds
saveat
if saveat !== round(saveat)
errors = true
@error "A finite saveat must be an integer number of seconds." saveat
end
else
errors = true
@error "Invalid saveat" saveat
error("Invalid saveat")
end

errors && error("Invalid saveat")
return saveat
end

"Convert the dt from our Config to SciML stepsize control arguments"
Expand Down
15 changes: 13 additions & 2 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ and data of edges (EdgeMetadata):
[`EdgeMetadata`](@ref)
"""
function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGraph
node_rows = execute(db, "SELECT node_id, node_type, subnetwork_id FROM Node ORDER BY fid")
node_rows =
execute(db, "SELECT node_id, node_type, subnetwork_id FROM Node ORDER BY fid")
edge_rows = execute(
db,
"SELECT fid, from_node_type, from_node_id, to_node_type, to_node_id, edge_type, subnetwork_id FROM Edge ORDER BY fid",
Expand Down Expand Up @@ -44,7 +45,8 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
end
push!(node_ids[allocation_network_id], node_id)
end
graph[node_id] = NodeMetadata(Symbol(snake_case(row.node_type)), allocation_network_id)
graph[node_id] =
NodeMetadata(Symbol(snake_case(row.node_type)), allocation_network_id)
if row.node_type in nonconservative_nodetypes
flow_vertical_counter += 1
flow_vertical_dict[node_id] = flow_vertical_counter
Expand Down Expand Up @@ -90,7 +92,11 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
end

flow = zeros(flow_counter)
flow_prev = fill(NaN, flow_counter)
flow_integrated = zeros(flow_counter)
flow_vertical = zeros(flow_vertical_counter)
flow_vertical_prev = fill(NaN, flow_vertical_counter)
flow_vertical_integrated = zeros(flow_vertical_counter)
if config.solver.autodiff
flow = DiffCache(flow, chunk_sizes)
flow_vertical = DiffCache(flow_vertical, chunk_sizes)
Expand All @@ -101,8 +107,13 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra
edges_source,
flow_dict,
flow,
flow_prev,
flow_integrated,
flow_vertical_dict,
flow_vertical,
flow_vertical_prev,
flow_vertical_integrated,
config.solver.saveat,
)
graph = @set graph.graph_data = graph_data

Expand Down
17 changes: 10 additions & 7 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ function Model(config::Config)::Model
# tell the solver to stop when new data comes in
# TODO add all time tables here
time_flow_boundary = load_structvector(db, config, FlowBoundaryTimeV1)
tstops_flow_boundary = get_tstops(time_flow_boundary.time, config.starttime)
tstops = Vector{Float64}[]
push!(tstops, get_tstops(time_flow_boundary.time, config.starttime))
time_user_demand = load_structvector(db, config, UserDemandTimeV1)
tstops_user_demand = get_tstops(time_user_demand.time, config.starttime)
tstops = sort(unique(vcat(tstops_flow_boundary, tstops_user_demand)))
push!(tstops, get_tstops(time_user_demand.time, config.starttime))

# use state
state = load_structvector(db, config, BasinStateV1)
Expand All @@ -105,7 +105,10 @@ function Model(config::Config)::Model
@assert eps(t_end) < 3600 "Simulation time too long"
t0 = zero(t_end)
timespan = (t0, t_end)

saveat = convert_saveat(config.solver.saveat, t_end)
saveat isa Float64 && push!(tstops, range(0, t_end; step = saveat))
tstops = sort(unique(vcat(tstops...)))
adaptive, dt = convert_dt(config.solver.dt)

jac_prototype = config.solver.sparse ? get_jac_prototype(parameters) : nothing
Expand All @@ -116,7 +119,7 @@ function Model(config::Config)::Model
end
@debug "Setup ODEProblem."

callback, saved = create_callbacks(parameters, config; saveat)
callback, saved = create_callbacks(parameters, config, saveat)
@debug "Created callbacks."

# Initialize the integrator, providing all solver options as described in
Expand Down Expand Up @@ -154,17 +157,17 @@ function Model(config::Config)::Model
end

"Get all saved times in seconds since start"
timesteps(model::Model)::Vector{Float64} = model.integrator.sol.t
tsaves(model::Model)::Vector{Float64} = model.integrator.sol.t

"Get all saved times as a Vector{DateTime}"
function datetimes(model::Model)::Vector{DateTime}
return datetime_since.(timesteps(model), model.config.starttime)
return datetime_since.(tsaves(model), model.config.starttime)
end

function Base.show(io::IO, model::Model)
(; config, integrator) = model
t = datetime_since(integrator.t, config.starttime)
nsaved = length(timesteps(model))
nsaved = length(tsaves(model))
println(io, "Model(ts: $nsaved, t: $t)")
end

Expand Down
5 changes: 5 additions & 0 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,13 @@ struct Parameters{T, C1, C2}
edges_source::Dict{Int, Set{EdgeMetadata}},
flow_dict::Dict{Tuple{NodeID, NodeID}, Int},
flow::T,
flow_prev::Vector{Float64},
flow_integrated::Vector{Float64},
flow_vertical_dict::Dict{NodeID, Int},
flow_vertical::T,
flow_vertical_prev::Vector{Float64},
flow_vertical_integrated::Vector{Float64},
saveat::Float64,
},
MetaGraphsNext.var"#11#13",
Float64,
Expand Down
14 changes: 0 additions & 14 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,6 @@ function formulate_du!(
basin::Basin,
storage::AbstractVector,
)::Nothing
(; flow_vertical) = graph[]
flow_vertical = get_tmp(flow_vertical, storage)
# loop over basins
# subtract all outgoing flows
# add all ingoing flows
Expand Down Expand Up @@ -665,15 +663,3 @@ function formulate_flows!(p::Parameters, storage::AbstractVector, t::Number)::No
formulate_flow!(level_boundary, p, storage, t)
formulate_flow!(terminal, p, storage, t)
end

function track_waterbalance!(u, t, integrator)::Nothing
(; p, tprev, uprev) = integrator
dt = t - tprev
du = u - uprev
p.storage_diff .+= du
p.precipitation.total .+= p.precipitation.value .* dt
p.evaporation.total .+= p.evaporation.value .* dt
p.infiltration.total .+= p.infiltration.value .* dt
p.drainage.total .+= p.drainage.value .* dt
return nothing
end
9 changes: 7 additions & 2 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function get_storages_and_levels(
(; sol, p) = integrator

node_id = p.basin.node_id.values::Vector{NodeID}
tsteps = datetime_since.(timesteps(model), config.starttime)
tsteps = datetime_since.(tsaves(model), config.starttime)

storage = hcat([collect(u_.storage) for u_ in sol.u]...)
level = zero(storage)
Expand Down Expand Up @@ -148,7 +148,12 @@ function flow_table(
nflow = length(unique_edge_ids_flow)
ntsteps = length(t)

time = repeat(datetime_since.(t, config.starttime); inner = nflow)
# the timestamp should represent the start of the period, not the end
t_starts = circshift(t, 1)
if !isempty(t)
t_starts[1] = 0.0
end
time = repeat(datetime_since.(t_starts, config.starttime); inner = nflow)
edge_id = repeat(unique_edge_ids_flow; outer = ntsteps)
from_node_type = repeat(from_node_type; outer = ntsteps)
from_node_id = repeat(from_node_id; outer = ntsteps)
Expand Down
2 changes: 1 addition & 1 deletion core/test/allocation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ end
model = Ribasim.run(toml_path)

storage = Ribasim.get_storages_and_levels(model).storage[1, :]
t = Ribasim.timesteps(model)
t = Ribasim.tsaves(model)

p = model.integrator.p
(; user_demand, graph, allocation, basin, level_demand) = p
Expand Down
2 changes: 1 addition & 1 deletion core/test/config_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ end
@test convert_saveat(0.0, t_end) == Float64[]
@test convert_saveat(60.0, t_end) == 60.0
@test convert_saveat(Inf, t_end) == [0.0, t_end]
@test convert_saveat(Inf, t_end) == [0.0, t_end]
@test_throws ErrorException convert_saveat(-Inf, t_end)
@test_throws ErrorException convert_saveat(NaN, t_end)
@test_throws ErrorException convert_saveat(3.1415, t_end)

@test convert_dt(nothing) == (true, 0.0)
@test convert_dt(360.0) == (false, 360.0)
Expand Down
Loading

0 comments on commit 51ed6ed

Please sign in to comment.