Allow shared parameters, take III #106

Merged
11 commits, merged Oct 13, 2022
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.2.9"
version = "0.2.10"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.2.8, 0.3"
Functors = "0.3"
Zygote = "0.6.40"
julia = "1.6"

66 changes: 59 additions & 7 deletions docs/src/index.md
@@ -1,6 +1,6 @@
# Optimisers.jl

## Defining an optimisation rule
## An optimisation rule

A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref).
These act on one array of parameters:
@@ -60,18 +60,18 @@ Notice that a completely new instance of the model is returned. Internally, this
is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the
tree formed by the model and update the parameters using the gradients.

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` for each rule is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.

Optimisers.jl does not depend on any one automatic differentiation package,
but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl).
Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above.
This `∇model` is another tree structure, rather than the dictionary-like object from
Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
[Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
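
As an illustrative sketch of this difference, with a toy `model` and `loss` (not the ones from the example above):

```julia
using Zygote, Optimisers

model = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0])   # toy "model"
loss(m, x) = sum(abs2, m.weight * x .+ m.bias)             # toy loss

# Explicit mode: differentiate with respect to the model itself,
# giving a gradient with the same tree structure as `model`:
∇model = gradient(m -> loss(m, [1.0, 1.0]), model)[1]      # (weight = ..., bias = ...)

# Implicit mode returns a dictionary-like Grads keyed by the arrays;
# this is not what `Optimisers.update` expects:
grads = gradient(() -> loss(model, [1.0, 1.0]), Params([model.weight, model.bias]))
```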

There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state,
but is free to mutate arrays within the old one for efficiency.
The method of `apply!` you write is likewise free to mutate arrays within its state;
they are defensively copied when this rule is used with `update`.
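
Continuing the toy sketch above, the two variants are called the same way; only the mutation behaviour differs:

```julia
st = Optimisers.setup(Descent(0.1), model);

st, model = Optimisers.update(st, model, ∇model);    # defensively copies arrays first
st, model = Optimisers.update!(st, model, ∇model);   # may mutate the old arrays instead
```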

## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)

The main design difference of Lux is that the tree of parameters is separate from
@@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other model
is stored in `lux_state`. For simplicity this example does not show how to propagate the
updated `lux_state` to the next iteration, see Lux's documentation.

## Non-`trainable` Parameters

Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s
making up the model, so each such type must be annotated with `@functor Type`.
By default optimisation will alter all [`isnumeric`](@ref) arrays.

If some arrays of a particular layer should not be treated this way,
you can define a method for [`trainable`](@ref):

```julia
struct Layer{T}
alpha::T
beta::T
length::Int
end
Layer(n::Int) = Layer(randn(n), zeros(n), n)

Functors.@functor Layer

# Both array fields will be, for example, moved to the GPU:
Functors.children(Layer(3)) # (alpha = [...], beta = [...], length)

Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of children

# Only the first field will be optimised:
st = Optimisers.setup(DecayDescent(0.1), Layer(3))
```
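
As a rough check (a sketch; the exact printed form of `st` may differ), only the `alpha` entry of this state tree is a `Leaf`:

```julia
st.alpha isa Optimisers.Leaf   # true: this field will be updated
st.beta                        # not a Leaf, so `update` leaves `beta` (and `length`) alone
```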

## Tied Parameters

If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this.
Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters.
Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained.

```julia
using Flux, Optimisers

enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10));
dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh));
model = Chain(; enc, dec)

st = Optimisers.setup(Optimisers.Adam(), model);

st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true
```
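
The same accumulation can be seen without Flux, in a small sketch with a hand-made tie and fake gradients:

```julia
v = [1.0, 2.0];
tied = (a = v, b = v);                        # the same array appears twice
st = Optimisers.setup(Descent(0.1), tied);
st.a === st.b                                 # true: one shared Leaf

st, tied = Optimisers.update(st, tied, (a = [1.0, 1.0], b = [10.0, 10.0]));
tied.a === tied.b                             # true: the tie is maintained
tied.a                                        # ≈ [-0.1, 0.9], both gradients accumulated
```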

This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s.
It will not at present work for `reshape`d arrays, nor for immutable arrays such as those
from StaticArrays.jl.
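
For instance, continuing the sketch above, a `reshape` of `v` is a different object, and hence gets its own `Leaf`:

```julia
not_tied = (a = v, b = reshape(v, 2, 1));     # same memory, but a different wrapper
st2 = Optimisers.setup(Descent(0.1), not_tied);
st2.a === st2.b                               # false: treated as independent parameters
```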


## Obtaining a flat parameter vector

Instead of a nested tree-like structure, sometimes it is convenient to have all the
@@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat)
```

Here `flat` contains only the 283 trainable parameters, while the non-trainable
ones are preserved inside `re`.
ones are preserved inside `re`, an object of type `Restructure`.
When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref).
By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl)
are assumed to contain trainable parameters.
Tied parameters (arrays appearing in different layers) are included only once in `flat`.
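
A quick sketch of that last point, using the toy tie `(a = v, b = v)` from the previous section (the flat vector should have 2 entries, not 4):

```julia
flat, re = Optimisers.destructure((a = v, b = v));
length(flat)   # 2: the shared array is stored only once
```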

Lux stores only the trainable parameters in `params`.
This can also be flattened to a plain `Vector` in the same way:
24 changes: 16 additions & 8 deletions src/Optimisers.jl
@@ -1,6 +1,6 @@
module Optimisers

using Functors: functor, fmap, isleaf
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children
using LinearAlgebra

include("interface.jl")
@@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain

###
### one-array functions
###

"""
Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient)

@@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0])
"""
init

###
### whole-model functions
###

"""
Optimisers.setup(rule, model) -> tree

@@ -69,7 +77,7 @@ or [`update!`](@ref).
julia> m = (x = rand(3), y = (true, false), z = tanh);

julia> Optimisers.setup(Momentum(), m) # same field names as m
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing)
(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ())
```

The recursion into structures uses Functors.jl, and any new `struct`s containing parameters
@@ -82,15 +90,15 @@ julia> struct Layer; mat; fun; end
julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]);

julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored
(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2))

julia> using Functors; @functor Layer # annotate this type as containing parameters

julia> Optimisers.setup(Momentum(), model)
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))
(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0]))

julia> destructure(model)
(Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6))
@@ -112,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s
julia> m = (x = Float32[1,2,3], y = tanh);

julia> t = Optimisers.setup(Descent(0.1f0), m)
(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing)
(x = Leaf(Descent{Float32}(0.1), nothing), y = ())

julia> g = (x = [1,1,1], y = nothing); # fake gradient

julia> Optimisers.update(t, m, g)
((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh))
((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh))
```
"""
update
@@ -157,8 +165,8 @@ true
julia> m # original should be discarded, may be mutated but no guarantee
(x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0])

julia> t # original state should likewise be discarded
(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0]))
julia> t == t2 # original state is in fact guaranteed to be mutated
true
```
"""
update!
6 changes: 3 additions & 3 deletions src/adjust.jl
@@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`.
julia> m = (vec = rand(Float32, 2), fun = sin);

julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ())

julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient

julia> st
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ())

julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing)
(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ())
```

To change other parameters, `adjust` also accepts keyword arguments matching the field
131 changes: 95 additions & 36 deletions src/interface.jl
@@ -1,57 +1,120 @@

using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero
using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent
base(dx::Tangent) = backing(canonicalize(dx))
base(dx) = dx
const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor}

abstract type AbstractRule end

struct Leaf{R,S}
###
### setup
###

mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing
rule::R
state::S
end

function setup(rule, x; seen = Base.IdSet())
rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup)
@functor Leaf

Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b)

function setup(rule::AbstractRule, model)
cache = IdDict()
tree = _setup(rule, model; cache)
isempty(cache) && @warn "setup found no trainable parameters in this model"
tree
end

# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc.
function _setup(rule, x; cache)
haskey(cache, x) && return cache[x]
if isnumeric(x)
x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry."))
isbits(x) || push!(seen, x)
return Leaf(rule, init(rule, x))
elseif isleaf(x)
return nothing
ℓ = Leaf(rule, init(rule, x))
if isbits(x)
cache[nothing] = nothing # just to disable the warning
else
cache[x] = ℓ
end
else
return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x))
map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
end
end

subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)
function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(ioc, ")")
end

update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x
update!(::Nothing, x, x̄s...) = nothing, x
###
### update
###

update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x
function update!(ℓ::Leaf, x, x̄s...)
s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...)
Leaf(ℓ.rule, s′), subtract!(x, x̄′)
function update(tree, model, grad, higher...)
t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf
x′ = fmap(copy, model; exclude = maywrite)
update!(t′, x′, grad, higher...)
end

update!(tree, x, ::Zero, ::Zero...) = tree, x
function update!(tree, x, x̄s...)
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, re = functor(typeof(x), x)
xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
map(first, xtree), re(map(last, xtree))
function update!(tree, model, grad, higher...)
# First walk is to accumulate the gradient. This recursion visits every copy of
# shared leaves, but stops when branches are absent from the gradient:
grads = IdDict{Leaf, Any}()
_grads!(grads, tree, model, grad, higher...)
# Second walk is to update the model. The params cache is indexed by (tree,x),
# so that identified Leafs can tie isbits parameters, but setup won't do that for you:
newmodel = _update!(tree, model; grads, params = IdDict())
tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree.
end

function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
[Review thread on the line above]

@ToucheSir (Member), Aug 29, 2022:
This does imply we will be caching almost every level of an average Flux model (since `BitsType{NotBits, BitsTypes...}` is not a bitstype). `objectid` being not the fastest function in the world, perhaps both cache lookup and insertion should be additionally guarded by `ismutable(x)`.

Member Author:
I wondered this too. For large ImmutableArrays this may eventually need something fancier. But for now I think every fmap walk does the same thing.

@ToucheSir (Member), Aug 29, 2022:
Oh I wasn't even thinking about those, but cases like JuliaLang/julia#43542. We're unlikely to see any truly pathological behaviour, but I have to imagine the single comparison `ismutable` makes is more efficient than the recursive hash function `objectid` uses.

Member Author:
OK. I guess `ismutable` really is right here. For parameter arrays IIRC there was a concern that it tells you e.g. that `PermutedDimsArray` is immutable. But for known non-leaf types, maybe it's always right?

@ToucheSir (Member), Aug 29, 2022:
Good point. `PermutedDimsArray` at least does implement `functor`, but you can always find an array wrapper which hasn't. Perhaps then the check should be `isleaf` instead? The `isbits` check is still useful either way.

Edit: I suppose `isnumeric` makes more sense since it forwards to `isleaf` already and `setup` guarantees only unfamiliar immutable wrappers of immutable arrays will get their own `Leaf`. Moving the `isbits` check up front also seems safe and could save a couple cycles on dict lookups.

function _update!(tree, x; grads, params)
  isbits(tree) && return x  # means () is not cached, and also (((),),)
  isnum = isnumeric(x)
  isnum && haskey(params, (tree,x)) && return params[(tree,x)]
  children, re = functor(x)
  children′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, children)
  x′ = re(children′)
  isnum ? (params[(tree,x)] = x′) : x′
end

It's likely this can be simplified, but I wanted to get something on the page first in case there are any unforeseen edge cases present in this formulation.

Member Author:
I think anything `isnumeric` should have a corresponding `Leaf` and hit the `_update!(::Leaf, x; ...)` method.

This one wants only to deal with mutable non-leaf things, like my `mutable struct MutTwo` example. Which makes me think that `ismutable` is fine -- we have `Foo(MutTwo(Bar(Transpose(Array`, then the `Array` is leaf, and the only level at which it's worthwhile for this method to cache anything is the `MutTwo` one. If this whole stack appears twice, a fresh new `struct Foo` cannot be distinguished from the old one.

[End of review thread; the diff continues below.]

x′, re = functor(x)
x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
else # no ties to preserve between immutable structs, right?
x′′
end
end
function _update!(ℓ::Leaf, x; grads, params)
haskey(params, (ℓ,x)) && return params[(ℓ,x)]
params[(ℓ,x)] = if haskey(grads, ℓ)
ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...)
subtract!(x, x̄′)
else
x # no gradient seen
end
end

subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄)

function update(tree, x, x̄s...)
t′ = fmap(copy, tree; exclude = maywrite)
x′ = fmap(copy, x; exclude = maywrite)
update!(t′, x′, x̄s...)
_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing
function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...)
x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s))
dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible.
nothing
end
_grads!(dict::IdDict, t, x, ::Zero...) = nothing
function _grads!(dict::IdDict, tree, x, x̄s...)
# The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from
# functor(typeof(tree), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# default all rules to first order calls
apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx)

###
### sources of truth
###

"""
isnumeric(x) -> Bool

@@ -98,8 +161,12 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
map(c -> c in tr ? c : nothing, ch)
end

###
### rule definition helpers
###

"""
@.. x = x + y
@.. x = y + z

Sometimes in-place broadcasting macro, for use in `apply!` rules.
If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`.
@@ -135,11 +202,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc)

onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x)
onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x)

function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type!
ioc = IOContext(io, :compact => true)
print(ioc, "Leaf(", ℓ.rule, ", ")
show(ioc, ℓ.state)
print(io, ")")
end
