diff --git a/Project.toml b/Project.toml index 71418a8b..239fa0e5 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -28,6 +29,7 @@ ConcreteStructs = "0.2" DiffEqBase = "6.94.2" ForwardDiff = "0.10" NonlinearSolve = "1" +RecursiveArrayTools = "2.38.10" Reexport = "0.2, 1.0" SciMLBase = "1.70" Setfield = "1" diff --git a/src/BoundaryValueDiffEq.jl b/src/BoundaryValueDiffEq.jl index 400e7d64..619a0e3b 100644 --- a/src/BoundaryValueDiffEq.jl +++ b/src/BoundaryValueDiffEq.jl @@ -1,13 +1,16 @@ module BoundaryValueDiffEq -using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays -@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools +using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase, + RecursiveArrayTools +@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase import ADTypes: AbstractADType import ArrayInterface: matrix_colors, parameterless_type import ConcreteStructs: @concrete import DiffEqBase: solve import ForwardDiff: pickchunksize +import RecursiveArrayTools: DiffEqArray +import SciMLBase: AbstractDiffEqInterpolation import SparseDiffTools: AbstractSparseADType import TruncatedStacktraces: @truncate_stacktrace import UnPack: @unpack @@ -22,6 +25,7 @@ include("collocation.jl") include("nlprob.jl") include("solve.jl") include("adaptivity.jl") +include("interpolation.jl") export Shooting export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6 diff --git a/src/adaptivity.jl b/src/adaptivity.jl index 1616d11b..cec9c7b8 100644 --- a/src/adaptivity.jl +++ b/src/adaptivity.jl @@ -253,7 +253,7 @@ end w′[(stage + 1):s_star], true, true) - z .= z .* dt .+ cache.y₀[i] + z .= z .* dt[1] .+ cache.y₀[i] return z, z′ end @@ -386,3 +386,17 @@ for order in (2, 3, 4, 5, 6) end end end + +function sol_eval(cache::MIRKCache{T}, t::T) where {T} + @unpack M, mesh, mesh_dt, alg, k_discrete, k_interp, y = cache + + @assert mesh[1] ≤ t ≤ mesh[end] + i = interval(mesh, t) + dt = mesh_dt[i] + τ = (t - mesh[i]) / dt + weights, weights_prime = interp_weights(τ, alg) + z = zeros(M) + z_prime = zeros(M) + sum_stages!(z, z_prime, cache, weights, weights_prime, i, mesh_dt) + return z +end diff --git a/src/cache.jl b/src/cache.jl index 3a8ade35..c4a77cef 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -1,21 +1,21 @@ @concrete struct MIRKCache{T} order::Int # The order of MIRK method stage::Int # The state of MIRK method - M::Int + M::Int # The number of equations in_size f! # FIXME: After supporting OOP functions bc! # FIXME: After supporting OOP functions - prob - problem_type - p - alg - TU - ITU + prob # BVProblem + problem_type # StandardBVProblem + p # Parameters + alg # MIRK methods + TU # MIRK Tableau + ITU # MIRK Interpolation Tableau # Everything below gets resized in adaptive methods - mesh - mesh_dt - k_discrete - k_interp + mesh # Discrete mesh + mesh_dt # Step size + k_discrete # Stage information associated with the discrete Runge-Kutta method + k_interp # Stage information associated with the discrete Runge-Kutta method y y₀ residual diff --git a/src/interpolation.jl b/src/interpolation.jl new file mode 100644 index 00000000..9b4600b8 --- /dev/null +++ b/src/interpolation.jl @@ -0,0 +1,88 @@ +struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation + t::T1 + u::T2 + cache +end + +function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) + interpolation(tvals, id, idxs, deriv, p, continuity) +end + +function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left) + interpolation!(val, tvals, id, idxs, deriv, p, continuity) +end + +@inline function interpolation(tvals, + id::I, + idxs, + deriv::D, + p, + continuity::Symbol = :left) where {I, D} + t = id.t + u = id.u + cache = id.cache + tdir = sign(t[end] - t[1]) + idx = sortperm(tvals, rev = tdir < 0) + + if typeof(idxs) <: Number + vals = Vector{eltype(first(u))}(undef, length(tvals)) + elseif typeof(idxs) <: AbstractVector + vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals)) + else + vals = Vector{eltype(u)}(undef, length(tvals)) + end + + for j in idx + tval = tvals[j] + i = interval(t, tval) + dt = t[i + 1] - t[i] + θ = (tval - t[i]) / dt + weights, _ = interp_weights(θ, cache.alg) + z = zeros(cache.M) + sum_stages!(z, cache, weights, i) + vals[j] = copy(z) + end + DiffEqArray(vals, tvals) +end + +@inline function interpolation!(vals, + tvals, + id::I, + idxs, + deriv::D, + p, + continuity::Symbol = :left) where {I, D} + t = id.t + cache = id.cache + tdir = sign(t[end] - t[1]) + idx = sortperm(tvals, rev = tdir < 0) + + for j in idx + tval = tvals[j] + i = interval(t, tval) + dt = t[i] - t[i - 1] + θ = (tval - t[i]) / dt + weights, _ = interp_weights(θ, cache.alg) + z = zeros(cache.M) + sum_stages!(z, cache, weights, i) + vals[j] = copy(z) + end +end + +@inline function interpolation(tval::Number, + id::I, + idxs, + deriv::D, + p, + continuity::Symbol = :left) where {I, D} + t = id.t + cache = id.cache + i = interval(t, tval) + dt = t[i] - t[i - 1] + θ = (tval - t[i]) / dt + weights, _ = interp_weights(θ, cache.alg) + z = zeros(cache.M) + sum_stages!(z, cache, weights, i) + val = copy(z) + val +end diff --git a/src/solve.jl b/src/solve.jl index dcb14182..ab827be0 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -156,6 +156,7 @@ function SciMLBase.solve!(cache::MIRKCache) end end + u = [reshape(y, cache.in_size) for y in cache.y₀] return DiffEqBase.build_solution(prob, alg, mesh, - [reshape(y, cache.in_size) for y in cache.y₀]; retcode = info) + u; interp = MIRKInterpolation(mesh, u, cache), retcode = info) end