Skip to content

Commit

Permalink
Write output table with solver stats per saveat (#1677)
Browse files Browse the repository at this point in the history
Fixes #1674.

Not sure how to test this?
  • Loading branch information
SouthEndMusic authored Aug 6, 2024
1 parent 89e1f99 commit a58f317
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 2 deletions.
25 changes: 23 additions & 2 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Create the different callbacks that are used to store results
and feed the simulation with new data. The different callbacks
Expand Down Expand Up @@ -44,6 +43,12 @@ function create_callbacks(
save_flow_cb = SavingCallback(save_flow, saved_flow; saveat, save_start = false)
push!(callbacks, save_flow_cb)

# save solver stats
saved_solver_stats = SavedValues(Float64, SolverStats)
solver_stats_cb =
SavingCallback(save_solver_stats, saved_solver_stats; saveat, save_start = true)
push!(callbacks, solver_stats_cb)

# interpolate the levels
saved_subgrid_level = SavedValues(Float64, Vector{Float64})
if config.results.subgrid
Expand All @@ -59,12 +64,28 @@ function create_callbacks(
discrete_control_cb = FunctionCallingCallback(apply_discrete_control!)
push!(callbacks, discrete_control_cb)

saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)
saved = SavedResults(
saved_flow,
saved_vertical_flux,
saved_subgrid_level,
saved_solver_stats,
)
callback = CallbackSet(callbacks...)

return callback, saved
end

function save_solver_stats(u, t, integrator)
(; stats) = integrator.sol
(;
time = t,
rhs_calls = stats.nf,
linear_solves = stats.nsolve,
accepted_timesteps = stats.naccept,
rejected_timesteps = stats.nreject,
)
end

function check_negative_storage(u, t, integrator)::Nothing
(; basin) = integrator.p
(; node_id) = basin
Expand Down
1 change: 1 addition & 0 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ struct SavedResults{V1 <: ComponentVector{Float64}}
flow::SavedValues{Float64, SavedFlow}
vertical_flux::SavedValues{Float64, V1}
subgrid_level::SavedValues{Float64, Vector{Float64}}
solver_stats::SavedValues{Float64, SolverStats}
end

"""
Expand Down
8 changes: 8 additions & 0 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
const SolverStats = @NamedTuple{
time::Float64,
rhs_calls::Int,
linear_solves::Int,
accepted_timesteps::Int,
rejected_timesteps::Int,
}

# EdgeType.flow and NodeType.FlowBoundary
@enumx EdgeType flow control none
@eval @enumx NodeType $(config.nodetypes...)
Expand Down
28 changes: 28 additions & 0 deletions core/src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ function write_results(model::Model)::Model
path = results_path(config, RESULTS_FILENAME.subgrid_level)
write_arrow(path, table, compress; remove_empty_table)

# solver stats
table = solver_stats_table(model)
path = results_path(config, RESULTS_FILENAME.solver_stats)
write_arrow(path, table, compress; remove_empty_table)

@debug "Wrote results."
return model
end
Expand All @@ -56,6 +61,7 @@ const RESULTS_FILENAME = (
allocation = "allocation.arrow",
allocation_flow = "allocation_flow.arrow",
subgrid_level = "subgrid_level.arrow",
solver_stats = "solver_stats.arrow",
)

"Get the storage and level of all basins as matrices of nbasin × ntime"
Expand Down Expand Up @@ -183,6 +189,28 @@ function basin_table(
)
end

function solver_stats_table(
model::Model,
)::@NamedTuple{
time::Vector{DateTime},
rhs_calls::Vector{Int},
linear_solves::Vector{Int},
accepted_timesteps::Vector{Int},
rejected_timesteps::Vector{Int},
}
solver_stats = StructVector(model.saved.solver_stats.saveval)
(;
time = datetime_since.(
solver_stats.time[1:(end - 1)],
model.integrator.p.starttime,
),
rhs_calls = diff(solver_stats.rhs_calls),
linear_solves = diff(solver_stats.linear_solves),
accepted_timesteps = diff(solver_stats.accepted_timesteps),
rejected_timesteps = diff(solver_stats.rejected_timesteps),
)
end

"Create a flow result table from the saved data"
function flow_table(
model::Model,
Expand Down
6 changes: 6 additions & 0 deletions core/test/run_models_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
flow_bytes = read(normpath(dirname(toml_path), "results/flow.arrow"))
basin_bytes = read(normpath(dirname(toml_path), "results/basin.arrow"))
subgrid_bytes = read(normpath(dirname(toml_path), "results/subgrid_level.arrow"))
solver_stats_bytes = read(normpath(dirname(toml_path), "results/solver_stats.arrow"))

flow = Arrow.Table(flow_bytes)
basin = Arrow.Table(basin_bytes)
subgrid = Arrow.Table(subgrid_bytes)
solver_stats = Arrow.Table(solver_stats_bytes)

@testset "Schema" begin
@test Tables.schema(flow) == Tables.Schema(
Expand Down Expand Up @@ -83,6 +85,10 @@
(:time, :subgrid_id, :subgrid_level),
(DateTime, Int32, Float64),
)
@test Tables.schema(solver_stats) == Tables.Schema(
(:time, :rhs_calls, :linear_solves, :accepted_timesteps, :rejected_timesteps),
(DateTime, Int, Int, Int, Int),
)
end

@testset "Results size" begin
Expand Down
12 changes: 12 additions & 0 deletions docs/reference/usage.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,15 @@ column | type
time | DateTime
subgrid_id | Int32
subgrid_level | Float64

## Solver statistics - `solver_stats.arrow`

This result file contains statistics about the solver, which can give an insight into how well the solver is performing over time. The data is solved by `saveat` (see [configuration file](#configuration-file)). `water_balance` refers to the right-hand-side function of the system of differential equations solved by the Ribasim core.

column | type
--------------------| -----
time | DateTime
water_balance_calls | Int
linear_solves | Int
accepted_timesteps | Int
rejected_timesteps | Int

0 comments on commit a58f317

Please sign in to comment.