Skip to content

Commit

Permalink
Merge pull request #112 from ErikQQY/qqy/interp
Browse files Browse the repository at this point in the history
Add interpolations for MIRK methods
  • Loading branch information
ChrisRackauckas authored Sep 23, 2023
2 parents a615b8a + b43f1d7 commit 186339d
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 15 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -28,6 +29,7 @@ ConcreteStructs = "0.2"
DiffEqBase = "6.94.2"
ForwardDiff = "0.10"
NonlinearSolve = "1"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "1.70"
Setfield = "1"
Expand Down
8 changes: 6 additions & 2 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module BoundaryValueDiffEq

using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools
using Adapt, LinearAlgebra, PreallocationTools, Reexport, Setfield, SparseArrays, SciMLBase,
RecursiveArrayTools
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase

import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: pickchunksize
import RecursiveArrayTools: DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation
import SparseDiffTools: AbstractSparseADType
import TruncatedStacktraces: @truncate_stacktrace
import UnPack: @unpack
Expand All @@ -22,6 +25,7 @@ include("collocation.jl")
include("nlprob.jl")
include("solve.jl")
include("adaptivity.jl")
include("interpolation.jl")

export Shooting
export MIRK2, MIRK3, MIRK4, MIRK5, MIRK6
Expand Down
16 changes: 15 additions & 1 deletion src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ end
w′[(stage + 1):s_star],
true,
true)
z .= z .* dt .+ cache.y₀[i]
z .= z .* dt[1] .+ cache.y₀[i]

return z, z′
end
Expand Down Expand Up @@ -386,3 +386,17 @@ for order in (2, 3, 4, 5, 6)
end
end
end

function sol_eval(cache::MIRKCache{T}, t::T) where {T}
@unpack M, mesh, mesh_dt, alg, k_discrete, k_interp, y = cache

@assert mesh[1] t mesh[end]
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
weights, weights_prime = interp_weights(τ, alg)
z = zeros(M)
z_prime = zeros(M)
sum_stages!(z, z_prime, cache, weights, weights_prime, i, mesh_dt)
return z
end
22 changes: 11 additions & 11 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
@concrete struct MIRKCache{T}
order::Int # The order of MIRK method
stage::Int # The state of MIRK method
M::Int
M::Int # The number of equations
in_size
f! # FIXME: After supporting OOP functions
bc! # FIXME: After supporting OOP functions
prob
problem_type
p
alg
TU
ITU
prob # BVProblem
problem_type # StandardBVProblem
p # Parameters
alg # MIRK methods
TU # MIRK Tableau
ITU # MIRK Interpolation Tableau
# Everything below gets resized in adaptive methods
mesh
mesh_dt
k_discrete
k_interp
mesh # Discrete mesh
mesh_dt # Step size
k_discrete # Stage information associated with the discrete Runge-Kutta method
k_interp # Stage information associated with the discrete Runge-Kutta method
y
y₀
residual
Expand Down
88 changes: 88 additions & 0 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation
t::T1
u::T2
cache
end

function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation(tvals, id, idxs, deriv, p, continuity)
end

function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol = :left)
interpolation!(val, tvals, id, idxs, deriv, p, continuity)
end

@inline function interpolation(tvals,
id::I,
idxs,
deriv::D,
p,
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
cache = id.cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

if typeof(idxs) <: Number
vals = Vector{eltype(first(u))}(undef, length(tvals))
elseif typeof(idxs) <: AbstractVector
vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals))
else
vals = Vector{eltype(u)}(undef, length(tvals))
end

for j in idx
tval = tvals[j]
i = interval(t, tval)
dt = t[i + 1] - t[i]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
vals[j] = copy(z)
end
DiffEqArray(vals, tvals)
end

@inline function interpolation!(vals,
tvals,
id::I,
idxs,
deriv::D,
p,
continuity::Symbol = :left) where {I, D}
t = id.t
cache = id.cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

for j in idx
tval = tvals[j]
i = interval(t, tval)
dt = t[i] - t[i - 1]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
vals[j] = copy(z)
end
end

@inline function interpolation(tval::Number,
id::I,
idxs,
deriv::D,
p,
continuity::Symbol = :left) where {I, D}
t = id.t
cache = id.cache
i = interval(t, tval)
dt = t[i] - t[i - 1]
θ = (tval - t[i]) / dt
weights, _ = interp_weights(θ, cache.alg)
z = zeros(cache.M)
sum_stages!(z, cache, weights, i)
val = copy(z)
val
end
3 changes: 2 additions & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ function SciMLBase.solve!(cache::MIRKCache)
end
end

u = [reshape(y, cache.in_size) for y in cache.y₀]
return DiffEqBase.build_solution(prob, alg, mesh,
[reshape(y, cache.in_size) for y in cache.y₀]; retcode = info)
u; interp = MIRKInterpolation(mesh, u, cache), retcode = info)
end

0 comments on commit 186339d

Please sign in to comment.