Skip to content

Commit

Permalink
Boundary conditions should always use solution object
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 7, 2024
1 parent 5ecacf6 commit 9f18a3b
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 67 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function simplependulum!(du, u, p, t)
du[2] = -9.81 * sin(θ)
end
function bc!(residual, u, p, t)
residual[1] = u[:, end ÷ 2][1] + pi / 2
residual[2] = u[:, end][1] - pi / 2
residual[1] = u(pi / 4)[1] + pi / 2
residual[2] = u(pi / 2)[1] - pi / 2
end
prob = BVProblem(simplependulum!, bc!, [pi / 2, pi / 2], tspan)
sol = solve(prob, MIRK4(), dt = 0.05)
Expand Down
3 changes: 2 additions & 1 deletion lib/BoundaryValueDiffEqCore/src/BoundaryValueDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import ADTypes: AbstractADType
import ArrayInterface: matrix_colors, parameterless_type, fast_scalar_indexing
import ConcreteStructs: @concrete
import DiffEqBase: solve
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
using NonlinearSolveFirstOrder: NonlinearSolvePolyAlgorithm
import LineSearch: BackTracking
Expand All @@ -28,6 +28,7 @@ include("algorithms.jl")
include("alg_utils.jl")
include("default_nlsolve.jl")
include("sparse_jacobians.jl")
include("misc_utils.jl")

function __solve(prob::BVProblem, alg::BoundaryValueDiffEqAlgorithm, args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
Expand Down
12 changes: 12 additions & 0 deletions lib/BoundaryValueDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Intermidiate solution evaluation
struct EvalSol{A <: BoundaryValueDiffEqAlgorithm}
u
t
alg::A
k_discrete
end

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
nodual_value(x::AbstractArray{<:AbstractArray{<:Dual}}) = map(nodual_value, x)
2 changes: 0 additions & 2 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,5 +368,3 @@ end
end

@inline (f::__Fix3{F})(a, b) where {F} = f.f(a, b, f.x)

# convert every vector of vector to AbstractVectorOfArray, especially if them come from get_tmp of PreallocationTools.jl
18 changes: 9 additions & 9 deletions lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg, nodual_value,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__has_initial_guess, __initial_guess_length, EvalSol,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand All @@ -33,7 +33,7 @@ import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scal
import ConcreteStructs: @concrete
import DiffEqBase: solve
import FastClosures: @closure
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
Expand All @@ -58,11 +58,11 @@ include("sparse_jacobians.jl")
f1 = (u, p, t) -> [u[2], 0]

function bc1!(residual, u, p, t)
residual[1] = u[:, 1][1] - 5
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 5
residual[2] = u(5.0)[1]
end

bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]]
bc1 = (u, p, t) -> [u(0.0)[1] - 5, u(5.0)[1]]

bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5)
bc1_b! = (residual, ub, p) -> (residual[1] = ub[1])
Expand Down Expand Up @@ -103,14 +103,14 @@ include("sparse_jacobians.jl")
f1_nlls = (u, p, t) -> [u[2], -u[1]]

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol[:, 1]
solₜ₂ = sol[:, end]
solₜ₁ = sol(0.0)
solₜ₂ = sol(100.0)
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109]
bc1_nlls = (sol, p, t) -> [sol(0.0)[1], sol(100.0)[1] - 1, sol(100.0)[2] + 1.729109]

bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1])
bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1;
Expand Down
69 changes: 59 additions & 10 deletions lib/BoundaryValueDiffEqMIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ end
@inline function interpolation(tvals, id::MIRKInterpolation, idxs, deriv::D,
p, continuity::Symbol = :left) where {D}
(; t, u, cache) = id
(; mesh, mesh_dt) = cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

Expand All @@ -34,7 +35,7 @@ end

for j in idx
z = similar(cache.fᵢ₂_cache)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
interpolant!(z, id, cache, tvals[j], mesh, mesh_dt, deriv)
vals[j] = idxs !== nothing ? z[idxs] : z
end
return DiffEqArray(vals, tvals)
Expand All @@ -43,41 +44,89 @@ end
@inline function interpolation!(vals, tvals, id::MIRKInterpolation, idxs,
deriv::D, p, continuity::Symbol = :left) where {D}
(; t, cache) = id
(; mesh, mesh_dt) = cache
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)

for j in idx
z = similar(cache.fᵢ₂_cache)
interpolant!(z, id.cache, tvals[j], id.cache.mesh, id.cache.mesh_dt, deriv)
z = similar(id.u[1])
interpolant!(z, id, cache, tvals[j], mesh, mesh_dt, deriv)
vals[j] = z
end
end

@inline function interpolation(tval::Number, id::MIRKInterpolation, idxs,
deriv::D, p, continuity::Symbol = :left) where {D}
z = similar(id.cache.fᵢ₂_cache)
interpolant!(z, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
z = similar(id.u[1])
interpolant!(z, id, id.cache, tval, id.cache.mesh, id.cache.mesh_dt, deriv)
return idxs !== nothing ? z[idxs] : z
end

@inline function interpolant!(
z::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}})
z::AbstractArray, id, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{0}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
sum_stages!(z, cache, w, i)
sum_stages!(z, id, cache, w, i)
end

@inline function interpolant!(
dz::AbstractArray, cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
@inline function interpolant!(dz::AbstractArray, id::MIRKInterpolation,
cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
i = interval(mesh, t)
dt = mesh_dt[i]
τ = (t - mesh[i]) / dt
w, w′ = interp_weights(τ, cache.alg)
z = similar(dz)
sum_stages!(z, dz, cache, w, w′, i)
sum_stages!(z, dz, id, cache, w, w′, i)
end

function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
cache::MIRKCache, w, i::Int, dt = cache.mesh_dt[i])
(; stage, k_discrete, k_interp) = cache
(; s_star) = cache.ITU
z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
z .= z .* dt .+ id.u[i]

return z
end

@views function sum_stages!(z, z′, id::MIRKInterpolation, cache::MIRKCache,
w, w′, i::Int, dt = cache.mesh_dt[i])
(; stage, k_discrete, k_interp) = cache
(; s_star) = cache.ITU

z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
__maybe_matmul!(
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
z′ .= zero(z′)
__maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage])
__maybe_matmul!(
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true)
z .= z .* dt[1] .+ id.u[i]

return z, z′
end

@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation(
cache.mesh, u, cache)

# Intermidiate solution for evaluating boundry conditions
# basically simplified version of the interpolation for MIRK
function (s::EvalSol{A})(tval::Number) where {A <: AbstractMIRK}
(; u, t, alg, k_discrete) = s
stage = alg_stage(alg)
z = similar(u[1])
i = interval(t, tval)
dt = t[i + 1] - t[i]
τ = (tval - t[i]) / dt
w, _ = interp_weights(τ, alg)
z .= zero(z)
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
z .= z .* dt .+ u[i]
return z
end
51 changes: 30 additions & 21 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ end

function __perform_mirk_iteration(
cache::MIRKCache, abstol, adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
Expand Down Expand Up @@ -206,9 +206,12 @@ function __perform_mirk_iteration(
end

# Constructing the Nonlinear Problem
function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {iip}
function __construct_nlproblem(
cache::MIRKCache{iip}, y::AbstractVector, y₀::VectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(y₀.u, cache.mesh, cache.alg, cache.k_discrete)

loss_bc = if iip
@closure (du, u, p) -> __mirk_loss_bc!(
du, u, p, pt, cache.bc, cache.y, cache.mesh, cache)
Expand All @@ -226,66 +229,72 @@ function __construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {

loss = if iip
@closure (du, u, p) -> __mirk_loss!(
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache)
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol)
else
@closure (u, p) -> __mirk_loss(u, p, cache.y, pt, cache.bc, cache.mesh, cache)
@closure (u, p) -> __mirk_loss(
u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol)
end

return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, pt)
end

@views function __mirk_loss!(
resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache) where {BC}
@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC,
residual, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
soly_ = VectorOfArray(y_)
eval_bc_residual!(resids[1], pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.k_discrete[1:end] .= cache.k_discrete
eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
end

@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache) where {BC1, BC2}
residual, mesh, cache, EvalSol) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.k_discrete[1:end] .= cache.k_discrete
resida = resids[1][1:prod(cache.resid_size[1])]
residb = resids[1][(prod(cache.resid_size[1]) + 1):end]
eval_bc_residual!((resida, residb), pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
eval_bc_residual!((resida, residb), pt, bc!, EvalSol, p, mesh)
recursive_flatten_twopoint!(resid, resids, cache.resid_size)
return nothing
end

@views function __mirk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC}
@views function __mirk_loss(
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bc = eval_bc_residual(pt, bc, soly_, p, mesh)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.k_discrete[1:end] .= cache.k_discrete
resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end

@views function __mirk_loss(
u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh, cache) where {BC1, BC2}
@views function __mirk_loss(u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2},
mesh, cache, EvalSol) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bca, resid_bcb = eval_bc_residual(pt, bc, soly_, p, mesh)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= nodual_value(y_)
EvalSol.k_discrete[1:end] .= cache.k_discrete
resid_bca, resid_bcb = eval_bc_residual(pt, bc, EvalSol, p, mesh)
return vcat(resid_bca, mapreduce(vec, vcat, resid_co), resid_bcb)
end

@views function __mirk_loss_bc!(
resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
return nothing
end

@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
soly_ = EvalSol(y_, mesh, cache.alg, cache.k_discrete)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
end

Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqMIRK/test/ensemble_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
end

function bc!(residual, u, p, t)
residual[1] = u[:, 1][1] - 1.0
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 1.0
residual[2] = u(1.0)[1]
end

prob_func(prob, i, repeat) = remake(prob, p = [rand()])
Expand Down
Loading

0 comments on commit 9f18a3b

Please sign in to comment.