Skip to content

Commit

Permalink
Precalculate resistance neighbors (#1436)
Browse files Browse the repository at this point in the history
On top of #1435.
This significantly speeds up the running time at the cost of using
slightly more storage.
  • Loading branch information
visr authored Apr 29, 2024
1 parent be32a44 commit 40cf6fa
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 55 deletions.
4 changes: 2 additions & 2 deletions core/src/allocation_init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ function add_constraints_conservation_node!(

# No flow conservation on nodes with FractionalFlow outneighbors
has_fractional_flow_outneighbors = any(
outneighbor_id.type == NodeType.FractionalFlow for
outneighbor_id in outflow_ids(graph, node_id)
outflow_id.type == NodeType.FractionalFlow for
outflow_id in outflow_ids(graph, node_id)
)

if is_source_sink | has_fractional_flow_outneighbors
Expand Down
8 changes: 4 additions & 4 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,16 @@ function save_flow(u, t, integrator)
outflow_mean = zeros(length(node_id))

for (i, basin_id) in enumerate(node_id)
for in_id in inflow_ids(graph, basin_id)
q = flow_mean[flow_dict[in_id, basin_id]]
for inflow_id in inflow_ids(graph, basin_id)
q = flow_mean[flow_dict[inflow_id, basin_id]]
if q > 0
inflow_mean[i] += q
else
outflow_mean[i] -= q
end
end
for out_id in outflow_ids(graph, basin_id)
q = flow_mean[flow_dict[basin_id, out_id]]
for outflow_id in outflow_ids(graph, basin_id)
q = flow_mean[flow_dict[basin_id, outflow_id]]
if q > 0
outflow_mean[i] += q
else
Expand Down
20 changes: 12 additions & 8 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,18 @@ struct TabulatedRatingCurve{C} <: AbstractParameterNode
end

"""
Requirements:
* from: must be (Basin,) node
* to: must be (Basin,) node
node_id: node ID of the LinearResistance node
inflow_id: node ID across the incoming flow edge
outflow_id: node ID across the outgoing flow edge
active: whether this node is active and thus contributes flows
resistance: the resistance to flow; `Q_unlimited = Δh/resistance`
max_flow_rate: the maximum flow rate allowed through the node; `Q = clamp(Q_unlimited, -max_flow_rate, max_flow_rate)`
control_mapping: dictionary from (node_id, control_state) to resistance and/or active state
"""
struct LinearResistance <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_id::Vector{NodeID}
outflow_id::Vector{NodeID}
active::BitVector
resistance::Vector{Float64}
max_flow_rate::Vector{Float64}
Expand All @@ -263,8 +262,11 @@ end
"""
This is a simple Manning-Gauckler reach connection.
* Length describes the reach length.
* roughness describes Manning's n in (SI units).
node_id: node ID of the ManningResistance node
inflow_id: node ID across the incoming flow edge
outflow_id: node ID across the outgoing flow edge
length: reach length
manning_n: roughness; Manning's n in (SI units).
The profile is described by a trapezoid:
Expand All @@ -288,13 +290,15 @@ Requirements:
* from: must be (Basin,) node
* to: must be (Basin,) node
* length > 0
* roughess > 0
* manning_n > 0
* profile_width >= 0
* profile_slope >= 0
* (profile_width == 0) xor (profile_slope == 0)
"""
struct ManningResistance <: AbstractParameterNode
node_id::Vector{NodeID}
inflow_id::Vector{NodeID}
outflow_id::Vector{NodeID}
active::BitVector
length::Vector{Float64}
manning_n::Vector{Float64}
Expand Down
20 changes: 14 additions & 6 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ function initialize_allocation!(p::Parameters, config::Config)::Nothing
return nothing
end

function LinearResistance(db::DB, config::Config)::LinearResistance
function LinearResistance(db::DB, config::Config, graph::MetaGraph)::LinearResistance
static = load_structvector(db, config, LinearResistanceStaticV1)
defaults = (; max_flow_rate = Inf, active = true)
parsed_parameters, valid =
Expand All @@ -242,8 +242,12 @@ function LinearResistance(db::DB, config::Config)::LinearResistance
)
end

node_id = NodeID.(NodeType.LinearResistance, parsed_parameters.node_id)

return LinearResistance(
NodeID.(NodeType.LinearResistance, parsed_parameters.node_id),
node_id,
inflow_id.(Ref(graph), node_id),
outflow_id.(Ref(graph), node_id),
BitVector(parsed_parameters.active),
parsed_parameters.resistance,
parsed_parameters.max_flow_rate,
Expand Down Expand Up @@ -321,7 +325,7 @@ function TabulatedRatingCurve(db::DB, config::Config)::TabulatedRatingCurve
return TabulatedRatingCurve(node_ids, active, interpolations, time, control_mapping)
end

function ManningResistance(db::DB, config::Config)::ManningResistance
function ManningResistance(db::DB, config::Config, graph::MetaGraph)::ManningResistance
static = load_structvector(db, config, ManningResistanceStaticV1)
parsed_parameters, valid =
parse_static_and_time(db, config, "ManningResistance"; static)
Expand All @@ -330,8 +334,12 @@ function ManningResistance(db::DB, config::Config)::ManningResistance
error("Errors occurred when parsing ManningResistance data.")
end

node_id = NodeID.(NodeType.ManningResistance, parsed_parameters.node_id)

return ManningResistance(
NodeID.(NodeType.ManningResistance, parsed_parameters.node_id),
node_id,
inflow_id.(Ref(graph), node_id),
outflow_id.(Ref(graph), node_id),
BitVector(parsed_parameters.active),
parsed_parameters.length,
parsed_parameters.manning_n,
Expand Down Expand Up @@ -1029,8 +1037,8 @@ function Parameters(db::DB, config::Config)::Parameters
error("Invalid edge(s) found.")
end

linear_resistance = LinearResistance(db, config)
manning_resistance = ManningResistance(db, config)
linear_resistance = LinearResistance(db, config, graph)
manning_resistance = ManningResistance(db, config, graph)
tabulated_rating_curve = TabulatedRatingCurve(db, config)
fractional_flow = FractionalFlow(db, config)
level_boundary = LevelBoundary(db, config)
Expand Down
44 changes: 22 additions & 22 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,24 +336,24 @@ function formulate_flow!(
(; graph) = p
(; node_id, active, resistance, max_flow_rate) = linear_resistance
for (i, id) in enumerate(node_id)
basin_a_id = inflow_id(graph, id)
basin_b_id = outflow_id(graph, id)
inflow_id = linear_resistance.inflow_id[i]
outflow_id = linear_resistance.outflow_id[i]

if active[i]
h_a = get_level(p, basin_a_id, t; storage)
h_b = get_level(p, basin_b_id, t; storage)
h_a = get_level(p, inflow_id, t; storage)
h_b = get_level(p, outflow_id, t; storage)
q_unlimited = (h_a - h_b) / resistance[i]
q = clamp(q_unlimited, -max_flow_rate[i], max_flow_rate[i])

# add reduction_factor on highest level
if q > 0
q *= low_storage_factor(storage, p.basin.node_id, basin_a_id, 10.0)
q *= low_storage_factor(storage, p.basin.node_id, inflow_id, 10.0)
else
q *= low_storage_factor(storage, p.basin.node_id, basin_b_id, 10.0)
q *= low_storage_factor(storage, p.basin.node_id, outflow_id, 10.0)
end

set_flow!(graph, basin_a_id, id, q)
set_flow!(graph, id, basin_b_id, q)
set_flow!(graph, inflow_id, id, q)
set_flow!(graph, id, outflow_id, q)
end
end
return nothing
Expand Down Expand Up @@ -438,17 +438,17 @@ function formulate_flow!(
(; node_id, active, length, manning_n, profile_width, profile_slope) =
manning_resistance
for (i, id) in enumerate(node_id)
basin_a_id = inflow_id(graph, id)
basin_b_id = outflow_id(graph, id)
inflow_id = manning_resistance.inflow_id[i]
outflow_id = manning_resistance.outflow_id[i]

if !active[i]
continue
end

h_a = get_level(p, basin_a_id, t; storage)
h_b = get_level(p, basin_b_id, t; storage)
bottom_a = basin_bottom(basin, basin_a_id)
bottom_b = basin_bottom(basin, basin_b_id)
h_a = get_level(p, inflow_id, t; storage)
h_b = get_level(p, outflow_id, t; storage)
bottom_a = basin_bottom(basin, inflow_id)
bottom_b = basin_bottom(basin, outflow_id)
slope = profile_slope[i]
width = profile_width[i]
n = manning_n[i]
Expand Down Expand Up @@ -478,8 +478,8 @@ function formulate_flow!(

q = q_sign * A / n * R_h^(2 / 3) * sqrt(Δh / L * 2 / π * atan(k * Δh) + eps)

set_flow!(graph, basin_a_id, id, q)
set_flow!(graph, id, basin_b_id, q)
set_flow!(graph, inflow_id, id, q)
set_flow!(graph, id, outflow_id, q)
end
return nothing
end
Expand Down Expand Up @@ -515,15 +515,15 @@ function formulate_flow!(

for (i, id) in enumerate(node_id)
# Requirement: edge points away from the flow boundary
for dst_id in outflow_ids(graph, id)
for outflow_id in outflow_ids(graph, id)
if !active[i]
continue
end

rate = flow_rate[i](t)

# Adding water is always possible
set_flow!(graph, id, dst_id, rate)
set_flow!(graph, id, outflow_id, rate)
end
end
end
Expand Down Expand Up @@ -605,11 +605,11 @@ function formulate_du!(
# subtract all outgoing flows
# add all ingoing flows
for (i, basin_id) in enumerate(basin.node_id)
for in_id in inflow_ids(graph, basin_id)
du[i] += get_flow(graph, in_id, basin_id, storage)
for inflow_id in inflow_ids(graph, basin_id)
du[i] += get_flow(graph, inflow_id, basin_id, storage)
end
for out_id in outflow_ids(graph, basin_id)
du[i] -= get_flow(graph, basin_id, out_id, storage)
for outflow_id in outflow_ids(graph, basin_id)
du[i] -= get_flow(graph, basin_id, outflow_id, storage)
end
end
return nothing
Expand Down
10 changes: 5 additions & 5 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,15 +486,15 @@ function get_fractional_flow_connected_basins(

has_fractional_flow_outneighbors = false

for first_outneighbor_id in outflow_ids(graph, node_id)
if first_outneighbor_id in fractional_flow.node_id
for first_outflow_id in outflow_ids(graph, node_id)
if first_outflow_id in fractional_flow.node_id
has_fractional_flow_outneighbors = true
second_outneighbor_id = outflow_id(graph, first_outneighbor_id)
has_index, basin_idx = id_index(basin.node_id, second_outneighbor_id)
second_outflow_id = outflow_id(graph, first_outflow_id)
has_index, basin_idx = id_index(basin.node_id, second_outflow_id)
if has_index
push!(
fractional_flow_idxs,
searchsortedfirst(fractional_flow.node_id, first_outneighbor_id),
searchsortedfirst(fractional_flow.node_id, first_outflow_id),
)
push!(basin_idxs, basin_idx)
end
Expand Down
10 changes: 4 additions & 6 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,19 +307,17 @@ function valid_fractional_flow(
control_states = Set{String}([key[2] for key in keys(control_mapping)])

for src_id in src_ids
src_outneighbor_ids = Set(outflow_ids(graph, src_id))
if src_outneighbor_ids node_id_set
src_outflow_ids = Set(outflow_ids(graph, src_id))
if src_outflow_ids node_id_set
errors = true
@error(
"$src_id combines fractional flow outneighbors with other outneigbor types."
)
@error("$src_id has outflow to FractionalFlow and other node types.")
end

# Each control state (including missing) must sum to 1
for control_state in control_states
fraction_sum = 0.0

for ff_id in intersect(src_outneighbor_ids, node_id_set)
for ff_id in intersect(src_outflow_ids, node_id_set)
parameter_values = get(control_mapping, (ff_id, control_state), nothing)
if parameter_values === nothing
continue
Expand Down
2 changes: 1 addition & 1 deletion core/test/validation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ end
@test length(logger.logs) == 4
@test logger.logs[1].level == Error
@test logger.logs[1].message ==
"TabulatedRatingCurve #7 combines fractional flow outneighbors with other outneigbor types."
"TabulatedRatingCurve #7 has outflow to FractionalFlow and other node types."
@test logger.logs[2].level == Error
@test logger.logs[2].message ==
"Fractional flow nodes must have non-negative fractions."
Expand Down
2 changes: 1 addition & 1 deletion python/ribasim_testmodels/ribasim_testmodels/invalid.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def invalid_fractional_flow_model() -> Model:
model.basin[1],
model.tabulated_rating_curve[7],
)
# Invalid: TabulatedRatingCurve #7 combines FractionalFlow outneighbors with other outneigbor types.
# Invalid: TabulatedRatingCurve #7 has outflow to FractionalFlow and other node types.
model.edge.add(
model.tabulated_rating_curve[7],
model.basin[2],
Expand Down

0 comments on commit 40cf6fa

Please sign in to comment.