Skip to content

Commit

Permalink
try adding optimal control back?
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 5, 2024
1 parent a906903 commit eebb963
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pages = ["index.md",
"examples/hybrid_jump/bouncing_ball.md"],
"Bayesian Estimation" => Any["examples/bayesian/turing_bayesian.md"],
"Optimal and Model Predictive Control" => Any[
"examples/optimal_control/optimal_control.md",
"examples/optimal_control/feedback_control.md"]],
"Manual and APIs" => Any[
"manual/differential_equation_sensitivities.md",
Expand Down
128 changes: 128 additions & 0 deletions docs/src/examples/optimal_control/optimal_control.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# [Solving Optimal Control Problems with Universal Differential Equations](@id optcontrol)

Here we will solve a classic optimal control problem with a universal differential
equation. Let

```math
x^{′′} = u^3(t)
```

where we want to optimize our controller `u(t)` such that the following is
minimized:

```math
L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert + 2 \Vert x^\prime(t_i) \Vert + \Vert u(t_i) \Vert
```

where ``i`` is measured on (0,8) at 0.01 intervals. To do this, we rewrite the
ODE in first order form:

```math
\begin{aligned}
x^\prime &= v \\
v^′ &= u^3(t) \\
\end{aligned}
```

and thus

```math
L(\theta) = \sum_i \Vert 4 - x(t_i) \Vert + 2 \Vert v(t_i) \Vert + \Vert u(t_i) \Vert
```

is our loss function on the first order system. We thus choose a neural network
form for ``u`` and optimize the equation with respect to this loss. Note that we
will first reduce control cost (the last term) by 10x in order to bump the network out
of a local minimum. This looks like:

```@example neuraloptimalcontrol
using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
ann = Chain(Dense(1, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(rng, ann)
p = ComponentArray(ps)
θ, _ax = getdata(p), getaxes(p)
const ax = _ax
function dxdt_(dx, x, p, t)
ps = ComponentArray(p, ax)
x1, x2 = x
dx[1] = x[2]
dx[2] = first(ann([t], ps, st))[1]^3
end
x0 = [-4.0f0, 0.0f0]
ts = Float32.(collect(0.0:0.01:tspan[2]))
prob = ODEProblem(dxdt_, x0, tspan, θ)
solve(prob, Vern9(), abstol = 1e-10, reltol = 1e-10)
function predict_adjoint(θ)
Array(solve(prob, Vern9(), p = θ, saveat = ts))
end
function loss_adjoint(θ)
x = predict_adjoint(θ)
ps = ComponentArray(θ, ax)
mean(abs2, 4.0f0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
mean(abs2, [first(first(ann([t], ps, st))) for t in ts]) / 10
end
l = loss_adjoint(θ)
cb = function (state, l; doplot = true)
println(l)
ps = ComponentArray(state.u, ax)
if doplot
p = plot(solve(remake(prob, p = state.u), Tsit5(), saveat = 0.01),
ylim = (-6, 6), lw = 3)
plot!(p, ts, [first(first(ann([t], ps, st))) for t in ts], label = "u(t)", lw = 3)
display(p)
end
return false
end
# Setup and run the optimization
loss1 = loss_adjoint(θ)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)
res1 = Optimization.solve(
optprob, OptimizationOptimisers.Adam(0.01), callback = cb, maxiters = 100)
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(
optprob2, OptimizationOptimJL.BFGS(), callback = cb, maxiters = 100)
```

Now that the system is in a better behaved part of parameter space, we return to
the original loss function to finish the optimization:

```@example neuraloptimalcontrol
function loss_adjoint(θ)
x = predict_adjoint(θ)
ps = ComponentArray(θ, ax)
mean(abs2, 4.0 .- x[1, :]) + 2mean(abs2, x[2, :]) +
mean(abs2, [first(first(ann([t], ps, st))) for t in ts])
end
optf3 = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob3 = Optimization.OptimizationProblem(optf3, res2.u)
res3 = Optimization.solve(optprob3, OptimizationOptimJL.BFGS(), maxiters = 100)
```

Now let's see what we received:

```@example neuraloptimalcontrol
l = loss_adjoint(res3.u)
cb(res3, l)
p = plot(solve(remake(prob, p = res3.u), Tsit5(), saveat = 0.01), ylim = (-6, 6), lw = 3)
plot!(p, ts, [first(first(ann([t], ComponentArray(res3.u, ax), st))) for t in ts],
label = "u(t)", lw = 3)
```

0 comments on commit eebb963

Please sign in to comment.