Skip to content

Commit

Permalink
Integrate vetrical fluxes
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Mar 21, 2024
1 parent 2a1e661 commit ee61d4b
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 44 deletions.
23 changes: 14 additions & 9 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,18 @@ Integrate flows over the last timestep
"""
function integrate_flows!(u, t, integrator)::Nothing
(; p, dt) = integrator
(; graph, user_demand) = p
(; graph, user_demand, basin) = p
(; flow, flow_dict, flow_prev, flow_integrated) = graph[]
(; vertical_flux, vertical_flux_prev, vertical_flux_integrated) = basin
flow = get_tmp(flow, 0)
vertical_flux = get_tmp(vertical_flux, 0)
if !isempty(flow_prev) && isnan(flow_prev[1])
# If flow_prev is not populated yet
copyto!(flow_prev, flow)
end

@. flow_integrated += 0.5 * (flow + flow_prev) * dt
@. vertical_flux_integrated += 0.5 * (vertical_flux + vertical_flux_prev) * dt

for (i, id) in enumerate(user_demand.node_id)
src_id = inflow_id(graph, id)
Expand All @@ -125,6 +128,7 @@ function integrate_flows!(u, t, integrator)::Nothing
end

copyto!(flow_prev, flow)
copyto!(vertical_flux_prev, vertical_flux)
return nothing
end

Expand Down Expand Up @@ -439,11 +443,11 @@ function save_flow(u, t, integrator)
end
end

mean_flow_all = vcat(flow_integrated)
mean_flow_all ./= Δt
flow_mean = copy(flow_integrated)
flow_mean ./= Δt
fill!(flow_integrated, 0.0)

return mean_flow_all
return flow_mean
end

function update_subgrid_level!(integrator)::Nothing
Expand All @@ -463,17 +467,18 @@ end
"Load updates from 'Basin / time' into the parameters"
function update_basin(integrator)::Nothing
(; basin) = integrator.p
(; node_id, time) = basin
(; node_id, time, vertical_flux) = basin
t = datetime_since(integrator.t, integrator.p.starttime)
vertical_flux = get_tmp(vertical_flux, integrator.u)

rows = searchsorted(time.time, t)
timeblock = view(time, rows)

table = (;
basin.precipitation,
basin.potential_evaporation,
basin.drainage,
basin.infiltration,
vertical_flux.precipitation,
vertical_flux.potential_evaporation,
vertical_flux.drainage,
vertical_flux.infiltration,
)

for row in timeblock
Expand Down
31 changes: 16 additions & 15 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ else
T = Vector{Float64}
end
"""
struct Basin{T, C} <: AbstractParameterNode
struct Basin{T, C, V1, V2} <: AbstractParameterNode
node_id::Indices{NodeID}
precipitation::Vector{Float64}
potential_evaporation::Vector{Float64}
drainage::Vector{Float64}
infiltration::Vector{Float64}
# Vertical fluxes
vertical_flux_from_input::V1
vertical_flux::V2
vertical_flux_prev::V1
vertical_flux_integrated::V1
# Cache this to avoid recomputation
current_level::T
current_area::T
Expand All @@ -171,26 +172,26 @@ struct Basin{T, C} <: AbstractParameterNode

function Basin(
node_id,
precipitation,
potential_evaporation,
drainage,
infiltration,
vertical_flux_from_input::V1,
vertical_flux::V2,
vertical_flux_prev,
vertical_flux_integrated,
current_level::T,
current_area::T,
area,
level,
storage,
demand,
time::StructVector{BasinTimeV1, C, Int},
) where {T, C}
) where {T, C, V1, V2}
is_valid = valid_profiles(node_id, level, area)
is_valid || error("Invalid Basin / profile table.")
return new{T, C}(
return new{T, C, V1, V2}(
node_id,
precipitation,
potential_evaporation,
drainage,
infiltration,
vertical_flux_from_input,
vertical_flux,
vertical_flux_prev,
vertical_flux_integrated,
current_level,
current_area,
area,
Expand Down
33 changes: 20 additions & 13 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,15 +488,10 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin
current_level = zeros(n)
current_area = zeros(n)

if config.solver.autodiff
current_level = DiffCache(current_level, chunk_sizes)
current_area = DiffCache(current_area, chunk_sizes)
end

precipitation = zeros(length(node_id))
potential_evaporation = zeros(length(node_id))
drainage = zeros(length(node_id))
infiltration = zeros(length(node_id))
precipitation = zeros(n)
potential_evaporation = zeros(n)
drainage = zeros(n)
infiltration = zeros(n)
table = (; precipitation, potential_evaporation, drainage, infiltration)

area, level, storage = create_storage_tables(db, config)
Expand All @@ -509,14 +504,26 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin
set_current_value!(table, node_id, time, config.starttime)
check_no_nans(table, "Basin")

vertical_flux_from_input =
ComponentVector(; precipitation, potential_evaporation, drainage, infiltration)
vertical_flux = zero(vertical_flux_from_input)
vertical_flux_prev = zero(vertical_flux_from_input)
vertical_flux_integrated = zero(vertical_flux_from_input)

if config.solver.autodiff
current_level = DiffCache(current_level, chunk_sizes)
current_area = DiffCache(current_area, chunk_sizes)
vertical_flux = DiffCache(vertical_flux, chunk_sizes)
end

demand = zeros(length(node_id))

return Basin(
Indices(NodeID.(NodeType.Basin, node_id)),
precipitation,
potential_evaporation,
drainage,
infiltration,
vertical_flux_from_input,
vertical_flux,
vertical_flux_prev,
vertical_flux_integrated,
current_level,
current_area,
area,
Expand Down
20 changes: 13 additions & 7 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function water_balance!(
set_current_basin_properties!(basin, storage)

# Basin forcings
formulate_basins!(du, basin, graph, storage)
formulate_basins!(du, basin, storage)

# First formulate intermediate flows
formulate_flows!(p, storage, t)
Expand Down Expand Up @@ -54,12 +54,13 @@ Currently at less than 0.1 m.
function formulate_basins!(
du::AbstractVector,
basin::Basin,
graph::MetaGraph,
storage::AbstractVector,
)::Nothing
(; node_id, current_level, current_area) = basin
(; node_id, current_level, current_area, vertical_flux_from_input, vertical_flux) =
basin
current_level = get_tmp(current_level, storage)
current_area = get_tmp(current_area, storage)
vertical_flux = get_tmp(vertical_flux, storage)

for (i, id) in enumerate(node_id)
# add all precipitation that falls within the profile
Expand All @@ -71,10 +72,15 @@ function formulate_basins!(
depth = max(level - bottom, 0.0)
factor = reduction_factor(depth, 0.1)

precipitation = fixed_area * basin.precipitation[i]
evaporation = area * factor * basin.potential_evaporation[i]
drainage = basin.drainage[i]
infiltration = factor * basin.infiltration[i]
precipitation = fixed_area * vertical_flux_from_input.precipitation[i]
evaporation = area * factor * vertical_flux_from_input.potential_evaporation[i]
drainage = vertical_flux_from_input.drainage[i]
infiltration = factor * vertical_flux_from_input.infiltration[i]

vertical_flux.precipitation[i] = precipitation
vertical_flux.potential_evaporation[i] = evaporation
vertical_flux.drainage[i] = drainage
vertical_flux.infiltration[i] = infiltration

influx = precipitation - evaporation + drainage - infiltration
du.storage[i] += influx
Expand Down

0 comments on commit ee61d4b

Please sign in to comment.