Skip to content

Commit

Permalink
gl2 seems ok
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMartinon committed Nov 21, 2024
1 parent d1d59c4 commit c65b3d0
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 70 deletions.
29 changes: 6 additions & 23 deletions src/irk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ $(TYPEDSIGNATURES)
Gauss Legendre 2 discretization, formulated as a generic IRK
"""
struct GaussLegendre2 <: GenericIRK
struct Gauss_Legendre_2 <: GenericIRK

stage::Int
butcher_a::Matrix{Float64}
Expand All @@ -72,7 +72,7 @@ struct GaussLegendre2 <: GenericIRK
_step_block::Int
info::String

function GaussLegendre2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons)
function Gauss_Legendre_2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons)

stage = 2

Expand Down Expand Up @@ -140,17 +140,8 @@ $(TYPEDSIGNATURES)
Retrieve stage variables at given time step/stage from the NLP variables.
Convention: 1 <= i <= dim_NLP_steps+1, 1 <= j <= s
Scalar / Vector output
Vector output
"""
#=function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK, ScalVariable, <: ScalVect, <: ScalVect}, i, j)
if i == docp.dim_NLP_steps+1
return xu[1] # unused but keep same type !
else
offset = (i-1) * docp.discretization._step_block + docp.dim_NLP_x + docp.dim_NLP_u + (j-1)*docp.dim_NLP_x
return xu[offset + 1]
end
end
function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK, VectVariable, <: ScalVect, <: ScalVect}, i, j)=#
function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j)
if i == docp.dim_NLP_steps+1
return @view xu[1:docp.dim_NLP_x] # unused but keep same type !
Expand All @@ -159,14 +150,6 @@ function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j)
return @view xu[(offset + 1):(offset + docp.dim_NLP_x)]
end
end
#=function get_lagrange_stagevar_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j)
if i == docp.dim_NLP_steps+1
return xu[1] # unused but keep same type !
else
offset = (i-1) * docp.discretization._step_block + docp.dim_NLP_x + docp.dim_NLP_u + (j-1)*docp.dim_NLP_x
return xu[offset + docp.dim_NLP_x]
end
end=#

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -205,7 +188,7 @@ function setWorkArray(docp::DOCP{ <: GenericIRK}, xu, time_grid, v)

# use work array to store all dynamics + lagrange costs
# + one state/stage variable (including lagrange part for setConstraintsBlock)
work = similar(xu, docp.dim_NLP_x * (docp.dim_NLP_steps) + docp.dim_NLP_x)
work = similar(xu, docp.dim_NLP_x * (docp.discretization.stage * docp.dim_NLP_steps + 1))

# loop over time steps ans stages
for i = 1:docp.dim_NLP_steps
Expand All @@ -222,7 +205,7 @@ function setWorkArray(docp::DOCP{ <: GenericIRK}, xu, time_grid, v)
@. work[end-docp.dim_OCP_x+1:end] = xi
for l = 1:docp.discretization.stage
kil = get_stagevars_at_time_step(xu, docp, i, l)
@views @. work[end-docp.dim_OCP_x+1:end] = work[end-docp.dim_OCP_x+1:end] + hi * docp.discretization.butcher_a[j][l] * kil[1:docp.dim_OCP_x]
@views @. work[end-docp.dim_OCP_x+1:end] = work[end-docp.dim_OCP_x+1:end] + hi * docp.discretization.butcher_a[j,l] * kil[1:docp.dim_OCP_x]
end
if docp.dim_OCP_x == 1
xij = work[end]
Expand Down Expand Up @@ -264,14 +247,14 @@ function setConstraintBlock!(docp::DOCP{ <: GenericIRK}, c, xu, v, time_grid, i,
hi = tip1 - ti
offset_dyn_i = (i-1) * docp.dim_NLP_x * docp.discretization.stage
offset_x = length(work) - docp.dim_NLP_x
offset_stage_eq = docp.dim_NLP_x

# work array for sum b_j k_i^j (w/ lagrange term)
#@. work[offset_x+1:offset_x+docp.dim_NLP_x] = 0 known AD bug with :optimized backend: cannot affect constants
@views @. work[offset_x+1:offset_x+docp.dim_NLP_x] = 0 * work[1:docp.dim_NLP_x]

# loop over stages
for j=1:docp.discretization.stage
offset_stage_eq = docp.dim_NLP_x
kij = get_stagevars_at_time_step(xu, docp, i, j)

# update sum b_j k_i^j (w/ lagrange term) for state equation below
Expand Down
5 changes: 3 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ struct DOCP{T <: Discretization, X <: ScalVect, U <: ScalVect, V <: ScalVect}
discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Trapeze(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons)
elseif disc_method == :midpoint_irk
discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Midpoint_IRK(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons)
elseif disc_method == :gauss_legendre_2
discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Gauss_Legendre_2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons)
else
error("Unknown discretization method: ", disc_method, " of type (should be Symbol) ", typeof(disc_method))
# print list of available methods (as symbols)
error("Unknown discretization method: ", disc_method, "\nValid options are disc_method={:trapeze, :midpoint, :midpoint_irk, :gauss_legendre_2}\n", typeof(disc_method))
end

# add initial condition for lagrange state
Expand Down
79 changes: 51 additions & 28 deletions test/suite/test_discretization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,44 +62,67 @@ end
if !isdefined(Main, :goddard_all)
include("../problems/goddard.jl")
end
@testset verbose = true showtiming = true ":implicit_midpoint :simple_integrator" begin
ocp = simple_integrator().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint)
@test sol_m.objective sol_t.objective rtol = 1e-2

@testset verbose = true showtiming = true ":trapeze :simple_integrator" begin
prob = simple_integrator()
sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze)
@test sol.objective prob.obj rtol = 1e-2
end
@testset verbose = true showtiming = true ":trapeze :double_integrator" begin
prob = double_integrator_freet0tf()
sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze)
@test sol.objective prob.obj rtol = 1e-2
end
@testset verbose = true showtiming = true ":trapeze :goddard" begin
prob = goddard_all()
sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":implicit_midpoint :simple_integrator" begin
prob = simple_integrator()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end
@testset verbose = true showtiming = true ":implicit_midpoint :double_integrator" begin
ocp = double_integrator_freet0tf().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint)
@test sol_m.objective sol_t.objective rtol = 1e-2
prob = double_integrator_freet0tf()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":implicit_midpoint :goddard" begin
ocp = goddard_all().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint)
@test sol_m.objective sol_t.objective rtol = 1e-2
prob = goddard_all()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":midpoint_irk :simple_integrator" begin
ocp = simple_integrator().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk)
@test sol_m.objective sol_t.objective rtol = 1e-2
prob = simple_integrator()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":midpoint_irk :double_integrator" begin
ocp = double_integrator_freet0tf().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk)
@test sol_m.objective sol_t.objective rtol = 1e-2
prob = double_integrator_freet0tf()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":midpoint_irk :goddard" begin
ocp = goddard_all().ocp
sol_t = direct_solve(ocp, display = false)
sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk)
@test sol_m.objective sol_t.objective rtol = 1e-2
prob = goddard_all()
sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":gauss_legendre_2 :simple_integrator" begin
prob = simple_integrator()
sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end
@testset verbose = true showtiming = true ":gauss_legendre_2 :double_integrator" begin
prob = double_integrator_freet0tf()
sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end
@testset verbose = true showtiming = true ":gauss_legendre_2 :goddard" begin
prob = goddard_all()
sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100)
@test sol.objective prob.obj rtol = 1e-2
end
21 changes: 11 additions & 10 deletions test/suite/test_nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ println("Test: nlp options")
if !isdefined(Main, :simple_integrator)
include("../problems/simple_integrator.jl")
end
ocp = simple_integrator().ocp
prob = simple_integrator()
ocp = prob.ocp

@testset verbose = true showtiming = true ":control_dim_2" begin
@test is_solvable(ocp)
Expand All @@ -16,35 +17,35 @@ end
solver_backend = CTDirect.IpoptBackend()
dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false)
sol = OptimalControlSolution(docp, dsol)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":solve_docp :midpoint" begin
docp, nlp = direct_transcription(ocp, disc_method = :midpoint)
solver_backend = CTDirect.IpoptBackend()
dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false)
sol = OptimalControlSolution(docp, dsol)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
end

@testset verbose = true showtiming = true ":solve_docp :madnlp" begin
docp, nlp = direct_transcription(ocp)
solver_backend = CTDirect.MadNLPBackend()
dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false)
sol = OptimalControlSolution(docp, dsol)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers)
@test sol.objective 0.313 rtol = 1e-2
@test sol.objective prob.obj rtol = 1e-2
end

# check solution building
Expand Down
62 changes: 55 additions & 7 deletions test/suite/test_ocp.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
println("Test: OCP definition")

# + beam
# beam
if !isdefined(Main, :beam)
include("../problems/beam.jl")
end
@testset verbose = true showtiming = true ":beam" begin
prob = beam()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

# double integrator min tf
if !isdefined(Main, :double_integrator_mintf)
include("../problems/double_integrator.jl")
end

@testset verbose = true showtiming = true ":double_integrator :min_tf" begin
prob = double_integrator_mintf()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

# + fuller
# fuller
if !isdefined(Main, :fuller)
include("../problems/fuller.jl")
end
@testset verbose = true showtiming = true ":fuller" begin
prob = fuller()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

# goddard max rf
if !isdefined(Main, :goddard)
include("../problems/goddard.jl")
end
Expand All @@ -24,10 +40,42 @@ end
@test sol.objective prob.obj rtol = 1e-2
end

# + jackson
# jackson
if !isdefined(Main, :jackson)
include("../problems/jackson.jl")
end
@testset verbose = true showtiming = true ":jackson" begin
prob = jackson()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

# + robbins
#= robbins
if !isdefined(Main, :robbins)
include("../problems/robbins.jl")
end
@testset verbose = true showtiming = true ":robbins" begin
prob = robbins()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective ≈ prob.obj rtol = 1e-2
end=#

# + simple integrator
# simple integrator
if !isdefined(Main, :simple_integrator)
include("../problems/simple_integrator.jl")
end
@testset verbose = true showtiming = true ":simple_integrator" begin
prob = simple_integrator()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

# + vanderpol
# vanderpol
if !isdefined(Main, :vanderpol)
include("../problems/vanderpol.jl")
end
@testset verbose = true showtiming = true ":vanderpol" begin
prob = vanderpol()
sol = direct_solve(prob.ocp, display = false)
@test sol.objective prob.obj rtol = 1e-2
end

0 comments on commit c65b3d0

Please sign in to comment.