Skip to content

Commit

Permalink
Call valid_n_neighbors earlier
Browse files Browse the repository at this point in the history
Because after Parameter construction is too late now, since we cannot put wrong neighbors in the structs anymore. So instead do the validation on the graph.
  • Loading branch information
visr committed May 3, 2024
1 parent 8a8ab10 commit 43a2378
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 31 deletions.
3 changes: 2 additions & 1 deletion core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ using OrdinaryDiffEq:
Tsit5

export Config, Solver, Results, Logging, Toml
export algorithm, snake_case, input_path, results_path, convert_saveat, convert_dt
export algorithm,
snake_case, input_path, results_path, convert_saveat, convert_dt, nodetypes

const schemas =
getfield.(
Expand Down
7 changes: 3 additions & 4 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,9 @@ function Parameters(db::DB, config::Config)::Parameters
if !valid_edges(graph)
error("Invalid edge(s) found.")
end
if !valid_n_neighbors(graph)
error("Invalid number of connections for certain node types.")
end

linear_resistance = LinearResistance(db, config, graph)
manning_resistance = ManningResistance(db, config, graph)
Expand Down Expand Up @@ -1098,10 +1101,6 @@ function Parameters(db::DB, config::Config)::Parameters

set_is_pid_controlled!(p)

if !valid_n_neighbors(p)
error("Invalid number of connections for certain node types.")
end

# Allocation data structures
if config.allocation.use_allocation
initialize_allocation!(p, config)
Expand Down
6 changes: 0 additions & 6 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -461,12 +461,6 @@ function expand_logic_mapping(
return logic_mapping_expanded
end

"""Get all node fieldnames of the parameter object."""
nodefields(p::Parameters) = (
name for
name in fieldnames(typeof(p)) if fieldtype(typeof(p), name) <: AbstractParameterNode
)

"""
Get the node type specific indices of the fractional flows and basins,
that are consecutively connected to a node of given id.
Expand Down
33 changes: 16 additions & 17 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,50 +429,49 @@ end
Test for each node given its node type whether it has an allowed
number of flow/control inneighbors and outneighbors
"""
function valid_n_neighbors(p::Parameters)::Bool
(; graph) = p

function valid_n_neighbors(graph::MetaGraph)::Bool
errors = false

for nodefield in nodefields(p)
errors |= !valid_n_neighbors(getfield(p, nodefield), graph)
for nodetype in nodetypes
errors |= !valid_n_neighbors(nodetype, graph)
end

return !errors
end

function valid_n_neighbors(node::AbstractParameterNode, graph::MetaGraph)::Bool
node_type = typeof(node)
node_name = nameof(node_type)

function valid_n_neighbors(node_name::Symbol, graph::MetaGraph)::Bool
node_type = NodeType.T(node_name)
bounds_flow = n_neighbor_bounds_flow(node_name)
bounds_control = n_neighbor_bounds_control(node_name)

errors = false

for id in node.node_id
# return !errors
for node_id in labels(graph)
node_id.type == node_type || continue
for (bounds, edge_type) in
zip((bounds_flow, bounds_control), (EdgeType.flow, EdgeType.control))
n_inneighbors = count(x -> true, inneighbor_labels_type(graph, id, edge_type))
n_outneighbors = count(x -> true, outneighbor_labels_type(graph, id, edge_type))
n_inneighbors =
count(x -> true, inneighbor_labels_type(graph, node_id, edge_type))
n_outneighbors =
count(x -> true, outneighbor_labels_type(graph, node_id, edge_type))

if n_inneighbors < bounds.in_min
@error "$id must have at least $(bounds.in_min) $edge_type inneighbor(s) (got $n_inneighbors)."
@error "$node_id must have at least $(bounds.in_min) $edge_type inneighbor(s) (got $n_inneighbors)."
errors = true
end

if n_inneighbors > bounds.in_max
@error "$id can have at most $(bounds.in_max) $edge_type inneighbor(s) (got $n_inneighbors)."
@error "$node_id can have at most $(bounds.in_max) $edge_type inneighbor(s) (got $n_inneighbors)."
errors = true
end

if n_outneighbors < bounds.out_min
@error "$id must have at least $(bounds.out_min) $edge_type outneighbor(s) (got $n_outneighbors)."
@error "$node_id must have at least $(bounds.out_min) $edge_type outneighbor(s) (got $n_outneighbors)."
errors = true
end

if n_outneighbors > bounds.out_max
@error "$id can have at most $(bounds.out_max) $edge_type outneighbor(s) (got $n_outneighbors)."
@error "$node_id can have at most $(bounds.out_max) $edge_type outneighbor(s) (got $n_outneighbors)."
errors = true
end
end
Expand Down
38 changes: 35 additions & 3 deletions core/test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,14 +290,46 @@ end
@test ispath(toml_path)
model = Ribasim.Model(toml_path)
(; p) = model.integrator
constraining_types = (:pump, :outlet, :linear_resistance)
constraining_types = (:Pump, :Outlet, :LinearResistance)

for type in Ribasim.nodefields(p)
node = getfield(p, type)
for type in Ribasim.nodetypes
type == :Terminal && continue # has no parameter field
node = getfield(p, snake_case(type))
if type in constraining_types
@test Ribasim.is_flow_constraining(node)
else
@test !Ribasim.is_flow_constraining(node)
end
end
end

@testitem "Node types" begin
using Ribasim: nodetypes, NodeType, Parameters, AbstractParameterNode

@test Set(nodetypes) == Set([
:Terminal,
:PidControl,
:LevelBoundary,
:Pump,
:UserDemand,
:TabulatedRatingCurve,
:FlowDemand,
:FlowBoundary,
:Basin,
:ManningResistance,
:LevelDemand,
:DiscreteControl,
:Outlet,
:LinearResistance,
:FractionalFlow,
])
for nodetype in nodetypes
NodeType.T(nodetype)
if nodetype != :Terminal
# It has a struct which is added to Parameters
T = getproperty(Ribasim, nodetype)
@test T <: AbstractParameterNode
@test hasfield(Parameters, snake_case(nodetype))
end
end
end

0 comments on commit 43a2378

Please sign in to comment.