
Commit

Merge branch 'master' into dg/bs_adjoint
DhairyaLGandhi authored Nov 11, 2024
2 parents e8ef956 + 59c8068 commit bb5be4e
Showing 24 changed files with 258 additions and 81 deletions.
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "SciMLSensitivity"
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
authors = ["Christopher Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "7.69.0"
version = "7.71.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -64,8 +64,8 @@ FiniteDiff = "2"
ForwardDiff = "0.10"
FunctionProperties = "0.1"
FunctionWrappersWrappers = "0.1"
Functors = "0.4"
GPUArraysCore = "0.1"
Functors = "0.4, 0.5"
GPUArraysCore = "0.1, 0.2"
LinearAlgebra = "1.10"
LinearSolve = "2"
Lux = "1"
@@ -81,7 +81,7 @@ PreallocationTools = "0.4.4"
QuadGK = "2.9.1"
Random = "1.10"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "3.18.1"
RecursiveArrayTools = "3.27.2"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
4 changes: 4 additions & 0 deletions docs/Project.toml
@@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
@@ -23,6 +24,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -45,6 +47,7 @@ Enzyme = "0.12, 0.13"
Flux = "0.14"
ForwardDiff = "0.10"
IterTools = "1"
MLUtils = "0.4"
Lux = "1"
LuxCUDA = "0.3"
Optimization = "3.9, 4"
@@ -56,6 +59,7 @@ Plots = "1.36"
QuadGK = "2.6"
RecursiveArrayTools = "2.32, 3"
ReverseDiff = "1.14"
SciMLBase = "2.58"
SciMLSensitivity = "7.11"
SimpleChains = "0.4"
StaticArrays = "1"
4 changes: 2 additions & 2 deletions docs/src/examples/dde/delay_diffeq.md
@@ -38,7 +38,7 @@ end
loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))
using Plots
callback = function (state, l...; doplot = false)
callback = function (state, l; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(
@@ -60,7 +60,7 @@ We define a callback to display the solution at the current parameters for each

```@example dde
using Plots
callback = function (state, l...; doplot = false)
callback = function (state, l; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(
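The callback signature change above reflects the current Optimization.jl convention: the callback receives a state object whose `u` field holds the current parameters, plus the loss value, and no longer takes trailing varargs. A minimal sketch under that convention, reusing the example's `loss_dde`:

```julia
callback = function (state, l; doplot = false)
    # `state.u` holds the current parameter values; `l` is the current loss.
    display(loss_dde(state.u))
    # Returning true would halt the optimization early; false continues it.
    return false
end
```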
8 changes: 4 additions & 4 deletions docs/src/examples/neural_ode/minibatch.md
@@ -2,7 +2,7 @@

```@example
using SciMLSensitivity
using DifferentialEquations, Flux, Random, Plots
using DifferentialEquations, Flux, Random, Plots, MLUtils
using IterTools: ncycle
rng = Random.default_rng()
@@ -46,7 +46,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
prob = ODEProblem{false}(dudt_, u0, tspan, θ)
k = 10
train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
train_loader = DataLoader((ode_data, t), batchsize = k)
for (x, y) in train_loader
@show x
@@ -96,7 +96,7 @@ When training a neural network, we need to find the gradient with respect to our
For this example, we will use a very simple ordinary differential equation, Newton's law of cooling. We can represent this in Julia like so.

```@example minibatch
using SciMLSensitivity
using SciMLSensitivity, MLUtils
using DifferentialEquations, Flux, Random, Plots
using IterTools: ncycle
@@ -152,7 +152,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
prob = ODEProblem{false}(dudt_, u0, tspan, θ)
k = 10
train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
train_loader = DataLoader((ode_data, t), batchsize = k)
for (x, y) in train_loader
@show x
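The two edits above swap the deprecated `Flux.Data.DataLoader` for the `DataLoader` provided by MLUtils.jl (hence the new `MLUtils` import). A minimal sketch of the minibatching pattern, assuming `ode_data` and `t` are defined as in the tutorial and using `ncycle` from IterTools to repeat the loader for a few epochs:

```julia
using MLUtils
using IterTools: ncycle

k = 10                                  # minibatch size
train_loader = DataLoader((ode_data, t), batchsize = k)

# Each iteration yields a batch of observations and the matching time points.
for (x, y) in ncycle(train_loader, 3)   # 3 epochs, chosen only for illustration
    @show size(x), size(y)
end
```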
9 changes: 5 additions & 4 deletions docs/src/examples/neural_ode/neural_ode_flux.md
@@ -114,22 +114,23 @@ end
function loss_n_ode(θ)
pred = predict_n_ode(θ)
loss = sum(abs2, ode_data .- pred)
loss, pred
return loss
end
loss_n_ode(θ)
callback = function (θ, l, pred; doplot = false) #callback function to observe training
callback = function (state, l; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
pred = predict_n_ode(state.u)
pl = scatter(t, ode_data[1, :], label = "data")
scatter!(pl, t, pred[1, :], label = "prediction")
display(plot(pl))
return false
end
# Display the ODE with the initial parameter values.
callback(θ, loss_n_ode(θ)...)
callback((; u = θ), loss_n_ode(θ)...)
# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()
@@ -143,7 +144,7 @@ result_neuralode = Optimization.solve(optprob,
maxiters = 300)
```

Notice that the advantage of this format is that we can use Optim's optimizers, like
Notice that the advantage of this format is that we can use other optimizers, like
`LBFGS` with a full `Chain` object, for all of Flux's neural networks, like
convolutional neural networks.
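For example, after the Adam run above, the result could (hypothetically) be refined with Optim's `LBFGS` through OptimizationOptimJL, reusing the same `optf` and callback; a sketch, assuming `result_neuralode` holds the first-stage solution:

```julia
using OptimizationOptimJL

# Restart the optimization from the Adam minimizer and polish it with LBFGS.
optprob2 = Optimization.OptimizationProblem(optf, result_neuralode.u)
result_refined = Optimization.solve(optprob2, OptimizationOptimJL.LBFGS(),
    callback = callback, maxiters = 100)
```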

7 changes: 4 additions & 3 deletions docs/src/examples/neural_ode/simplechains.md
@@ -54,19 +54,20 @@ end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, data .- pred)
return loss, pred
return loss
end
```

## Training

The next step is to minimize the loss so that the NeuralODE gets trained. But in order to do that, we have to be able to backpropagate through the NeuralODE model. Here the backpropagation through the neural network is the easy part, and we get that out of the box with any deep learning package (although not as fast as SimpleChains for the small-network case here). But we also have to find a way to propagate the sensitivities of the loss back, first through the ODE solver and then to the neural network.

The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://docs.sciml.ai/StaticArrays/stable/) in SimpleChains.jl requires a special adjoint method as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place. Hence, we have specifically used `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)` which computes everything out-of-place when u0 is a StaticArray. Hence, we can move forward with the training of the NeuralODE
The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://juliaarrays.github.io/StaticArrays.jl/stable/) in SimpleChains.jl requires a special adjoint method, as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with heap-allocated normal arrays. For our statically sized, stack-allocated StaticArrays, we need to do everything out of place in order to compute the ODE adjoint. Hence, we have specifically used the `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)`, which computes everything out-of-place when `u0` is a StaticArray. With that, we can move forward with training the NeuralODE.
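A sketch of what that out-of-place adjoint choice looks like inside a prediction function (hypothetical names; the tutorial's `predict_neuralode(p)` wraps an equivalent `solve` call):

```julia
function predict_static(p)
    # QuadratureAdjoint with a Zygote VJP keeps the adjoint computation
    # fully out-of-place, which is required for StaticArray states.
    sol = solve(prob, Tsit5(); p = p, saveat = tsteps,
        sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))
    return Array(sol)
end
```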

```@example sc_neuralode
callback = function (state, l, pred; doplot = true)
callback = function (state, l; doplot = true)
display(l)
pred = predict_neuralode(state.u)
plt = scatter(tsteps, data[1, :], label = "data")
scatter!(plt, tsteps, pred[1, :], label = "prediction")
if doplot
5 changes: 3 additions & 2 deletions docs/src/examples/ode/second_order_adjoints.md
@@ -49,13 +49,13 @@ end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, ode_data .- pred)
return loss, pred
return loss
end
# Callback function to observe training
list_plots = []
iter = 0
callback = function (state, l, pred; doplot = false)
callback = function (state, l; doplot = false)
global list_plots, iter
if iter == 0
@@ -66,6 +66,7 @@ callback = function (state, l, pred; doplot = false)
display(l)
# plot current prediction against data
pred = predict_neuralode(state.u)
plt = scatter(tsteps, ode_data[1, :], label = "data")
scatter!(plt, tsteps, pred[1, :], label = "prediction")
push!(list_plots, plt)
6 changes: 3 additions & 3 deletions docs/src/examples/ode/second_order_neural.md
@@ -33,7 +33,7 @@ t = range(tspan[1], tspan[2], length = 20)
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps = ComponentArray(ps)
model = StatefulLuxLayer{true}(model, ps, st)
model = Lux.StatefulLuxLayer{true}(model, ps, st)
ff(du, u, p, t) = model(u, p)
prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)
@@ -46,12 +46,12 @@ correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:
function loss_n_ode(p)
pred = predict(p)
sum(abs2, correct_pos .- pred[1:2, :]), pred
sum(abs2, correct_pos .- pred[1:2, :])
end
l1 = loss_n_ode(ps)
callback = function (state, l, pred)
callback = function (state, l)
println(l)
l < 0.01
end
14 changes: 10 additions & 4 deletions docs/src/examples/optimal_control/feedback_control.md
@@ -61,7 +61,7 @@ l = loss_univ(θ)
```@example udeneuralcontrol
list_plots = []
iter = 0
cb = function (state, l)
cb = function (state, l; makeplot = false)
global list_plots, iter
if iter == 0
@@ -71,9 +71,11 @@ cb = function (state, l)
println(l)
plt = plot(predict_univ(state.u)', ylim = (0, 6))
push!(list_plots, plt)
display(plt)
if makeplot
plt = plot(predict_univ(state.u)', ylim = (0, 6))
push!(list_plots, plt)
display(plt)
end
return false
end
```
Expand All @@ -84,3 +86,7 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_univ(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)
result_univ = Optimization.solve(optprob, PolyOpt(), callback = cb)
```

```@example udeneuralcontrol
cb(result_univ, result_univ.minimum; makeplot = true)
```
5 changes: 3 additions & 2 deletions docs/src/examples/optimal_control/optimal_control.md
@@ -37,7 +37,8 @@ of a local minimum. This looks like:

```@example neuraloptimalcontrol
using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random,
ForwardDiff
rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
@@ -89,7 +90,7 @@ end
# Setup and run the optimization
loss1 = loss_adjoint(θ)
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)
26 changes: 14 additions & 12 deletions docs/src/examples/pde/pde_constrained.md
@@ -67,25 +67,26 @@ end
## Defining Loss function
function loss(θ)
pred = predict(θ)
return sum(abs2.(predict(θ) .- arr_sol)), pred # Mean squared error
return sum(abs2.(predict(θ) .- arr_sol)) # Mean squared error
end
l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
l = loss(ps)
size(sol), size(t) # Checking sizes
LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator
cb = function (θ, l, pred) #callback function to observe training
cb = function (st, l) #callback function to observe training
display(l)
pred = predict(st.u)
append!(PRED, [pred])
append!(LOSS, l)
append!(PARS, [θ])
append!(PARS, [st.u])
false
end
cb(ps, loss(ps)...) # Testing callback function
cb((; u = ps), loss(ps)) # Testing callback function
# Let see prediction vs. Truth
scatter(sol[:, end], label = "Truth", size = (800, 500))
@@ -228,11 +229,11 @@ use the **mean squared error**.
## Defining Loss function
function loss(θ)
pred = predict(θ)
return sum(abs2.(predict(θ) .- arr_sol)), pred # Mean squared error
return sum(abs2.(predict(θ) .- arr_sol)) # Mean squared error
end
l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
l = loss(ps)
size(sol), size(t) # Checking sizes
```

#### Optimizer
@@ -251,15 +252,16 @@ LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator
cb = function (θ, l, pred) #callback function to observe training
cb = function (st, l) #callback function to observe training
display(l)
pred = predict(st.u)
append!(PRED, [pred])
append!(LOSS, l)
append!(PARS, [θ])
append!(PARS, [st.u])
false
end
cb(ps, loss(ps)...) # Testing callback function
cb((; u = ps), loss(ps)) # Testing callback function
```

### Plotting Prediction vs Ground Truth
11 changes: 5 additions & 6 deletions docs/src/examples/sde/SDE_control.md
@@ -189,19 +189,18 @@ function loss(p_nn; alg = EM(), sensealg = BacksolveAdjoint(autojacvec = Reverse
W = sqrt(myparameters.dt) * randn(typeof(myparameters.dt), size(myparameters.ts)) #for 1 trajectory
W1 = cumsum([zero(myparameters.dt); W[1:(end - 1)]], dims = 1)
NG = CreateGrid(myparameters.ts, W1)
remake(prob,
p = pars,
u0 = u0tmp,
callback = callback,
noise = NG)
end
_prob = remake(prob, p = pars)
ensembleprob = EnsembleProblem(prob,
ensembleprob = EnsembleProblem(_prob,
prob_func = prob_func,
safetycopy = true)
_sol = solve(ensembleprob, alg, EnsembleThreads(),
_sol = solve(ensembleprob, alg, EnsembleSerial(),
sensealg = sensealg,
saveat = myparameters.tinterval,
dt = myparameters.dt,
@@ -293,7 +292,7 @@ visualization_callback((; u = p_nn), l; doplot = true)
# optimize the parameters for a few epochs with Adam on time span
# Setup and run the optimization
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_nn)
@@ -655,7 +654,7 @@ is computed under the hood in the SciMLSensitivity package.
```@example sdecontrol
# optimize the parameters for a few epochs with Adam on time span
# Setup and run the optimization
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p_nn)