diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 479a45442..e3742808e 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -162,6 +162,8 @@ end """ struct Basin{T, C, V1, V2, V3} <: AbstractParameterNode node_id::Indices{NodeID} + inflow_ids::Vector{Vector{NodeID}} + outflow_ids::Vector{Vector{NodeID}} # Vertical fluxes vertical_flux_from_input::V1 vertical_flux::V2 @@ -182,6 +184,8 @@ struct Basin{T, C, V1, V2, V3} <: AbstractParameterNode function Basin( node_id, + inflow_ids, + outflow_ids, vertical_flux_from_input::V1, vertical_flux::V2, vertical_flux_prev::V3, @@ -199,6 +203,8 @@ struct Basin{T, C, V1, V2, V3} <: AbstractParameterNode is_valid || error("Invalid Basin / profile table.") return new{T, C, V1, V2, V3}( node_id, + inflow_ids, + outflow_ids, vertical_flux_from_input, vertical_flux, vertical_flux_prev, diff --git a/core/src/read.jl b/core/src/read.jl index f2efc9050..00a17b895 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -490,7 +490,7 @@ function Terminal(db::DB, config::Config)::Terminal return Terminal(NodeID.(NodeType.Terminal, static.node_id)) end -function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin +function Basin(db::DB, config::Config, graph::MetaGraph, chunk_sizes::Vector{Int})::Basin node_id = get_ids(db, "Basin") n = length(node_id) current_level = zeros(n) @@ -533,8 +533,12 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin demand = zeros(length(node_id)) + node_id = NodeID.(NodeType.Basin, node_id) + return Basin( - Indices(NodeID.(NodeType.Basin, node_id)), + Indices(node_id), + [collect(inflow_ids(graph, id)) for id in node_id], + [collect(outflow_ids(graph, id)) for id in node_id], vertical_flux_from_input, vertical_flux, vertical_flux_prev, @@ -1053,7 +1057,7 @@ function Parameters(db::DB, config::Config)::Parameters level_demand = LevelDemand(db, config) flow_demand = FlowDemand(db, config) - basin = Basin(db, config, chunk_sizes) + basin = Basin(db, config, graph, chunk_sizes) subgrid_level = Subgrid(db, config, basin) p = Parameters( diff --git a/core/src/solve.jl b/core/src/solve.jl index 160d00718..96ec3b6f5 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -607,10 +607,10 @@ function formulate_du!( # subtract all outgoing flows # add all ingoing flows for (i, basin_id) in enumerate(basin.node_id) - for inflow_id in inflow_ids(graph, basin_id) + for inflow_id in basin.inflow_ids[i] du[i] += get_flow(graph, inflow_id, basin_id, storage) end - for outflow_id in outflow_ids(graph, basin_id) + for outflow_id in basin.outflow_ids[i] du[i] -= get_flow(graph, basin_id, outflow_id, storage) end end diff --git a/core/test/utils_test.jl b/core/test/utils_test.jl index e6dadb501..fab1b3ab1 100644 --- a/core/test/utils_test.jl +++ b/core/test/utils_test.jl @@ -38,6 +38,8 @@ end demand = zeros(2) basin = Ribasim.Basin( Indices(NodeID.(:Basin, [5, 7])), + [NodeID[]], + [NodeID[]], [2.0, 3.0], [2.0, 3.0], [2.0, 3.0], @@ -92,6 +94,8 @@ end demand = zeros(1) basin = Ribasim.Basin( Indices(NodeID.(:Basin, [1])), + [NodeID[]], + [NodeID[]], zeros(1), zeros(1), zeros(1),