Skip to content

Commit

Permalink
Fix autodiff w.r.t. time
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Nov 29, 2023
1 parent a1b848f commit b4434fe
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 33 deletions.
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ using MetaGraphsNext:
outneighbor_labels,
inneighbor_labels
using OrdinaryDiffEq
using OrdinaryDiffEq: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
using PreallocationTools: DiffCache, FixedSizeDiffCache, get_tmp
using SciMLBase
using SparseArrays
Expand Down
43 changes: 28 additions & 15 deletions core/src/create.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary
)
end

function Pump(db::DB, config::Config, chunk_size::Int)::Pump
function Pump(db::DB, config::Config, chunk_sizes::Vector{Int})::Pump
static = load_structvector(db, config, PumpStaticV1)
defaults = (; min_flow_rate = 0.0, max_flow_rate = Inf, active = true)
parsed_parameters, valid = parse_static_and_time(db, config, "Pump"; static, defaults)
Expand All @@ -417,7 +417,7 @@ function Pump(db::DB, config::Config, chunk_size::Int)::Pump

# If flow rate is set by PID control, it is part of the AD Jacobian computations
flow_rate = if config.solver.autodiff
DiffCache(parsed_parameters.flow_rate, chunk_size)
DiffCache(parsed_parameters.flow_rate, chunk_sizes)
else
parsed_parameters.flow_rate
end
Expand All @@ -433,7 +433,7 @@ function Pump(db::DB, config::Config, chunk_size::Int)::Pump
)
end

function Outlet(db::DB, config::Config, chunk_size::Int)::Outlet
function Outlet(db::DB, config::Config, chunk_sizes::Vector{Int})::Outlet
static = load_structvector(db, config, OutletStaticV1)
defaults =
(; min_flow_rate = 0.0, max_flow_rate = Inf, min_crest_level = -Inf, active = true)
Expand All @@ -446,7 +446,7 @@ function Outlet(db::DB, config::Config, chunk_size::Int)::Outlet

# If flow rate is set by PID control, it is part of the AD Jacobian computations
flow_rate = if config.solver.autodiff
DiffCache(parsed_parameters.flow_rate, chunk_size)
DiffCache(parsed_parameters.flow_rate, chunk_sizes)
else
parsed_parameters.flow_rate
end
Expand All @@ -468,15 +468,15 @@ function Terminal(db::DB, config::Config)::Terminal
return Terminal(NodeID.(static.node_id))
end

function Basin(db::DB, config::Config, chunk_size::Int)::Basin
function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin
node_id = get_ids(db, "Basin")
n = length(node_id)
current_level = zeros(n)
current_area = zeros(n)

if config.solver.autodiff
current_level = DiffCache(current_level, chunk_size)
current_area = DiffCache(current_area, chunk_size)
current_level = DiffCache(current_level, chunk_sizes)
current_area = DiffCache(current_area, chunk_sizes)
end

precipitation = fill(NaN, length(node_id))
Expand Down Expand Up @@ -555,7 +555,7 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl
)
end

function PidControl(db::DB, config::Config, chunk_size::Int)::PidControl
function PidControl(db::DB, config::Config, chunk_sizes::Vector{Int})::PidControl
static = load_structvector(db, config, PidControlStaticV1)
time = load_structvector(db, config, PidControlTimeV1)

Expand All @@ -577,7 +577,7 @@ function PidControl(db::DB, config::Config, chunk_size::Int)::PidControl
pid_error = zeros(length(node_ids))

if config.solver.autodiff
pid_error = DiffCache(pid_error, chunk_size)
pid_error = DiffCache(pid_error, chunk_sizes)
end

# Combine PID parameters into one vector interpolation object
Expand Down Expand Up @@ -775,10 +775,23 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
return Subgrid(basin_ids, interpolations, fill(NaN, length(basin_ids)))
end

"""
Get the chunk sizes for DiffCache; differentiation w.r.t. u
and t (the latter only if a Rosenbrock algorithm is used).
"""
function get_chunk_sizes(config::Config, n_states::Int)::Vector{Int}
chunk_sizes = [pickchunksize(n_states)]
if Ribasim.config.algorithms[config.solver.algorithm] <:
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm
push!(chunk_sizes, 1)

Check warning on line 786 in core/src/create.jl

View check run for this annotation

Codecov / codecov/patch

core/src/create.jl#L786

Added line #L786 was not covered by tests
end
return chunk_sizes
end

function Parameters(db::DB, config::Config)::Parameters
n_states = length(get_ids(db, "Basin")) + length(get_ids(db, "PidControl"))
chunk_size = pickchunksize(n_states)
graph = create_graph(db, config, chunk_size)
chunk_sizes = get_chunk_sizes(config, n_states)
graph = create_graph(db, config, chunk_sizes)
allocation_models = Vector{AllocationModel}()

linear_resistance = LinearResistance(db, config)
Expand All @@ -787,14 +800,14 @@ function Parameters(db::DB, config::Config)::Parameters
fractional_flow = FractionalFlow(db, config)
level_boundary = LevelBoundary(db, config)
flow_boundary = FlowBoundary(db, config)
pump = Pump(db, config, chunk_size)
outlet = Outlet(db, config, chunk_size)
pump = Pump(db, config, chunk_sizes)
outlet = Outlet(db, config, chunk_sizes)
terminal = Terminal(db, config)
discrete_control = DiscreteControl(db, config)
pid_control = PidControl(db, config, chunk_size)
pid_control = PidControl(db, config, chunk_sizes)
user = User(db, config)

basin = Basin(db, config, chunk_size)
basin = Basin(db, config, chunk_sizes)
subgrid_level = Subgrid(db, config, basin)

# Set is_pid_controlled to true for those pumps and outlets that are PID controlled
Expand Down
28 changes: 14 additions & 14 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ function formulate_basins!(
return nothing
end

function set_error!(pid_control::PidControl, p::Parameters, u::ComponentVector, t::Float64)
function set_error!(pid_control::PidControl, p::Parameters, u::ComponentVector, t::Number)
(; basin) = p
(; listen_node_id, target, error) = pid_control
error = get_tmp(error, u)
Expand All @@ -588,7 +588,7 @@ function continuous_control!(
pid_control::PidControl,
p::Parameters,
integral_value::SubArray,
t::Float64,
t::Number,
)::Nothing
(; graph, pump, outlet, basin, fractional_flow) = p
min_flow_rate_pump = pump.min_flow_rate
Expand Down Expand Up @@ -751,7 +751,7 @@ function formulate_flow!(
user::User,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph, basin) = p
(; node_id, allocated, demand, active, return_factor, min_level) = user
Expand Down Expand Up @@ -803,7 +803,7 @@ function formulate_flow!(
linear_resistance::LinearResistance,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph) = p
(; node_id, active, resistance) = linear_resistance
Expand Down Expand Up @@ -831,7 +831,7 @@ function formulate_flow!(
tabulated_rating_curve::TabulatedRatingCurve,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; basin, graph) = p
(; node_id, active, tables) = tabulated_rating_curve
Expand Down Expand Up @@ -899,7 +899,7 @@ function formulate_flow!(
manning_resistance::ManningResistance,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; basin, graph) = p
(; node_id, active, length, manning_n, profile_width, profile_slope) =
Expand Down Expand Up @@ -954,7 +954,7 @@ function formulate_flow!(
fractional_flow::FractionalFlow,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph) = p
(; node_id, fraction) = fractional_flow
Expand All @@ -974,7 +974,7 @@ function formulate_flow!(
terminal::Terminal,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph) = p
(; node_id) = terminal
Expand All @@ -992,7 +992,7 @@ function formulate_flow!(
level_boundary::LevelBoundary,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph) = p
(; node_id) = level_boundary
Expand All @@ -1014,7 +1014,7 @@ function formulate_flow!(
flow_boundary::FlowBoundary,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph) = p
(; node_id, active, flow_rate) = flow_boundary
Expand All @@ -1039,7 +1039,7 @@ function formulate_flow!(
pump::Pump,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph, basin) = p
(; node_id, active, flow_rate, is_pid_controlled) = pump
Expand Down Expand Up @@ -1072,7 +1072,7 @@ function formulate_flow!(
outlet::Outlet,
p::Parameters,
storage::AbstractVector,
t::Float64,
t::Number,
)::Nothing
(; graph, basin) = p
(; node_id, active, flow_rate, is_pid_controlled, min_crest_level) = outlet
Expand Down Expand Up @@ -1136,7 +1136,7 @@ function formulate_du!(
return nothing
end

function formulate_flows!(p::Parameters, storage::AbstractVector, t::Float64)::Nothing
function formulate_flows!(p::Parameters, storage::AbstractVector, t::Number)::Nothing
(;
linear_resistance,
manning_resistance,
Expand Down Expand Up @@ -1171,7 +1171,7 @@ function water_balance!(
du::ComponentVector,
u::ComponentVector,
p::Parameters,
t::Float64,
t::Number,
)::Nothing
(; graph, basin, pid_control) = p

Expand Down
8 changes: 4 additions & 4 deletions core/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Return a directed metagraph with data of nodes (NodeMetadata):
and data of edges (EdgeMetadata):
[`EdgeMetadata`](@ref)
"""
function create_graph(db::DB, config::Config, chunk_size::Int)::MetaGraph
function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGraph
node_rows = execute(db, "select fid, type, allocation_network_id from Node")
edge_rows = execute(
db,
Expand Down Expand Up @@ -85,8 +85,8 @@ function create_graph(db::DB, config::Config, chunk_size::Int)::MetaGraph
flow = zeros(flow_counter)
flow_vertical = zeros(flow_vertical_counter)
if config.solver.autodiff
flow = DiffCache(flow, chunk_size)
flow_vertical = DiffCache(flow_vertical, chunk_size)
flow = DiffCache(flow, chunk_sizes)
flow_vertical = DiffCache(flow_vertical, chunk_sizes)
end
graph_data = (;
node_ids,
Expand Down Expand Up @@ -669,7 +669,7 @@ storage: tells ForwardDiff whether this call is for differentiation or not
function get_level(
p::Parameters,
node_id::NodeID,
t::Float64;
t::Number;
storage::Union{AbstractArray, Number} = 0,
)::Union{Real, Nothing}
(; basin, level_boundary) = p
Expand Down

0 comments on commit b4434fe

Please sign in to comment.