diff --git a/core/src/Ribasim.jl b/core/src/Ribasim.jl index 0827e4a87..956cb7f71 100644 --- a/core/src/Ribasim.jl +++ b/core/src/Ribasim.jl @@ -57,6 +57,7 @@ using MetaGraphsNext: outneighbor_labels, inneighbor_labels using OrdinaryDiffEq +using OrdinaryDiffEq: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm using PreallocationTools: DiffCache, FixedSizeDiffCache, get_tmp using SciMLBase using SparseArrays diff --git a/core/src/create.jl b/core/src/create.jl index c038d2f01..47620ee24 100644 --- a/core/src/create.jl +++ b/core/src/create.jl @@ -405,7 +405,7 @@ function FlowBoundary(db::DB, config::Config)::FlowBoundary ) end -function Pump(db::DB, config::Config, chunk_size::Int)::Pump +function Pump(db::DB, config::Config, chunk_sizes::Vector{Int})::Pump static = load_structvector(db, config, PumpStaticV1) defaults = (; min_flow_rate = 0.0, max_flow_rate = Inf, active = true) parsed_parameters, valid = parse_static_and_time(db, config, "Pump"; static, defaults) @@ -417,7 +417,7 @@ function Pump(db::DB, config::Config, chunk_size::Int)::Pump # If flow rate is set by PID control, it is part of the AD Jacobian computations flow_rate = if config.solver.autodiff - DiffCache(parsed_parameters.flow_rate, chunk_size) + DiffCache(parsed_parameters.flow_rate, chunk_sizes) else parsed_parameters.flow_rate end @@ -433,7 +433,7 @@ function Pump(db::DB, config::Config, chunk_size::Int)::Pump ) end -function Outlet(db::DB, config::Config, chunk_size::Int)::Outlet +function Outlet(db::DB, config::Config, chunk_sizes::Vector{Int})::Outlet static = load_structvector(db, config, OutletStaticV1) defaults = (; min_flow_rate = 0.0, max_flow_rate = Inf, min_crest_level = -Inf, active = true) @@ -446,7 +446,7 @@ function Outlet(db::DB, config::Config, chunk_size::Int)::Outlet # If flow rate is set by PID control, it is part of the AD Jacobian computations flow_rate = if config.solver.autodiff - DiffCache(parsed_parameters.flow_rate, chunk_size) + DiffCache(parsed_parameters.flow_rate, chunk_sizes) else parsed_parameters.flow_rate end @@ -468,15 +468,15 @@ function Terminal(db::DB, config::Config)::Terminal return Terminal(NodeID.(static.node_id)) end -function Basin(db::DB, config::Config, chunk_size::Int)::Basin +function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin node_id = get_ids(db, "Basin") n = length(node_id) current_level = zeros(n) current_area = zeros(n) if config.solver.autodiff - current_level = DiffCache(current_level, chunk_size) - current_area = DiffCache(current_area, chunk_size) + current_level = DiffCache(current_level, chunk_sizes) + current_area = DiffCache(current_area, chunk_sizes) end precipitation = fill(NaN, length(node_id)) @@ -555,7 +555,7 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl ) end -function PidControl(db::DB, config::Config, chunk_size::Int)::PidControl +function PidControl(db::DB, config::Config, chunk_sizes::Vector{Int})::PidControl static = load_structvector(db, config, PidControlStaticV1) time = load_structvector(db, config, PidControlTimeV1) @@ -577,7 +577,7 @@ function PidControl(db::DB, config::Config, chunk_size::Int)::PidControl pid_error = zeros(length(node_ids)) if config.solver.autodiff - pid_error = DiffCache(pid_error, chunk_size) + pid_error = DiffCache(pid_error, chunk_sizes) end # Combine PID parameters into one vector interpolation object @@ -775,10 +775,23 @@ function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid return Subgrid(basin_ids, interpolations, fill(NaN, length(basin_ids))) end +""" +Get the chunk sizes for DiffCache; differentiation w.r.t. u +and t (the latter only if a Rosenbrock algorithm is used). +""" +function get_chunk_sizes(config::Config, n_states::Int)::Vector{Int} + chunk_sizes = [pickchunksize(n_states)] + if Ribasim.config.algorithms[config.solver.algorithm] <: + OrdinaryDiffEqRosenbrockAdaptiveAlgorithm + push!(chunk_sizes, 1) + end + return chunk_sizes +end + function Parameters(db::DB, config::Config)::Parameters n_states = length(get_ids(db, "Basin")) + length(get_ids(db, "PidControl")) - chunk_size = pickchunksize(n_states) - graph = create_graph(db, config, chunk_size) + chunk_sizes = get_chunk_sizes(config, n_states) + graph = create_graph(db, config, chunk_sizes) allocation_models = Vector{AllocationModel}() linear_resistance = LinearResistance(db, config) @@ -787,14 +800,14 @@ function Parameters(db::DB, config::Config)::Parameters fractional_flow = FractionalFlow(db, config) level_boundary = LevelBoundary(db, config) flow_boundary = FlowBoundary(db, config) - pump = Pump(db, config, chunk_size) - outlet = Outlet(db, config, chunk_size) + pump = Pump(db, config, chunk_sizes) + outlet = Outlet(db, config, chunk_sizes) terminal = Terminal(db, config) discrete_control = DiscreteControl(db, config) - pid_control = PidControl(db, config, chunk_size) + pid_control = PidControl(db, config, chunk_sizes) user = User(db, config) - basin = Basin(db, config, chunk_size) + basin = Basin(db, config, chunk_sizes) subgrid_level = Subgrid(db, config, basin) # Set is_pid_controlled to true for those pumps and outlets that are PID controlled diff --git a/core/src/solve.jl b/core/src/solve.jl index f773fabb5..b9b0b19ae 100644 --- a/core/src/solve.jl +++ b/core/src/solve.jl @@ -568,7 +568,7 @@ function formulate_basins!( return nothing end -function set_error!(pid_control::PidControl, p::Parameters, u::ComponentVector, t::Float64) +function set_error!(pid_control::PidControl, p::Parameters, u::ComponentVector, t::Number) (; basin) = p (; listen_node_id, target, error) = pid_control error = get_tmp(error, u) @@ -588,7 +588,7 @@ function continuous_control!( pid_control::PidControl, p::Parameters, integral_value::SubArray, - t::Float64, + t::Number, )::Nothing (; graph, pump, outlet, basin, fractional_flow) = p min_flow_rate_pump = pump.min_flow_rate @@ -751,7 +751,7 @@ function formulate_flow!( user::User, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph, basin) = p (; node_id, allocated, demand, active, return_factor, min_level) = user @@ -803,7 +803,7 @@ function formulate_flow!( linear_resistance::LinearResistance, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph) = p (; node_id, active, resistance) = linear_resistance @@ -831,7 +831,7 @@ function formulate_flow!( tabulated_rating_curve::TabulatedRatingCurve, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; basin, graph) = p (; node_id, active, tables) = tabulated_rating_curve @@ -899,7 +899,7 @@ function formulate_flow!( manning_resistance::ManningResistance, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; basin, graph) = p (; node_id, active, length, manning_n, profile_width, profile_slope) = @@ -954,7 +954,7 @@ function formulate_flow!( fractional_flow::FractionalFlow, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph) = p (; node_id, fraction) = fractional_flow @@ -974,7 +974,7 @@ function formulate_flow!( terminal::Terminal, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph) = p (; node_id) = terminal @@ -992,7 +992,7 @@ function formulate_flow!( level_boundary::LevelBoundary, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph) = p (; node_id) = level_boundary @@ -1014,7 +1014,7 @@ function formulate_flow!( flow_boundary::FlowBoundary, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph) = p (; node_id, active, flow_rate) = flow_boundary @@ -1039,7 +1039,7 @@ function formulate_flow!( pump::Pump, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph, basin) = p (; node_id, active, flow_rate, is_pid_controlled) = pump @@ -1072,7 +1072,7 @@ function formulate_flow!( outlet::Outlet, p::Parameters, storage::AbstractVector, - t::Float64, + t::Number, )::Nothing (; graph, basin) = p (; node_id, active, flow_rate, is_pid_controlled, min_crest_level) = outlet @@ -1136,7 +1136,7 @@ function formulate_du!( return nothing end -function formulate_flows!(p::Parameters, storage::AbstractVector, t::Float64)::Nothing +function formulate_flows!(p::Parameters, storage::AbstractVector, t::Number)::Nothing (; linear_resistance, manning_resistance, @@ -1171,7 +1171,7 @@ function water_balance!( du::ComponentVector, u::ComponentVector, p::Parameters, - t::Float64, + t::Number, )::Nothing (; graph, basin, pid_control) = p diff --git a/core/src/utils.jl b/core/src/utils.jl index 4e2a3ce88..fcd7c1f0f 100644 --- a/core/src/utils.jl +++ b/core/src/utils.jl @@ -19,7 +19,7 @@ Return a directed metagraph with data of nodes (NodeMetadata): and data of edges (EdgeMetadata): [`EdgeMetadata`](@ref) """ -function create_graph(db::DB, config::Config, chunk_size::Int)::MetaGraph +function create_graph(db::DB, config::Config, chunk_sizes::Vector{Int})::MetaGraph node_rows = execute(db, "select fid, type, allocation_network_id from Node") edge_rows = execute( db, @@ -85,8 +85,8 @@ function create_graph(db::DB, config::Config, chunk_size::Int)::MetaGraph flow = zeros(flow_counter) flow_vertical = zeros(flow_vertical_counter) if config.solver.autodiff - flow = DiffCache(flow, chunk_size) - flow_vertical = DiffCache(flow_vertical, chunk_size) + flow = DiffCache(flow, chunk_sizes) + flow_vertical = DiffCache(flow_vertical, chunk_sizes) end graph_data = (; node_ids, @@ -669,7 +669,7 @@ storage: tells ForwardDiff whether this call is for differentiation or not function get_level( p::Parameters, node_id::NodeID, - t::Float64; + t::Number; storage::Union{AbstractArray, Number} = 0, )::Union{Real, Nothing} (; basin, level_boundary) = p