Skip to content

Commit

Permalink
Merge pull request #312 from control-toolbox/move_load_save_to_base
Browse files Browse the repository at this point in the history
Load / save OCP solution moved from CTDirect to CTBase
  • Loading branch information
ocots authored Dec 6, 2024
2 parents 14b6599 + 73e838f commit 9d7ec29
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 4 deletions.
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CTBase"
uuid = "54762871-cc72-4466-b8e8-f6c8b58076cd"
authors = ["Olivier Cots <[email protected]>", "Jean-Baptiste Caillau <[email protected]>"]
version = "0.14.0"
version = "0.14.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -22,9 +22,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[weakdeps]
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[extensions]
CTBaseLoadSave = ["JLD2", "JSON3"]
CTBasePlots = "Plots"

[compat]
Expand All @@ -33,6 +36,8 @@ DifferentiationInterface = "0.5"
DocStringExtensions = "0.9"
ForwardDiff = "0.10"
Interpolations = "0.15"
JLD2 = "0.5"
JSON3 = "1"
MLStyle = "0.4"
MacroTools = "0.5"
Parameters = "0.12"
Expand Down
64 changes: 64 additions & 0 deletions ext/CTBaseLoadSave.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
module CTBaseLoadSave

using CTBase
using DocStringExtensions

using JLD2
using JSON3

"""
$(TYPEDSIGNATURES)
Export OCP solution in JLD / JSON format
"""
function CTBase.export_ocp_solution(sol::OptimalControlSolution; filename_prefix = "solution", format = :JLD)
if format == :JLD
save_object(filename_prefix * ".jld2", sol)
elseif format == :JSON
blob = Dict(
"objective" => sol.objective,
"time_grid" => sol.time_grid,
"state" => state_discretized(sol),
"control" => control_discretized(sol),
"costate" => costate_discretized(sol)[1:(end - 1), :],
"variable" => sol.variable,
)
open(filename_prefix * ".json", "w") do io
JSON3.pretty(io, blob)
end
else
error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format)
end
return nothing
end

"""
$(TYPEDSIGNATURES)
Read OCP solution in JLD / JSON format
"""
function CTBase.import_ocp_solution(ocp::OptimalControlModel; filename_prefix = "solution", format = :JLD)

if format == :JLD
return load_object(filename_prefix * ".jld2")
elseif format == :JSON
json_string = read(filename_prefix * ".json", String)
blob = JSON3.read(json_string)

# NB. convert vect{vect} to matrix
return OptimalControlSolution(
ocp,
blob.time_grid,
stack(blob.state, dims = 1),
stack(blob.control, dims = 1),
blob.variable,
stack(blob.costate, dims = 1);
objective = blob.objective,
)
else
error("Export_ocp_solution: unknow format (should be :JLD or :JSON): ", format)
end
end


end
4 changes: 4 additions & 0 deletions src/CTBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,8 @@ export @def
export ct_repl, ct_repl_update_model
isdefined(Base, :active_repl) && ct_repl()

# load and save solution
export export_ocp_solution
export import_ocp_solution

end
136 changes: 136 additions & 0 deletions src/optimal_control_solution-setters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,142 @@ function OptimalControlSolution(
)
end


"""
$(TYPEDSIGNATURES)
Build OCP functional solution from discrete solution (given as raw variables and multipliers plus some optional infos)
"""
function OptimalControlSolution(
ocp::OptimalControlModel,
T,
X,
U,
v,
P;
objective = 0,
iterations = 0,
constraints_violation = 0,
message = "No msg",
stopping = nothing,
success = nothing,
constraints_types = (nothing, nothing, nothing, nothing, nothing),
constraints_mult = (nothing, nothing, nothing, nothing, nothing),
box_multipliers = (nothing, nothing, nothing, nothing, nothing, nothing),
)
dim_x = state_dimension(ocp)
dim_u = control_dimension(ocp)
dim_v = variable_dimension(ocp)

# check that time grid is strictly increasing
# if not proceed with list of indexes as time grid
if !issorted(T, lt = <=)
println(
"WARNING: time grid at solution is not strictly increasing, replacing with list of indices...",
)
println(T)
dim_NLP_steps = length(T) - 1
T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1)
end

# variables: remove additional state for lagrange cost
x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1))
p = ctinterpolate(T[1:(end - 1)], matrix2vec(P[:, 1:dim_x], 1))
u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1))

# force scalar output when dimension is 1
fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t))
fu = (dim_u == 1) ? deepcopy(t -> u(t)[1]) : deepcopy(t -> u(t))
fp = (dim_x == 1) ? deepcopy(t -> p(t)[1]) : deepcopy(t -> p(t))
var = (dim_v == 1) ? v[1] : v

# misc infos
infos = Dict{Symbol, Any}()
infos[:constraints_violation] = constraints_violation

# nonlinear constraints and multipliers
control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[1], 1))(t)
mult_control_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[1], 1))(t)
state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[2], 1))(t)
mult_state_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[2], 1))(t)
mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_types[3], 1))(t)
mult_mixed_constraints = t -> ctinterpolate(T, matrix2vec(constraints_mult[3], 1))(t)

# boundary and variable constraints
boundary_constraints = constraints_types[4]
mult_boundary_constraints = constraints_mult[4]
variable_constraints = constraints_types[5]
mult_variable_constraints = constraints_mult[5]

# box constraints multipliers
mult_state_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[1][:, 1:dim_x], 1))(t)
mult_state_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[2][:, 1:dim_x], 1))
mult_control_box_lower = t -> ctinterpolate(T, matrix2vec(box_multipliers[3][:, 1:dim_u], 1))(t)
mult_control_box_upper = t -> ctinterpolate(T, matrix2vec(box_multipliers[4][:, 1:dim_u], 1))
mult_variable_box_lower, mult_variable_box_upper = box_multipliers[5], box_multipliers[6]

# build and return solution
if is_variable_dependent(ocp)
return OptimalControlSolution(
ocp;
state = fx,
control = fu,
objective = objective,
costate = fp,
time_grid = T,
variable = var,
iterations = iterations,
stopping = stopping,
message = message,
success = success,
infos = infos,
control_constraints = control_constraints,
state_constraints = state_constraints,
mixed_constraints = mixed_constraints,
boundary_constraints = boundary_constraints,
variable_constraints = variable_constraints,
mult_control_constraints = mult_control_constraints,
mult_state_constraints = mult_state_constraints,
mult_mixed_constraints = mult_mixed_constraints,
mult_boundary_constraints = mult_boundary_constraints,
mult_variable_constraints = mult_variable_constraints,
mult_state_box_lower = mult_state_box_lower,
mult_state_box_upper = mult_state_box_upper,
mult_control_box_lower = mult_control_box_lower,
mult_control_box_upper = mult_control_box_upper,
mult_variable_box_lower = mult_variable_box_lower,
mult_variable_box_upper = mult_variable_box_upper,
)
else
return OptimalControlSolution(
ocp;
state = fx,
control = fu,
objective = objective,
costate = fp,
time_grid = T,
iterations = iterations,
stopping = stopping,
message = message,
success = success,
infos = infos,
control_constraints = control_constraints,
state_constraints = state_constraints,
mixed_constraints = mixed_constraints,
boundary_constraints = boundary_constraints,
mult_control_constraints = mult_control_constraints,
mult_state_constraints = mult_state_constraints,
mult_mixed_constraints = mult_mixed_constraints,
mult_boundary_constraints = mult_boundary_constraints,
mult_state_box_lower = mult_state_box_lower,
mult_state_box_upper = mult_state_box_upper,
mult_control_box_lower = mult_control_box_lower,
mult_control_box_upper = mult_control_box_upper,
)
end
end


# setters
#state!(sol::OptimalControlSolution, state::Function) = (sol.state = state; nothing)
#control!(sol::OptimalControlSolution, control::Function) = (sol.control = control; nothing)
Expand Down
8 changes: 8 additions & 0 deletions src/optimal_control_solution-type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,11 @@ $(TYPEDFIELDS)
mult_control_box_lower::Union{Nothing, Function} = nothing
mult_control_box_upper::Union{Nothing, Function} = nothing
end

# placeholders (see extension CTBaseLoadSave)
function export_ocp_solution(args...; kwargs...)
throw(ExtensionError(:JLD2, :JSON3))
end
function import_ocp_solution(args...; kwargs...)
throw(ExtensionError(:JLD2, :JSON3))
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Aqua
using CTBase
using DifferentiationInterface: AutoForwardDiff
using Plots
using JLD2, JSON3
using Test

# functions and types that are not exported
Expand Down
Binary file added test/solution_test.jld2
Binary file not shown.
108 changes: 108 additions & 0 deletions test/solution_test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"time_grid": [
0,
0.1111111111111111,
0.2222222222222222,
0.3333333333333333,
0.4444444444444444,
0.5555555555555556,
0.6666666666666666,
0.7777777777777778,
0.8888888888888888,
1
],
"objective": 1,
"control": [
0,
0.2222222222222222,
0.4444444444444444,
0.6666666666666666,
0.8888888888888888,
1.1111111111111112,
1.3333333333333333,
1.5555555555555556,
1.7777777777777777,
2
],
"costate": [
[
0,
-1
],
[
0.1111111111111111,
-0.8888888888888888
],
[
0.2222222222222222,
-0.7777777777777778
],
[
0.3333333333333333,
-0.6666666666666667
],
[
0.4444444444444444,
-0.5555555555555556
],
[
0.5555555555555556,
-0.4444444444444444
],
[
0.6666666666666666,
-0.33333333333333337
],
[
0.7777777777777778,
-0.2222222222222222
],
[
0.8888888888888888,
-0.11111111111111116
]
],
"variable": null,
"state": [
[
0,
1
],
[
0.1111111111111111,
1.1111111111111112
],
[
0.2222222222222222,
1.2222222222222223
],
[
0.3333333333333333,
1.3333333333333333
],
[
0.4444444444444444,
1.4444444444444444
],
[
0.5555555555555556,
1.5555555555555556
],
[
0.6666666666666666,
1.6666666666666665
],
[
0.7777777777777778,
1.7777777777777777
],
[
0.8888888888888888,
1.8888888888888888
],
[
1,
2
]
]
}
Loading

0 comments on commit 9d7ec29

Please sign in to comment.