Skip to content

Commit

Permalink
Finally done
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 13, 2024
1 parent 86a2d26 commit b5a7f5d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 25 deletions.
9 changes: 2 additions & 7 deletions lib/BoundaryValueDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
# Intermidiate solution evaluation
struct EvalSol{A <: BoundaryValueDiffEqAlgorithm}
@concrete struct EvalSol{iip}
u
t
alg::A
alg
k_discrete
end

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
nodual_value(x::AbstractArray{<:AbstractArray{<:Dual}}) = map(nodual_value, x)
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg, nodual_value,
__FastShortcutBVPCompatibleNLLSPolyalg,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
Expand Down
20 changes: 11 additions & 9 deletions lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,18 @@ end

# Intermidiate solution for evaluating boundry conditions
# basically simplified version of the interpolation for MIRK
function (s::EvalSol{A})(tval::Number) where {A <: AbstractMIRK}
(; u, t, alg, k_discrete) = s
function (s::EvalSol)(tval::Number)
(; t, u, alg, k_discrete) = s
stage = alg_stage(alg)
z = similar(u[1])
i = interval(t, tval)
dt = t[i + 1] - t[i]
τ = (tval - t[i]) / dt
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
z = zero(last(u))
ii = interval(t, tval)
dt = t[ii + 1] - t[ii]
τ = (tval - t[ii]) / dt
w, _ = interp_weights(τ, alg)
z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
z .= z .* dt .+ u[i]
__maybe_matmul!(z, k_discrete[ii].du[:, 1:stage], w[1:stage])
z .= z .* dt .+ u[ii]
return z
end
12 changes: 6 additions & 6 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ function __construct_nlproblem(
cache::MIRKCache{iip}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(y₀.u, cache.mesh, cache.alg, cache.k_discrete)
eval_sol = EvalSol{iip}(y₀.u, cache.mesh, cache.alg, cache.k_discrete)

loss_bc = if iip
@closure (du, u, p) -> __mirk_loss_bc!(
Expand Down Expand Up @@ -243,15 +243,15 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.u[1:end] .= y_
EvalSol.k_discrete[1:end] .= cache.k_discrete
eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
end

@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache, EvalSol) where {BC1, BC2}
residual, mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, p)
Expand All @@ -267,7 +267,7 @@ end
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.u[1:end] .= y_
EvalSol.k_discrete[1:end] .= cache.k_discrete
resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
Expand All @@ -285,14 +285,14 @@ end
@views function __mirk_loss_bc!(
resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete)
soly_ = EvalSol{true}(y_, mesh, cache.alg, cache.k_discrete)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
return nothing
end

@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete)
soly_ = EvalSol{false}(y_, mesh, cache.alg, cache.k_discrete)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
end

Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqMIRK/test/mirk_basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ probArr = [BVProblem(odef1!, boundary!, u0, tspan, nlls = Val(false)),
TwoPointBVProblem(odef2, (boundary_two_point_a, boundary_two_point_b),
u0, tspan; bcresid_prototype, nlls = Val(false))]

testTol = 0.2
testTol = 0.4
affineTol = 1e-2
dts = 1 .// 2 .^ (3:-1:1)

Expand Down Expand Up @@ -108,7 +108,7 @@ end

@testset "Problem: $i" for i in (3, 4, 7, 8)
prob = probArr[i]
@testset "MIRK$order" for (i, order) in enumerate((2, 3, 4, 5, 6))
@testset "MIRK$order" for (_, order) in enumerate((2, 3, 4, 5, 6))
sim = test_convergence(
dts, prob, mirk_solver(Val(order)); abstol = 1e-8, reltol = 1e-8)
@test sim.𝒪est[:final]order atol=testTol
Expand Down

0 comments on commit b5a7f5d

Please sign in to comment.