Skip to content

Commit

Permalink
Fix non-vector input case
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 16, 2024
1 parent 5c9d67e commit d766831
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 39 deletions.
18 changes: 2 additions & 16 deletions benchmark/simple_pendulum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ function bc_pendulum!(residual, u, p, t)
return nothing
end

function bc_pendulum_mirk!(residual, u, p, t)
residual[1] = u[:, end ÷ 2][1] + π / 2
residual[2] = u[:, end][1] - π / 2
return nothing
end

function simple_pendulum(u, p, t)
g, L, θ, dθ = 9.81, 1.0, u[1], u[2]
return [dθ, -(g / L) * sin(θ)]
Expand All @@ -34,16 +28,8 @@ function bc_pendulum(u, p, t)
return [u((t0 + t1) / 2)[1] + π / 2, u(t1)[1] - π / 2]
end

function bc_pendulum_mirk(u, p, t)
return [u[:, end ÷ 2][1] + π / 2, u[:, end][1] - π / 2]
end

const prob_oop = BVProblem{false}(simple_pendulum, bc_pendulum, [π / 2, π / 2], tspan)
const prob_iip = BVProblem{true}(simple_pendulum!, bc_pendulum!, [π / 2, π / 2], tspan)
const prob_oop_mirk = BVProblem{false}(
simple_pendulum, bc_pendulum_mirk, [π / 2, π / 2], tspan)
const prob_iip_mirk = BVProblem{true}(
simple_pendulum!, bc_pendulum_mirk!, [π / 2, π / 2], tspan)

end

Expand Down Expand Up @@ -77,7 +63,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
iip_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_iip_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_iip, $alg(), dt = 0.05)
end
end

Expand All @@ -102,7 +88,7 @@ function create_simple_pendulum_benchmark()
for alg in (MIRK2, MIRK3, MIRK4, MIRK5, MIRK6)
if @isdefined(alg)
oop_suite["$alg()"] = @benchmarkable solve(
$SimplePendulumBenchmark.prob_oop_mirk, $alg(), dt = 0.05)
$SimplePendulumBenchmark.prob_oop, $alg(), dt = 0.05)
end
end

Expand Down
4 changes: 2 additions & 2 deletions lib/BoundaryValueDiffEqCore/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ end
__vec_f(u, p, t, f, u_size) = vec(f(reshape(u, u_size), p, t))

function __vec_bc!(resid, sol, p, t, bc!, resid_size, u_size)
bc!(reshape(resid, resid_size), __restructure_sol(sol, u_size), p, t)
bc!(reshape(resid, resid_size), sol, p, t)
return nothing
end

Expand All @@ -232,7 +232,7 @@ function __vec_bc!(resid, sol, p, bc!, resid_size, u_size)
return nothing
end

__vec_bc(sol, p, t, bc, u_size) = vec(bc(__restructure_sol(sol, u_size), p, t))
__vec_bc(sol, p, t, bc, u_size) = vec(bc(sol, p, t))
__vec_bc(sol, p, bc, u_size) = vec(bc(reshape(sol, u_size), p))

@inline __get_non_sparse_ad(ad::AbstractADType) = ad
Expand Down
9 changes: 5 additions & 4 deletions lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
__initial_guess, __maybe_allocate_diffcache,
__get_bcresid_prototype, __similar, __vec, __vec_f,
__vec_f!, __vec_bc, __vec_bc!, recursive_flatten_twopoint!,
__internal_nlsolve_problem, MaybeDiffCache, __extract_mesh,
__extract_u0, __has_initial_guess, __initial_guess_length,
__restructure_sol, __get_bcresid_prototype, __similar,
__vec, __vec_f, __vec_f!, __vec_bc, __vec_bc!,
recursive_flatten_twopoint!, __internal_nlsolve_problem,
MaybeDiffCache, __extract_mesh, __extract_u0,
__has_initial_guess, __initial_guess_length,
__initial_guess_on_mesh, __flatten_initial_guess,
__build_solution, __Fix3, __sparse_jacobian_cache,
__sparsity_detection_alg, _sparse_like, ColoredMatrix
Expand Down
6 changes: 3 additions & 3 deletions lib/BoundaryValueDiffEqFIRK/src/firk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ function __construct_nlproblem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpan
y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(y₀.u, cache.mesh, cache)
eval_sol = EvalSol(__restructure_sol(y₀.u, cache.in_size), cache.mesh, cache)

loss_bc = if iip
@closure (du, u, p) -> __firk_loss_bc!(
Expand Down Expand Up @@ -707,15 +707,15 @@ end
@views function __firk_loss_bc!(resid, u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
eval_sol = EvalSol(y_, mesh, cache)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
eval_bc_residual!(resid, pt, bc!, eval_sol, p, mesh)
return nothing
end

@views function __firk_loss_bc(u, p, pt, bc!::BC, y, mesh,
cache::Union{FIRKCacheNested, FIRKCacheExpand}, eval_sol) where {BC}
y_ = recursive_unflatten!(y, u)
eval_sol = EvalSol(y_, mesh, cache)
eval_sol = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
return eval_bc_residual(pt, bc!, eval_sol, p, mesh)
end

Expand Down
2 changes: 1 addition & 1 deletion lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, BVPJacobianAlgorit
recursive_flatten, recursive_flatten!, recursive_unflatten!,
__concrete_nonlinearsolve_algorithm, diff!,
__FastShortcutBVPCompatibleNonlinearPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg,
__FastShortcutBVPCompatibleNLLSPolyalg, __restructure_sol,
concrete_jacobian_algorithm, eval_bc_residual,
eval_bc_residual!, get_tmp, __maybe_matmul!,
__append_similar!, __extract_problem_details,
Expand Down
10 changes: 5 additions & 5 deletions lib/BoundaryValueDiffEqMIRK/src/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ function __construct_nlproblem(
cache::MIRKCache{iip}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip}
pt = cache.problem_type

eval_sol = EvalSol(y₀.u, cache.mesh, cache)
eval_sol = EvalSol(__restructure_sol(y₀.u, cache.in_size), cache.mesh, cache)

loss_bc = if iip
@closure (du, u, p) -> __mirk_loss_bc!(
Expand Down Expand Up @@ -243,7 +243,7 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[2:end], cache, y_, u, p)
EvalSol.u[1:end] .= y_
EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size)
EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh)
recursive_flatten!(resid, resids)
Expand All @@ -267,7 +267,7 @@ end
u, p, y, pt::StandardBVProblem, bc::BC, mesh, cache, EvalSol) where {BC}
y_ = recursive_unflatten!(y, u)
resid_co = Φ(cache, y_, u, p)
EvalSol.u[1:end] .= y_
EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size)
EvalSol.cache.k_discrete[1:end] .= cache.k_discrete
resid_bc = eval_bc_residual(pt, bc, EvalSol, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
Expand All @@ -285,14 +285,14 @@ end
@views function __mirk_loss_bc!(
resid, u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = EvalSol(y_, mesh, cache)
soly_ = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
eval_bc_residual!(resid, pt, bc!, soly_, p, mesh)
return nothing
end

@views function __mirk_loss_bc(u, p, pt, bc!::BC, y, mesh, cache::MIRKCache) where {BC}
y_ = recursive_unflatten!(y, u)
soly_ = EvalSol(y_, mesh, cache)
soly_ = EvalSol(__restructure_sol(y_, cache.in_size), mesh, cache)
return eval_bc_residual(pt, bc!, soly_, p, mesh)
end

Expand Down
12 changes: 8 additions & 4 deletions lib/BoundaryValueDiffEqMIRKN/src/mirkn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ end
y_ = recursive_unflatten!(y, u)
resids = [get_tmp(r, u) for r in residual]
Φ!(resids[3:end], cache, y_, u, p)
soly_ = EvalSol(y_[1:length(cache.mesh)], cache.mesh, cache)
dsoly_ = EvalSol(y_[(length(cache.mesh) + 1):end], cache.mesh, cache)
soly_ = EvalSol(
__restructure_sol(y_[1:length(cache.mesh)], cache.in_size), cache.mesh, cache)
dsoly_ = EvalSol(__restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size),
cache.mesh, cache)
eval_bc_residual!(resids[1:2], pt, bc, soly_, dsoly_, p, mesh)
recursive_flatten!(resid, resids)
return nothing
Expand All @@ -172,8 +174,10 @@ end
bc::BC, mesh, cache::MIRKNCache) where {BC}
y_ = recursive_unflatten!(y, u)
resid_co = Φ(cache, y_, u, p)
soly_ = EvalSol(y_[1:length(cache.mesh)], cache.mesh, cache)
dsoly_ = EvalSol(y_[(length(cache.mesh) + 1):end], cache.mesh, cache)
soly_ = EvalSol(
__restructure_sol(y_[1:length(cache.mesh)], cache.in_size), cache.mesh, cache)
dsoly_ = EvalSol(__restructure_sol(y_[(length(cache.mesh) + 1):end], cache.in_size),
cache.mesh, cache)
resid_bc = eval_bc_residual(pt, bc, soly_, dsoly_, p, mesh)
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
end
Expand Down
2 changes: 1 addition & 1 deletion test/misc/manifolds_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
if alg isa Shooting || alg isa MultipleShooting
sol = solve(bvp, alg)
else
sol = solve(bvp, alg; dt)
sol = solve(bvp, alg; dt, abstol = 1e-8)
end
@test SciMLBase.successful_retcode(sol)
resid = zeros(4)
Expand Down
6 changes: 3 additions & 3 deletions test/misc/non_vector_input_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
end

function boundary!(residual, u, p, t)
residual[1, 1] = u.u[1][1, 1] - 5
residual[1, 2] = u.u[end][1, 1]
residual[1, 1] = u(0.0)[1, 1] - 5
residual[1, 2] = u(5.0)[1, 1]
end

function boundary_a!(resida, ua, p)
Expand All @@ -28,7 +28,7 @@
end

function boundary(u, p, t)
return [u.u[1][1, 1] - 5 u.u[end][1, 1]]
return [u(0.0)[1, 1] - 5 u(5.0)[1, 1]]
end

boundary_a = (ua, p) -> [ua[1, 1] - 5]
Expand Down

0 comments on commit d766831

Please sign in to comment.