Skip to content

Commit

Permalink
Fix NLLS for Shooting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2023
1 parent 30e7fdf commit 2c81ce1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
7 changes: 6 additions & 1 deletion src/solve/single_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ function __solve(prob::BVProblem, alg_::Shooting; odesolve_kwargs = (;),

# Construct the residual function
actual_ode_kwargs = (; kwargs..., verbose, odesolve_kwargs...)
ode_kwargs = (; save_everystep = false, actual_ode_kwargs...)
# For TwoPointBVPs we don't need to save every step
if prob.problem_type isa TwoPointBVProblem
ode_kwargs = (; save_everystep = false, actual_ode_kwargs...)
else
ode_kwargs = (; actual_ode_kwargs...)
end
internal_prob = ODEProblem{iip}(prob.f, u0, prob.tspan, prob.p)
ode_cache_loss_fn = SciMLBase.__init(internal_prob, alg.ode_alg; ode_kwargs...)

Expand Down
33 changes: 16 additions & 17 deletions test/shooting/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test
for solver in SOLVERS
@info "Testing $solver"
sol = @time solve(bvp1, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-3))
@test norm(bc1(sol, nothing, sol.t), Inf) < 1e-4
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6))
@test norm(sol.resid, Inf) < 0.005
end

# IIP MP-BVP
Expand All @@ -56,40 +56,39 @@ using BoundaryValueDiffEq, LinearAlgebra, OrdinaryDiffEq, Test
return nothing
end

bvp2 = BVProblem(BVPFunction{true}(f1!, bc1!; bcresid_prototype = zeros(4)), u0, tspan)
bvp2 = BVProblem(BVPFunction{true}(f1!, bc1!; bcresid_prototype = zeros(4)), u0, tspan;
nlls = Val(true))

for solver in SOLVERS
sol = @time solve(bvp2, solver; verbose = false)
resid_f = Array{Float64}(undef, 4)
bc1!(resid_f, sol, nothing, sol.t)
@test norm(resid_f, Inf) < 1e-4
@info "Testing $solver"
sol = @time solve(bvp2, solver; verbose = false, abstol = 1e-6, reltol = 1e-6,
odesolve_kwargs = (; abstol = 1e-6, reltol = 1e-6))
@test norm(sol.resid, Inf) < 0.005
end

# OOP TP-BVP
bc1a(ua, p) = [ua[1]]
bc1b(ub, p) = [ub[1] - 1, ub[2] + 1.729109]

bvp3 = TwoPointBVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)
bvp3 = BVProblem(BVPFunction{false}(f1, (bc1a, bc1b); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan; nlls = Val(true))

for solver in SOLVERS
@info "Testing $solver"
sol = @time solve(bvp3, solver; verbose = false)
@test norm(vcat(bc1a(sol(0.0), nothing), bc1b(sol(100.0), nothing)), Inf) < 1e-4
@test norm(sol.resid, Inf) < 1e-4
end

# IIP TP-BVP
bc1a!(resid, ua, p) = (resid[1] = ua[1])
bc1b!(resid, ub, p) = (resid[1] = ub[1] - 1; resid[2] = ub[2] + 1.729109)

bvp4 = TwoPointBVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan)
bvp4 = BVProblem(BVPFunction{true}(f1!, (bc1a!, bc1b!); twopoint = Val(true),
bcresid_prototype = (zeros(1), zeros(2))), u0, tspan; nlls = Val(true))

for solver in SOLVERS
@info "Testing $solver"
sol = @time solve(bvp4, solver; verbose = false)
resida = Array{Float64}(undef, 1)
residb = Array{Float64}(undef, 2)
bc1a!(resida, sol(0.0), nothing)
bc1b!(residb, sol(100.0), nothing)
@test norm(vcat(resida, residb), Inf) < 1e-4
@test norm(sol.resid, Inf) < 1e-4
end
end

0 comments on commit 2c81ce1

Please sign in to comment.