diff --git a/Project.toml b/Project.toml
index f1d8501e..bac796ea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes "]
-version = "0.2.9"
+version = "0.2.10"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Functors = "0.2.8, 0.3"
+Functors = "0.3"
 Zygote = "0.6.40"
 julia = "1.6"
diff --git a/docs/src/index.md b/docs/src/index.md
index aaee5d7f..65b441bb 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,6 +1,6 @@
 # Optimisers.jl
 
-## Defining an optimisation rule
+## An optimisation rule
 
 A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
 These act on one array of parameters:
@@ -60,6 +60,11 @@ Notice that a completely new instance of the model is returned. Internally, this
 is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
 tree formed by the model and update the parameters using the gradients.
 
+There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
+but is free to mutate arrays within the old one for efficiency.
+The method of `apply!` for each rule is likewise free to mutate arrays within its state;
+they are defensively copied when this rule is used with `update`.
+
 Optimisers.jl does not depend on any one automatic differentiation package,
 but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
 Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
@@ -67,11 +72,6 @@ This `∇model` is another tree structure, rather than the dictionary-like objec
 Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
-There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
-but is free to mutate arrays within the old one for efficiency.
-The method of `apply!` you write is likewise free to mutate arrays within its state;
-they are defensively copied when this rule is used with `update`.
-
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
 The main design difference of Lux is that the tree of parameters is separate from
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod
 is stored in `lux_state`. For simplicity this example does not show how to propagate the
 updated `lux_state` to the next iteration, see Lux's documentation.
 
+## Non-`trainable` Parameters
+
+Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
+making up the model, for which they must be annotated `@functor Type`.
+By default optimisation will alter all [`isnumeric`](@ref) arrays.
+
+If some arrays of a particular layer should not be treated this way,
+you can define a method for [`trainable`](@ref):
+
+```julia
+struct Layer{T}
+  alpha::T
+  beta::T
+  length::Int
+end
+Layer(n::Int) = Layer(randn(n), zeros(n), n)
+
+Functors.@functor Layer
+
+# Both array fields will be, for example, moved to the GPU:
+Functors.children(Layer(3)) # (alpha = [...], beta = [...], length)
+
+Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of children
+
+# Only the first field will be optimised:
+st = Optimisers.setup(DecayDescent(0.1), Layer(3))
+```
+
+## Tied Parameters
+
+If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
+Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
+Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.
+
+```julia
+using Flux, Optimisers
+
+enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
+dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
+model = Chain(; enc, dec)
+
+st = Optimisers.setup(Optimisers.Adam(), model);
+
+st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
+```
+
+This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
+It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
+from StaticArrays.jl.
+
+
 ## Obtaining a flat parameter vector
 
 Instead of a nested tree-like structure, sometimes is is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
 ```
 
 Here `flat` contains only the 283 trainable parameters, while the non-trainable
-ones are preserved inside `re`.
+ones are preserved inside `re`, an object of type `Restructure`.
 When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
 By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
 are assumed to contain trainable parameters.
+Tied parameters (arrays appearing in different layers) are included only once in `flat`.
 
 Lux stores only the trainable parameters in `params`.
 This can also be flattened to a plain `Vector` in the same way:
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 7c94e233..73c56305 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, ClipGrad, ClipNorm, OptimiserChain
 
+###
+### one-array functions
+###
+
 """
     Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)
@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
 """
 init
 
+###
+### whole-model functions
+###
+
 """
     Optimisers.setup(rule, model) -> tree
diff --git a/src/interface.jl b/src/interface.jl
index 1903d615..79d03396 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -166,7 +166,7 @@ end
 ###
 
 """
-    @.. x = x + y
+    @.. x = y + z
 
 Sometimes in-place broadcasting macro, for use in `apply!` rules.
 If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
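
For context on the `@..` docstring touched in the last hunk, here is a minimal sketch (not part of the patch) of a custom rule whose `apply!` uses it. The rule name `SignDescent` and its update formula are invented for illustration; the `apply!`/`init` signatures and the `setup`/`update` calls follow the documented API above, `Optimisers.AbstractRule` and the unexported `@..` macro are assumed to be importable as shown, and the write-in-place-or-rebind behaviour is taken from the docstring.

```julia
using Optimisers
using Optimisers: @..   # unexported; behaviour as described by the docstring above

# Hypothetical rule, for illustration only: steps against the sign of the
# gradient, keeping a smoothed "velocity" array as its state.
struct SignDescent <: Optimisers.AbstractRule
  eta::Float64
end

Optimisers.init(o::SignDescent, x::AbstractArray) = zero(x)

function Optimisers.apply!(o::SignDescent, state, x, dx)
  η = convert(float(eltype(x)), o.eta)
  vel = state
  # Writes into `vel` when `maywrite(vel)` holds (an ordinary Array);
  # otherwise rebinds `vel` to a freshly broadcast array.
  @.. vel = vel / 2 + η * sign(dx)
  return vel, vel   # (new state, adjusted gradient to subtract from the parameters)
end

# Usage sketch on a bare parameter array:
st = Optimisers.setup(SignDescent(0.1), [1.0, 2.0, 3.0])
st, xnew = Optimisers.update(st, [1.0, 2.0, 3.0], [0.5, -0.25, 4.0])
```

Returning the state as both outputs is just the usual trick when the adjusted gradient and the new state happen to be the same array.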