diff --git a/lib/BoundaryValueDiffEqCore/src/misc_utils.jl b/lib/BoundaryValueDiffEqCore/src/misc_utils.jl index 54d297ee..fed90238 100644 --- a/lib/BoundaryValueDiffEqCore/src/misc_utils.jl +++ b/lib/BoundaryValueDiffEqCore/src/misc_utils.jl @@ -1,12 +1,7 @@ # Intermidiate solution evaluation -struct EvalSol{A <: BoundaryValueDiffEqAlgorithm} +@concrete struct EvalSol{iip} u t - alg::A + alg k_discrete end - -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) -nodual_value(x::AbstractArray{<:AbstractArray{<:Dual}}) = map(nodual_value, x) diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl index 193be664..942e5b80 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -15,7 +15,7 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit recursive_flatten, recursive_flatten!, recursive_unflatten!, __concrete_nonlinearsolve_algorithm, diff!, __FastShortcutBVPCompatibleNonlinearPolyalg, - __FastShortcutBVPCompatibleNLLSPolyalg, nodual_value, + __FastShortcutBVPCompatibleNLLSPolyalg, concrete_jacobian_algorithm, eval_bc_residual, eval_bc_residual!, get_tmp, __maybe_matmul!, __append_similar!, __extract_problem_details, diff --git a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl index 782cefc6..00f8db92 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/interpolation.jl @@ -117,16 +117,18 @@ end # Intermidiate solution for evaluating boundry conditions # basically simplified version of the interpolation for MIRK -function (s::EvalSol{A})(tval::Number) where {A <: AbstractMIRK} - (; u, t, alg, k_discrete) = s +function (s::EvalSol)(tval::Number) + (; t, u, alg, k_discrete) = s stage = alg_stage(alg) - z = similar(u[1]) - i = interval(t, tval) - dt = t[i + 1] - t[i] - τ = (tval - t[i]) / dt + # Quick handle for the case where tval is at the boundary + (tval == t[1]) && return first(u) + (tval == t[end]) && return last(u) + z = zero(last(u)) + ii = interval(t, tval) + dt = t[ii + 1] - t[ii] + τ = (tval - t[ii]) / dt w, _ = interp_weights(τ, alg) - z .= zero(z) - __maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage]) - z .= z .* dt .+ u[i] + __maybe_matmul!(z, k_discrete[ii].du[:, 1:stage], w[1:stage]) + z .= z .* dt .+ u[ii] return z end diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index 9cf6c951..27ccf157 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -210,7 +210,7 @@ function __construct_nlproblem( cache::MIRKCache{iip}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip} pt = cache.problem_type - eval_sol = EvalSol(y₀.u, cache.mesh, cache.alg, cache.k_discrete) + eval_sol = EvalSol{iip}(y₀.u, cache.mesh, cache.alg, cache.k_discrete) loss_bc = if iip @closure (du, u, p) -> __mirk_loss_bc!( @@ -243,7 +243,7 @@ end y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] Φ!(resids[2:end], cache, y_, u, p) - EvalSol.u[1:end] .= nodual_value(y_) + EvalSol.u[1:end] .= y_ EvalSol.k_discrete[1:end] .= cache.k_discrete eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh) recursive_flatten!(resid, resids) @@ -251,7 +251,7 @@ end end @views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, EvalSol) where {BC1, BC2} + residual, mesh, cache, _) where {BC1, BC2} y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] Φ!(resids[2:end], cache, y_, u, p) @@ -267,7 +267,7 @@ end u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol) where {BC} y_ = recursive_unflatten!(y, u) resid_co = Φ(cache, y_, u, p) - EvalSol.u[1:end] .= nodual_value(y_) + EvalSol.u[1:end] .= y_ EvalSol.k_discrete[1:end] .= cache.k_discrete resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh) return vcat(resid_bc, mapreduce(vec, vcat, resid_co)) @@ -285,14 +285,14 @@ end @views function __mirk_loss_bc!( resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) - soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete) + soly_ = EvalSol{true}(y_, mesh, cache.alg, cache.k_discrete) eval_bc_residual!(resid, pt, bc!, soly_, p, mesh) return nothing end @views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC} y_ = recursive_unflatten!(y, u) - soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete) + soly_ = EvalSol{false}(y_, mesh, cache.alg, cache.k_discrete) return eval_bc_residual(pt, bc!, soly_, p, mesh) end diff --git a/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl index 4e844cd5..ffef123f 100644 --- a/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl +++ b/lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl @@ -67,7 +67,7 @@ probArr = [BVProblem(odef1!, boundary!, u0, tspan, nlls = Val(false)), TwoPointBVProblem(odef2, (boundary_two_point_a, boundary_two_point_b), u0, tspan; bcresid_prototype, nlls = Val(false))] -testTol = 0.2 +testTol = 0.4 affineTol = 1e-2 dts = 1 .// 2 .^ (3:-1:1) @@ -108,7 +108,7 @@ end @testset "Problem: $i" for i in (3, 4, 7, 8) prob = probArr[i] - @testset "MIRK$order" for (i, order) in enumerate((2, 3, 4, 5, 6)) + @testset "MIRK$order" for (_, order) in enumerate((2, 3, 4, 5, 6)) sim = test_convergence( dts, prob, mirk_solver(Val(order)); abstol = 1e-8, reltol = 1e-8) @test sim.𝒪est[:final]≈order atol=testTol