
Commit

docs etc
mcabbott committed Oct 12, 2022
1 parent 37521c8 commit 0d6619a
Showing 4 changed files with 70 additions and 10 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
-version = "0.2.9"
+version = "0.2.10"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
-Functors = "0.2.8, 0.3"
+Functors = "0.3"
Zygote = "0.6.40"
julia = "1.6"

66 changes: 59 additions & 7 deletions docs/src/index.md
@@ -1,6 +1,6 @@
# Optimisers.jl

-## Defining an optimisation rule
+## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
These act on one array of parameters:
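The full worked example is collapsed in this diff. A minimal sketch of such a rule (loosely modelled on the `DecayDescent` rule used later on this page; the exact fields and supertype are assumptions, not copied from the package) might look like:

```julia
using Optimisers

# A rule which scales the gradient by eta / sqrt(t) at step t.
# Field name and supertype are illustrative assumptions.
struct DecayDescent <: Optimisers.AbstractRule
  eta::Float64
end

# `init` returns the initial state for one parameter array: here, a step counter.
Optimisers.init(o::DecayDescent, x::AbstractArray) = 1

# `apply!` takes and returns the state, plus the re-scaled gradient.
function Optimisers.apply!(o::DecayDescent, state, x, dx)
  newdx = (o.eta / sqrt(state)) .* dx
  return state + 1, newdx
end
```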
@@ -60,18 +60,18 @@ Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
tree formed by the model and update the parameters using the gradients.

-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
-but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` for each rule is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.

+There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+but is free to mutate arrays within the old one for efficiency.
+The method of `apply!` you write is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.
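As a concrete illustration of this workflow (a toy sketch, with a NamedTuple standing in for a real model), one explicit-gradient training step might look like:

```julia
using Optimisers, Zygote

# Any Functors-compatible tree of arrays works as a "model".
model = (weight = rand(3, 3), bias = zeros(3))
state = Optimisers.setup(Optimisers.Adam(), model)   # tree of per-array optimiser states

loss(m, x) = sum(abs2, m.weight * x .+ m.bias)
∇model = Zygote.gradient(m -> loss(m, rand(3)), model)[1]   # explicit-mode gradient, mirrors `model`

state, model = Optimisers.update(state, model, ∇model)    # returns a new state and a new model
# state, model = Optimisers.update!(state, model, ∇model) # may instead mutate the old arrays
```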

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration; see Lux's documentation.

## Non-`trainable` Parameters

Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
making up the model, which must be annotated with `@functor Type`.
By default optimisation will alter all [`isnumeric`](@ref) arrays.

If some arrays of a particular layer should not be treated this way,
you can define a method for [`trainable`](@ref):

```julia
struct Layer{T}
alpha::T
beta::T
length::Int
end
Layer(n::Int) = Layer(randn(n), zeros(n), n)

Functors.@functor Layer

# Both array fields will be, for example, moved to the GPU:
Functors.children(Layer(3)) # (alpha = [...], beta = [...], length)

Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of children

# Only the first field will be optimised:
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```

## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.

```julia
using Flux, Optimisers

enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
model = Chain(; enc, dec)

st = Optimisers.setup(Optimisers.Adam(), model);

st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
```

This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.
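The same tie can be seen without Flux, using a plain NamedTuple as the model (a hypothetical sketch, not from the docs):

```julia
using Optimisers

w = rand(3, 3)                                    # one array, stored in two places
model = (layer1 = (weight = w,), layer2 = (weight = w,))

st = Optimisers.setup(Optimisers.Adam(), model)
st.layer1.weight === st.layer2.weight  # true: one shared `Leaf`, hence one shared optimiser state
```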


## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
```

Here `flat` contains only the 283 trainable parameters, while the non-trainable
-ones are preserved inside `re`.
+ones are preserved inside `re`, an object of type `Restructure`.
When defining new layers, the trainable arrays can, if necessary, be specified by overloading [`trainable`](@ref).
By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
are assumed to contain trainable parameters.
Tied parameters (arrays appearing in different layers) are included only once in `flat`.
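For reference, a self-contained sketch of this flattening workflow (the model here is a stand-in NamedTuple, not the 283-parameter model from the collapsed example above):

```julia
using Optimisers, Zygote

model = (weight = rand(2, 2), bias = zeros(2))
flat, re = Optimisers.destructure(model)    # flat::Vector, re::Restructure

st = Optimisers.setup(Optimisers.Adam(), flat)
∇flat = Zygote.gradient(v -> sum(abs2, re(v).weight), flat)[1]   # differentiates through `re`
st, flat = Optimisers.update(st, flat, ∇flat)

model = re(flat)   # rebuild the tree with the updated parameters
```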

Lux stores only the trainable parameters in `params`.
This can also be flattened to a plain `Vector` in the same way:
8 changes: 8 additions & 0 deletions src/Optimisers.jl
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain

###
### one-array functions
###

"""
Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
"""
init

###
### whole-model functions
###

"""
Optimisers.setup(rule, model) -> tree
2 changes: 1 addition & 1 deletion src/interface.jl
@@ -166,7 +166,7 @@ end
###

"""
-@.. x = x + y
+@.. x = y + z
Sometimes in-place broadcasting macro, for use in `apply!` rules.
If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
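An illustrative use inside a rule's `apply!` (a made-up momentum-like rule, not code from the package):

```julia
struct MyMomentum <: Optimisers.AbstractRule   # hypothetical rule, for illustration only
  eta::Float64
  rho::Float64
end

Optimisers.init(o::MyMomentum, x::AbstractArray) = zero(x)

function Optimisers.apply!(o::MyMomentum, state, x, dx)
  eta, rho = o.eta, o.rho
  velocity = state
  # Broadcasts into `velocity` when it is writable, otherwise rebinds it to a new array:
  Optimisers.@.. velocity = rho * velocity + dx
  return velocity, eta .* velocity
end
```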
