diff --git a/Project.toml b/Project.toml index 7aa43f32..35027d33 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 7612cf6e..0a3a4f2a 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -1,7 +1,7 @@ module BoundaryValueDiffEq using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase, - RecursiveArrayTools, ForwardDiff + Static, RecursiveArrayTools, ForwardDiff @reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase import ADTypes: AbstractADType @@ -10,7 +10,7 @@ import ConcreteStructs: @concrete import DiffEqBase: solve import ForwardDiff: pickchunksize import RecursiveArrayTools: ArrayPartition, DiffEqArray -import SciMLBase: AbstractDiffEqInterpolation +import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem import RecursiveArrayTools: ArrayPartition import SparseDiffTools: AbstractSparseADType import TruncatedStacktraces: @truncate_stacktrace @@ -20,15 +20,16 @@ include("types.jl") include("utils.jl") include("algorithms.jl") include("alg_utils.jl") + include("mirk_tableaus.jl") -include("cache.jl") -include("collocation.jl") -include("sparse_jacobians.jl") include("solve/single_shooting.jl") include("solve/multiple_shooting.jl") include("solve/mirk.jl") +include("collocation.jl") +include("sparse_jacobians.jl") + include("adaptivity.jl") include("interpolation.jl") diff --git a/src/cache.jl b/src/cache.jl deleted file mode 100644 index 6d70d9e9..00000000 --- a/src/cache.jl +++ /dev/null @@ -1,67 +0,0 @@ -@concrete struct MIRKCache{iip, T} - order::Int # The order of MIRK method - stage::Int # The state of MIRK method - M::Int # The number of equations - in_size - f - bc - prob # BVProblem - problem_type # StandardBVProblem - p # Parameters - alg # MIRK methods - TU # MIRK Tableau - ITU # MIRK Interpolation Tableau - # Everything below gets resized in adaptive methods - mesh # Discrete mesh - mesh_dt # Step size - k_discrete # Stage information associated with the discrete Runge-Kutta method - k_interp # Stage information associated with the discrete Runge-Kutta method - y - y₀ - residual - # The following 2 caches are never resized - fᵢ_cache - fᵢ₂_cache - defect - new_stages - kwargs -end - -Base.eltype(::MIRKCache{iip, T}) where {iip, T} = T - -""" - expand_cache!(cache::MIRKCache) - -After redistributing or halving the mesh, this function expands the required vectors to -match the length of the new mesh. -""" -function expand_cache!(cache::MIRKCache) - Nₙ = length(cache.mesh) - __append_similar!(cache.k_discrete, Nₙ - 1, cache.M) - __append_similar!(cache.k_interp, Nₙ - 1, cache.M) - __append_similar!(cache.y, Nₙ, cache.M) - __append_similar!(cache.y₀, Nₙ, cache.M) - __append_similar!(cache.residual, Nₙ, cache.M) - __append_similar!(cache.defect, Nₙ - 1, cache.M) - __append_similar!(cache.new_stages, Nₙ - 1, cache.M) - return cache -end - -__append_similar!(::Nothing, n, _) = nothing - -function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) - N = n - length(x) - N == 0 && return x - N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - append!(x, [similar(first(x)) for _ in 1:N]) - return x -end - -function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M) - N = n - length(x) - N == 0 && return x - N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) - chunksize = pickchunksize(M * (N + length(x))) - append!(x, [maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N]) - return x -end diff --git a/src/collocation.jl b/src/collocation.jl index 8ab995f3..0bfeb3f3 100644 --- a/src/collocation.jl +++ b/src/collocation.jl @@ -1,9 +1,3 @@ -__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh) -__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh] -function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _) - [copy(vec(u)) for u in u0] -end - function Φ!(residual, cache::MIRKCache, y, u, p = cache.p) return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, p, cache.mesh, cache.mesh_dt, cache.stage) diff --git a/src/solve/mirk.jl b/src/solve/mirk.jl index 5f7553a7..9242f60b 100644 --- a/src/solve/mirk.jl +++ b/src/solve/mirk.jl @@ -1,27 +1,47 @@ +@concrete struct MIRKCache{iip, T} + order::Int # The order of MIRK method + stage::Int # The state of MIRK method + M::Int # The number of equations + in_size + f + bc + prob # BVProblem + problem_type # StandardBVProblem + p # Parameters + alg # MIRK methods + TU # MIRK Tableau + ITU # MIRK Interpolation Tableau + bcresid_prototype + # Everything below gets resized in adaptive methods + mesh # Discrete mesh + mesh_dt # Step size + k_discrete # Stage information associated with the discrete Runge-Kutta method + k_interp # Stage information associated with the discrete Runge-Kutta method + y + y₀ + residual + # The following 2 caches are never resized + fᵢ_cache + fᵢ₂_cache + defect + new_stages + kwargs +end + +Base.eltype(::MIRKCache{iip, T}) where {iip, T} = T + function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol = 1e-3, adaptive = true, kwargs...) - has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray} @set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg) iip = isinplace(prob) - (T, M, n) = if has_initial_guess - # If user provided a vector of initial guesses - _u0 = first(prob.u0) - eltype(_u0), length(_u0), (length(prob.u0) - 1) - else - dt ≤ 0 && throw(ArgumentError("dt must be positive")) - eltype(prob.u0), length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)) - end + has_initial_guess, T, M, n, X = __extract_problem_details(prob; dt, + check_positive_dt = true) chunksize = pickchunksize(M * (n + 1)) - if has_initial_guess - fᵢ_cache = maybe_allocate_diffcache(vec(similar(_u0)), chunksize, alg.jac_alg) - fᵢ₂_cache = vec(similar(_u0)) - else - fᵢ_cache = maybe_allocate_diffcache(vec(similar(prob.u0)), chunksize, alg.jac_alg) - fᵢ₂_cache = vec(similar(prob.u0)) - end - # Without this, boxing breaks type stability - X = has_initial_guess ? _u0 : prob.u0 + __alloc_diffcache = x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) + + fᵢ_cache = __alloc_diffcache(similar(X)) + fᵢ₂_cache = vec(similar(X)) # NOTE: Assumes the user provided initial guess is on a uniform mesh mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1)) @@ -32,85 +52,51 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, # Don't flatten this here, since we need to expand it later if needed y₀ = __initial_state_from_prob(prob, mesh) - y = [maybe_allocate_diffcache(vec(copy(yᵢ)), chunksize, alg.jac_alg) for yᵢ in y₀] + y = __alloc_diffcache.(copy.(y₀)) TU, ITU = constructMIRK(alg, T) stage = alg_stage(alg) - k_discrete = [maybe_allocate_diffcache(similar(X, M, stage), chunksize, alg.jac_alg) + k_discrete = [__maybe_allocate_diffcache(similar(X, M, stage), chunksize, alg.jac_alg) for _ in 1:n] - k_interp = adaptive ? [similar(X, M, ITU.s_star - stage) for _ in 1:n] : - [similar(X, 0, 0) for _ in 1:n] + k_interp = [similar(X, ifelse(adaptive, M, 0), ifelse(adaptive, ITU.s_star - stage, 0)) + for _ in 1:n] - resid₁_size = if prob.f.bcresid_prototype === nothing - size(X) - elseif prob.f.bcresid_prototype isa ArrayPartition - size.(prob.f.bcresid_prototype.x) - else - size(prob.f.bcresid_prototype) - end + bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X) - if iip - if prob.f.bcresid_prototype === nothing - residual = [maybe_allocate_diffcache(vec(copy(yᵢ)), chunksize, alg.jac_alg) - for yᵢ in y₀] - else - residual = vcat([ - maybe_allocate_diffcache(vec(copy(prob.f.bcresid_prototype)), - chunksize, alg.jac_alg)], - [maybe_allocate_diffcache(vec(copy(yᵢ)), chunksize, alg.jac_alg) - for yᵢ in y₀[2:end]]) - end + residual = if iip + vcat([__alloc_diffcache(bcresid_prototype)], + __alloc_diffcache.(copy.(@view(y₀[2:end])))) else - residual = nothing + nothing end - defect = adaptive ? [similar(X, M) for _ in 1:n] : [similar(X, 0) for _ in 1:n] - - new_stages = adaptive ? [similar(X, M) for _ in 1:n] : [similar(X, 0) for _ in 1:n] + defect = [similar(X, ifelse(adaptive, M, 0)) for _ in 1:n] + new_stages = [similar(X, ifelse(adaptive, M, 0)) for _ in 1:n] # Transform the functions to handle non-vector inputs f, bc = if X isa AbstractVector prob.f, prob.f.bc elseif iip - function vecf!(du, u, p, t) - du_ = reshape(du, size(X)) - x_ = reshape(u, size(X)) - prob.f(du_, x_, p, t) - return du - end + vecf!(du, u, p, t) = prob.f(reshape(du, size(X)), reshape(u, size(X)), p, t) vecbc! = if !(prob.problem_type isa TwoPointBVProblem) function __vecbc!(resid, sol, p, t) - resid_ = reshape(resid, resid₁_size) - sol_ = map(s -> reshape(s, size(X)), sol) - prob.f.bc(resid_, sol_, p, t) - return resid + prob.f.bc(reshape(resid, resid₁_size), + map(Base.Fix2(reshape, size(X)), sol), p, t) end else function __vecbc_a!(resida, ua, p) - resida_ = reshape(resida, resid₁_size[1]) - ua_ = reshape(ua, size(X)) - prob.f.bc[1](resida_, ua_, p) - return nothing + prob.f.bc[1](reshape(resida, resid₁_size[1]), reshape(ua, size(X)), p) end function __vecbc_b!(residb, ub, p) - residb_ = reshape(residb, resid₁_size[2]) - ub_ = reshape(ub, size(X)) - prob.f.bc[2](residb_, ub_, p) - return nothing + prob.f.bc[2](reshape(residb, resid₁_size[2]), reshape(ub, size(X)), p) end (__vecbc_a!, __vecbc_b!) end vecf!, vecbc! else - function vecf(u, p, t) - x_ = reshape(u, size(X)) - return vec(prob.f(x_, p, t)) - end + vecf(u, p, t) = vec(prob.f(reshape(u, size(X)), p, t)) vecbc = if !(prob.problem_type isa TwoPointBVProblem) - function __vecbc(sol, p, t) - sol_ = map(s -> reshape(s, size(X)), sol) - return vec(prob.f.bc(sol_, p, t)) - end + __vecbc(sol, p, t) = vec(prob.f.bc(map(Base.Fix2(reshape, size(X)), sol), p, t)) else __vecbc_a(ua, p) = vec(prob.f.bc[1](reshape(ua, size(X)), p)) __vecbc_b(ub, p) = vec(prob.f.bc[2](reshape(ub, size(X)), p)) @@ -120,11 +106,29 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, end return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob, - prob.problem_type, prob.p, alg, TU, ITU, mesh, mesh_dt, k_discrete, k_interp, y, y₀, - residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages, + prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, + k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...)) end +""" + __expand_cache!(cache::MIRKCache) + +After redistributing or halving the mesh, this function expands the required vectors to +match the length of the new mesh. +""" +function __expand_cache!(cache::MIRKCache) + Nₙ = length(cache.mesh) + __append_similar!(cache.k_discrete, Nₙ - 1, cache.M) + __append_similar!(cache.k_interp, Nₙ - 1, cache.M) + __append_similar!(cache.y, Nₙ, cache.M) + __append_similar!(cache.y₀, Nₙ, cache.M) + __append_similar!(cache.residual, Nₙ, cache.M) + __append_similar!(cache.defect, Nₙ - 1, cache.M) + __append_similar!(cache.new_stages, Nₙ - 1, cache.M) + return cache +end + function __split_mirk_kwargs(; defect_threshold, MxNsub, abstol, dt, adaptive = true, kwargs...) return ((defect_threshold, MxNsub, abstol, adaptive, dt), @@ -139,7 +143,7 @@ function SciMLBase.solve!(cache::MIRKCache) defect_norm = 2 * abstol while SciMLBase.successful_retcode(info) && defect_norm > abstol - nlprob = construct_nlproblem(cache, recursive_flatten(y₀)) + nlprob = __construct_nlproblem(cache, recursive_flatten(y₀)) sol_nlprob = solve(nlprob, alg.nlsolve; abstol, kwargs...) recursive_unflatten!(cache.y₀, sol_nlprob.u) @@ -162,7 +166,7 @@ function SciMLBase.solve!(cache::MIRKCache) for (i, m) in enumerate(cache.mesh) interp_eval!(cache.y₀[i], cache, m, mesh, mesh_dt) end - expand_cache!(cache) + __expand_cache!(cache) end end else @@ -172,7 +176,7 @@ function SciMLBase.solve!(cache::MIRKCache) info = ReturnCode.Failure else half_mesh!(cache) - expand_cache!(cache) + __expand_cache!(cache) recursive_fill!(cache.y₀, 0) info = ReturnCode.Success # Force a restart defect_norm = 2 * abstol @@ -186,7 +190,7 @@ function SciMLBase.solve!(cache::MIRKCache) end # Constructing the Nonlinear Problem -function construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {iip} +function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {iip} loss_bc = if iip function loss_bc_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p) y_ = recursive_unflatten!(cache.y, u) @@ -217,63 +221,47 @@ function construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {ii end end - loss = if !(cache.problem_type isa TwoPointBVProblem) - if iip - function loss_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p) - y_ = recursive_unflatten!(cache.y, u) - resids = [get_tmp(r, u) for r in cache.residual] - eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, - cache.mesh) - Φ!(resids[2:end], cache, y_, u, p) + loss = if iip + function loss_internal!(resid::AbstractVector, u::AbstractVector, p = cache.p) + y_ = recursive_unflatten!(cache.y, u) + resids = [get_tmp(r, u) for r in cache.residual] + eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, + cache.mesh) + Φ!(resids[2:end], cache, y_, u, p) + if cache.problem_type isa TwoPointBVProblem + recursive_flatten_twopoint!(resid, resids) + else recursive_flatten!(resid, resids) - return resid - end - else - function loss_internal(u::AbstractVector, p = cache.p) - y_ = recursive_unflatten!(cache.y, u) - resid_bc = eval_bc_residual(cache.problem_type, cache.bc, y_, p, cache.mesh) - resid_co = Φ(cache, y_, u, p) - return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) end + return resid end else - # Reordering for 2 point BVP - if iip - function loss_internal_2point!(resid::AbstractVector, u::AbstractVector, - p = cache.p) - y_ = recursive_unflatten!(cache.y, u) - resids = [get_tmp(r, u) for r in cache.residual] - eval_bc_residual!(resids[1], cache.problem_type, cache.bc, y_, p, - cache.mesh) - Φ!(resids[2:end], cache, y_, u, p) - recursive_flatten_twopoint!(resid, resids) - return resid - end - else - function loss_internal_2point(u::AbstractVector, p = cache.p) - y_ = recursive_unflatten!(cache.y, u) - resid_bc = eval_bc_residual(cache.problem_type, cache.bc, y_, p, cache.mesh) - resid_co = Φ(cache, y_, u, p) + function loss_internal(u::AbstractVector, p = cache.p) + y_ = recursive_unflatten!(cache.y, u) + resid_bc = eval_bc_residual(cache.problem_type, cache.bc, y_, p, cache.mesh) + resid_co = Φ(cache, y_, u, p) + if cache.problem_type isa TwoPointBVProblem return vcat(resid_bc.x[1], mapreduce(vec, vcat, resid_co), resid_bc.x[2]) + else + return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) end end end - return generate_nlprob(cache, y, loss_bc, loss_collocation, loss, cache.problem_type) + return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, + cache.problem_type) end -function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss, - _) where {iip} +function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss, + ::StandardBVProblem) where {iip} @unpack nlsolve, jac_alg = cache.alg N = length(cache.mesh) - resid_bc = cache.prob.f.bcresid_prototype === nothing ? similar(y, cache.M) : - cache.prob.f.bcresid_prototype + resid_bc = cache.bcresid_prototype resid_collocation = similar(y, cache.M * (N - 1)) sd_bc = jac_alg.bc_diffmode isa AbstractSparseADType ? SymbolicsSparsityDetection() : NoSparsityDetection() - cache_bc = __sparse_jacobian_cache(Val(iip), jac_alg.bc_diffmode, sd_bc, loss_bc, resid_bc, y) @@ -283,14 +271,11 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo else NoSparsityDetection() end - cache_collocation = __sparse_jacobian_cache(Val(iip), jac_alg.nonbc_diffmode, sd_collocation, loss_collocation, resid_collocation, y) jac_prototype = vcat(init_jacobian(cache_bc), init_jacobian(cache_collocation)) - # TODO: Pass `p` into `loss_bc` and `loss_collocation`. Currently leads to a Tag - # mismatch for ForwardDiff jac = if iip function jac_internal!(J, x, p) sparse_jacobian!(@view(J[1:(cache.M), :]), jac_alg.bc_diffmode, cache_bc, @@ -313,19 +298,12 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo return NonlinearProblem(NonlinearFunction{iip}(loss; jac, jac_prototype), y, cache.p) end -function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss, +function __construct_nlproblem(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, loss, ::TwoPointBVProblem) where {iip} @unpack nlsolve, jac_alg = cache.alg N = length(cache.mesh) - if !iip && cache.prob.f.bcresid_prototype === nothing - y_ = recursive_unflatten!(cache.y, y) - resid_ = ArrayPartition(cache.bc[1](y_[1], cache.p), cache.bc[2](y_[end], cache.p)) - resid = ArrayPartition(resid_, similar(y, cache.M * (N - 1))) - else - resid = ArrayPartition(cache.prob.f.bcresid_prototype, - similar(y, cache.M * (N - 1))) - end + resid = ArrayPartition(cache.bcresid_prototype, similar(y, cache.M * (N - 1))) sd = if jac_alg.diffmode isa AbstractSparseADType PrecomputedJacobianColorvec(__generate_sparse_jacobian_prototype(cache, @@ -333,13 +311,9 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo else NoSparsityDetection() end - diffcache = __sparse_jacobian_cache(Val(iip), jac_alg.diffmode, sd, loss, resid, y) - jac_prototype = init_jacobian(diffcache) - # TODO: Pass `p` into `loss_bc` and `loss_collocation`. Currently leads to a Tag - # mismatch for ForwardDiff jac = if iip function jac_internal!(J, x, p) sparse_jacobian!(J, jac_alg.diffmode, diffcache, loss, resid, x) diff --git a/src/solve/single_shooting.jl b/src/solve/single_shooting.jl index 2de87648..5a02f0f4 100644 --- a/src/solve/single_shooting.jl +++ b/src/solve/single_shooting.jl @@ -1,9 +1,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::Shooting; odesolve_kwargs = (;), nlsolve_kwargs = (;), verbose = true, kwargs...) - has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray} - has_initial_guess && verbose && + ig, T, _, _, u0 = __extract_problem_details(prob; dt = 0.1) + known(ig) && verbose && @warn "Initial guess provided, but will be ignored for Shooting!" - u0 = has_initial_guess ? first(prob.u0) : prob.u0 iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0) resid_size = prob.f.bcresid_prototype === nothing ? u0_size : diff --git a/src/types.jl b/src/types.jl index bf3ff899..5232faa3 100644 --- a/src/types.jl +++ b/src/types.jl @@ -82,15 +82,11 @@ end du end -function maybe_allocate_diffcache(x, chunksize, jac_alg) - if __needs_diffcache(jac_alg) - return DiffCache(x, chunksize) - else - return FakeDiffCache(x) - end +function __maybe_allocate_diffcache(x, chunksize, jac_alg) + return __needs_diffcache(jac_alg) ? DiffCache(x, chunksize) : FakeDiffCache(x) end -maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize) -maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du)) +__maybe_allocate_diffcache(x::DiffCache, chunksize) = DiffCache(similar(x.du), chunksize) +__maybe_allocate_diffcache(x::FakeDiffCache, _) = FakeDiffCache(similar(x.du)) PreallocationTools.get_tmp(dc::FakeDiffCache, _) = dc.du diff --git a/src/utils.jl b/src/utils.jl index 3082566d..6a99dbdb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -90,3 +90,68 @@ eval_bc_residual!(resid, _, bc!, sol, p, t) = bc!(resid, sol, p, t) bcb!(resid.x[2], ub, p) return resid end + +__append_similar!(::Nothing, n, _) = nothing + +function __append_similar!(x::AbstractVector{<:AbstractArray}, n, _) + N = n - length(x) + N == 0 && return x + N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) + append!(x, [similar(first(x)) for _ in 1:N]) + return x +end + +function __append_similar!(x::AbstractVector{<:MaybeDiffCache}, n, M) + N = n - length(x) + N == 0 && return x + N < 0 && throw(ArgumentError("Cannot append a negative number of elements")) + chunksize = pickchunksize(M * (N + length(x))) + append!(x, [maybe_allocate_diffcache(first(x), chunksize) for _ in 1:N]) + return x +end + +## Problem with Initial Guess +function __extract_problem_details(prob; kwargs...) + return __extract_problem_details(prob, prob.u0; kwargs...) +end +function __extract_problem_details(prob, u0::AbstractVector{<:AbstractArray}; kwargs...) + # Problem has Initial Guess + _u0 = first(u0) + return True(), eltype(_u0), length(_u0), (length(u0) - 1), _u0 +end +function __extract_problem_details(prob, u0; dt = 0.0, check_positive_dt::Bool = false) + # Problem does not have Initial Guess + check_positive_dt && dt ≤ 0 && throw(ArgumentError("dt must be positive")) + t₀, t₁ = prob.tspan + return False(), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0 +end + +__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh) +__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh] +function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _) + return [copy(vec(u)) for u in u0] +end + +function __get_bcresid_prototype(::TwoPointBVProblem, prob, u) + prototype = if isinplace(prob) + prob.f.bcresid_prototype + elseif prob.f.bcresid_prototype === nothing + prob.f.bcresid_prototype + else + ArrayPartition(first(prob.f.bc)(u, prob.p), last(prob.f.bc)(u, prob.p)) + end + return prototype, size.(prototype.x) +end +function __get_bcresid_prototype(::StandardBVProblem, prob, u) + prototype = prob.f.bcresid_prototype !== nothing ? prob.f.bcresid_prototype : + fill!(similar(u), 0) + return prototype, size(prototype) +end + +function __fill_like(v, x, args...) + y = similar(x, args...) + fill!(y, v) + return y +end +__zeros_like(args...) = __fill_like(0, args...) +__ones_like(args...) = __fill_like(1, args...)