Skip to content

Commit

Permalink
Finalize integration and saving
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Mar 25, 2024
1 parent 0b6ecd8 commit 560f695
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 21 deletions.
56 changes: 38 additions & 18 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ function create_callbacks(
push!(callbacks, allocation_cb)
end

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
saved_vertical_flux = SavedValues(Float64, typeof(basin.vertical_flux_integrated))
save_vertical_flux_cb =
SavingCallback(save_vertical_flux, saved_vertical_flux; saveat, save_start = false)
push!(callbacks, save_vertical_flux_cb)

# save the flows over time, as a Vector of the nonzeros(flow)
saved_flow = SavedValues(Float64, Vector{Float64})
save_flow_cb = SavingCallback(
save_flow,
saved_flow;
# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat,
save_start = false,
)
save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
push!(callbacks, save_flow_cb)

# interpolate the levels
Expand All @@ -85,7 +86,7 @@ function create_callbacks(
push!(callbacks, export_cb)
end

saved = SavedResults(saved_flow, saved_subgrid_level)
saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
if n_conditions > 0
Expand Down Expand Up @@ -421,14 +422,10 @@ function set_control_params!(p::Parameters, node_id::NodeID, control_state::Stri
end
end

"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
(; dt, p) = integrator
(; graph) = p
(; flow_integrated, saveat) = graph[]

Δt = if iszero(saveat)
function get_Δt(integrator, graph)::Float64
(; t, dt) = integrator
(; saveat) = graph[]
if iszero(saveat)
dt
elseif isinf(saveat)
t
Expand All @@ -442,14 +439,36 @@ function save_flow(u, t, integrator)
iszero(rem) ? saveat : rem
end
end
end

"Compute the average flows over the last saveat interval and write
them to SavedValues"
function save_flow(u, t, integrator)
(; p) = integrator
(; graph) = p
(; flow_integrated) = graph[]

Δt = get_Δt(integrator, graph)
flow_mean = copy(flow_integrated)
flow_mean ./= Δt
fill!(flow_integrated, 0.0)

return flow_mean
end

function save_vertical_flux(u, t, integrator)
(; p) = integrator
(; basin, graph) = p
(; vertical_flux_integrated) = basin

Δt = get_Δt(integrator, graph)
vertical_flux_mean = copy(vertical_flux_integrated)
vertical_flux_mean ./= Δt
fill!(vertical_flux_mean, 0.0)

return vertical_flux_mean
end

function update_subgrid_level!(integrator)::Nothing
basin_level = get_tmp(integrator.p.basin.current_level, 0)
subgrid = integrator.p.subgrid
Expand All @@ -467,7 +486,7 @@ end
"Load updates from 'Basin / time' into the parameters"
function update_basin(integrator)::Nothing
(; basin) = integrator.p
(; node_id, time, vertical_flux) = basin
(; node_id, time, vertical_flux, vertical_flux_prev) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)

Expand All @@ -487,6 +506,7 @@ function update_basin(integrator)::Nothing
set_table_row!(table, row, i)
end

copyto!(vertical_flux_prev, vertical_flux)
return nothing
end

Expand Down
3 changes: 2 additions & 1 deletion core/src/model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
struct SavedResults
struct SavedResults{V1 <: ComponentVector{Float64}}
flow::SavedValues{Float64, Vector{Float64}}
vertical_flux::SavedValues{Float64, V1}
subgrid_level::SavedValues{Float64, Vector{Float64}}
end

Expand Down
45 changes: 43 additions & 2 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,57 @@ function basin_table(
node_id::Vector{Int32},
storage::Vector{Float64},
level::Vector{Float64},
precipitation::Vector{Float64},
evaporation::Vector{Float64},
drainage::Vector{Float64},
infiltration::Vector{Float64},
}
(; saved) = model
(; vertical_flux) = saved

data = get_storages_and_levels(model)
storage = vec(data.storage)
level = vec(data.level)

nbasin = length(data.node_id)
ntsteps = length(data.time)
nrows = nbasin * ntsteps

precipitation = zeros(nrows)
evaporation = zeros(nrows)
drainage = zeros(nrows)
infiltration = zeros(nrows)

idx_row = 0

for vec in vertical_flux.saveval
for (precipitation_, evaporation_, drainage_, infiltration_) in zip(
vec.precipitation,
vec.potential_evaporation,
vec.drainage,
vec.infiltration,
)
idx_row += 1
precipitation[idx_row] = precipitation_
evaporation[idx_row] = evaporation_
drainage[idx_row] = drainage_
infiltration[idx_row] = infiltration_
end
end

time = repeat(data.time; inner = nbasin)
node_id = repeat(Int32.(data.node_id); outer = ntsteps)

return (; time, node_id, storage = vec(data.storage), level = vec(data.level))
return (;
time,
node_id,
storage,
level,
precipitation,
evaporation,
drainage,
infiltration,
)
end

"Create a flow result table from the saved data"
Expand All @@ -112,7 +154,6 @@ function flow_table(
(; graph) = integrator.p
(; flow_dict) = graph[]

# self-loops have no edge ID
from_node_type = String[]
from_node_id = Int32[]
to_node_type = String[]
Expand Down

0 comments on commit 560f695

Please sign in to comment.