Skip to content

Commit

Permalink
Avoid lookups in get_level for UserDemand, ManningResistance
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed May 8, 2024
1 parent 30cbbf4 commit acf434f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,8 @@ function Parameters(db::DB, config::Config)::Parameters
end

basin = Basin(db, config, graph, chunk_sizes)
set_basin_idxs!(graph, basin)

linear_resistance = LinearResistance(db, config, graph)
manning_resistance = ManningResistance(db, config, graph, basin)
tabulated_rating_curve = TabulatedRatingCurve(db, config, graph)
Expand All @@ -1098,8 +1100,6 @@ function Parameters(db::DB, config::Config)::Parameters

subgrid_level = Subgrid(db, config, basin)

set_basin_idxs!(graph, basin)

p = Parameters(
config.starttime,
graph,
Expand Down
12 changes: 6 additions & 6 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function water_balance!(
formulate_flows!(p, storage, t)

# Now formulate du
formulate_du!(du, graph, basin, storage)
formulate_du!(du, graph, storage)

# PID control (changes the du of PID controlled basins)
continuous_control!(u, du, pid_control, p, integral, t)
Expand Down Expand Up @@ -321,7 +321,7 @@ function formulate_flow!(

# Smoothly let abstraction go to 0 as the source basin
# level reaches its minimum level
_, source_level = get_level(p, inflow_id, t; storage)
source_level = get_level(inflow_edge.basin_idxs[1], basin; storage)
Δsource_level = source_level - min_level
factor_level = reduction_factor(Δsource_level, 0.1)
q *= factor_level
Expand Down Expand Up @@ -471,8 +471,8 @@ function formulate_flow!(
continue
end

_, h_a = get_level(p, inflow_id, t; storage)
_, h_b = get_level(p, outflow_id, t; storage)
h_a = get_level(inflow_edge.basin_idxs[1], basin; storage)
h_b = get_level(outflow_edge.basin_idxs[2], basin; storage)
bottom_a = upstream_bottom[i]
bottom_b = downstream_bottom[i]
slope = profile_slope[i]
Expand Down Expand Up @@ -650,7 +650,6 @@ end
function formulate_du!(
du::ComponentVector,
graph::MetaGraph,
basin::Basin,
storage::AbstractVector,
)::Nothing
# loop over basins
Expand All @@ -661,12 +660,13 @@ function formulate_du!(
if type !== EdgeType.flow
continue
end
q = get_flow(graph, edge_metadata, storage)
from_id, to_id = edge

if from_id.type == NodeType.Basin
q = get_flow(graph, edge_metadata, storage)
du[basin_idxs[1]] -= q
elseif to_id.type == NodeType.Basin
q = get_flow(graph, edge_metadata, storage)
du[basin_idxs[2]] += q
end
end
Expand Down
16 changes: 12 additions & 4 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ function get_level(
(; basin, level_boundary) = p
if node_id.type == NodeType.Basin
_, i = id_index(basin.node_id, node_id)
current_level = get_tmp(basin.current_level, storage)
return true, current_level[i]
return true, get_level(i, basin; storage)
elseif node_id.type == NodeType.LevelBoundary
i = findsorted(level_boundary.node_id, node_id)
return true, level_boundary.level[i](t)
Expand All @@ -390,6 +389,14 @@ function get_level(
end
end

function get_level(
i::Integer,
basin::Basin;
storage::Union{AbstractArray, Number} = 0,
)::Number
return get_tmp(basin.current_level, storage)[i]
end

"Get the index of an ID in a set of indices."
function id_index(ids::Indices{NodeID}, id::NodeID)::Tuple{Bool, Int}
# We avoid creating Dictionary here since it converts the values to a Vector,
Expand Down Expand Up @@ -741,12 +748,13 @@ outflow_edges(graph, node_id)::Vector{EdgeMetadata} =
[graph[node_id, outflow_id] for outflow_id in outflow_ids(graph, node_id)]

function set_basin_idxs!(graph::MetaGraph, basin::Basin)::Nothing
for edge_metadata in values(graph.edge_data)
(; edge) = edge_metadata
for (edge, edge_metadata) in graph.edge_data
id_src, id_dst = edge
edge_metadata = @set edge_metadata.basin_idxs =
(id_index(basin.node_id, id_src)[2], id_index(basin.node_id, id_dst)[2])
graph[edge...] = edge_metadata
if edge[2] == NodeID(:ManningResistance, 8)
end
end
return nothing
end

0 comments on commit acf434f

Please sign in to comment.