diff --git a/Project.toml b/Project.toml index 38155434..71059320 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CTBase" uuid = "54762871-cc72-4466-b8e8-f6c8b58076cd" authors = ["Olivier Cots ", "Jean-Baptiste Caillau "] -version = "0.14.0" +version = "0.14.1" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -22,9 +22,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [weakdeps] +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [extensions] +CTBaseLoadSave = ["JLD2", "JSON3"] CTBasePlots = "Plots" [compat] @@ -33,6 +36,8 @@ DifferentiationInterface = "0.5" DocStringExtensions = "0.9" ForwardDiff = "0.10" Interpolations = "0.15" +JLD2 = "0.5" +JSON3 = "1" MLStyle = "0.4" MacroTools = "0.5" Parameters = "0.12" diff --git a/ext/CTBaseLoadSave.jl b/ext/CTBaseLoadSave.jl new file mode 100644 index 00000000..01fe0192 --- /dev/null +++ b/ext/CTBaseLoadSave.jl @@ -0,0 +1,64 @@ +module CTBaseLoadSave + +using CTBase +using DocStringExtensions + +using JLD2 +using JSON3 + +""" +$(TYPEDSIGNATURES) + +Export OCP solution in JLD / JSON format +""" +function CTBase.export_ocp_solution(sol::OptimalControlSolution; filename_prefix = "solution", format = :JLD) + if format == :JLD + save_object(filename_prefix * ".jld2", sol) + elseif format == :JSON + blob = Dict( + "objective" => sol.objective, + "time_grid" => sol.time_grid, + "state" => state_discretized(sol), + "control" => control_discretized(sol), + "costate" => costate_discretized(sol)[1:(end - 1), :], + "variable" => sol.variable, + ) + open(filename_prefix * ".json", "w") do io + JSON3.pretty(io, blob) + end + else + error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format) + end + return nothing +end + +""" +$(TYPEDSIGNATURES) + +Read OCP solution in JLD / JSON format +""" +function CTBase.import_ocp_solution(ocp::OptimalControlModel; filename_prefix = "solution", format = :JLD) + + if format == :JLD + return load_object(filename_prefix * ".jld2") + elseif format == :JSON + json_string = read(filename_prefix * ".json", String) + blob = JSON3.read(json_string) + + # NB. convert vect{vect} to matrix + return OptimalControlSolution( + ocp, + blob.time_grid, + stack(blob.state, dims = 1), + stack(blob.control, dims = 1), + blob.variable, + stack(blob.costate, dims = 1); + objective = blob.objective, + ) + else + error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format) + end +end + + +end diff --git a/src/CTBase.jl b/src/CTBase.jl index 88dea2b0..d9270be2 100644 --- a/src/CTBase.jl +++ b/src/CTBase.jl @@ -326,4 +326,8 @@ export @def export ct_repl, ct_repl_update_model isdefined(Base, :active_repl) && ct_repl() +# load and save solution +export export_ocp_solution +export import_ocp_solution + end diff --git a/src/optimal_control_solution-setters.jl b/src/optimal_control_solution-setters.jl index e6f44fc7..7cd6f6c6 100644 --- a/src/optimal_control_solution-setters.jl +++ b/src/optimal_control_solution-setters.jl @@ -224,6 +224,142 @@ function OptimalControlSolution( ) end + +""" +$(TYPEDSIGNATURES) + +Build OCP functional solution from discrete solution (given as raw variables and multipliers plus some optional infos) +""" +function OptimalControlSolution( + ocp::OptimalControlModel, + T, + X, + U, + v, + P; + objective = 0, + iterations = 0, + constraints_violation = 0, + message = "No msg", + stopping = nothing, + success = nothing, + constraints_types = (nothing, nothing, nothing, nothing, nothing), + constraints_mult = (nothing, nothing, nothing, nothing, nothing), + box_multipliers = (nothing, nothing, nothing, nothing, nothing, nothing), +) + dim_x = state_dimension(ocp) + dim_u = control_dimension(ocp) + dim_v = variable_dimension(ocp) + + # check that time grid is strictly increasing + # if not proceed with list of indexes as time grid + if !issorted(T, lt = <=) + println( + "WARNING: time grid at solution is not strictly increasing, replacing with list of indices...", + ) + println(T) + dim_NLP_steps = length(T) - 1 + T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1) + end + + # variables: remove additional state for lagrange cost + x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1)) + p = ctinterpolate(T[1:(end - 1)], matrix2vec(P[:, 1:dim_x], 1)) + u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1)) + + # force scalar output when dimension is 1 + fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t)) + fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t)) + fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t)) + var = (dim_v == 1) ? v[1] : v + + # misc infos + infos = Dict{Symbol, Any}() + infos[:constraints_violation] = constraints_violation + + # nonlinear constraints and multipliers + control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[1], 1))(t) + mult_control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[1], 1))(t) + state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[2], 1))(t) + mult_state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[2], 1))(t) + mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[3], 1))(t) + mult_mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[3], 1))(t) + + # boundary and variable constraints + boundary_constraints = constraints_types[4] + mult_boundary_constraints = constraints_mult[4] + variable_constraints = constraints_types[5] + mult_variable_constraints = constraints_mult[5] + + # box constraints multipliers + mult_state_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[1][:, 1:dim_x], 1))(t) + mult_state_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[2][:, 1:dim_x], 1)) + mult_control_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[3][:, 1:dim_u], 1))(t) + mult_control_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[4][:, 1:dim_u], 1)) + mult_variable_box_lower, mult_variable_box_upper = box_multipliers[5], box_multipliers[6] + + # build and return solution + if is_variable_dependent(ocp) + return OptimalControlSolution( + ocp; + state = fx, + control = fu, + objective = objective, + costate = fp, + time_grid = T, + variable = var, + iterations = iterations, + stopping = stopping, + message = message, + success = success, + infos = infos, + control_constraints = control_constraints, + state_constraints = state_constraints, + mixed_constraints = mixed_constraints, + boundary_constraints = boundary_constraints, + variable_constraints = variable_constraints, + mult_control_constraints = mult_control_constraints, + mult_state_constraints = mult_state_constraints, + mult_mixed_constraints = mult_mixed_constraints, + mult_boundary_constraints = mult_boundary_constraints, + mult_variable_constraints = mult_variable_constraints, + mult_state_box_lower = mult_state_box_lower, + mult_state_box_upper = mult_state_box_upper, + mult_control_box_lower = mult_control_box_lower, + mult_control_box_upper = mult_control_box_upper, + mult_variable_box_lower = mult_variable_box_lower, + mult_variable_box_upper = mult_variable_box_upper, + ) + else + return OptimalControlSolution( + ocp; + state = fx, + control = fu, + objective = objective, + costate = fp, + time_grid = T, + iterations = iterations, + stopping = stopping, + message = message, + success = success, + infos = infos, + control_constraints = control_constraints, + state_constraints = state_constraints, + mixed_constraints = mixed_constraints, + boundary_constraints = boundary_constraints, + mult_control_constraints = mult_control_constraints, + mult_state_constraints = mult_state_constraints, + mult_mixed_constraints = mult_mixed_constraints, + mult_boundary_constraints = mult_boundary_constraints, + mult_state_box_lower = mult_state_box_lower, + mult_state_box_upper = mult_state_box_upper, + mult_control_box_lower = mult_control_box_lower, + mult_control_box_upper = mult_control_box_upper, + ) + end +end + + # setters #state!(sol::OptimalControlSolution, state::Function) = (sol.state = state; nothing) #control!(sol::OptimalControlSolution, control::Function) = (sol.control = control; nothing) diff --git a/src/optimal_control_solution-type.jl b/src/optimal_control_solution-type.jl index 8710802c..635f8f82 100644 --- a/src/optimal_control_solution-type.jl +++ b/src/optimal_control_solution-type.jl @@ -79,3 +79,11 @@ $(TYPEDFIELDS) mult_control_box_lower::Union{Nothing, Function} = nothing mult_control_box_upper::Union{Nothing, Function} = nothing end + +# placeholders (see extension CTBaseLoadSave) +function export_ocp_solution(args...; kwargs...) + throw(ExtensionError(:JLD2, :JSON3)) +end +function import_ocp_solution(args...; kwargs...) + throw(ExtensionError(:JLD2, :JSON3)) +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 5bb19cbe..8d8b4e57 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 28f34bfd..dc140f93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Aqua using CTBase using DifferentiationInterface: AutoForwardDiff using Plots +using JLD2, JSON3 using Test # functions and types that are not exported diff --git a/test/solution_test.jld2 b/test/solution_test.jld2 new file mode 100644 index 00000000..0c2791d4 Binary files /dev/null and b/test/solution_test.jld2 differ diff --git a/test/solution_test.json b/test/solution_test.json new file mode 100644 index 00000000..bf6c704a --- /dev/null +++ b/test/solution_test.json @@ -0,0 +1,108 @@ +{ + "time_grid": [ + 0, + 0.1111111111111111, + 0.2222222222222222, + 0.3333333333333333, + 0.4444444444444444, + 0.5555555555555556, + 0.6666666666666666, + 0.7777777777777778, + 0.8888888888888888, + 1 + ], + "objective": 1, + "control": [ + 0, + 0.2222222222222222, + 0.4444444444444444, + 0.6666666666666666, + 0.8888888888888888, + 1.1111111111111112, + 1.3333333333333333, + 1.5555555555555556, + 1.7777777777777777, + 2 + ], + "costate": [ + [ + 0, + -1 + ], + [ + 0.1111111111111111, + -0.8888888888888888 + ], + [ + 0.2222222222222222, + -0.7777777777777778 + ], + [ + 0.3333333333333333, + -0.6666666666666667 + ], + [ + 0.4444444444444444, + -0.5555555555555556 + ], + [ + 0.5555555555555556, + -0.4444444444444444 + ], + [ + 0.6666666666666666, + -0.33333333333333337 + ], + [ + 0.7777777777777778, + -0.2222222222222222 + ], + [ + 0.8888888888888888, + -0.11111111111111116 + ] + ], + "variable": null, + "state": [ + [ + 0, + 1 + ], + [ + 0.1111111111111111, + 1.1111111111111112 + ], + [ + 0.2222222222222222, + 1.2222222222222223 + ], + [ + 0.3333333333333333, + 1.3333333333333333 + ], + [ + 0.4444444444444444, + 1.4444444444444444 + ], + [ + 0.5555555555555556, + 1.5555555555555556 + ], + [ + 0.6666666666666666, + 1.6666666666666665 + ], + [ + 0.7777777777777778, + 1.7777777777777777 + ], + [ + 0.8888888888888888, + 1.8888888888888888 + ], + [ + 1, + 2 + ] + ] +} \ No newline at end of file diff --git a/test/test_solution.jl b/test/test_solution.jl index 0c8b9e80..0993ed27 100644 --- a/test/test_solution.jl +++ b/test/test_solution.jl @@ -12,9 +12,9 @@ function test_solution() end times = range(0, 1, 10) - x = t -> t + x = t -> [t, t+1] u = t -> 2t - p = t -> t + p = t -> [t, t-1] obj = 1 sol = OptimalControlSolution( ocp; @@ -33,6 +33,12 @@ function test_solution() @test all(control_discretized(sol) .== u.(times)) @test all(costate_discretized(sol) .== p.(times)) + # test export / read solution in JSON format (NB. requires time grid in solution !) + println(sol.time_grid) + export_ocp_solution(sol; filename_prefix = "solution_test", format = :JSON) + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test", format = :JSON) + @test sol.objective == sol_reloaded.objective + # NonFixed ocp @def ocp begin v ∈ R, variable @@ -45,7 +51,7 @@ function test_solution() ∫(0.5u(t)^2) → min end - x = t -> t + x = t -> [t, t+1] u = t -> 2t obj = 1 v = 1 @@ -54,4 +60,10 @@ function test_solution() @test variable(sol) == v @test typeof(sol) == OptimalControlSolution @test_throws UndefKeywordError OptimalControlSolution(ocp; x, u, obj) + + # test save / load solution in JLD2 format + export_ocp_solution(sol; filename_prefix = "solution_test") + sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test") + @test sol.objective == sol_reloaded.objective + end