diff --git a/core/src/Ribasim.jl b/core/src/Ribasim.jl index cb94448cb..0827e4a87 100644 --- a/core/src/Ribasim.jl +++ b/core/src/Ribasim.jl @@ -49,7 +49,13 @@ using Legolas: Legolas, @schema, @version, validate, SchemaVersion, declared using Logging: current_logger, min_enabled_level, with_logger using LoggingExtras: EarlyFilteredLogger, LevelOverrideLogger using MetaGraphsNext: - MetaGraphsNext, MetaGraph, label_for, labels, outneighbor_labels, inneighbor_labels + MetaGraphsNext, + MetaGraph, + label_for, + code_for, + labels, + outneighbor_labels, + inneighbor_labels using OrdinaryDiffEq using PreallocationTools: DiffCache, FixedSizeDiffCache, get_tmp using SciMLBase diff --git a/core/src/allocation.jl b/core/src/allocation.jl index ef7c52b4c..e1ec98e2c 100644 --- a/core/src/allocation.jl +++ b/core/src/allocation.jl @@ -13,8 +13,8 @@ function allocation_graph_used_nodes!(p::Parameters, allocation_network_id::Int) node_type = graph[node_id].type if node_type in [:user, :basin] push!(used_nodes, node_id) - - elseif length(inoutflow_ids(graph, node_id)) > 2 + elseif count(x -> true, inoutflow_ids(graph, node_id)) > 2 + # use count since the length of the iterator is unknown push!(used_nodes, node_id) end end diff --git a/core/src/solve.jl b/core/src/solve.jl index 06c05a957..ff61b7a47 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -515,8 +515,8 @@ function valid_n_neighbors(node::AbstractParameterNode, graph::MetaGraph)::Bool for id in node.node_id for (bounds, edge_type) in zip((bounds_flow, bounds_control), (EdgeType.flow, EdgeType.control)) - n_inneighbors = length(inneighbor_labels_type(graph, id, edge_type)) - n_outneighbors = length(outneighbor_labels_type(graph, id, edge_type)) + n_inneighbors = count(x -> true, inneighbor_labels_type(graph, id, edge_type)) + n_outneighbors = count(x -> true, outneighbor_labels_type(graph, id, edge_type)) if n_inneighbors < bounds.in_min @error "Nodes of type $node_type must have at least $(bounds.in_min) $edge_type inneighbor(s) (got $n_inneighbors for node $id)." diff --git a/core/src/utils.jl b/core/src/utils.jl index 2174f7b7d..c124b4b26 100644 --- a/core/src/utils.jl +++ b/core/src/utils.jl @@ -74,6 +74,61 @@ function create_graph(db::DB)::MetaGraph return graph end +abstract type AbstractNeighbors end + +""" +Iterate over incoming neighbors of a given label in a MetaGraph, only for edges of edge_type +""" +struct InNeighbors{T} <: AbstractNeighbors + graph::T + label::NodeID + edge_type::EdgeType.T +end + +""" +Iterate over outgoing neighbors of a given label in a MetaGraph, only for edges of edge_type +""" +struct OutNeighbors{T} <: AbstractNeighbors + graph::T + label::NodeID + edge_type::EdgeType.T +end + +Base.IteratorSize(::Type{<:AbstractNeighbors}) = Base.SizeUnknown() +Base.eltype(::Type{<:AbstractNeighbors}) = NodeID + +function Base.iterate(iter::InNeighbors, state = 1) + (; graph, label, edge_type) = iter + code = code_for(graph, label) + local label_in + while true + x = iterate(inneighbors(graph, code), state) + x === nothing && return nothing + code_in, state = x + label_in = label_for(graph, code_in) + if graph[label_in, label].type == edge_type + break + end + end + return label_in, state +end + +function Base.iterate(iter::OutNeighbors, state = 1) + (; graph, label, edge_type) = iter + code = code_for(graph, label) + local label_out + while true + x = iterate(outneighbors(graph, code), state) + x === nothing && return nothing + code_out, state = x + label_out = label_for(graph, code_out) + if graph[label, label_out].type == edge_type + break + end + end + return label_out, state +end + """ Get the inneighbor node IDs of the given node ID (label) over the given edge type in the graph. @@ -82,11 +137,8 @@ function inneighbor_labels_type( graph::MetaGraph, label::NodeID, edge_type::EdgeType.T, -)::Vector{NodeID} - return [ - label_in for label_in in inneighbor_labels(graph, label) if - graph[label_in, label].type == edge_type - ] +)::InNeighbors + return InNeighbors(graph, label, edge_type) end """ @@ -97,11 +149,8 @@ function outneighbor_labels_type( graph::MetaGraph, label::NodeID, edge_type::EdgeType.T, -)::Vector{NodeID} - return [ - label_out for label_out in outneighbor_labels(graph, label) if - graph[label, label_out].type == edge_type - ] +)::OutNeighbors + return OutNeighbors(graph, label, edge_type) end """ @@ -112,31 +161,31 @@ function all_neighbor_labels_type( graph::MetaGraph, label::NodeID, edge_type::EdgeType.T, -)::Vector{NodeID} - return [ - outneighbor_labels_type(graph, label, edge_type)..., - inneighbor_labels_type(graph, label, edge_type)..., - ] +)::Iterators.Flatten + return Iterators.flatten(( + outneighbor_labels_type(graph, label, edge_type), + inneighbor_labels_type(graph, label, edge_type), + )) end """ Get the outneighbors over flow edges. """ -function outflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID} +function outflow_ids(graph::MetaGraph, id::NodeID)::OutNeighbors return outneighbor_labels_type(graph, id, EdgeType.flow) end """ Get the inneighbors over flow edges. """ -function inflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID} +function inflow_ids(graph::MetaGraph, id::NodeID)::InNeighbors return inneighbor_labels_type(graph, id, EdgeType.flow) end """ Get the in- and outneighbors over flow edges. """ -function inoutflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID} +function inoutflow_ids(graph::MetaGraph, id::NodeID)::Iterators.Flatten return all_neighbor_labels_type(graph, id, EdgeType.flow) end