From b1b98274ac71f005adfa067b383c3123d811387f Mon Sep 17 00:00:00 2001 From: Pierre Martinon Date: Mon, 2 Dec 2024 11:00:05 +0100 Subject: [PATCH 1/3] extension for load/save functions (moved from CTDirect) --- Project.toml | 5 ++- ext/CTBaseLoadSave.jl | 64 ++++++++++++++++++++++++++++ src/optimal_control_solution-type.jl | 7 +++ test/test_solution.jl | 16 +++++++ 4 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 ext/CTBaseLoadSave.jl diff --git a/Project.toml b/Project.toml index 38155434..855a2343 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CTBase" uuid = "54762871-cc72-4466-b8e8-f6c8b58076cd" authors = ["Olivier Cots ", "Jean-Baptiste Caillau "] -version = "0.14.0" +version = "0.14.1" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -23,9 +23,12 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [weakdeps] Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" [extensions] CTBasePlots = "Plots" +CTBaseLoadSave = ["JLD2", "JSON3"] [compat] DataStructures = "0.18" diff --git a/ext/CTBaseLoadSave.jl b/ext/CTBaseLoadSave.jl new file mode 100644 index 00000000..01fe0192 --- /dev/null +++ b/ext/CTBaseLoadSave.jl @@ -0,0 +1,64 @@ +module CTBaseLoadSave + +using CTBase +using DocStringExtensions + +using JLD2 +using JSON3 + +""" +$(TYPEDSIGNATURES) + +Export OCP solution in JLD / JSON format +""" +function CTBase.export_ocp_solution(sol::OptimalControlSolution; filename_prefix = "solution", format = :JLD) + if format == :JLD + save_object(filename_prefix * ".jld2", sol) + elseif format == :JSON + blob = Dict( + "objective" => sol.objective, + "time_grid" => sol.time_grid, + "state" => state_discretized(sol), + "control" => control_discretized(sol), + "costate" => costate_discretized(sol)[1:(end - 1), :], + "variable" => sol.variable, + ) + open(filename_prefix * ".json", "w") do io + JSON3.pretty(io, blob) + end + else + error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format) + end + return nothing +end + +""" +$(TYPEDSIGNATURES) + +Read OCP solution in JLD / JSON format +""" +function CTBase.import_ocp_solution(ocp::OptimalControlModel; filename_prefix = "solution", format = :JLD) + + if format == :JLD + return load_object(filename_prefix * ".jld2") + elseif format == :JSON + json_string = read(filename_prefix * ".json", String) + blob = JSON3.read(json_string) + + # NB. convert vect{vect} to matrix + return OptimalControlSolution( + ocp, + blob.time_grid, + stack(blob.state, dims = 1), + stack(blob.control, dims = 1), + blob.variable, + stack(blob.costate, dims = 1); + objective = blob.objective, + ) + else + error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format) + end +end + + +end diff --git a/src/optimal_control_solution-type.jl b/src/optimal_control_solution-type.jl index 8710802c..56afd16c 100644 --- a/src/optimal_control_solution-type.jl +++ b/src/optimal_control_solution-type.jl @@ -79,3 +79,10 @@ $(TYPEDFIELDS) mult_control_box_lower::Union{Nothing, Function} = nothing mult_control_box_upper::Union{Nothing, Function} = nothing end + +export export_ocp_solution +export import_ocp_solution + +# placeholders (see extension CTBaseLoadSave) +function export_ocp_solution end +function import_ocp_solution end \ No newline at end of file diff --git a/test/test_solution.jl b/test/test_solution.jl index 0c8b9e80..3caf7b94 100644 --- a/test/test_solution.jl +++ b/test/test_solution.jl @@ -54,4 +54,20 @@ function test_solution() @test variable(sol) == v @test typeof(sol) == OptimalControlSolution @test_throws UndefKeywordError OptimalControlSolution(ocp; x, u, obj) + + + # test save / load solution in JLD2 format + @testset verbose = true showtiming = true ":save_load :JLD2" begin + export_ocp_solution(sol; filename_prefix = "solution_test") + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test") + @test sol.objective == sol_reloaded.objective + end + + # test export / read solution in JSON format + @testset verbose = true showtiming = true ":export_read :JSON" begin + export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON) + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON) + @test sol.objective == sol_reloaded.objective + end + end From c2ce2b12e606c717f423c356767ccdd504e4fd9e Mon Sep 17 00:00:00 2001 From: Pierre Martinon Date: Mon, 2 Dec 2024 15:44:04 +0100 Subject: [PATCH 2/3] tests ok --- Project.toml | 6 +- src/optimal_control_solution-setters.jl | 136 ++++++++++++++++++++++++ src/optimal_control_solution-type.jl | 8 +- test/Project.toml | 2 + test/runtests.jl | 1 + test/solution_test.jld2 | Bin 0 -> 10653 bytes test/solution_test.json | 108 +++++++++++++++++++ test/test_solution.jl | 28 +++-- 8 files changed, 269 insertions(+), 20 deletions(-) create mode 100644 test/solution_test.jld2 create mode 100644 test/solution_test.json diff --git a/Project.toml b/Project.toml index 855a2343..71059320 100644 --- a/Project.toml +++ b/Project.toml @@ -22,13 +22,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [weakdeps] -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [extensions] -CTBasePlots = "Plots" CTBaseLoadSave = ["JLD2", "JSON3"] +CTBasePlots = "Plots" [compat] DataStructures = "0.18" @@ -36,6 +36,8 @@ DifferentiationInterface = "0.5" DocStringExtensions = "0.9" ForwardDiff = "0.10" Interpolations = "0.15" +JLD2 = "0.5" +JSON3 = "1" MLStyle = "0.4" MacroTools = "0.5" Parameters = "0.12" diff --git a/src/optimal_control_solution-setters.jl b/src/optimal_control_solution-setters.jl index e6f44fc7..7cd6f6c6 100644 --- a/src/optimal_control_solution-setters.jl +++ b/src/optimal_control_solution-setters.jl @@ -224,6 +224,142 @@ function OptimalControlSolution( ) end + +""" +$(TYPEDSIGNATURES) + +Build OCP functional solution from discrete solution (given as raw variables and multipliers plus some optional infos) +""" +function OptimalControlSolution( + ocp::OptimalControlModel, + T, + X, + U, + v, + P; + objective = 0, + iterations = 0, + constraints_violation = 0, + message = "No msg", + stopping = nothing, + success = nothing, + constraints_types = (nothing, nothing, nothing, nothing, nothing), + constraints_mult = (nothing, nothing, nothing, nothing, nothing), + box_multipliers = (nothing, nothing, nothing, nothing, nothing, nothing), +) + dim_x = state_dimension(ocp) + dim_u = control_dimension(ocp) + dim_v = variable_dimension(ocp) + + # check that time grid is strictly increasing + # if not proceed with list of indexes as time grid + if !issorted(T, lt = <=) + println( + "WARNING: time grid at solution is not strictly increasing, replacing with list of indices...", + ) + println(T) + dim_NLP_steps = length(T) - 1 + T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1) + end + + # variables: remove additional state for lagrange cost + x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1)) + p = ctinterpolate(T[1:(end - 1)], matrix2vec(P[:, 1:dim_x], 1)) + u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1)) + + # force scalar output when dimension is 1 + fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t)) + fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t)) + fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t)) + var = (dim_v == 1) ? v[1] : v + + # misc infos + infos = Dict{Symbol, Any}() + infos[:constraints_violation] = constraints_violation + + # nonlinear constraints and multipliers + control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[1], 1))(t) + mult_control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[1], 1))(t) + state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[2], 1))(t) + mult_state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[2], 1))(t) + mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[3], 1))(t) + mult_mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[3], 1))(t) + + # boundary and variable constraints + boundary_constraints = constraints_types[4] + mult_boundary_constraints = constraints_mult[4] + variable_constraints = constraints_types[5] + mult_variable_constraints = constraints_mult[5] + + # box constraints multipliers + mult_state_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[1][:, 1:dim_x], 1))(t) + mult_state_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[2][:, 1:dim_x], 1)) + mult_control_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[3][:, 1:dim_u], 1))(t) + mult_control_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[4][:, 1:dim_u], 1)) + mult_variable_box_lower, mult_variable_box_upper = box_multipliers[5], box_multipliers[6] + + # build and return solution + if is_variable_dependent(ocp) + return OptimalControlSolution( + ocp; + state = fx, + control = fu, + objective = objective, + costate = fp, + time_grid = T, + variable = var, + iterations = iterations, + stopping = stopping, + message = message, + success = success, + infos = infos, + control_constraints = control_constraints, + state_constraints = state_constraints, + mixed_constraints = mixed_constraints, + boundary_constraints = boundary_constraints, + variable_constraints = variable_constraints, + mult_control_constraints = mult_control_constraints, + mult_state_constraints = mult_state_constraints, + mult_mixed_constraints = mult_mixed_constraints, + mult_boundary_constraints = mult_boundary_constraints, + mult_variable_constraints = mult_variable_constraints, + mult_state_box_lower = mult_state_box_lower, + mult_state_box_upper = mult_state_box_upper, + mult_control_box_lower = mult_control_box_lower, + mult_control_box_upper = mult_control_box_upper, + mult_variable_box_lower = mult_variable_box_lower, + mult_variable_box_upper = mult_variable_box_upper, + ) + else + return OptimalControlSolution( + ocp; + state = fx, + control = fu, + objective = objective, + costate = fp, + time_grid = T, + iterations = iterations, + stopping = stopping, + message = message, + success = success, + infos = infos, + control_constraints = control_constraints, + state_constraints = state_constraints, + mixed_constraints = mixed_constraints, + boundary_constraints = boundary_constraints, + mult_control_constraints = mult_control_constraints, + mult_state_constraints = mult_state_constraints, + mult_mixed_constraints = mult_mixed_constraints, + mult_boundary_constraints = mult_boundary_constraints, + mult_state_box_lower = mult_state_box_lower, + mult_state_box_upper = mult_state_box_upper, + mult_control_box_lower = mult_control_box_lower, + mult_control_box_upper = mult_control_box_upper, + ) + end +end + + # setters #state!(sol::OptimalControlSolution, state::Function) = (sol.state = state; nothing) #control!(sol::OptimalControlSolution, control::Function) = (sol.control = control; nothing) diff --git a/src/optimal_control_solution-type.jl b/src/optimal_control_solution-type.jl index 56afd16c..da75c718 100644 --- a/src/optimal_control_solution-type.jl +++ b/src/optimal_control_solution-type.jl @@ -84,5 +84,9 @@ export export_ocp_solution export import_ocp_solution # placeholders (see extension CTBaseLoadSave) -function export_ocp_solution end -function import_ocp_solution end \ No newline at end of file +function export_ocp_solution(args...; kwargs...) + error("Requires JLD2 and JSON3 packages") +end +function import_ocp_solution(args...; kwargs...) + error("Requires JLD2 and JSON3 packages") +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 5bb19cbe..8d8b4e57 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 28f34bfd..dc140f93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Aqua using CTBase using DifferentiationInterface: AutoForwardDiff using Plots +using JLD2, JSON3 using Test # functions and types that are not exported diff --git a/test/solution_test.jld2 b/test/solution_test.jld2 new file mode 100644 index 0000000000000000000000000000000000000000..0c2791d4a2bef8a40a3fa30011e8a5a4d10b36e7 GIT binary patch literal 10653 zcmeHNZ)_Ar6ra5-WqZArtAJEQT?C~PIIpFDiP3{rD5O9`i6Izr>|K|#aJPHj?UmN} zK`~MBFCR2f6oEt|{v%){A$~A?FxVJFP)Pt2{{$uSfe(<7NKA;n*?BW}x3_JUfFB?` zxx2UX-kbMk-kX`7{q6PkENfesRP)-Pva*oYRi#I@RArf&%c#}@WmwDQb<#rM-1wQn!@8#~&hoxm6>uv$tniU1(+~Aj^aV*eL$RvIi4;^+Cq}Ir;hV zHN8D+1LVcvqa@xA>Iz51vT6=w7EB@%=>iUwci0j_!b`xP zR7*mnifpAbFJX;lH4+R^1GLc2)WLU*R-iDW?3(|!@2=@%*UP)t^w}EJ2NVhQsRdfP z&72mcb9){2Laj4<_t;X9uUiU>Ai5@N=@~WMZ5me2Ob?jpf(7#*bb_cK=K9p$YSY@H z8$*;;g>v4@0_sh2^kbE(8__u(?J@Y$kW{0*3iU_M?Rbt}POjwdpnTcgURxK~>jZfX z%6UJAX(K2hP@n1h18zQ;DCqg;f zulW)1Bj88CkANQmKLUOP{0R6F@FUmN!9swJtkhWF@XII7B=MmBg55(|juO)_Z z`XG#r!w@#lMlfW-(+%B%=hH+{77?<|x=|)+6p|Ff3=)Ij#h|DeQ7U99Gm|w9&9L$b z>N8K8Dpc&g#GGu>DL8pcwX_m1afhVBN3TP{_RwF>B#O?AUL|Q+02teqMkGh#-N%TO+fncRu0e^D^HFJai);A60SPOgpLSUsfjTm zwoSO|oE%3dX^te)=5{Sd&P&3A*jwfaVSIWhFXvOwT<0~)=F`?iIi=;>)$089Z zE!!i%SvBtv+jX+iBv1&}NC4_l&0vTGgMl*$k;#upfEIR2D`v3fuOdBxG}6fUYmx3n z+ISb^pG7)rCadd6FCjgAH{)MHx?~osA4U2P(&Y$!zKrxV(wX-%{f$WXA>E`f{&_^! zn(k-yZAcFyy^OSV4wG+3`W4b^NaJ&v{CT9uk=8uG_8z68%1U+;&<)hslPE~ z+6u}f-Ub5L@4aEn6w&2*4GN3Ji!dOW0l;a)kTornwMNSz?f@uCn~fDmnETGM%m-NJ zzq^xJlVLuF}=XWdzk#-LFEfPI=;T7uJv z6CVXvYc! zchkF5o&7?4drr43&2W3AO<#oaX7zCGEVY02-NUBlM%WeOZud7&-K8&bxgY*gx763) zWAEZEYQ5uDZJ$3$`K9-PXs@0IXR?SS$5?6R_X`KVg`*x0{vv)Va`0REvx$S>#?Q76emlR3 zIQSj>e@Jepztf*A-p{tg-4hN`v4=P<4=5wRp#V7M0ODc*50 Wr{f>x$O0zy{OfW9Ir8zqm;V74w;Uz_ literal 0 HcmV?d00001 diff --git a/test/solution_test.json b/test/solution_test.json new file mode 100644 index 00000000..bf6c704a --- /dev/null +++ b/test/solution_test.json @@ -0,0 +1,108 @@ +{ + "time_grid": [ + 0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1 + ], + "objective": 1, + "control": [ + 0, + 0.2222222222222222, + 0.4444444444444444, + 0.6666666666666666, + 0.8888888888888888, + 1.1111111111111112, + 1.3333333333333333, + 1.5555555555555556, + 1.7777777777777777, + 2 + ], + "costate": [ + [ + 0, + -1 + ], + [ + 0.1111111111111111, + -0.8888888888888888 + ], + [ + 0.2222222222222222, + -0.7777777777777778 + ], + [ + 0.3333333333333333, + -0.6666666666666667 + ], + [ + 0.4444444444444444, + -0.5555555555555556 + ], + [ + 0.5555555555555556, + -0.4444444444444444 + ], + [ + 0.6666666666666666, + -0.33333333333333337 + ], + [ + 0.7777777777777778, + -0.2222222222222222 + ], + [ + 0.8888888888888888, + -0.11111111111111116 + ] + ], + "variable": null, + "state": [ + [ + 0, + 1 + ], + [ + 0.1111111111111111, + 1.1111111111111112 + ], + [ + 0.2222222222222222, + 1.2222222222222223 + ], + [ + 0.3333333333333333, + 1.3333333333333333 + ], + [ + 0.4444444444444444, + 1.4444444444444444 + ], + [ + 0.5555555555555556, + 1.5555555555555556 + ], + [ + 0.6666666666666666, + 1.6666666666666665 + ], + [ + 0.7777777777777778, + 1.7777777777777777 + ], + [ + 0.8888888888888888, + 1.8888888888888888 + ], + [ + 1, + 2 + ] + ] +} \ No newline at end of file diff --git a/test/test_solution.jl b/test/test_solution.jl index 3caf7b94..0993ed27 100644 --- a/test/test_solution.jl +++ b/test/test_solution.jl @@ -12,9 +12,9 @@ function test_solution() end times = range(0, 1, 10) - x = t -> t + x = t -> [t, t+1] u = t -> 2t - p = t -> t + p = t -> [t, t-1] obj = 1 sol = OptimalControlSolution( ocp; @@ -33,6 +33,12 @@ function test_solution() @test all(control_discretized(sol) .== u.(times)) @test all(costate_discretized(sol) .== p.(times)) + # test export / read solution in JSON format (NB. requires time grid in solution !) + println(sol.time_grid) + export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON) + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON) + @test sol.objective == sol_reloaded.objective + # NonFixed ocp @def ocp begin v ∈ R, variable @@ -45,7 +51,7 @@ function test_solution() ∫(0.5u(t)^2) → min end - x = t -> t + x = t -> [t, t+1] u = t -> 2t obj = 1 v = 1 @@ -55,19 +61,9 @@ function test_solution() @test typeof(sol) == OptimalControlSolution @test_throws UndefKeywordError OptimalControlSolution(ocp; x, u, obj) - # test save / load solution in JLD2 format - @testset verbose = true showtiming = true ":save_load :JLD2" begin - export_ocp_solution(sol; filename_prefix = "solution_test") - sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test") - @test sol.objective == sol_reloaded.objective - end - - # test export / read solution in JSON format - @testset verbose = true showtiming = true ":export_read :JSON" begin - export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON) - sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON) - @test sol.objective == sol_reloaded.objective - end + export_ocp_solution(sol; filename_prefix = "solution_test") + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test") + @test sol.objective == sol_reloaded.objective end From 73e838f02f65d8edb0a0d36a13e0a31c899dcee4 Mon Sep 17 00:00:00 2001 From: Olivier Cots Date: Fri, 6 Dec 2024 15:46:05 +0100 Subject: [PATCH 3/3] foo --- src/CTBase.jl | 4 ++++ src/optimal_control_solution-type.jl | 9 +++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/CTBase.jl b/src/CTBase.jl index 88dea2b0..d9270be2 100644 --- a/src/CTBase.jl +++ b/src/CTBase.jl @@ -326,4 +326,8 @@ export @def export ct_repl, ct_repl_update_model isdefined(Base, :active_repl) && ct_repl() +# load and save solution +export export_ocp_solution +export import_ocp_solution + end diff --git a/src/optimal_control_solution-type.jl b/src/optimal_control_solution-type.jl index da75c718..635f8f82 100644 --- a/src/optimal_control_solution-type.jl +++ b/src/optimal_control_solution-type.jl @@ -80,13 +80,10 @@ $(TYPEDFIELDS) mult_control_box_upper::Union{Nothing, Function} = nothing end -export export_ocp_solution -export import_ocp_solution - # placeholders (see extension CTBaseLoadSave) -function export_ocp_solution(args...; kwargs...) - error("Requires JLD2 and JSON3 packages") +function export_ocp_solution(args...; kwargs...) + throw(ExtensionError(:JLD2, :JSON3)) end function import_ocp_solution(args...; kwargs...) - error("Requires JLD2 and JSON3 packages") + throw(ExtensionError(:JLD2, :JSON3)) end \ No newline at end of file