Skip to content

Commit

Permalink
Merge pull request #116 from avik-pal/ap/tp_newapi
Browse files Browse the repository at this point in the history
Split up the 2Point BVP
  • Loading branch information
ChrisRackauckas authored Oct 7, 2023
2 parents 1f99b3b + 49314ba commit 5eede34
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 77 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "BoundaryValueDiffEq"
uuid = "764a87c0-6b3e-53db-9096-fe964310641d"
version = "5.0.0"
version = "5.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -39,7 +39,7 @@ ODEInterface = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2.38.10"
Reexport = "0.2, 1.0"
SciMLBase = "2"
SciMLBase = "2.2"
Setfield = "1"
SparseDiffTools = "2.6"
TruncatedStacktraces = "1"
Expand Down
9 changes: 7 additions & 2 deletions ext/BoundaryValueDiffEqODEInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3,
alg.max_num_subintervals)

bvp2m_f(t, u, du) = prob.f(du, u, prob.p, t)
bvp2m_bc(ya, yb, bca, bcb) = prob.bc((bca, bcb), (ya, yb), prob.p)
function bvp2m_bc(ya, yb, bca, bcb)
prob.f.bc[1](bca, ya, prob.p)
prob.f.bc[2](bcb, yb, prob.p)
return nothing
end

opt = OptionsODE(OPT_RTOL => reltol, OPT_METHODCHOICE => alg.method_choice,
OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
Expand Down Expand Up @@ -76,7 +80,8 @@ function SciMLBase.__solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol
function bc!(ya, yb, r)
ra = first(prob.f.bcresid_prototype.x)
rb = last(prob.f.bcresid_prototype.x)
prob.bc((ra, rb), (ya, yb), prob.p)
prob.f.bc[1](ra, ya, prob.p)
prob.f.bc[2](rb, yb, prob.p)
r[1:length(ra)] .= ra
r[(length(ra) + 1):(length(ra) + length(rb))] .= rb
return r
Expand Down
5 changes: 2 additions & 3 deletions src/nlprob.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ function construct_nlproblem(cache::MIRKCache{iip}, y::AbstractVector) where {ii
function loss_collocation_internal(u::AbstractVector, p = cache.p)
y_ = recursive_unflatten!(cache.y, u)
resids = Φ(cache, y_, u, p)
xxx = mapreduce(vec, vcat, resids)
return xxx
return mapreduce(vec, vcat, resids)
end
end

Expand Down Expand Up @@ -211,7 +210,7 @@ function generate_nlprob(cache::MIRKCache{iip}, y, loss_bc, loss_collocation, lo

if !iip && cache.prob.f.bcresid_prototype === nothing
y_ = recursive_unflatten!(cache.y, y)
resid_ = cache.bc((y_[1], y_[end]), cache.p)
resid_ = cache.bc[1](y_[1], cache.p), cache.bc[2](y_[end], cache.p)
resid = ArrayPartition(ArrayPartition(resid_), similar(y, cache.M * (N - 1)))
else
resid = ArrayPartition(cache.prob.f.bcresid_prototype,
Expand Down
53 changes: 31 additions & 22 deletions src/solve/mirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,42 +69,51 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,

# Transform the functions to handle non-vector inputs
f, bc = if X isa AbstractVector
prob.f, prob.bc
prob.f, prob.f.bc
elseif iip
function vecf!(du, u, p, t)
du_ = reshape(du, size(X))
x_ = reshape(u, size(X))
prob.f(du_, x_, p, t)
return du
end
function vecbc!(resid, sol, p, t)
resid_ = reshape(resid, resid₁_size)
sol_ = map(s -> reshape(s, size(X)), sol)
prob.bc(resid_, sol_, p, t)
return resid
end
function vecbc!((resida, residb), (ua, ub), p)
resida_ = reshape(resida, resid₁_size[1])
residb_ = reshape(residb, resid₁_size[2])
ua_ = reshape(ua, size(X))
ub_ = reshape(ub, size(X))
prob.bc((resida_, residb_), (ua_, ub_), p)
return (resida, residb)
vecbc! = if !(prob.problem_type isa TwoPointBVProblem)
function __vecbc!(resid, sol, p, t)
resid_ = reshape(resid, resid₁_size)
sol_ = map(s -> reshape(s, size(X)), sol)
prob.f.bc(resid_, sol_, p, t)
return resid
end
else
function __vecbc_a!(resida, ua, p)
resida_ = reshape(resida, resid₁_size[1])
ua_ = reshape(ua, size(X))
prob.f.bc[1](resida_, ua_, p)
return nothing
end
function __vecbc_b!(residb, ub, p)
residb_ = reshape(residb, resid₁_size[2])
ub_ = reshape(ub, size(X))
prob.f.bc[2](residb_, ub_, p)
return nothing
end
(__vecbc_a!, __vecbc_b!)
end
vecf!, vecbc!
else
function vecf(u, p, t)
x_ = reshape(u, size(X))
return vec(prob.f(x_, p, t))
end
function vecbc(sol, p, t)
sol_ = map(s -> reshape(s, size(X)), sol)
return vec(prob.bc(sol_, p, t))
end
function vecbc((ua, ub), p)
ua_ = reshape(ua, size(X))
ub_ = reshape(ub, size(X))
return vec.(prob.bc((ua_, ub_), p))
vecbc = if !(prob.problem_type isa TwoPointBVProblem)
function __vecbc(sol, p, t)
sol_ = map(s -> reshape(s, size(X)), sol)
return vec(prob.f.bc(sol_, p, t))
end
else
__vecbc_a(ua, p) = vec(prob.f.bc[1](reshape(ua, size(X)), p))
__vecbc_b(ub, p) = vec(prob.f.bc[2](reshape(ub, size(X)), p))
(__vecbc_a, __vecbc_b)
end
vecf, vecbc
end
Expand Down
2 changes: 1 addition & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# TODO: Support Non-Vector Inputs
function SciMLBase.__solve(prob::BVProblem, alg::Shooting; kwargs...)
iip = isinplace(prob)
bc = prob.bc
bc = prob.f.bc
u0 = deepcopy(prob.u0)
loss_fn = if iip
function loss!(resid, u0, p)
Expand Down
10 changes: 6 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ end
## Easier to dispatch
eval_bc_residual(pt, bc, sol, p) = eval_bc_residual(pt, bc, sol, p, sol.t)
eval_bc_residual(_, bc, sol, p, t) = bc(sol, p, t)
function eval_bc_residual(::TwoPointBVProblem, bc, sol, p, t)
function eval_bc_residual(::TwoPointBVProblem, (bca, bcb), sol, p, t)
ua = sol isa AbstractVector ? sol[1] : sol(first(t))
ub = sol isa AbstractVector ? sol[end] : sol(last(t))
resid₀, resid₁ = bc((ua, ub), p)
resid₀ = bca(ua, p)
resid₁ = bcb(ub, p)
return ArrayPartition(resid₀, resid₁)
end

eval_bc_residual!(resid, pt, bc!, sol, p) = eval_bc_residual!(resid, pt, bc!, sol, p, sol.t)
eval_bc_residual!(resid, _, bc!, sol, p, t) = bc!(resid, sol, p, t)
@views function eval_bc_residual!(resid, ::TwoPointBVProblem, bc!, sol, p, t)
@views function eval_bc_residual!(resid, ::TwoPointBVProblem, (bca!, bcb!), sol, p, t)
ua = sol isa AbstractVector ? sol[1] : sol(first(t))
ub = sol isa AbstractVector ? sol[end] : sol(last(t))
bc!((resid.x[1], resid.x[2]), (ua, ub), p)
bca!(resid.x[1], ua, p)
bcb!(resid.x[2], ub, p)
return resid
end
22 changes: 15 additions & 7 deletions test/mirk_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ function boundary!(residual, u, p, t)
end
boundary(u, p, t) = [u[1][1] - 5, u[end][1]]

function boundary_two_point!((resida, residb), (ua, ub), p)
function boundary_two_point_a!(resida, ua, p)
resida[1] = ua[1] - 5
end
function boundary_two_point_b!(residb, ub, p)
residb[1] = ub[1]
end
boundary_two_point((ua, ub), p) = [ua[1] - 5, ub[1]]

boundary_two_point_a(ua, p) = [ua[1] - 5]
boundary_two_point_b(ub, p) = [ub[1]]

# Not able to change the initial condition.
# Hard coded solution.
Expand Down Expand Up @@ -57,10 +61,14 @@ probArr = [
BVProblem(odef1, boundary, u0, tspan),
BVProblem(odef2!, boundary!, u0, tspan),
BVProblem(odef2, boundary, u0, tspan),
TwoPointBVProblem(odef1!, boundary_two_point!, u0, tspan; bcresid_prototype),
TwoPointBVProblem(odef1, boundary_two_point, u0, tspan; bcresid_prototype),
TwoPointBVProblem(odef2!, boundary_two_point!, u0, tspan; bcresid_prototype),
TwoPointBVProblem(odef2, boundary_two_point, u0, tspan; bcresid_prototype),
TwoPointBVProblem(odef1!, (boundary_two_point_a!, boundary_two_point_b!), u0, tspan;
bcresid_prototype),
TwoPointBVProblem(odef1, (boundary_two_point_a, boundary_two_point_b), u0, tspan;
bcresid_prototype),
TwoPointBVProblem(odef2!, (boundary_two_point_a!, boundary_two_point_b!), u0, tspan;
bcresid_prototype),
TwoPointBVProblem(odef2, (boundary_two_point_a, boundary_two_point_b), u0, tspan;
bcresid_prototype),
];

testTol = 0.2
Expand All @@ -73,7 +81,7 @@ dts = 1 .// 2 .^ (3:-1:1)
@testset "Problem: $i" for i in (1, 2, 5, 6)
prob = probArr[i]
@testset "MIRK$order" for order in (2, 3, 4, 5, 6)
@time sol = solve(prob, mirk_solver(Val(order)), dt = 0.2)
@time sol = solve(prob, mirk_solver(Val(order)); dt = 0.2)
@test norm(diff(first.(sol.u)) .+ 0.2, Inf) + abs(sol[1][1] - 5) < affineTol
end
end
Expand Down
13 changes: 7 additions & 6 deletions test/non_vector_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,28 @@ function boundary!(residual, u, p, t)
residual[1, 2] = u[end][1, 1]
end

function boundary!((resida, residb), (ua, ub), p)
function boundary_a!(resida, ua, p)
resida[1, 1] = ua[1, 1] - 5
end
function boundary_b!(residb, ub, p)
residb[1, 1] = ub[1, 1]
end

function boundary(u, p, t)
return [u[1][1, 1] - 5 u[end][1, 1]]
end

function boundary((ua, ub), p)
return (reshape([ua[1, 1] - 5], (1, 1)), reshape([ub[1, 1]], (1, 1)))
end
boundary_a = (ua, p) -> [ua[1, 1] - 5]
boundary_b = (ub, p) -> [ub[1, 1]]

tspan = (0.0, 5.0)
u0 = [5.0 -3.5]
probs = [
BVProblem(f1!, boundary!, u0, tspan),
TwoPointBVProblem(f1!, boundary!, u0, tspan;
TwoPointBVProblem(f1!, (boundary_a!, boundary_b!), u0, tspan;
bcresid_prototype = (Array{Float64}(undef, 1, 1), Array{Float64}(undef, 1, 1))),
BVProblem(f1, boundary, u0, tspan),
TwoPointBVProblem(f1, boundary, u0, tspan),
TwoPointBVProblem(f1, (boundary_a, boundary_b), u0, tspan),
];

@testset "Affineness" begin
Expand Down
13 changes: 9 additions & 4 deletions test/odeinterface_ex7.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ function ex7_f!(du, u, p, t)
return nothing
end

function ex7_2pbc!((resa, resb), (ua, ub), p)
function ex7_2pbc1!(resa, ua, p)
resa[1] = ua[1] - 1
return nothing
end

function ex7_2pbc2!(resb, ub, p)
resb[1] = ub[1] - 1
return nothing
end
Expand All @@ -19,15 +23,16 @@ u0 = [0.5, 1.0]
p = [0.1]
tspan = (-π / 2, π / 2)

tpprob = TwoPointBVProblem(ex7_f!, ex7_2pbc!, u0, tspan, p;
tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan, p;
bcresid_prototype = (zeros(1), zeros(1)))

@info "BVPM2"

sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20)
@test SciMLBase.successful_retcode(sol_bvpm2)
resid_f = (Array{Float64, 1}(undef, 1), Array{Float64, 1}(undef, 1))
ex7_2pbc!(resid_f, (sol_bvpm2(tspan[1]), sol_bvpm2(tspan[2])), nothing)
ex7_2pbc1!(resid_f[1], sol_bvpm2(tspan[1]), nothing)
ex7_2pbc2!(resid_f[2], sol_bvpm2(tspan[2]), nothing)
@test norm(resid_f) < 1e-6

function ex7_f2!(du, u, p, t)
Expand All @@ -40,7 +45,7 @@ end
@info "BVPSOL"

initial_u0 = [sol_bvpm2(t) .+ rand() for t in tspan[1]:/ 20):tspan[2]]
tpprob = TwoPointBVProblem(ex7_f2!, ex7_2pbc!, initial_u0, tspan;
tpprob = TwoPointBVProblem(ex7_f2!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan;
bcresid_prototype = (zeros(1), zeros(1)))

# Just test that it runs. BVPSOL only works with linearly separable BCs.
Expand Down
14 changes: 9 additions & 5 deletions test/orbital.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,20 @@ function bc!_generator(resid, sol, init_val)
resid[6] = sol(t1)[3] - init_val[6]
end

function bc!_generator_2p((resid0, resid1), (ua, ub), init_val)
function bc!_generator_2p_a(resid0, ua, init_val)
resid0[1] = ua[1] - init_val[1]
resid0[2] = ua[2] - init_val[2]
resid0[3] = ua[3] - init_val[3]
end
function bc!_generator_2p_b(resid1, ub, init_val)
resid1[1] = ub[1] - init_val[4]
resid1[2] = ub[2] - init_val[5]
resid1[3] = ub[3] - init_val[6]
end

cur_bc! = (resid, sol, p, t) -> bc!_generator(resid, sol, init_val)
cur_bc_2point! = (resid, sol, p) -> bc!_generator_2p(resid, sol, init_val)
cur_bc_2point_a! = (resid, sol, p) -> bc!_generator_2p_a(resid, sol, init_val)
cur_bc_2point_b! = (resid, sol, p) -> bc!_generator_2p_b(resid, sol, init_val)
resid_f = Array{Float64}(undef, 6)
resid_f_2p = (Array{Float64, 1}(undef, 3), Array{Float64, 1}(undef, 3))

Expand All @@ -78,14 +81,15 @@ for autodiff in (AutoForwardDiff(), AutoFiniteDiff(; fdtype = Val(:central)),
end

### Using the TwoPoint BVP Structure
bvp = TwoPointBVProblem(orbital!, cur_bc_2point!, y0, tspan;
bvp = TwoPointBVProblem(orbital!, (cur_bc_2point_a!, cur_bc_2point_b!), y0, tspan;
bcresid_prototype = (Array{Float64}(undef, 3), Array{Float64}(undef, 3)))
for autodiff in (AutoForwardDiff(), AutoFiniteDiff(; fdtype = Val(:central)),
AutoSparseForwardDiff(), AutoFiniteDiff(; fdtype = Val(:forward)),
AutoSparseFiniteDiff())
nlsolve = NewtonRaphson(; autodiff)
@time sol = solve(bvp, Shooting(DP5(); nlsolve); force_dtmin = true, abstol = 1e-13,
reltol = 1e-13)
cur_bc_2point!(resid_f_2p, (sol(t0), sol(t1)), nothing)
@test norm(vcat(resid_f_2p...), Inf) < TestTol
cur_bc_2point_a!(resid_f_2p[1], sol(t0), nothing)
cur_bc_2point_b!(resid_f_2p[2], sol(t1), nothing)
@test norm(reduce(vcat, resid_f_2p), Inf) < TestTol
end
Loading

0 comments on commit 5eede34

Please sign in to comment.