Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use custom iterators to speed up inflow_ids and friends #830

Merged
merged 2 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ using Legolas: Legolas, @schema, @version, validate, SchemaVersion, declared
using Logging: current_logger, min_enabled_level, with_logger
using LoggingExtras: EarlyFilteredLogger, LevelOverrideLogger
using MetaGraphsNext:
MetaGraphsNext, MetaGraph, label_for, labels, outneighbor_labels, inneighbor_labels
MetaGraphsNext,
MetaGraph,
label_for,
code_for,
labels,
outneighbor_labels,
inneighbor_labels
using OrdinaryDiffEq
using PreallocationTools: DiffCache, FixedSizeDiffCache, get_tmp
using SciMLBase
Expand Down
4 changes: 2 additions & 2 deletions core/src/allocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ function allocation_graph_used_nodes!(p::Parameters, allocation_network_id::Int)
node_type = graph[node_id].type
if node_type in [:user, :basin]
push!(used_nodes, node_id)

elseif length(inoutflow_ids(graph, node_id)) > 2
elseif count(x -> true, inoutflow_ids(graph, node_id)) > 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is that preferable? I've assumed that length simply consumes unsized iterators, just like your code using count.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it's actually not defined, a MethodError.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see :)

# use count since the length of the iterator is unknown
push!(used_nodes, node_id)
end
end
Expand Down
4 changes: 2 additions & 2 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,8 +515,8 @@ function valid_n_neighbors(node::AbstractParameterNode, graph::MetaGraph)::Bool
for id in node.node_id
for (bounds, edge_type) in
zip((bounds_flow, bounds_control), (EdgeType.flow, EdgeType.control))
n_inneighbors = length(inneighbor_labels_type(graph, id, edge_type))
n_outneighbors = length(outneighbor_labels_type(graph, id, edge_type))
n_inneighbors = count(x -> true, inneighbor_labels_type(graph, id, edge_type))
n_outneighbors = count(x -> true, outneighbor_labels_type(graph, id, edge_type))

if n_inneighbors < bounds.in_min
@error "Nodes of type $node_type must have at least $(bounds.in_min) $edge_type inneighbor(s) (got $n_inneighbors for node $id)."
Expand Down
85 changes: 67 additions & 18 deletions core/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,61 @@
return graph
end

abstract type AbstractNeighbors end

"""
Iterate over incoming neighbors of a given label in a MetaGraph, only for edges of edge_type
"""
struct InNeighbors{T} <: AbstractNeighbors
graph::T
label::NodeID
edge_type::EdgeType.T
end

"""
Iterate over outgoing neighbors of a given label in a MetaGraph, only for edges of edge_type
"""
struct OutNeighbors{T} <: AbstractNeighbors
graph::T
label::NodeID
edge_type::EdgeType.T
end

Base.IteratorSize(::Type{<:AbstractNeighbors}) = Base.SizeUnknown()
Base.eltype(::Type{<:AbstractNeighbors}) = NodeID

Check warning on line 98 in core/src/utils.jl

View check run for this annotation

Codecov / codecov/patch

core/src/utils.jl#L97-L98

Added lines #L97 - L98 were not covered by tests

function Base.iterate(iter::InNeighbors, state = 1)
(; graph, label, edge_type) = iter
code = code_for(graph, label)
local label_in

Check warning on line 103 in core/src/utils.jl

View check run for this annotation

Codecov / codecov/patch

core/src/utils.jl#L103

Added line #L103 was not covered by tests
while true
x = iterate(inneighbors(graph, code), state)
x === nothing && return nothing
code_in, state = x
label_in = label_for(graph, code_in)
if graph[label_in, label].type == edge_type
break
end
end
return label_in, state
end

function Base.iterate(iter::OutNeighbors, state = 1)
(; graph, label, edge_type) = iter
code = code_for(graph, label)
local label_out

Check warning on line 119 in core/src/utils.jl

View check run for this annotation

Codecov / codecov/patch

core/src/utils.jl#L119

Added line #L119 was not covered by tests
while true
x = iterate(outneighbors(graph, code), state)
x === nothing && return nothing
code_out, state = x
label_out = label_for(graph, code_out)
if graph[label, label_out].type == edge_type
break
end
end
return label_out, state
end

"""
Get the inneighbor node IDs of the given node ID (label)
over the given edge type in the graph.
Expand All @@ -82,11 +137,8 @@
graph::MetaGraph,
label::NodeID,
edge_type::EdgeType.T,
)::Vector{NodeID}
return [
label_in for label_in in inneighbor_labels(graph, label) if
graph[label_in, label].type == edge_type
]
)::InNeighbors
return InNeighbors(graph, label, edge_type)
end

"""
Expand All @@ -97,11 +149,8 @@
graph::MetaGraph,
label::NodeID,
edge_type::EdgeType.T,
)::Vector{NodeID}
return [
label_out for label_out in outneighbor_labels(graph, label) if
graph[label, label_out].type == edge_type
]
)::OutNeighbors
return OutNeighbors(graph, label, edge_type)
end

"""
Expand All @@ -112,31 +161,31 @@
graph::MetaGraph,
label::NodeID,
edge_type::EdgeType.T,
)::Vector{NodeID}
return [
outneighbor_labels_type(graph, label, edge_type)...,
inneighbor_labels_type(graph, label, edge_type)...,
]
)::Iterators.Flatten
return Iterators.flatten((
outneighbor_labels_type(graph, label, edge_type),
inneighbor_labels_type(graph, label, edge_type),
))
end

"""
Get the outneighbors over flow edges.
"""
function outflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID}
function outflow_ids(graph::MetaGraph, id::NodeID)::OutNeighbors
return outneighbor_labels_type(graph, id, EdgeType.flow)
end

"""
Get the inneighbors over flow edges.
"""
function inflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID}
function inflow_ids(graph::MetaGraph, id::NodeID)::InNeighbors
return inneighbor_labels_type(graph, id, EdgeType.flow)
end

"""
Get the in- and outneighbors over flow edges.
"""
function inoutflow_ids(graph::MetaGraph, id::NodeID)::Vector{NodeID}
function inoutflow_ids(graph::MetaGraph, id::NodeID)::Iterators.Flatten
return all_neighbor_labels_type(graph, id, EdgeType.flow)
end

Expand Down