update minibatch and sophia docs
Vaibhavdixit02 committed Sep 19, 2024
1 parent 1947896 commit f29d5d8
Showing 5 changed files with 75 additions and 39 deletions.
23 changes: 2 additions & 21 deletions docs/src/optimization_packages/optimisers.md
@@ -12,27 +12,7 @@ Pkg.add("OptimizationOptimisers");
In addition to the optimisation algorithms provided by the Optimisers.jl package, this subpackage
also provides the Sophia optimisation algorithm.
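
The rules listed below are used by passing them directly to `solve`. As a quick orientation, here is a minimal, hedged sketch using the Adam rule on a Rosenbrock objective; the objective, starting point, and `maxiters` value are illustrative assumptions and not part of this commit.

```julia
using Optimization, OptimizationOptimisers, Zygote

# Illustrative objective and starting point (assumptions for this sketch)
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
optf = OptimizationFunction(rosenbrock, AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# Any Optimisers.jl rule can be passed as the solver; Adam is shown here
sol = solve(prob, Adam(0.05), maxiters = 1000)
```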

## Local Unconstrained Optimizers

- `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second order information
in the form of the diagonal of the Hessian matrix hence avoiding the need to compute the complete hessian. It has been shown to converge faster than other first order methods such as Adam and SGD.

+ `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

+ `η` is the learning rate
+ `βs` are the decay of momentums
+ `ϵ` is the epsilon value
+ `λ` is the weight decay parameter
+ `k` is the number of iterations to re-compute the diagonal of the Hessian matrix
+ `ρ` is the momentum
+ Defaults:

* `η = 0.001`
* `βs = (0.9, 0.999)`
* `ϵ = 1e-8`
* `λ = 0.1`
* `k = 10`
* `ρ = 0.04`
## List of optimizers

- [`Optimisers.Descent`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Descent): **Classic gradient descent optimizer with learning rate**

@@ -42,6 +22,7 @@ also provides the Sophia optimisation algorithm.
+ Defaults:

* `η = 0.1`

- [`Optimisers.Momentum`](https://fluxml.ai/Optimisers.jl/dev/api/#Optimisers.Momentum): **Classic gradient descent optimizer with learning rate and momentum**

+ `solve(problem, Momentum(η, ρ))`
55 changes: 54 additions & 1 deletion docs/src/optimization_packages/optimization.md
@@ -4,15 +4,35 @@ There are some solvers that are available in the Optimization.jl package directly

## Methods

`LBFGS`: The popular quasi-Newton method that leverages limited memory BFGS approximation of the inverse of the Hessian. Through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box-constraints.
- `LBFGS`: The popular quasi-Newton method that uses a limited-memory BFGS approximation of the inverse Hessian. It is provided through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) Fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package, and it directly supports box constraints.

It can also handle arbitrary nonlinear constraints through an Augmented Lagrangian method with bound constraints, as described in Section 17.4 of Numerical Optimization by Nocedal and Wright, making it a general-purpose nonlinear optimization solver available directly in Optimization.jl (see the constrained example below).

- `Sophia`: Based on the paper https://arxiv.org/abs/2305.14342. It incorporates second-order information in the form of the diagonal of the Hessian matrix, avoiding the need to compute the full Hessian. It has been shown to converge faster than first-order methods such as Adam and SGD. A minimal minibatched sketch is given after the parameter list below.

+ `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

+ `η` is the learning rate
+ `βs` are the exponential decay rates for the moving averages of the gradient and of the Hessian diagonal estimate
+ `ϵ` is a small constant added for numerical stability
+ `λ` is the weight decay parameter
+ `k` is the interval (in iterations) at which the Hessian diagonal estimate is re-computed
+ `ρ` is the clipping parameter that scales the Hessian diagonal estimate in the update
+ Defaults:

* `η = 0.001`
* `βs = (0.9, 0.999)`
* `ϵ = 1e-8`
* `λ = 0.1`
* `k = 10`
* `ρ = 0.04`
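
As a quick illustration of these keywords, here is a minimal, hedged sketch that fits a line to noisy data with `Sophia` over minibatches. The data, batch size, and hyperparameter values are illustrative assumptions, not recommendations; a fuller neural-network example follows below.

```julia
using Optimization, Zygote, MLUtils

# Noisy data for a least-squares line fit (sizes are arbitrary)
xs = collect(range(0, 1, length = 200))
ys = 2 .* xs .+ 1 .+ 0.01 .* randn(200)
data = MLUtils.DataLoader((xs, ys), batchsize = 20)

# Loss over one minibatch; θ[1] is the slope, θ[2] the intercept
loss(θ, batch) = sum(abs2, batch[2] .- (θ[1] .* batch[1] .+ θ[2]))

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, zeros(2), data)
sol = solve(prob, Optimization.Sophia(; η = 0.01, k = 5), maxiters = 100)
```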

## Examples

### Unconstrained Rosenbrock problem

```@example L-BFGS
using Optimization, Zygote
rosenbrock(x, p) = (p[1] - x[1])^2 + p[2] * (x[2] - x[1]^2)^2
@@ -27,6 +47,7 @@ sol = solve(prob, Optimization.LBFGS())
### With nonlinear and bounds constraints

```@example L-BFGS
function con2_c(res, x, p)
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
end
@@ -37,3 +58,35 @@ prob = OptimizationProblem(optf, x0, p, lcons = [1.0, -Inf],
ub = [1.0, 1.0])
res = solve(prob, Optimization.LBFGS(), maxiters = 100)
```

### Train NN with Sophia

```@example Sophia
using Optimization, Lux, Zygote, MLUtils, Statistics, Plots, Random, ComponentArrays

x = rand(10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

# Define the neural network
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)

function callback(state, l)
    state.iter % 25 == 1 && println("Iteration: $(state.iter), Loss: $(l)")
    return l < 1e-1 ## Terminate if loss is small
end

function loss(ps, data)
    ypred = [smodel([data[1][i]], ps)[1] for i in eachindex(data[1])]
    return sum(abs2, ypred .- data[2])
end

optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimization.Sophia(), callback = callback)
```
25 changes: 14 additions & 11 deletions docs/src/tutorials/minibatch.md
@@ -1,14 +1,15 @@
# Data Iterators and Minibatching

It is possible to solve an optimization problem with batches using a `Flux.Data.DataLoader`, which is passed to `Optimization.solve` with `ncycles`. All data for the batches need to be passed as a tuple of vectors.
It is possible to solve an optimization problem with batches using an `MLUtils.DataLoader`, which is passed to `Optimization.solve` with `ncycle`. All of the data for the batches needs to be passed as a tuple of vectors (see the small sketch after the note below).

!!! note

This example uses the OptimizationOptimisers.jl package. See the
[Optimisers.jl page](@ref optimisers) for details on the installation and usage.
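
Before the full tutorial, here is a minimal, hedged sketch of how `MLUtils.DataLoader` serves batches from a tuple of vectors; the arrays and batch size are illustrative and unrelated to the example below.

```julia
using MLUtils

# Each array in the tuple is batched along its last dimension, and every
# batch is returned as a tuple with one entry per array.
xs = collect(1:10)
ys = xs .^ 2
loader = MLUtils.DataLoader((xs, ys), batchsize = 4)

for (xb, yb) in loader
    @show xb yb   # first batch: xb = [1, 2, 3, 4], yb = [1, 4, 9, 16]
end
```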

```@example
using Flux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity
```@example minibatch
using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity, MLUtils, Random, ComponentArrays
function newtons_cooling(du, u, p, t)
temp = u[1]
@@ -21,14 +22,16 @@ function true_sol(du, u, p, t)
newtons_cooling(du, u, true_p, t)
end
ann = Chain(Dense(1, 8, tanh), Dense(8, 1, tanh))
pp, re = Flux.destructure(ann)
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)
function dudt_(u, p, t)
re(p)(u) .* u
smodel(u, p) .* u
end
callback = function (state, l, pred; doplot = false) #callback function to observe training
function callback(state, l, pred; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
if doplot
@@ -53,21 +56,21 @@ function predict_adjoint(fullp, time_batch)
Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))
end
function loss_adjoint(fullp, batch, time_batch)
function loss_adjoint(fullp, data)
batch, time_batch = data
pred = predict_adjoint(fullp, time_batch)
sum(abs2, batch .- pred), pred
end
k = 10
# Pass the data for the batches as separate vectors wrapped in a tuple
train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
train_loader = MLUtils.DataLoader((ode_data, t), batchsize = k)
numEpochs = 300
l1 = loss_adjoint(ps_ca, (train_loader.data[1], train_loader.data[2]))[1]
optfun = OptimizationFunction(
(θ, p, batch, time_batch) -> loss_adjoint(θ, batch,
time_batch),
loss_adjoint,
Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, ps_ca)
using IterTools: ncycle
2 changes: 1 addition & 1 deletion lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -52,7 +52,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.solver_args.epochs
end

maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
maxiters = Optimization._check_and_convert_maxiters(maxiters)
if maxiters === nothing
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end
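
For context on this change, here is a hedged sketch of the two call forms the check above is meant to support; `optprob` and the optimiser choice are assumed from the minibatch tutorial and are not part of this diff.

```julia
# With a DataLoader-backed problem, either keyword sets the number of epochs;
# omitting both raises the ArgumentError shown above.
sol = solve(optprob, Adam(0.05); epochs = 100)
sol = solve(optprob, Adam(0.05); maxiters = 100)
```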
9 changes: 4 additions & 5 deletions test/minibatch.jl
@@ -58,11 +58,10 @@ optfun = OptimizationFunction(loss_adjoint,
Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, pp, train_loader)

# res1 = Optimization.solve(optprob,
# Optimization.Sophia(; η = 0.5,
# λ = 0.0), callback = callback,
# maxiters = 1000)
# @test 10res1.objective < l1
res1 = Optimization.solve(optprob,
Optimization.Sophia(), callback = callback,
maxiters = 1000)
@test 10res1.objective < l1

optfun = OptimizationFunction(loss_adjoint,
Optimization.AutoForwardDiff())