Bump compats and update tutorials for Optimization v4
Vaibhavdixit02 committed Sep 24, 2024
1 parent f1843f3 commit 394806c
Showing 12 changed files with 100 additions and 89 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -54,9 +54,9 @@ LuxLib = "1.2"
NNlib = "0.9.22"
OneHotArrays = "0.2.5"
Optimisers = "0.3"
Optimization = "3.25.0"
OptimizationOptimJL = "0.3.0"
OptimizationOptimisers = "0.2.1"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.76.0"
Printf = "1.10"
Random = "1.10"
8 changes: 4 additions & 4 deletions docs/Project.toml
@@ -57,10 +57,10 @@ MLUtils = "0.4"
NNlib = "0.9"
OneHotArrays = "0.2"
Optimisers = "0.3"
Optimization = "3.9"
OptimizationOptimJL = "0.2, 0.3"
OptimizationOptimisers = "0.2"
OptimizationPolyalgorithms = "0.2"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationOptimisers = "0.3"
OptimizationPolyalgorithms = "0.3"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
Printf = "1"
34 changes: 17 additions & 17 deletions docs/src/examples/augmented_neural_ode.md
@@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300)
return contour(x, y, sol; fill = true, linewidth = 0.0)
end
-loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
+loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
dataloader = concentric_sphere(
2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)
iter = 0
-cb = function (ps, l)
+cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
@@ -87,15 +87,15 @@ end
model, ps, st = construct_model(1, 2, 64, 0)
opt = OptimizationOptimisers.Adam(0.005)
-loss_node(model, dataloader.data[1], dataloader.data[2], ps, st)
+loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st)
println("Training Neural ODE")
optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
plt_node = plot_contour(model, res.u, st)
@@ -106,10 +106,10 @@ println()
println("Training Augmented Neural ODE")
optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
plot_contour(model, res.u, st)
```
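The hunks above capture the core of the v4 migration: the data iterator moves into `OptimizationProblem` as its `p` argument, `solve` takes an `epochs` keyword instead of cycling the data manually, and callbacks receive `(state, l)`. A minimal, self-contained sketch of that calling convention, with a hypothetical linear model and random data standing in for the tutorial's network and dataloader:

```julia
using Optimization, OptimizationOptimisers, MLUtils, Statistics, Random

x = rand(Xoshiro(0), Float32, 2, 128)            # hypothetical features
y = rand(Xoshiro(1), Float32, 1, 128)            # hypothetical targets
dataloader = DataLoader((x, y); batchsize = 32)

# v4 objective signature: (parameters, batch); each batch is one element of
# the iterator stored in the problem.
function toy_loss(w, batch)
    xb, yb = batch
    return mean(abs2, w * xb .- yb)              # w is a 1×2 weight matrix
end

optfunc = OptimizationFunction(toy_loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optfunc, zeros(Float32, 1, 2), dataloader)  # data rides in the problem
cb = (state, l) -> (println("loss: ", l); false) # v4 callback: parameters live in state.u
res = solve(optprob, OptimizationOptimisers.Adam(0.01f0); callback = cb, epochs = 10)
```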
@@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr
optimization objective.

```@example augneuralode
-loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
+loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
```

#### Dataset
@@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe

```@example augneuralode
iter = 0
-cb = function (ps, l)
+cb = function (state, l)
global iter
iter += 1
if iter % 10 == 0
@@ -276,10 +276,10 @@ for `20` epochs.
model, ps, st = construct_model(1, 2, 64, 0)
optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
plot_contour(model, res.u, st)
```
@@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs
model, ps, st = construct_model(1, 2, 64, 1)
optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
plot_contour(model, res.u, st)
```
25 changes: 11 additions & 14 deletions docs/src/examples/hamiltonian_nn.md
@@ -75,7 +75,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we

```@example hamiltonian
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-    ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+    ComponentArrays, Optimization, OptimizationOptimisers, MLUtils
t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
@@ -87,12 +87,8 @@ dpdt = -2π_32 .* q_t
data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
-NEPOCHS = 100
-dataloader = ncycle(
-    ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-        selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-    for i in 1:(size(data, 2) ÷ B)),
-    NEPOCHS)
+NEPOCHS = 1000
+dataloader = DataLoader((data, target); batchsize = B)
```
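For reference (not part of the commit), the `MLUtils.DataLoader` that replaces the hand-rolled `ncycle` batching yields `(data, target)` tuples; with the 2×1024 arrays built above and `B = 256`, that is four batches per epoch, and the cycling previously done by `ncycle` is now handled by the `epochs` keyword of `solve`:

```julia
xb, yb = first(dataloader)   # one (data, target) batch
size(xb), size(yb)           # ((2, 256), (2, 256))
length(dataloader)           # 4 batches per epoch
```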

### Training the HamiltonianNN
@@ -103,24 +99,25 @@ We parameterize the with a small MultiLayered Perceptron. HNNs are trained by o
hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray
+hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)
opt = OptimizationOptimisers.Adam(0.01f0)
-function loss_function(ps, data, target)
-    pred, st_ = hnn(data, ps, st)
+function loss_function(ps, databatch)
+    (data, target) = databatch
+    pred = hnn_stateful(data, ps)
return mean(abs2, pred .- target), pred
end
-function callback(ps, loss, pred)
+function callback(state, loss)
println("[Hamiltonian NN] Loss: ", loss)
return false
end
-opt_func = OptimizationFunction(
-    (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps_c)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)
-res = solve(opt_prob, opt, dataloader; callback)
+res = solve(opt_prob, opt; callback, epochs = NEPOCHS)
ps_trained = res.u
```
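The new `StatefulLuxLayer` wrapper is what lets the loss call the network without threading the Lux state through every call. An illustrative sketch with a plain `Dense` layer (hypothetical, not from the tutorial):

```julia
using Lux, Random

layer = Dense(2 => 2)
ps, st = Lux.setup(Xoshiro(0), layer)
smodel = StatefulLuxLayer{true}(layer, ps, st)  # closes over the state `st`
y1 = smodel(rand(Float32, 2, 8))                # uses the stored parameters
y2 = smodel(rand(Float32, 2, 8), ps)            # or explicit parameters, as in the new loss
```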
14 changes: 7 additions & 7 deletions docs/src/examples/mnist_conv_neural_ode.md
@@ -89,20 +89,20 @@ end
# burn in accuracy
accuracy(m, ((img, lab),), ps, st)
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
pred, _ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
end
# burn in loss
-loss_function(ps, img, lab)
+loss_function(ps, (img, lab))
opt = OptimizationOptimisers.Adam(0.005)
iter = 0
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps);
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader);
function callback(ps, l, pred)
global iter += 1
@@ -112,7 +112,7 @@ function callback(ps, l, pred)
end
# Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback)
+res = Optimization.solve(opt_prob, opt; maxiters = 5, callback)
acc = accuracy(m, dataloader, res.u, st)
acc # hide
```
34 changes: 17 additions & 17 deletions docs/src/examples/mnist_neural_ode.md
@@ -81,29 +81,29 @@ end
accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
pred, st_ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
end
loss_function(ps, x_train1, y_train1) # burn in loss
opt = OptimizationOptimisers.Adam(0.05)
iter = 0
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader)
-function callback(ps, l, pred)
+function callback(state, l)
global iter += 1
iter % 10 == 0 &&
-        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
+        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end
# Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
+res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
accuracy(m, dataloader, res.u, st)
```

@@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our
model `model(x)` and compares it to actual output `y`:

```@example mnist
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
pred, st_ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
end
-loss_function(ps, x_train1, y_train1) # burn in loss
+loss_function(ps, (x_train1, y_train1)) # burn in loss
```
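As a quick, illustrative sanity check (not part of the commit, and assuming the tutorial's `logitcrossentropy` and `OneHotArrays` are in scope): confidently correct logits should give a near-zero loss.

```julia
using OneHotArrays: onehotbatch

logits = Float32[5 0; 0 5; 0 0]    # 3 classes × 2 samples, both predicted correctly
labels = onehotbatch([1, 2], 1:3)
logitcrossentropy(logits, labels)  # ≈ 0.013
```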

#### Optimizer
@@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a
```@example mnist
iter = 0
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader)
-function callback(ps, l, pred)
+function callback(state, l)
global iter += 1
iter % 10 == 0 &&
-        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
+        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end
```
@@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`:

```@example mnist
# Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
+res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
accuracy(m, dataloader, res.u, st)
```
19 changes: 12 additions & 7 deletions docs/src/examples/multiple_shooting.md
@@ -48,6 +48,9 @@ ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2))
p_init, st = Lux.setup(rng, nn)
+ps = ComponentArray(p_init)
+pd, pax = getdata(ps), getaxes(ps)
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init))
@@ -62,14 +65,13 @@ end
anim = Plots.Animation()
iter = 0
-callback = function (p, l, preds; doplot = true)
+function callback(state, l; doplot = true, prob_node = prob_node)
display(l)
global iter
iter += 1
if doplot && iter % 1 == 0
# plot the original data
plt = scatter(tsteps, ode_data[1, :]; label = "Data")
# plot the different predictions for individual shoot
plot_multiple_shoot(plt, preds, group_size)
@@ -83,23 +85,26 @@ end
group_size = 3
continuity_term = 200
-l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
-    Tsit5(), group_size; continuity_term)
function loss_function(data, pred)
return sum(abs2, data - pred)
end
-ps = ComponentArray(p_init)
-pd, pax = getdata(ps), getaxes(ps)
function loss_multiple_shooting(p)
ps = ComponentArray(p, pax)
-    return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
        Tsit5(), group_size; continuity_term)
+    global preds = currpred
+    return loss
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
-res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
+res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
gif(anim, "multiple_shooting.gif"; fps = 15)
```
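For intuition (an illustrative sketch, not code from this commit): `multiple_shoot` splits the `saveat` grid into overlapping windows of `group_size` points, solves the ODE on each window, and adds `continuity_term` times the mismatch at the shared endpoints to the loss. The windowing behaves like this:

```julia
# Overlapping windows of `groupsize` points; consecutive windows share one endpoint.
group_ranges(datasize, groupsize) =
    [i:min(datasize, i + groupsize - 1) for i in 1:(groupsize - 1):(datasize - 1)]

group_ranges(10, 3)   # [1:3, 3:5, 5:7, 7:9, 9:10]
```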
