Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use solution object in all solvers #261

Merged
merged 11 commits into from
Dec 21, 2024
18 changes: 2 additions & 16 deletions benchmark/simple_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ function bc_pendulum!(residual, u, p, t)
return nothing
end

function bc_pendulum_mirk!(residual, u, p, t)
residual[1] = u[:, end ÷ 2][1] + π / 2
residual[2] = u[:, end][1] - π / 2
return nothing
end

function simple_pendulum(u, p, t)
g, L, θ, dθ = 9.81, 1.0, u[1], u[2]
return [dθ, -(g / L) * sin(θ)]
Expand All @@ -34,16 +28,8 @@ function bc_pendulum(u, p, t)
return [u((t0 + t1) / 2)[1] + π / 2, u(t1)[1] - π / 2]
end

function bc_pendulum_mirk(u, p, t)
return [u[:, end ÷ 2][1] + π / 2, u[:, end][1] - π / 2]
end

const prob_oop = BVProblem{false}(simple_pendulum, bc_pendulum, [π / 2, π / 2], tspan)
const prob_iip = BVProblem{true}(simple_pendulum!, bc_pendulum!, [π / 2, π / 2], tspan)
const prob_oop_mirk = BVProblem{false}(
simple_pendulum, bc_pendulum_mirk, [π / 2, π / 2], tspan)
const prob_iip_mirk = BVProblem{true}(
simple_pendulum!, bc_pendulum_mirk!, [π / 2, π / 2], tspan)

end

Expand Down Expand Up @@ -77,7 +63,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
iip_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_iip, $alg(), dt = 0.05)
end
end

Expand All @@ -102,7 +88,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
oop_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_oop, $alg(), dt = 0.05)
end
end

Expand Down
13 changes: 10 additions & 3 deletions lib/BoundaryValueDiffEqCore/src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Intermidiate solution evaluation
@concrete struct EvalSol{iip}
@concrete struct EvalSol{C}
u
t
alg
k_discrete
cache::C
end

Base.size(e::EvalSol) = (size(e.u[1])..., length(e.u))
Base.size(e::EvalSol, i) = size(e)[i]

Base.axes(e::EvalSol) = Base.OneTo.(size(e))
Base.axes(e::EvalSol, d::Int) = Base.OneTo.(size(e)[d])

Base.getindex(e::EvalSol, args...) = Base.getindex(VectorOfArray(e.u), args...)
8 changes: 5 additions & 3 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ end
__vec_f(u, p, t, f, u_size) = vec(f(reshape(u, u_size), p, t))

function __vec_bc!(resid, sol, p, t, bc!, resid_size, u_size)
bc!(reshape(resid, resid_size), __restructure_sol(sol, u_size), p, t)
bc!(reshape(resid, resid_size), sol, p, t)
return nothing
end

Expand All @@ -232,17 +232,19 @@ function __vec_bc!(resid, sol, p, bc!, resid_size, u_size)
return nothing
end

__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
__vec_bc(sol, p, t, bc, u_size) = vec(bc(sol, p, t))
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))

@inline __get_non_sparse_ad(ad::AbstractADType) = ad
@inline __get_non_sparse_ad(ad::AutoSparse) = ADTypes.dense_ad(ad)

# Restructure Solution
function __restructure_sol(sol::AbstractVectorOfArray, u_size)
(size(first(sol)) == u_size) && return sol
return VectorOfArray(map(Base.Fix2(reshape, u_size), sol))
end
function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
function __restructure_sol(sol::AbstractArray{<:AbstractArray}, u_size)
(size(first(sol)) == u_size) && return sol
return map(Base.Fix2(reshape, u_size), sol)
end

Expand Down
25 changes: 13 additions & 12 deletions lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg, EvalSol,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__restructure_sol, __get_bcresid_prototype, __similar,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
recursive_flatten_twopoint!, __internal_nlsolve_problem,
MaybeDiffCache, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand All @@ -33,7 +34,7 @@ import ArrayInterface: matrix_colors, parameterless_type, undefmatrix, fast_scal
import ConcreteStructs: @concrete
import DiffEqBase: solve
import FastClosures: @closure
import ForwardDiff: ForwardDiff, pickchunksize
import ForwardDiff: ForwardDiff, pickchunksize, Dual
import Logging
import RecursiveArrayTools: ArrayPartition, DiffEqArray
import SciMLBase: AbstractDiffEqInterpolation, StandardBVProblem, __solve, _unwrap_val
Expand All @@ -60,11 +61,11 @@ include("sparse_jacobians.jl")
f1 = (u, p, t) -> [u[2], 0]

function bc1!(residual, u, p, t)
residual[1] = u[:, 1][1] - 5
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 5
residual[2] = u(5.0)[1]
end

bc1 = (u, p, t) -> [u[:, 1][1] - 5, u[:, end][1]]
bc1 = (u, p, t) -> [u(0.0)[1] - 5, u(5.0)[1]]

bc1_a! = (residual, ua, p) -> (residual[1] = ua[1] - 5)
bc1_b! = (residual, ub, p) -> (residual[1] = ub[1])
Expand Down Expand Up @@ -143,14 +144,14 @@ include("sparse_jacobians.jl")
f1_nlls = (u, p, t) -> [u[2], -u[1]]

bc1_nlls! = (resid, sol, p, t) -> begin
solₜ₁ = sol[:, 1]
solₜ₂ = sol[:, end]
solₜ₁ = sol(0.0)
solₜ₂ = sol(100.0)
resid[1] = solₜ₁[1]
resid[2] = solₜ₂[1] - 1
resid[3] = solₜ₂[2] + 1.729109
return nothing
end
bc1_nlls = (sol, p, t) -> [sol[:, 1][1], sol[:, end][1] - 1, sol[:, end][2] + 1.729109]
bc1_nlls = (sol, p, t) -> [sol(0.0)[1], sol(100.0)[1] - 1, sol(100.0)[2] + 1.729109]

bc1_nlls_a! = (resid, ua, p) -> (resid[1] = ua[1])
bc1_nlls_b! = (resid, ub, p) -> (resid[1] = ub[1] - 1;
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,10 @@ function apply_q_prime(τ, h, coeffs)
return sum(i * coeffs[i] * (τ * h)^(i - 1) for i in axes(coeffs, 1))
end

function eval_q(y_i, τ, h, A, K)
function eval_q(y_i::AbstractArray{T}, τ, h, A, K) where {T}
M = size(K, 1)
q = zeros(M)
q′ = zeros(M)
q = zeros(T, M)
q′ = zeros(T, M)
for i in 1:M
ki = @view K[i, :]
coeffs = get_q_coeffs(A, ki, h)
Expand Down
40 changes: 22 additions & 18 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ end

function __perform_firk_iteration(cache::Union{FIRKCacheExpand, FIRKCacheNested}, abstol,
adaptive::Bool; nlsolve_kwargs = (;), kwargs...)
nlprob = __construct_nlproblem(cache, vec(cache.y₀))
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
sol_nlprob = __solve(
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
Expand Down Expand Up @@ -402,9 +402,11 @@ end

# Constructing the Nonlinear Problem
function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{iip}},
y::AbstractVector) where {iip}
y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(__restructure_sol(y₀.u, cache.in_size), cache.mesh, cache)

loss_bc = if iip
@closure (du, u, p) -> __firk_loss_bc!(
du, u, p, pt, cache.bc, cache.y, cache.mesh, cache)
Expand All @@ -422,9 +424,10 @@ function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpan

loss = if iip
@closure (du, u, p) -> __firk_loss!(
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache)
du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, eval_sol)
else
@closure (u, p) -> __firk_loss(u, p, cache.y, pt, cache.bc, cache.mesh, cache)
@closure (u, p) -> __firk_loss(
u, p, cache.y, pt, cache.bc, cache.mesh, cache, eval_sol)
end

return __construct_nlproblem(cache, y, loss_bc, loss_collocation, loss, pt)
Expand Down Expand Up @@ -658,19 +661,19 @@ function __construct_nlproblem(
return __internal_nlsolve_problem(cache.prob, resid_prototype, y, nlf, y, cache.p)
end

@views function __firk_loss!(
resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, cache) where {BC}
@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC,
residual, mesh, cache, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
soly_ = VectorOfArray(y_)
eval_bc_residual!(resids[1], pt, bc!, soly_, p, mesh)
Φ!(resids[2:end], cache, y_, u, p)
eval_sol.u[1:end] .= y_
eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh)
recursive_flatten!(resid, resids)
return nothing
end

@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
residual, mesh, cache) where {BC1, BC2}
residual, mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resids = [get_tmp(r, u) for r in residual]
Expand All @@ -682,16 +685,17 @@ end
return nothing
end

@views function __firk_loss(u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache) where {BC}
@views function __firk_loss(
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bc = eval_bc_residual(pt, bc, soly_, p, mesh)
eval_sol.u[1:end] .= y_
resid_bc = eval_bc_residual(pt, bc, eval_sol, p, mesh)
resid_co = Φ(cache, y_, u, p)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end

@views function __firk_loss(
u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2}, mesh, cache) where {BC1, BC2}
@views function __firk_loss(u, p, y, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2},
mesh, cache, _) where {BC1, BC2}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
resid_bca, resid_bcb = eval_bc_residual(pt, bc, soly_, p, mesh)
Expand All @@ -702,16 +706,16 @@ end
@views function __firk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
eval_bc_residual!(resid, pt, bc!, eval_sol, p, mesh)
return nothing
end

@views function __firk_loss_bc(u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = VectorOfArray(y_)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
return eval_bc_residual(pt, bc!, eval_sol, p, mesh)
end

@views function __firk_loss_collocation!(resid, u, p, y, mesh, residual, cache)
Expand Down
94 changes: 94 additions & 0 deletions lib/BoundaryValueDiffEqFIRK/src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,97 @@ end
cache.mesh, u, cache)
@inline __build_interpolation(cache::FIRKCacheNested, u::AbstractVector) = FIRKNestedInterpolation(
cache.mesh, u, cache)

# Intermidiate solution for evaluating boundry conditions
# basically simplified version of the interpolation for FIRK
# Expanded FIRK
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheExpand}
(; t, u, cache) = s
(; f, alg, ITU, p) = cache
(; q_coeff) = ITU
stage = alg_stage(alg)
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
K = __similar(first(u), length(first(u)), stage)
j = interval(t, tval)
ctr_y = (j - 1) * (stage + 1) + 1

yᵢ = u[ctr_y]
yᵢ₊₁ = u[ctr_y + stage + 1]

if SciMLBase.isinplace(cache.prob)
dyᵢ = similar(yᵢ)
dyᵢ₊₁ = similar(yᵢ₊₁)

f(dyᵢ, yᵢ, p, t[j])
f(dyᵢ₊₁, yᵢ₊₁, p, t[j + 1])
else
dyᵢ = f(yᵢ, p, t[j])
dyᵢ₊₁ = f(yᵢ₊₁, p, t[j + 1])
end

# Load interpolation residual
for jj in 1:stage
K[:, jj] = u[ctr_y + jj]
end
h = t[j + 1] - t[j]
τ = tval - t[j]

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)

z = similar(yᵢ)

S_interpolate!(z, τ, S_coeffs)
return z
end

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

# Nested FIRK
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheNested}
(; t, u, cache) = s
(; f, nest_prob, nest_tol, alg, mesh_dt, p, ITU) = cache
(; q_coeff) = ITU
stage = alg_stage(alg)
# Quick handle for the case where tval is at the boundary
(tval == t[1]) && return first(u)
(tval == t[end]) && return last(u)
j = interval(t, tval)
h = t[j + 1] - t[j]
τ = tval - t[j]

nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm(nest_prob, alg.nlsolve)
nestprob_p = zeros(cache.M + 2)

yᵢ = u[j]
yᵢ₊₁ = u[j + 1]

if SciMLBase.isinplace(cache.prob)
dyᵢ = similar(yᵢ)
dyᵢ₊₁ = similar(yᵢ₊₁)

f(dyᵢ, yᵢ, p, t[j])
f(dyᵢ₊₁, yᵢ₊₁, p, t[j + 1])
else
dyᵢ = f(yᵢ, p, t[j])
dyᵢ₊₁ = f(yᵢ₊₁, p, t[j + 1])
end

nestprob_p[1] = t[j]
nestprob_p[2] = mesh_dt[j]
nestprob_p[3:end] .= nodual_value(yᵢ)

_nestprob = remake(nest_prob, p = nestprob_p)
nestsol = __solve(_nestprob, nest_nlsolve_alg; abstol = nest_tol)
K = nestsol.u

z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
z = similar(yᵢ)
S_interpolate!(z, τ, S_coeffs)
return z
end
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqFIRK/test/expanded/ensemble_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
end

function bc!(residual, u, p, t)
residual[1] = u[:, 1][1] - 1.0
residual[2] = u[:, end][1]
residual[1] = u(0.0)[1] - 1.0
residual[2] = u(1.0)[1]
end

prob_func(prob, i, repeat) = remake(prob, p = [rand()])
Expand Down
Loading
Loading