From c65b3d02855af72060da733e0f2aff9cd1ce922d Mon Sep 17 00:00:00 2001 From: Pierre Martinon Date: Thu, 21 Nov 2024 18:55:18 +0100 Subject: [PATCH] gl2 seems ok --- src/irk.jl | 29 +++--------- src/problem.jl | 5 +- test/suite/test_discretization.jl | 79 ++++++++++++++++++++----------- test/suite/test_nlp.jl | 21 ++++---- test/suite/test_ocp.jl | 62 +++++++++++++++++++++--- 5 files changed, 126 insertions(+), 70 deletions(-) diff --git a/src/irk.jl b/src/irk.jl index 4b00f06..2c7b912 100644 --- a/src/irk.jl +++ b/src/irk.jl @@ -63,7 +63,7 @@ $(TYPEDSIGNATURES) Gauss Legendre 2 discretization, formulated as a generic IRK """ -struct GaussLegendre2 <: GenericIRK +struct Gauss_Legendre_2 <: GenericIRK stage::Int butcher_a::Matrix{Float64} @@ -72,7 +72,7 @@ struct GaussLegendre2 <: GenericIRK _step_block::Int info::String - function GaussLegendre2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons) + function Gauss_Legendre_2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons) stage = 2 @@ -140,17 +140,8 @@ $(TYPEDSIGNATURES) Retrieve stage variables at given time step/stage from the NLP variables. Convention: 1 <= i <= dim_NLP_steps+1, 1 <= j <= s -Scalar / Vector output +Vector output """ -#=function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK, ScalVariable, <: ScalVect, <: ScalVect}, i, j) - if i == docp.dim_NLP_steps+1 - return xu[1] # unused but keep same type ! - else - offset = (i-1) * docp.discretization._step_block + docp.dim_NLP_x + docp.dim_NLP_u + (j-1)*docp.dim_NLP_x - return xu[offset + 1] - end -end -function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK, VectVariable, <: ScalVect, <: ScalVect}, i, j)=# function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j) if i == docp.dim_NLP_steps+1 return @view xu[1:docp.dim_NLP_x] # unused but keep same type ! @@ -159,14 +150,6 @@ function get_stagevars_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j) return @view xu[(offset + 1):(offset + docp.dim_NLP_x)] end end -#=function get_lagrange_stagevar_at_time_step(xu, docp::DOCP{ <: GenericIRK}, i, j) - if i == docp.dim_NLP_steps+1 - return xu[1] # unused but keep same type ! - else - offset = (i-1) * docp.discretization._step_block + docp.dim_NLP_x + docp.dim_NLP_u + (j-1)*docp.dim_NLP_x - return xu[offset + docp.dim_NLP_x] - end -end=# """ $(TYPEDSIGNATURES) @@ -205,7 +188,7 @@ function setWorkArray(docp::DOCP{ <: GenericIRK}, xu, time_grid, v) # use work array to store all dynamics + lagrange costs # + one state/stage variable (including lagrange part for setConstraintsBlock) - work = similar(xu, docp.dim_NLP_x * (docp.dim_NLP_steps) + docp.dim_NLP_x) + work = similar(xu, docp.dim_NLP_x * (docp.discretization.stage * docp.dim_NLP_steps + 1)) # loop over time steps ans stages for i = 1:docp.dim_NLP_steps @@ -222,7 +205,7 @@ function setWorkArray(docp::DOCP{ <: GenericIRK}, xu, time_grid, v) @. work[end-docp.dim_OCP_x+1:end] = xi for l = 1:docp.discretization.stage kil = get_stagevars_at_time_step(xu, docp, i, l) - @views @. work[end-docp.dim_OCP_x+1:end] = work[end-docp.dim_OCP_x+1:end] + hi * docp.discretization.butcher_a[j][l] * kil[1:docp.dim_OCP_x] + @views @. work[end-docp.dim_OCP_x+1:end] = work[end-docp.dim_OCP_x+1:end] + hi * docp.discretization.butcher_a[j,l] * kil[1:docp.dim_OCP_x] end if docp.dim_OCP_x == 1 xij = work[end] @@ -264,6 +247,7 @@ function setConstraintBlock!(docp::DOCP{ <: GenericIRK}, c, xu, v, time_grid, i, hi = tip1 - ti offset_dyn_i = (i-1) * docp.dim_NLP_x * docp.discretization.stage offset_x = length(work) - docp.dim_NLP_x + offset_stage_eq = docp.dim_NLP_x # work array for sum b_j k_i^j (w/ lagrange term) #@. work[offset_x+1:offset_x+docp.dim_NLP_x] = 0 known AD bug with :optimized backend: cannot affect constants @@ -271,7 +255,6 @@ function setConstraintBlock!(docp::DOCP{ <: GenericIRK}, c, xu, v, time_grid, i, # loop over stages for j=1:docp.discretization.stage - offset_stage_eq = docp.dim_NLP_x kij = get_stagevars_at_time_step(xu, docp, i, j) # update sum b_j k_i^j (w/ lagrange term) for state equation below diff --git a/src/problem.jl b/src/problem.jl index 76d9efe..9c55ec1 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -185,9 +185,10 @@ struct DOCP{T <: Discretization, X <: ScalVect, U <: ScalVect, V <: ScalVect} discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Trapeze(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons) elseif disc_method == :midpoint_irk discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Midpoint_IRK(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons) + elseif disc_method == :gauss_legendre_2 + discretization, dim_NLP_variables, dim_NLP_constraints = CTDirect.Gauss_Legendre_2(dim_NLP_steps, dim_NLP_x, dim_NLP_u, dim_NLP_v, dim_path_cons, dim_boundary_cons, dim_v_cons) else - error("Unknown discretization method: ", disc_method, " of type (should be Symbol) ", typeof(disc_method)) - # print list of available methods (as symbols) + error("Unknown discretization method: ", disc_method, "\nValid options are disc_method={:trapeze, :midpoint, :midpoint_irk, :gauss_legendre_2}\n", typeof(disc_method)) end # add initial condition for lagrange state diff --git a/test/suite/test_discretization.jl b/test/suite/test_discretization.jl index c922aef..979608c 100644 --- a/test/suite/test_discretization.jl +++ b/test/suite/test_discretization.jl @@ -62,44 +62,67 @@ end if !isdefined(Main, :goddard_all) include("../problems/goddard.jl") end -@testset verbose = true showtiming = true ":implicit_midpoint :simple_integrator" begin - ocp = simple_integrator().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + +@testset verbose = true showtiming = true ":trapeze :simple_integrator" begin + prob = simple_integrator() + sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end +@testset verbose = true showtiming = true ":trapeze :double_integrator" begin + prob = double_integrator_freet0tf() + sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end +@testset verbose = true showtiming = true ":trapeze :goddard" begin + prob = goddard_all() + sol = direct_solve(prob.ocp, display = false, disc_method = :trapeze) + @test sol.objective ≈ prob.obj rtol = 1e-2 end +@testset verbose = true showtiming = true ":implicit_midpoint :simple_integrator" begin + prob = simple_integrator() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end @testset verbose = true showtiming = true ":implicit_midpoint :double_integrator" begin - ocp = double_integrator_freet0tf().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + prob = double_integrator_freet0tf() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 end - @testset verbose = true showtiming = true ":implicit_midpoint :goddard" begin - ocp = goddard_all().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + prob = goddard_all() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 end @testset verbose = true showtiming = true ":midpoint_irk :simple_integrator" begin - ocp = simple_integrator().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + prob = simple_integrator() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 end - @testset verbose = true showtiming = true ":midpoint_irk :double_integrator" begin - ocp = double_integrator_freet0tf().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + prob = double_integrator_freet0tf() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 end - @testset verbose = true showtiming = true ":midpoint_irk :goddard" begin - ocp = goddard_all().ocp - sol_t = direct_solve(ocp, display = false) - sol_m = direct_solve(ocp, display = false, disc_method = :midpoint_irk) - @test sol_m.objective ≈ sol_t.objective rtol = 1e-2 + prob = goddard_all() + sol = direct_solve(prob.ocp, display = false, disc_method = :midpoint_irk, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end + +@testset verbose = true showtiming = true ":gauss_legendre_2 :simple_integrator" begin + prob = simple_integrator() + sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end +@testset verbose = true showtiming = true ":gauss_legendre_2 :double_integrator" begin + prob = double_integrator_freet0tf() + sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end +@testset verbose = true showtiming = true ":gauss_legendre_2 :goddard" begin + prob = goddard_all() + sol = direct_solve(prob.ocp, display = false, disc_method = :gauss_legendre_2, grid_size=100) + @test sol.objective ≈ prob.obj rtol = 1e-2 end \ No newline at end of file diff --git a/test/suite/test_nlp.jl b/test/suite/test_nlp.jl index c531fa2..f9abdac 100644 --- a/test/suite/test_nlp.jl +++ b/test/suite/test_nlp.jl @@ -3,7 +3,8 @@ println("Test: nlp options") if !isdefined(Main, :simple_integrator) include("../problems/simple_integrator.jl") end -ocp = simple_integrator().ocp +prob = simple_integrator() +ocp = prob.ocp @testset verbose = true showtiming = true ":control_dim_2" begin @test is_solvable(ocp) @@ -16,11 +17,11 @@ end solver_backend = CTDirect.IpoptBackend() dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false) sol = OptimalControlSolution(docp, dsol) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 end @testset verbose = true showtiming = true ":solve_docp :midpoint" begin @@ -28,11 +29,11 @@ end solver_backend = CTDirect.IpoptBackend() dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false) sol = OptimalControlSolution(docp, dsol) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 end @testset verbose = true showtiming = true ":solve_docp :madnlp" begin @@ -40,11 +41,11 @@ end solver_backend = CTDirect.MadNLPBackend() dsol = CTDirect.solve_docp(solver_backend, docp, nlp, display = false) sol = OptimalControlSolution(docp, dsol) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 sol = OptimalControlSolution(docp, primal = dsol.solution, dual = dsol.multipliers) - @test sol.objective ≈ 0.313 rtol = 1e-2 + @test sol.objective ≈ prob.obj rtol = 1e-2 end # check solution building diff --git a/test/suite/test_ocp.jl b/test/suite/test_ocp.jl index a189e95..34d5177 100644 --- a/test/suite/test_ocp.jl +++ b/test/suite/test_ocp.jl @@ -1,20 +1,36 @@ println("Test: OCP definition") -# + beam +# beam +if !isdefined(Main, :beam) + include("../problems/beam.jl") +end +@testset verbose = true showtiming = true ":beam" begin + prob = beam() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end # double integrator min tf if !isdefined(Main, :double_integrator_mintf) include("../problems/double_integrator.jl") end - @testset verbose = true showtiming = true ":double_integrator :min_tf" begin prob = double_integrator_mintf() sol = direct_solve(prob.ocp, display = false) @test sol.objective ≈ prob.obj rtol = 1e-2 end -# + fuller +# fuller +if !isdefined(Main, :fuller) + include("../problems/fuller.jl") +end +@testset verbose = true showtiming = true ":fuller" begin + prob = fuller() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end +# goddard max rf if !isdefined(Main, :goddard) include("../problems/goddard.jl") end @@ -24,10 +40,42 @@ end @test sol.objective ≈ prob.obj rtol = 1e-2 end -# + jackson +# jackson +if !isdefined(Main, :jackson) + include("../problems/jackson.jl") +end +@testset verbose = true showtiming = true ":jackson" begin + prob = jackson() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end -# + robbins +#= robbins +if !isdefined(Main, :robbins) + include("../problems/robbins.jl") +end +@testset verbose = true showtiming = true ":robbins" begin + prob = robbins() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end=# -# + simple integrator +# simple integrator +if !isdefined(Main, :simple_integrator) + include("../problems/simple_integrator.jl") +end +@testset verbose = true showtiming = true ":simple_integrator" begin + prob = simple_integrator() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end -# + vanderpol \ No newline at end of file +# vanderpol +if !isdefined(Main, :vanderpol) + include("../problems/vanderpol.jl") +end +@testset verbose = true showtiming = true ":vanderpol" begin + prob = vanderpol() + sol = direct_solve(prob.ocp, display = false) + @test sol.objective ≈ prob.obj rtol = 1e-2 +end \ No newline at end of file