diff --git a/Project.toml b/Project.toml index f1d8501e..bac796ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Optimisers" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" authors = ["Mike J Innes "] -version = "0.2.9" +version = "0.2.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Functors = "0.2.8, 0.3" +Functors = "0.3" Zygote = "0.6.40" julia = "1.6" diff --git a/docs/src/index.md b/docs/src/index.md index aaee5d7f..65b441bb 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,6 @@ # Optimisers.jl -## Defining an optimisation rule +## An optimisation rule A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref). These act on one array of parameters: @@ -60,6 +60,11 @@ Notice that a completely new instance of the model is returned. Internally, this is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the tree formed by the model and update the parameters using the gradients. +There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state, +but is free to mutate arrays within the old one for efficiency. +The method of `apply!` for each rule is likewise free to mutate arrays within its state; +they are defensively copied when this rule is used with `update`. + Optimisers.jl does not depend on any one automatic differentiation package, but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl). Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above. @@ -67,11 +72,6 @@ This `∇model` is another tree structure, rather than the dictionary-like objec Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference. -There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state, -but is free to mutate arrays within the old one for efficiency. -The method of `apply!` you write is likewise free to mutate arrays within its state; -they are defensively copied when this rule is used with `update`. - ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) The main design difference of Lux is that the tree of parameters is separate from @@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod is stored in `lux_state`. For simplicity this example does not show how to propagate the updated `lux_state` to the next iteration, see Lux's documentation. +## Non-`trainable` Parameters + +Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s +making up the model, for which they must be annotated `@functor Type`. +By default optimisation will alter all [`isnumeric`](@ref) arrays. + +If some arrays of a particular layer should not be treated this way, +you can define a method for [`trainable`](@ref) + +```julia +struct Layer{T} + alpha::T + beta::T + length::Int +end +Layer(n::Int) = Layer(randn(n), zeros(n), n) + +Functors.@functor Layer + +# Both array fields will be, for example, moved to the GPU: +Functors.children(Layer(3)) # (alpha = [...], beta = [...], length) + +Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of chidlren + +# Only the first field will be optimised: +st = Optimisers.setup(DecayDescent(0.1), Layer(3)) +``` + +## Tied Parameters + +If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this. +Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters. +Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained. + +```julia +using Flux, Optimisers + +enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10)); +dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh)); +model = Chain(; enc, dec) + +st = Optimisers.setup(Optimisers.Adam(), model); + +st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true +``` + +This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s. +It will not at present work for `reshape`d arrays, nor for immutable arrays such as those +from StaticArrays.jl. + + ## Obtaining a flat parameter vector Instead of a nested tree-like structure, sometimes is is convenient to have all the @@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat) ``` Here `flat` contains only the 283 trainable parameters, while the non-trainable -ones are preserved inside `re`. +ones are preserved inside `re`, an object of type `Restructure`. When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref). By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl) are assumed to contain trainable parameters. +Tied parameters (arrays appearing in different layers) are included only once in `flat`. Lux stores only the trainable parameters in `params`. This can also be flattened to a plain `Vector` in the same way: diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 5b61dbf1..8e8cb19f 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -1,6 +1,6 @@ module Optimisers -using Functors: functor, fmap, isleaf +using Functors: functor, fmap, isleaf, @functor, fmapstructure, children using LinearAlgebra include("interface.jl") @@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, WeightDecay, ClipGrad, ClipNorm, OptimiserChain +### +### one-array functions +### + """ Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient) @@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0]) """ init +### +### whole-model functions +### + """ Optimisers.setup(rule, model) -> tree @@ -69,7 +77,7 @@ or [`update!`](@ref). julia> m = (x = rand(3), y = (true, false), z = tanh); julia> Optimisers.setup(Momentum(), m) # same field names as m -(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing) +(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ()) ``` The recursion into structures uses Functors.jl, and any new `struct`s containing parameters @@ -82,7 +90,7 @@ julia> struct Layer; mat; fun; end julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]); julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored -(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) +(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) julia> destructure(model) (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2)) @@ -90,7 +98,7 @@ julia> destructure(model) julia> using Functors; @functor Layer # annotate this type as containing parameters julia> Optimisers.setup(Momentum(), model) -(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) +(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) julia> destructure(model) (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6)) @@ -112,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s julia> m = (x = Float32[1,2,3], y = tanh); julia> t = Optimisers.setup(Descent(0.1f0), m) -(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing) +(x = Leaf(Descent{Float32}(0.1), nothing), y = ()) julia> g = (x = [1,1,1], y = nothing); # fake gradient julia> Optimisers.update(t, m, g) -((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh)) +((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh)) ``` """ update @@ -157,8 +165,8 @@ true julia> m # original should be discarded, may be mutated but no guarantee (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0]) -julia> t # original state should likewise be discarded -(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0])) +julia> t == t2 # original state is in fact guaranteed to be mutated +true ``` """ update! diff --git a/src/adjust.jl b/src/adjust.jl index d6f3647d..78b3d452 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`. julia> m = (vec = rand(Float32, 2), fun = sin); julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ()) julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient julia> st -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ()) julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched -(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ()) ``` To change other parameters, `adjust` also accepts keyword arguments matching the field diff --git a/src/interface.jl b/src/interface.jl index ae831906..79d03396 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,57 +1,120 @@ -using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero +using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent base(dx::Tangent) = backing(canonicalize(dx)) base(dx) = dx const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor} abstract type AbstractRule end -struct Leaf{R,S} +### +### setup +### + +mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing rule::R state::S end -function setup(rule, x; seen = Base.IdSet()) - rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup) +@functor Leaf + +Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b) + +function setup(rule::AbstractRule, model) + cache = IdDict() + tree = _setup(rule, model; cache) + isempty(cache) && @warn "setup found no trainable parameters in this model" + tree +end + +# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc. +function _setup(rule, x; cache) + haskey(cache, x) && return cache[x] if isnumeric(x) - x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry.")) - isbits(x) || push!(seen, x) - return Leaf(rule, init(rule, x)) - elseif isleaf(x) - return nothing + ℓ = Leaf(rule, init(rule, x)) + if isbits(x) + cache[nothing] = nothing # just to disable the warning + ℓ + else + cache[x] = ℓ + end else - return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x)) + map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x)) end end -subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) +function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! + ioc = IOContext(io, :compact => true) + print(ioc, "Leaf(", ℓ.rule, ", ") + show(ioc, ℓ.state) + print(ioc, ")") +end -update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x -update!(::Nothing, x, x̄s...) = nothing, x +### +### update +### -update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x -function update!(ℓ::Leaf, x, x̄s...) - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...) - Leaf(ℓ.rule, s′), subtract!(x, x̄′) +function update(tree, model, grad, higher...) + t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf + x′ = fmap(copy, model; exclude = maywrite) + update!(t′, x′, grad, higher...) end -update!(tree, x, ::Zero, ::Zero...) = tree, x -function update!(tree, x, x̄s...) - x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) - x′, re = functor(typeof(x), x) - xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) - map(first, xtree), re(map(last, xtree)) +function update!(tree, model, grad, higher...) + # First walk is to accumulate the gradient. This recursion visits every copy of + # shared leaves, but stops when branches are absent from the gradient: + grads = IdDict{Leaf, Any}() + _grads!(grads, tree, model, grad, higher...) + # Second walk is to update the model. The params cache indexed by (tree,x), + # so that identified Leafs can tie isbits parameters, but setup won't do that for you: + newmodel = _update!(tree, model; grads, params = IdDict()) + tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree. +end + +function _update!(tree, x; grads, params) + haskey(params, (tree,x)) && return params[(tree,x)] + isbits(tree) && return x # means () is not cached, and also (((),),) + x′, re = functor(x) + x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)) + if ismutable(x′′) + params[(tree,x)] = x′′ + else # no ties to preserve between immutable structs, right? + x′′ + end end +function _update!(ℓ::Leaf, x; grads, params) + haskey(params, (ℓ,x)) && return params[(ℓ,x)] + params[(ℓ,x)] = if haskey(grads, ℓ) + ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...) + subtract!(x, x̄′) + else + x # no gradient seen + end +end + +subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) -function update(tree, x, x̄s...) - t′ = fmap(copy, tree; exclude = maywrite) - x′ = fmap(copy, x; exclude = maywrite) - update!(t′, x′, x̄s...) +_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing +function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...) + x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s)) + dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible. + nothing +end +_grads!(dict::IdDict, t, x, ::Zero...) = nothing +function _grads!(dict::IdDict, tree, x, x̄s...) + # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from + # functor(typeof(tree), base(x̄)), for things like Transpose + x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) + x′, _ = functor(typeof(x), x) + foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) end # default all rules to first order calls apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) +### +### sources of truth +### + """ isnumeric(x) -> Bool @@ -98,8 +161,12 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu map(c -> c in tr ? c : nothing, ch) end +### +### rule definition helpers +### + """ - @.. x = x + y + @.. x = y + z Sometimes in-place broadcasting macro, for use in `apply!` rules. If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`. @@ -135,11 +202,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc) onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x) onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x) - -function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! - ioc = IOContext(io, :compact => true) - print(ioc, "Leaf(", ℓ.rule, ", ") - show(ioc, ℓ.state) - print(io, ")") -end - diff --git a/test/runtests.jl b/test/runtests.jl index 59f2cb0a..51e76053 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,8 @@ using Optimisers: @.., @lazy Random.seed!(1) +# Fake "models" for testing + struct Foo; x; y; end Functors.@functor Foo Optimisers.trainable(x::Foo) = (x.y, x.x) @@ -13,19 +15,28 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) -struct DummyHigherOrder <: AbstractRule end +mutable struct MutTwo; x; y; end +Functors.@functor MutTwo + +# Simple rules for testing +struct DummyHigherOrder <: AbstractRule end Optimisers.init(::DummyHigherOrder, x::AbstractArray) = (ones(eltype(x), size(x)), zero(x)) - dummy_update_rule(st, p, dx, dx2) = @. p - (st[1] * dx + st[2] * dx2) function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2) a, b = state @.. dx = a * dx + b * dx2 - return (a .+ 1, b .+ 1), dx end +struct BiRule <: Optimisers.AbstractRule end +Optimisers.init(o::BiRule, x::AbstractArray) = nothing +function Optimisers.apply!(o::BiRule, state, x, dx, dx2) + dx == dx2 || error("expected 1st & 2nd gradients to agree") + return state, dx +end + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin @@ -48,6 +59,17 @@ end g4 = Tangent{typeof(m)}(g...) s4, m4 = Optimisers.update!(s, ([1.0, 2.0],), g4) @test m4[1] ≈ [1,2] .- 0.1 .* [25, 33] + + o5 = Momentum(0.1) + s5 = Optimisers.setup(o5, m) + + s6, m6 = Optimisers.update(s5, m, g) + @test s6[1].state ≈ [2.5, 3.3] + @test s5[1].state == [0, 0] # not mutated -- wrong on v0.2.9 + + s7, m7 = Optimisers.update!(s5, m, g) + @test s7[1].state === s5[1].state # same array + @test s7[1] === s5[1] # same Leaf end @testset "gradient clipping" begin @@ -225,40 +247,158 @@ end end @testset "tied weights" begin - ok = (1.0:3.0, sin, "abc", :abc) - m = (α = ok, β = rand(3), γ = ok) - m1 = (rand(3), m, rand(3)) - @test Optimisers.setup(AdamW(), m1) isa Tuple - m2 = (rand(3), m, rand(3), m, rand(3)) # illegal - @test_throws ArgumentError Optimisers.setup(AdamW(), m2) - end - - @testset "higher order interface" begin - w, b = rand(3, 4), rand(3) - - o = DummyHigherOrder() - psin = (w, b) - dxs = map(x -> rand(size(x)...), psin) - dx2s = map(x -> rand(size(x)...), psin) - stin = Optimisers.setup(o, psin) - stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - - # hardcoded rule behavior for dummy rule - @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) - @test stout[1].state[1] == stin[1].state[1] .+ 1 - @test stout[2].state[2] == stin[2].state[2] .+ 1 - - # error if only given one derivative - @test_throws MethodError Optimisers.update(stin, psin, dxs) - - # first-order rules compose with second-order - ochain = OptimiserChain(Descent(0.1), o) - stin = Optimisers.setup(ochain, psin) - stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) - end + @testset "tuples" begin + twice = [1,2.0] + mtup = (twice, (copy(twice), twice)) # (tied (not tied, tied)) + + # simplest rule for which opt(g1) + opt(g2) != opt(g1 + g2) + stup = Optimisers.setup(Momentum(0.1), mtup) + gtup = ([3,3], ([10,10], [7,7])) # (g1, (g1 + g2, g2)) + + snew, mnew = Optimisers.update(stup, mtup, gtup) + @test mnew[1] ≈ mnew[2][1] # gradient was accumulated + @test mnew[2][2] === mnew[1] # and tie is not broken + + st3, mt3 = Optimisers.update(stup, mtup, ([3,3], nothing)) + @test mt3[1] ≈ [1,2] - 0.1 * [3,3] + @test mt3[2][2] === mt3[1] + + st4, mt4 = Optimisers.update(stup, mtup, (nothing, ([5,5], [7,7]))) + @test mt4[1] ≈ [1,2] - 0.1 * [7,7] + end + + @testset "named" begin + thrice = [3f0] + model = (a = (x = thrice, y = Float32[4,5,6], z = true), b = ((m = (0, 1, thrice),),), c = (x = Float32[7,8], y = thrice)) + tree = Optimisers.setup(Momentum(0.1, 0.9), model) + @test model.a.x === model.b[1].m[3] == model.c.y + + loss(x::Array) = sum(abs2, x) + loss(x::Number) = x^3 + loss(m) = sum(2 * loss(x) for x in m) + gradient(loss, model) + _, m2 = Optimisers.update(tree, model, gradient(loss, model)...) + @test m2.a.x === m2.b[1].m[3] == m2.c.y + + loss3(m) = sum(x isa Tuple ? 0 : 2 * loss(x) for x in m) + gradient(loss3, model) # truncates the b limb + _, m3 = Optimisers.update(tree, model, gradient(loss3, model)...) + @test m3.a.x === m3.b[1].m[3] == m3.c.y + end + + @testset "transpose" begin + mat = [1 2 3; 4 5 6.0] + bidir = (m = mat, f = log, t = transpose(mat), v = [7, 8, 9.0]) + bigrad, _ = gradient((m, x) -> sum(abs2, m.m * (m.f).(m.t*x .+ m.v)), bidir, [1, 0.1]) + @test bigrad.t isa Matrix # not a Transpose, that's the point here + + state = Optimisers.setup(Descent(0.1), bidir) + @test state.t.parent === state.m # successfully tied + + s2, b2 = Optimisers.update(state, bidir, bigrad) + @test b2.t.parent === b2.m # tie restored + @test b2.m ≈ bidir.m - 0.1 * (bigrad.m + transpose(bigrad.t)) # grad accumulated + + state = Optimisers.setup(OptimiserChain(ClipGrad(10), Descent(0.1), ClipGrad(10)), bidir) + s2, b2 = Optimisers.update(state, bidir, bigrad) + @test b2.t.parent === b2.m + @test b2.m ≈ bidir.m - 0.1 * clamp.((bigrad.m + transpose(bigrad.t)), -10, 10) + + # Similar, but now "primary" field is the transposed one: + tri = (a = transpose(mat), b = mat, c = transpose(mat), d = 4.0) + trigrad = gradient(m -> sum(abs2, m.a * (m.b * (m.c * [0.1, 1] .+ m.d) .- m.d)), tri)[1] + stri = Optimisers.setup(Descent(0.1), tri) + s3, t3 = Optimisers.update(stri, tri, trigrad) + @test t3.a.parent === t3.b === t3.c.parent + @test t3.a ≈ tri.a - 0.1 * (trigrad.a + trigrad.b' + trigrad.c) + + g4 = (a = Broadcast.broadcasted(+, mat', 1), b = nothing, c = @thunk(mat' .+ 1), d = nothing) + # Error: no constructors for type Any + @test_broken s4, t4 = Optimisers.update(stri, tri, g4) + end + + @testset "artificial" begin + # Interpret shared Leaf as implying shared parameters, even if this did not arise from shared arrays. + # No API for setting this at the moment, but can construct one by hand: + model = (a = SA[1,2.0], b = SA[1, 2.0], c = SA[1, 2.0], d = SA[1, 2.0]) + auto = Optimisers.setup(Momentum(0.1), model) + @test auto.a !== auto.b # not tied just by value + + trick = (a = auto.a, b = auto.a, c = auto.c, d= auto.d) # makes a & b tied + + trick2, model2 = Optimisers.update(trick, model, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) + trick3, model3 = Optimisers.update(trick2, model2, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) + + @test model3.a == model3.b == model3.d # same as having the gradients added + @test !(model3.a ≈ model3.c) + @test trick3.a === trick3.b # leaves remain shared + end + + @testset "mutable containers" begin + tmp = MutTwo([1.0], [2.0]) + model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y)) + state = Optimisers.setup(Momentum(), model) + + @test model.a === model.b + @test model.a !== model.c # fields are identified, but struct is not + + @test state.a.x === state.b.x + @test state.a === state.b + @test state.a === state.c # unavoidable, but means we can't use leaf ID alone + + mgrad = (a=(x=[1], y=[10]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30])) + state2, model2 = Optimisers.update(state, model, mgrad) + + @test model2.a === model2.b # tie of MutTwo structs is restored + @test model2.a !== model2.c # but a new tie is not created + end + end # tied weights + + @testset "2nd-order interface" begin + @testset "BiRule" begin + m = (α = ([1.0], sin), γ = Float32[4,3,2]) + + # Special rule which requires this: + s = Optimisers.setup(BiRule(), m) + g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) + s1, m1 = Optimisers.update(s, m, g, g) + @test m1.α[1] == [0.9] + @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) + + # Ordinary rule which doesn't need it: + s2 = Optimisers.setup(Adam(), m) + s3, m3 = Optimisers.update(s2, m, g) + s4, m4 = Optimisers.update(s2, m, g, g) + @test m3.γ == m4.γ + end + + @testset "DummyHigherOrder" begin + w, b = rand(3, 4), rand(3) + + o = DummyHigherOrder() + psin = (w, b) + dxs = map(x -> rand(size(x)...), psin) + dx2s = map(x -> rand(size(x)...), psin) + stin = Optimisers.setup(o, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + + # hardcoded rule behavior for dummy rule + @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) + @test stout[1].state[1] == stin[1].state[1] .+ 1 + @test stout[2].state[2] == stin[2].state[2] .+ 1 + + # error if only given one derivative + @test_throws MethodError Optimisers.update(stin, psin, dxs) + + # first-order rules compose with second-order + ochain = OptimiserChain(Descent(0.1), o) + stin = Optimisers.setup(ochain, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) + end + end # 2nd-order end @testset verbose=true "Destructure" begin