Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save all results via saving callback (follow up) #1896

Merged
merged 8 commits into from
Oct 10, 2024
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "0257f2772e4bfed0b6316b6871c4495978c173b4"
project_hash = "a2d1a982c3293971ae40e0a7bdfb40a85bd30ac1"

[[deps.ADTypes]]
git-tree-sha1 = "eea5d80188827b35333801ef97a40c2ed653b081"
Expand Down Expand Up @@ -1346,7 +1346,7 @@ uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.6.0"

[[deps.Ribasim]]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "EnumX", "FiniteDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEqBDF", "OrdinaryDiffEqCore", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqTsit5", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
deps = ["Accessors", "Arrow", "BasicModelInterface", "CodecZstd", "ComponentArrays", "Configurations", "DBInterface", "DataInterpolations", "DataStructures", "Dates", "DiffEqBase", "DiffEqCallbacks", "EnumX", "FiniteDiff", "Graphs", "HiGHS", "IterTools", "JuMP", "Legolas", "LineSearches", "LinearAlgebra", "LinearSolve", "Logging", "LoggingExtras", "MetaGraphsNext", "OrdinaryDiffEqBDF", "OrdinaryDiffEqCore", "OrdinaryDiffEqLowOrderRK", "OrdinaryDiffEqNonlinearSolve", "OrdinaryDiffEqRosenbrock", "OrdinaryDiffEqSDIRK", "OrdinaryDiffEqTsit5", "PreallocationTools", "SQLite", "SciMLBase", "SparseArrays", "SparseConnectivityTracer", "StructArrays", "Tables", "TerminalLoggers", "TranscodingStreams"]
path = "core"
uuid = "aac5e3d9-0b8f-4d4f-8241-b1a7a9632635"
version = "2024.11.0"
Expand Down
2 changes: 2 additions & 0 deletions core/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand Down Expand Up @@ -77,6 +78,7 @@ IOCapture = "0.2"
IterTools = "1.4"
JuMP = "1.15"
Legolas = "0.5"
LinearAlgebra = "1"
LineSearches = "7"
LinearSolve = "2.24"
Logging = "<0.0.1, 1"
Expand Down
8 changes: 7 additions & 1 deletion core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ using SciMLBase:
# through operator overloading
using SparseConnectivityTracer: TracerSparsityDetector, jacobian_sparsity, GradientTracer

# For efficient sparse computations
using SparseArrays: SparseMatrixCSC, spzeros

# Linear algebra
using LinearAlgebra: mul!

# PreallocationTools is used because the RHS function (water_balance!) gets called with different input types
# for u, du:
# - Float64 for normal calls
Expand Down Expand Up @@ -92,7 +98,7 @@ using TerminalLoggers: TerminalLogger
# Convenience wrapper around arrays, divides vectors in
# separate sections which can be indexed individually.
# Used for e.g. Basin forcing and the state vector.
using ComponentArrays: ComponentVector, Axis
using ComponentArrays: ComponentVector, ComponentArray, Axis, getaxes

# Date and time handling; externally we use the proleptic Gregorian calendar,
# internally we use a Float64; seconds since the start of the simulation.
Expand Down
9 changes: 6 additions & 3 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,10 @@ function check_water_balance_error(
(; basin, water_balance_abstol, water_balance_reltol) = p
errors = false
current_storage = basin.current_storage[parent(u)]
formulate_storages!(current_storage, u, u, p, t)

# The initial storage is irrelevant for the storage rate and can only cause
# floating point truncation errors
formulate_storages!(current_storage, u, u, p, t; add_initial_storage = false)

for (
inflow_rate,
Expand All @@ -289,7 +292,7 @@ function check_water_balance_error(
saved_flow.flow.evaporation,
saved_flow.flow.infiltration,
current_storage,
basin.storage_prev_saveat,
basin.Δstorage_prev_saveat,
basin.node_id,
)
storage_rate = (s_now - s_prev) / Δt
Expand All @@ -310,7 +313,7 @@ function check_water_balance_error(
error("Too large water balance error(s) detected at t = $t")
end

@. basin.storage_prev_saveat = current_storage
@. basin.Δstorage_prev_saveat = current_storage
return nothing
end

Expand Down
1 change: 1 addition & 0 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ function Model(config::Config)::Model
du0 = zero(u0)

parameters = set_state_flow_edges(parameters, u0)
parameters = build_flow_to_storage(parameters, u0)
parameters = @set parameters.u_prev_saveat = zero(u0)

# The Solver algorithm
Expand Down
5 changes: 4 additions & 1 deletion core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ end
vertical_flux::V = zeros(length(node_id))
# Initial_storage
storage0::Vector{Float64} = zeros(length(node_id))
storage_prev_saveat::Vector{Float64} = zeros(length(node_id))
# Storage at previous saveat without storage0
Δstorage_prev_saveat::Vector{Float64} = zeros(length(node_id))
# Analytically integrated forcings
cumulative_precipitation::Vector{Float64} = zeros(length(node_id))
cumulative_drainage::Vector{Float64} = zeros(length(node_id))
Expand Down Expand Up @@ -840,6 +841,8 @@ const ModelGraph = MetaGraph{
state_outflow_edge::C4 = ComponentVector()
all_nodes_active::Base.RefValue{Bool} = Ref(false)
tprev::Base.RefValue{Float64} = Ref(0.0)
# Sparse matrix for combining flows into storages
flow_to_storage::SparseMatrixCSC{Float64, Int64} = spzeros(1, 1)
# Water balance tolerances
water_balance_abstol::Float64
water_balance_reltol::Float64
Expand Down
1 change: 0 additions & 1 deletion core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
storage0 = get_storages_from_levels(basin, state.level)
@assert length(storage0) == n "Basin / state length differs from number of Basins"
basin.storage0 .= storage0
basin.storage_prev_saveat .= storage0
return basin
end

Expand Down
70 changes: 12 additions & 58 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,86 +118,40 @@ function formulate_storages!(
du::ComponentVector,
u::ComponentVector,
p::Parameters,
t::Number,
t::Number;
add_initial_storage::Bool = true,
)::Nothing
(;
basin,
flow_boundary,
tabulated_rating_curve,
pump,
outlet,
linear_resistance,
manning_resistance,
user_demand,
tprev,
) = p
# Current storage: initial conditdion +
(; basin, flow_boundary, tprev, flow_to_storage) = p
# Current storage: initial condition +
# total inflows and outflows since the start
# of the simulation
current_storage .= basin.storage0
formulate_storage!(current_storage, basin, du, u)
if add_initial_storage
current_storage .= basin.storage0
else
current_storage .= 0.0
end
mul!(current_storage, flow_to_storage, u, 1, 1)
formulate_storage!(current_storage, basin, du)
formulate_storage!(current_storage, tprev[], t, flow_boundary)
formulate_storage!(current_storage, t, u.tabulated_rating_curve, tabulated_rating_curve)
formulate_storage!(current_storage, t, u.pump, pump)
formulate_storage!(current_storage, t, u.outlet, outlet)
formulate_storage!(current_storage, t, u.linear_resistance, linear_resistance)
formulate_storage!(current_storage, t, u.manning_resistance, manning_resistance)
formulate_storage!(
current_storage,
t,
u.user_demand_inflow,
user_demand;
edge_volume_out = u.user_demand_outflow,
)
return nothing
end

"""
The storage contributions of the forcings that are part of the state.
The storage contributions of the forcings that are not part of the state.
"""
function formulate_storage!(
current_storage::AbstractVector,
basin::Basin,
du::ComponentVector,
u::ComponentVector,
)
(; current_cumulative_precipitation, current_cumulative_drainage) = basin

current_storage .-= u.evaporation
current_storage .-= u.infiltration

current_cumulative_precipitation = current_cumulative_precipitation[parent(du)]
current_cumulative_drainage = current_cumulative_drainage[parent(du)]
current_storage .+= current_cumulative_precipitation
current_storage .+= current_cumulative_drainage
end

"""
Formulate storage contributions of nodes.
"""
function formulate_storage!(
current_storage::AbstractVector,
t::Number,
edge_volume_in::AbstractVector,
node::AbstractParameterNode;
edge_volume_out = nothing,
)
edge_volume_out = isnothing(edge_volume_out) ? edge_volume_in : edge_volume_out

for (volume_in, volume_out, inflow_edge, outflow_edge) in
zip(edge_volume_in, edge_volume_out, node.inflow_edge, node.outflow_edge)
inflow_id = inflow_edge.edge[1]
if inflow_id.type == NodeType.Basin
current_storage[inflow_id.idx] -= volume_in
end

outflow_id = outflow_edge.edge[2]
if outflow_id.type == NodeType.Basin
current_storage[outflow_id.idx] += volume_out
end
end
end

"""
Formulate storage contributions of flow boundaries.
"""
Expand Down
55 changes: 53 additions & 2 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ function get_variable_ref(
variable::String;
listen::Bool = true,
)::Tuple{PreallocationRef, Bool}
(; basin, graph) = p
(; basin) = p
errors = false

# Only built here because it is needed to obtain indices
Expand Down Expand Up @@ -840,7 +840,8 @@ but the derivative is bounded at x = 0.
"""
function relaxed_root(x, threshold)
if abs(x) < threshold
1 / 4 * (x / sqrt(threshold)) * (5 - (x / threshold)^2)
x_scaled = x / threshold
sqrt(threshold) * x_scaled^3 * (9 - 5x_scaled^2) / 4
else
sign(x) * sqrt(abs(x))
end
Expand Down Expand Up @@ -953,6 +954,56 @@ function build_state_vector(p::Parameters)
)
end

function build_flow_to_storage(p::Parameters, u::ComponentVector)::Parameters
n_basins = length(p.basin.node_id)
n_states = length(u)
flow_to_storage = ComponentArray(
spzeros(n_basins, n_states),
(Axis(; basins = 1:n_basins), only(getaxes(u))),
)

for node_name in (
:tabulated_rating_curve,
:pump,
:outlet,
:linear_resistance,
:manning_resistance,
:user_demand,
)
node = getfield(p, node_name)

if node_name == :user_demand
flow_to_storage_node_inflow = view(flow_to_storage, :, :user_demand_inflow)
flow_to_storage_node_outflow = view(flow_to_storage, :, :user_demand_outflow)
else
flow_to_storage_node_inflow = view(flow_to_storage, :, node_name)
flow_to_storage_node_outflow = flow_to_storage_node_inflow
end

for (inflow_edge, outflow_edge) in zip(node.inflow_edge, node.outflow_edge)
inflow_id, node_id = inflow_edge.edge
if inflow_id.type == NodeType.Basin
flow_to_storage_node_inflow[inflow_id.idx, node_id.idx] = -1.0
end

outflow_id = outflow_edge.edge[2]
if outflow_id.type == NodeType.Basin
flow_to_storage_node_outflow[outflow_id.idx, node_id.idx] = 1.0
end
end
end

flow_to_storage_evaporation = view(flow_to_storage, :, :evaporation)
flow_to_storage_infiltration = view(flow_to_storage, :, :infiltration)

for i in 1:n_basins
flow_to_storage_evaporation[i, i] = -1.0
flow_to_storage_infiltration[i, i] = -1.0
end

@set p.flow_to_storage = parent(flow_to_storage)
end

"""
Create vectors state_inflow_edge and state_outflow_edge which give for each state
in the state vector in order the metadata of the edge that is associated with that state.
Expand Down
4 changes: 2 additions & 2 deletions core/test/run_models_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ end
@test successful_retcode(model)
@test length(model.integrator.sol) == 2 # start and end
@test model.integrator.p.basin.current_storage[Float64[]] ≈
Float32[803.7093, 803.68274, 495.241, 1318.3053] skip = Sys.isapple() atol = 1.5
Float32[828.5386, 801.88289, 492.290, 1318.3053] skip = Sys.isapple() atol = 1.5

@test length(logger.logs) > 10
@test logger.logs[1].level == Debug
Expand Down Expand Up @@ -242,7 +242,7 @@ end
precipitation = model.integrator.p.basin.vertical_flux.precipitation
@test length(precipitation) == 4
@test model.integrator.p.basin.current_storage[parent(du)] ≈
Float32[697.30591, 697.2799, 419.19034, 1334.3859] atol = 2.0 skip = Sys.isapple()
Float32[720.23611, 694.8785, 415.22371, 1334.3859] atol = 2.0 skip = Sys.isapple()
end

@testitem "Allocation example model" begin
Expand Down
2 changes: 2 additions & 0 deletions core/test/utils_test.jl
SouthEndMusic marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ end
t0 = 0.0
u0 = Ribasim.build_state_vector(p)
du0 = copy(u0)
p = Ribasim.build_flow_to_storage(p, u0)
jac_prototype = Ribasim.get_jac_prototype(du0, u0, p, t0)

# rows, cols, _ = findnz(jac_prototype)
Expand All @@ -195,6 +196,7 @@ end
close(db)
u0 = Ribasim.build_state_vector(p)
du0 = copy(u0)
p = Ribasim.build_flow_to_storage(p, u0)
jac_prototype = Ribasim.get_jac_prototype(du0, u0, p, t0)

#! format: off
Expand Down
8 changes: 5 additions & 3 deletions docs/reference/node/manning-resistance.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,10 @@ where $s$ is a relaxed square root function:

$$
s(x; x_0)
=
\begin{align}
\begin{cases}
\frac{x}{4\sqrt{x_0}}\left(5-\left(\frac{x}{x_0}\right)^2\right) &\text{ if } |x| < x_0 \\
\frac{\sqrt{x_0}}{4}\left(\frac{x}{p}\right)^3\left(9 - 5\left(\frac{x}{p}\right)^2\right) &\text{ if } |x| < x_0 \\
\textrm{sign}(x)\sqrt{|x|} &\text{ if } |x| \ge x_0
\end{cases}
\end{align}
Expand All @@ -151,7 +152,8 @@ import matplotlib.pyplot as plt

def s(x, threshold):
if np.abs(x) < threshold:
return 1/4 * (x / np.sqrt(threshold)) * (5.0 - (x / threshold)**2)
x_scaled = x / threshold
return np.sqrt(threshold) * x_scaled**3 * (9 - 5*x_scaled**2) / 4
else:
return np.sign(x)*np.sqrt(np.abs(x))

Expand All @@ -165,7 +167,7 @@ y_s = [s(x_, threshold) for x_ in x]

ax.plot(x, y_o, ls = ":", label = r"sign$(x)\sqrt{|x|}$")
ax.plot(x, y_s, color = "C0", label = r"$s\left(x; 10^{-3}\right)$")
ax.legend()
ax.legend();
```

:::{.callout-note}
Expand Down