Skip to content

Commit

Permalink
Fix problem where edge_data dict of graph makes node ids in flow tabl…
Browse files Browse the repository at this point in the history
…e non-deterministic
  • Loading branch information
SouthEndMusic committed Sep 18, 2024
1 parent 6d08366 commit a622e46
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
50 changes: 23 additions & 27 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,44 +263,41 @@ function save_flow(u, t, integrator)
precipitation,
drainage,
)
check_water_balance_error(p, saved_flow, t, Δt, u)
check_water_balance_error(integrator, saved_flow, Δt)
return saved_flow
end

function check_water_balance_error(
p::Parameters,
integrator::DEIntegrator,
saved_flow::SavedFlow,
t::Float64,
Δt::Float64,
u::ComponentVector,
)::Nothing
(; u, p, t) = integrator
(; basin, water_balance_abstol, water_balance_reltol) = p
errors = false
current_storage = basin.current_storage[parent(u)]
formulate_storages!(current_storage, u, u, p, t)

for (
i,
(
inflow_rate,
outflow_rate,
precipitation,
drainage,
evaporation,
infiltration,
s_now,
s_prev,
),
) in enumerate(
zip(
saved_flow.inflow,
saved_flow.outflow,
saved_flow.precipitation,
saved_flow.drainage,
saved_flow.flow.evaporation,
saved_flow.flow.infiltration,
current_storage,
basin.storage_prev_saveat,
),
inflow_rate,
outflow_rate,
precipitation,
drainage,
evaporation,
infiltration,
s_now,
s_prev,
id,
) in zip(
saved_flow.inflow,
saved_flow.outflow,
saved_flow.precipitation,
saved_flow.drainage,
saved_flow.flow.evaporation,
saved_flow.flow.infiltration,
current_storage,
basin.storage_prev_saveat,
basin.node_id,
)
storage_rate = (s_now - s_prev) / Δt
total_in = inflow_rate + precipitation + drainage
Expand All @@ -312,7 +309,6 @@ function check_water_balance_error(
if abs(balance_error) > water_balance_abstol &&
abs(relative_error) > water_balance_reltol
errors = true
id = id_from_state_index(p, saved_flow.flow, i)
@error "Too large water balance error" id balance_error relative_error
end
end
Expand Down
4 changes: 2 additions & 2 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ const nodetypes = collect(keys(nodekinds))
force_dtmin::Bool = false
abstol::Float64 = 1e-6
reltol::Float64 = 1e-5
water_balance_abstol::Float64 = 1e-6
water_balance_reltol::Float64 = 1e-6
water_balance_abstol::Float64 = 1e-3
water_balance_reltol::Float64 = 1e-2
maxiters::Int = 1e9
sparse::Bool = true
autodiff::Bool = true
Expand Down
11 changes: 6 additions & 5 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ function create_graph(db::DB, config::Config)::MetaGraph
node_ids = Dict{Int32, Set{NodeID}}()
# Source edges per subnetwork
edges_source = Dict{Int32, Set{EdgeMetadata}}()
# The flow counter gives a unique consecutive id to the
# flow edges to index the flow vectors
flow_counter = 0
# The metadata of the flow edges in the order in which they are in the input
# and will be in the output
flow_edges = EdgeMetadata[]
# Dictionary from flow edge to index in flow vector
graph = MetaGraph(
DiGraph();
Expand Down Expand Up @@ -79,11 +79,13 @@ function create_graph(db::DB, config::Config)::MetaGraph
end
edge_metadata = EdgeMetadata(;
id = edge_id,
flow_idx = edge_type == EdgeType.flow ? flow_counter + 1 : 0,
type = edge_type,
subnetwork_id_source = subnetwork_id,
edge = (id_src, id_dst),
)
if edge_type == EdgeType.flow
push!(flow_edges, edge_metadata)
end
if haskey(graph, id_src, id_dst)
errors = true
@error "Duplicate edge" id_src id_dst
Expand All @@ -104,7 +106,6 @@ function create_graph(db::DB, config::Config)::MetaGraph
error("Incomplete connectivity in subnetwork")
end

flow_edges = [edge for edge in values(graph.edge_data) if edge.type == EdgeType.flow]
graph_data = (; node_ids, edges_source, flow_edges, config.solver.saveat)
graph = @set graph.graph_data = graph_data

Expand Down
2 changes: 0 additions & 2 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,13 @@ end
"""
Type for storing metadata of edges in the graph:
id: ID of the edge (only used for labeling flow output)
flow_idx: Index in the vector of flows
type: type of the edge
subnetwork_id_source: ID of subnetwork where this edge is a source
(0 if not a source)
edge: (from node ID, to node ID)
"""
@kwdef struct EdgeMetadata
id::Int32
flow_idx::Int
type::EdgeType.T
subnetwork_id_source::Int32
edge::Tuple{NodeID, NodeID}
Expand Down
5 changes: 4 additions & 1 deletion core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,6 @@ function set_state_flow_edges(p::Parameters, u0::ComponentVector)::Parameters
state_outflow_edges = Vector{EdgeMetadata}[]

placeholder_edge = EdgeMetadata(
0,
0,
EdgeType.flow,
0,
Expand Down Expand Up @@ -1042,6 +1041,10 @@ function id_from_state_index(
component = :basin
end

@show global_idx
@show NT
@show local_idx

getfield(p, component).node_id[local_idx]
end

Expand Down

0 comments on commit a622e46

Please sign in to comment.