From 43c946c71c9943ab8e0cdc2d60413765c782dd87 Mon Sep 17 00:00:00 2001 From: ErikQQY <2283984853@qq.com> Date: Thu, 12 Oct 2023 13:34:01 +0800 Subject: [PATCH] Fix interpolant evaluation error Signed-off-by: ErikQQY <2283984853@qq.com> --- src/interpolation.jl | 69 ++++++++++++-------------------------- test/interpolation_test.jl | 33 ++++++++++++++++++ test/runtests.jl | 6 ++++ 3 files changed, 60 insertions(+), 48 deletions(-) create mode 100644 test/interpolation_test.jl diff --git a/src/interpolation.jl b/src/interpolation.jl index 9b4600b8..a32f526f 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -4,6 +4,10 @@ struct MIRKInterpolation{T1, T2} <: AbstractDiffEqInterpolation cache end +function DiffEqBase.interp_summary(interp::MIRKInterpolation) + return "MIRK Order $(interp.cache.order) Interpolation" +end + function (id::MIRKInterpolation)(tvals, idxs, deriv, p, continuity::Symbol = :left) interpolation(tvals, id, idxs, deriv, p, continuity) end @@ -12,15 +16,11 @@ function (id::MIRKInterpolation)(val, tvals, idxs, deriv, p, continuity::Symbol interpolation!(val, tvals, id, idxs, deriv, p, continuity) end -@inline function interpolation(tvals, - id::I, - idxs, - deriv::D, - p, +# FIXME: Fix the interpolation outside the tspan + +@inline function interpolation(tvals, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - t = id.t - u = id.u - cache = id.cache + @unpack t, u, cache = id tdir = sign(t[end] - t[1]) idx = sortperm(tvals, rev = tdir < 0) @@ -33,56 +33,29 @@ end end for j in idx - tval = tvals[j] - i = interval(t, tval) - dt = t[i + 1] - t[i] - θ = (tval - t[i]) / dt - weights, _ = interp_weights(θ, cache.alg) - z = zeros(cache.M) - sum_stages!(z, cache, weights, i) - vals[j] = copy(z) + z = similar(cache.fᵢ₂_cache) + interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt) + vals[j] = z end - DiffEqArray(vals, tvals) + return DiffEqArray(vals, tvals) end -@inline function interpolation!(vals, - tvals, - id::I, - idxs, - deriv::D, - p, +@inline function interpolation!(vals, tvals, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - t = id.t - cache = id.cache + @unpack t, cache = id tdir = sign(t[end] - t[1]) idx = sortperm(tvals, rev = tdir < 0) for j in idx - tval = tvals[j] - i = interval(t, tval) - dt = t[i] - t[i - 1] - θ = (tval - t[i]) / dt - weights, _ = interp_weights(θ, cache.alg) - z = zeros(cache.M) - sum_stages!(z, cache, weights, i) - vals[j] = copy(z) + z = similar(cache.fᵢ₂_cache) + interp_eval!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt) + vals[j] = z end end -@inline function interpolation(tval::Number, - id::I, - idxs, - deriv::D, - p, +@inline function interpolation(tval::Number, id::I, idxs, deriv::D, p, continuity::Symbol = :left) where {I, D} - t = id.t - cache = id.cache - i = interval(t, tval) - dt = t[i] - t[i - 1] - θ = (tval - t[i]) / dt - weights, _ = interp_weights(θ, cache.alg) - z = zeros(cache.M) - sum_stages!(z, cache, weights, i) - val = copy(z) - val + z = similar(id.cache.fᵢ₂_cache) + interp_eval!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt) + return z end diff --git a/test/interpolation_test.jl b/test/interpolation_test.jl new file mode 100644 index 00000000..a1836c64 --- /dev/null +++ b/test/interpolation_test.jl @@ -0,0 +1,33 @@ +using BoundaryValueDiffEq, DiffEqBase, DiffEqDevTools, LinearAlgebra, Test + +λ = 1 +function prob_bvp_linear_analytic(u, λ, t) + a = 1 / sqrt(λ) + [(exp(-a * t) - exp((t - 2) * a)) / (1 - exp(-2 * a)), + (-a * exp(-t * a) - a * exp((t - 2) * a)) / (1 - exp(-2 * a))] +end +function prob_bvp_linear_f!(du, u, p, t) + du[1] = u[2] + du[2] = 1 / p * u[1] +end +function prob_bvp_linear_bc!(res, u, p, t) + res[1] = u[1][1] - 1 + res[2] = u[end][1] +end +prob_bvp_linear_function = ODEFunction(prob_bvp_linear_f!, analytic = prob_bvp_linear_analytic) +prob_bvp_linear_tspan = (0.0, 1.0) +prob_bvp_linear = BVProblem(prob_bvp_linear_function, prob_bvp_linear_bc!, + [1.0, 0.0], prob_bvp_linear_tspan, λ) +testTol = 1e-6 + +for order in (2, 3, 4, 5, 6) + s = Symbol("MIRK$(order)") + @eval mirk_solver(::Val{$order}) = $(s)() +end + +@testset "Interpolation" begin + @testset "MIRK$order" for order in (2, 3, 4, 5, 6) + @time sol = solve(prob_bvp_linear, mirk_solver(Val(order)); dt = 0.001) + @test sol(0.001) ≈ [0.998687464, -1.312035941] atol=testTol + end +end diff --git a/test/runtests.jl b/test/runtests.jl index af8b3bbd..fbf6fd64 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,4 +33,10 @@ using Test, SafeTestsets include("non_vector_inputs.jl") end end + + @time @testset "Interpolation Tests" begin + @time @safetestset "MIRK Interpolation Test" begin + include("interpolation_test.jl") + end + end end