diff --git a/core/src/graph.jl b/core/src/graph.jl index dfd093452..c6f0e6d29 100644 --- a/core/src/graph.jl +++ b/core/src/graph.jl @@ -71,6 +71,7 @@ function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGra edge_type, subnetwork_id, (id_src, id_dst), + (0, 0), ) if haskey(graph, id_src, id_dst) errors = true diff --git a/core/src/parameter.jl b/core/src/parameter.jl index 33ef2f03a..4b03b690d 100644 --- a/core/src/parameter.jl +++ b/core/src/parameter.jl @@ -121,6 +121,7 @@ type: type of the edge subnetwork_id_source: ID of subnetwork where this edge is a source (0 if not a source) edge: (from node ID, to node ID) +basin_idxs: Basin indices of source and destination nodes (0 if not a basin) """ struct EdgeMetadata id::Int32 @@ -128,6 +129,7 @@ struct EdgeMetadata type::EdgeType.T subnetwork_id_source::Int32 edge::Tuple{NodeID, NodeID} + basin_idxs::Tuple{Int32, Int32} end abstract type AbstractParameterNode end diff --git a/core/src/read.jl b/core/src/read.jl index 5208c16d9..79904426b 100644 --- a/core/src/read.jl +++ b/core/src/read.jl @@ -1098,6 +1098,8 @@ function Parameters(db::DB, config::Config)::Parameters subgrid_level = Subgrid(db, config, basin) + set_basin_idxs!(graph, basin) + p = Parameters( config.starttime, graph, diff --git a/core/src/solve.jl b/core/src/solve.jl index 1f1bae89a..8abf0a1ac 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -656,12 +656,18 @@ function formulate_du!( # loop over basins # subtract all outgoing flows # add all ingoing flows - for (i, basin_id) in enumerate(basin.node_id) - for inflow_id in basin.inflow_ids[i] - du[i] += get_flow(graph, inflow_id, basin_id, storage) + for edge_metadata in values(graph.edge_data) + (; type, edge, basin_idxs) = edge_metadata + if type !== EdgeType.flow + continue end - for outflow_id in basin.outflow_ids[i] - du[i] -= get_flow(graph, basin_id, outflow_id, storage) + q = get_flow(graph, edge_metadata, storage) + from_id, to_id = edge + + if from_id.type == NodeType.Basin + du[basin_idxs[1]] -= q + elseif to_id.type == NodeType.Basin + du[basin_idxs[2]] += q end end return nothing diff --git a/core/src/util.jl b/core/src/util.jl index d963f161e..0e57d5ed9 100644 --- a/core/src/util.jl +++ b/core/src/util.jl @@ -739,3 +739,14 @@ inflow_edge(graph, node_id)::EdgeMetadata = graph[inflow_id(graph, node_id), nod outflow_edge(graph, node_id)::EdgeMetadata = graph[node_id, outflow_id(graph, node_id)] outflow_edges(graph, node_id)::Vector{EdgeMetadata} = [graph[node_id, outflow_id] for outflow_id in outflow_ids(graph, node_id)] + +function set_basin_idxs!(graph::MetaGraph, basin::Basin)::Nothing + for edge_metadata in values(graph.edge_data) + (; edge) = edge_metadata + id_src, id_dst = edge + edge_metadata = @set edge_metadata.basin_idxs = + (id_index(basin.node_id, id_src)[2], id_index(basin.node_id, id_dst)[2]) + graph[edge...] = edge_metadata + end + return nothing +end