From 9b28112f8f35e537717d0fe8bce517311c0b1d98 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 27 Aug 2022 20:57:47 -0400 Subject: [PATCH 01/11] allow shared parameters, take III Co-authored-by: Brian Chen --- src/Optimisers.jl | 6 +-- src/adjust.jl | 4 +- src/interface.jl | 101 ++++++++++++++++++++++++++++++---------------- test/runtests.jl | 86 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 152 insertions(+), 45 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 5b61dbf1..7c94e233 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -1,6 +1,6 @@ module Optimisers -using Functors: functor, fmap, isleaf +using Functors: functor, fmap, isleaf, @functor, fmapstructure, children using LinearAlgebra include("interface.jl") @@ -157,8 +157,8 @@ true julia> m # original should be discarded, may be mutated but no guarantee (x = Float32[0.6666666, 1.5333333], y = Float32[4.0, 5.0]) -julia> t # original state should likewise be discarded -(x = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.333333, 0.466667]), y = Leaf(Momentum{Float64}(0.0333333, 0.9), Float32[0.0, 0.0])) +julia> t == t2 # original state is in fact guaranteed to be mutated +true ``` """ update! diff --git a/src/adjust.jl b/src/adjust.jl index d6f3647d..3ec2f220 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree) adjust(::Nothing, ::Real) = nothing adjust(::Nothing; kw...) = nothing -adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state) -adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state) +adjust(ℓ::Leaf, eta::Real) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen) +adjust(ℓ::Leaf; kw...) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen) """ diff --git a/src/interface.jl b/src/interface.jl index ae831906..236b1ed9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,50 +1,83 @@ -using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero +using ChainRulesCore: canonicalize, backing, Tangent, AbstractZero, ZeroTangent base(dx::Tangent) = backing(canonicalize(dx)) base(dx) = dx const Zero = Union{Nothing, AbstractZero} # Union{Zygote, Diffractor} abstract type AbstractRule end -struct Leaf{R,S} +### +### setup +### + +mutable struct Leaf{R,S} rule::R state::S + frozen::Bool end -function setup(rule, x; seen = Base.IdSet()) - rule isa AbstractRule || Base.depwarn("In future, all optimisation rules should be <: AbstractRule", :setup) - if isnumeric(x) - x in seen && throw(ArgumentError("Optimisers.jl does not at present handle tied weights, sorry.")) - isbits(x) || push!(seen, x) - return Leaf(rule, init(rule, x)) - elseif isleaf(x) - return nothing - else - return map(xᵢ -> setup(rule, xᵢ; seen), _trainable(x)) +@functor Leaf + +Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b) + +function setup(rule::AbstractRule, model) + cnt = Ref(0) + # Rely on Functors to identify shared arrays, they will share a Leaf in this tree: + tree = fmapstructure(model, exclude = isnumeric) do x + cnt[] += 1 + Leaf(rule, init(rule, x), false) end + cnt[] == 0 && @warn "setup found no parameters in the given model" + tree end -subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) - -update!(::Nothing, x, ::Zero, ::Zero...) = nothing, x -update!(::Nothing, x, x̄s...) = nothing, x +function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! 
+ ioc = IOContext(io, :compact => true) + print(ioc, "Leaf(", ℓ.rule, ", ") + show(ioc, ℓ.state) + print(ioc, ", ", ℓ.frozen, ")") +end -update!(ℓ::Leaf, x, ::Zero, ::Zero...) = ℓ, x -function update!(ℓ::Leaf, x, x̄s...) - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, base.(x̄s)...) - Leaf(ℓ.rule, s′), subtract!(x, x̄′) +### +### update +### + +function update!(tree, model, grad) + # First walk is to accumulate the gradient. This recursion visits every copy of + # shared leaves, but stops when branches are absent from the gradient: + dict = IdDict{Leaf, Any}() + grads!(dict, tree, model, grad) + # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once: + newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ + ℓ isa Leaf || error("this state does not match the model, expected a Leaf here") + ℓ.frozen && return x + haskey(dict, ℓ) || return x + s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) + ℓ.state = s′ # to get state out of here, rely on mutability of Leaf + subtract!(x, x̄′) + end + tree, newmodel # note that tree is guaranteed to be updated end -update!(tree, x, ::Zero, ::Zero...) = tree, x -function update!(tree, x, x̄s...) +subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) + +grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing +function grads!(dict::IdDict, ℓ::Leaf, x, x̄) + x̄₀ = get(dict, ℓ, false) + dict[ℓ] = Broadcast.broadcasted(+, x̄, x̄₀) + nothing +end +grads!(dict::IdDict, t, x, ::Zero) = nothing +function grads!(dict::IdDict, tree, x, x̄s...) + # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from + # functor(typeof(tree), base(x̄)), for things like Transpose x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) - x′, re = functor(typeof(x), x) - xtree = map((stᵢ, xᵢ, x̄sᵢ...) -> update!(stᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) - map(first, xtree), re(map(last, xtree)) + x′, _ = functor(typeof(x), x) + foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) end function update(tree, x, x̄s...) - t′ = fmap(copy, tree; exclude = maywrite) + t′ = fmap(copy, tree; exclude = maywrite) # goes inside Leaf x′ = fmap(copy, x; exclude = maywrite) update!(t′, x′, x̄s...) end @@ -52,6 +85,10 @@ end # default all rules to first order calls apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) +### +### sources of truth +### + """ isnumeric(x) -> Bool @@ -98,6 +135,10 @@ function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tu map(c -> c in tr ? c : nothing, ch) end +### +### rule definition helpers +### + """ @.. x = x + y @@ -135,11 +176,3 @@ Broadcast.materialize(x::Lazy) = Broadcast.instantiate(x.bc) onevalue(λ::T, x::AbstractArray{T}) where T = map(_ -> λ, x) onevalue(λ, x::AbstractArray{T}) where T = onevalue(convert(float(T), λ), x) - -function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! - ioc = IOContext(io, :compact => true) - print(ioc, "Leaf(", ℓ.rule, ", ") - show(ioc, ℓ.state) - print(io, ")") -end - diff --git a/test/runtests.jl b/test/runtests.jl index 59f2cb0a..7b54f24d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,17 @@ end g4 = Tangent{typeof(m)}(g...) 
s4, m4 = Optimisers.update!(s, ([1.0, 2.0],), g4) @test m4[1] ≈ [1,2] .- 0.1 .* [25, 33] + + o5 = Momentum(0.1) + s5 = Optimisers.setup(o5, m) + + s6, m6 = Optimisers.update(s5, m, g) + @test s6[1].state ≈ [2.5, 3.3] + @test s5[1].state == [0, 0] # not mutated -- wrong on v0.2.9 + + s7, m7 = Optimisers.update!(s5, m, g) + @test s7[1].state === s5[1].state # same array + @test s7[1] === s5[1] # same Leaf end @testset "gradient clipping" begin @@ -225,12 +236,75 @@ end end @testset "tied weights" begin - ok = (1.0:3.0, sin, "abc", :abc) - m = (α = ok, β = rand(3), γ = ok) - m1 = (rand(3), m, rand(3)) - @test Optimisers.setup(AdamW(), m1) isa Tuple - m2 = (rand(3), m, rand(3), m, rand(3)) # illegal - @test_throws ArgumentError Optimisers.setup(AdamW(), m2) + @testset "tuples" begin + twice = [1,2.0] + mtup = (twice, (copy(twice), twice)) # (tied (not tied, tied)) + + # simplest rule for which opt(g1) + opt(g2) != opt(g1 + g2) + stup = Optimisers.setup(Momentum(0.1), mtup) + gtup = ([3,3], ([10,10], [7,7])) # (g1, (g1 + g2, g2)) + + snew, mnew = Optimisers.update(stup, mtup, gtup) + @test mnew[1] ≈ mnew[2][1] # gradient was accumulated + @test mnew[2][2] === mnew[1] # and tie is not broken + + st3, mt3 = Optimisers.update(stup, mtup, ([3,3], nothing)) + @test mt3[1] ≈ [1,2] - 0.1 * [3,3] + @test mt3[2][2] === mt3[1] + + st4, mt4 = Optimisers.update(stup, mtup, (nothing, ([5,5], [7,7]))) + @test mt4[1] ≈ [1,2] - 0.1 * [7,7] + end + + @testset "named" begin + thrice = [3f0] + model = (a = (x = thrice, y = Float32[4,5,6], z = true), b = ((m = (0, 1, thrice),),), c = (x = Float32[7,8], y = thrice)) + tree = Optimisers.setup(Momentum(0.1, 0.9), model) + @test model.a.x === model.b[1].m[3] == model.c.y + + loss(x::Array) = sum(abs2, x) + loss(x::Number) = x^3 + loss(m) = sum(2 * loss(x) for x in m) + gradient(loss, model) + _, m2 = Optimisers.update(tree, model, gradient(loss, model)...) + @test m2.a.x === m2.b[1].m[3] == m2.c.y + + loss3(m) = sum(x isa Tuple ? 0 : 2 * loss(x) for x in m) + gradient(loss3, model) # truncates the b limb + _, m3 = Optimisers.update(tree, model, gradient(loss3, model)...) 
+ @test m3.a.x === m3.b[1].m[3] == m3.c.y + end + + @testset "transpose" begin + mat = [1 2 3; 4 5 6.0] + bidir = (m = mat, f = log, t = transpose(mat), v = [7, 8, 9.0]) + bigrad, _ = gradient((m, x) -> sum(abs2, m.m * (m.f).(m.t*x .+ m.v)), bidir, [1, 0.1]) + @test bigrad.t isa Matrix # not a Transpose, that's the point here + + state = Optimisers.setup(Descent(0.1), bidir) + @test state.t.parent === state.m # successfully tied + + s2, b2 = Optimisers.update(state, bidir, bigrad) + @test b2.t.parent === b2.m # tie restored + @test b2.m ≈ bidir.m - 0.1 * (bigrad.m + transpose(bigrad.t)) # grad accumulated + + state = Optimisers.setup(OptimiserChain(ClipGrad(10), Descent(0.1), ClipGrad(10)), bidir) + s2, b2 = Optimisers.update(state, bidir, bigrad) + @test b2.t.parent === b2.m + @test b2.m ≈ bidir.m - 0.1 * clamp.((bigrad.m + transpose(bigrad.t)), -10, 10) + + # Similar, but now "primary" field is the transposed one: + tri = (a = transpose(mat), b = mat, c = transpose(mat), d = 4.0) + trigrad = gradient(m -> sum(abs2, m.a * (m.b * (m.c * [0.1, 1] .+ m.d) .- m.d)), tri)[1] + stri = Optimisers.setup(Descent(0.1), tri) + s3, t3 = Optimisers.update(stri, tri, trigrad) + @test t3.a.parent === t3.b === t3.c.parent + @test t3.a ≈ tri.a - 0.1 * (trigrad.a + trigrad.b' + trigrad.c) + + g4 = (a = Broadcast.broadcasted(+, mat', 1), b = nothing, c = @thunk(mat' .+ 1), d = nothing) + # Error: no constructors for type Any + @test_broken s4, t4 = Optimisers.update(stri, tri, g4) + end end @testset "higher order interface" begin From 64d5d9f7dd08939bb05db19b79bf6ad044cf27f1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 11:51:06 -0400 Subject: [PATCH 02/11] one more dict to allow artificial ties --- src/interface.jl | 19 +++++++++++++------ test/runtests.jl | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 236b1ed9..4cef995e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -45,16 +45,23 @@ end function update!(tree, model, grad) # First walk is to accumulate the gradient. This recursion visits every copy of # shared leaves, but stops when branches are absent from the gradient: - dict = IdDict{Leaf, Any}() - grads!(dict, tree, model, grad) - # Second walk is to update the model, using same fmap walk as setup, thus each Leaf exactly once: + gdict = IdDict{Leaf, Any}() + grads!(gdict, tree, model, grad) + # Second walk is to update the model, using same fmap walk as setup: + xdict = IdDict{Leaf, Any}() # (this exists to allow for shared ℓ without shared x) newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ ℓ isa Leaf || error("this state does not match the model, expected a Leaf here") ℓ.frozen && return x - haskey(dict, ℓ) || return x - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) + haskey(gdict, ℓ) || return x # no gradient seen, nothing to do + if haskey(xdict, ℓ) + # This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet. + x′ = xdict[ℓ] # ... and is why xdict exists. 
+ size(x′) == size(x) || error("the same Leaf belongs to arrays of size $(size(x)) and $(size(x′))") + return x′ + end + s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, gdict[ℓ]) ℓ.state = s′ # to get state out of here, rely on mutability of Leaf - subtract!(x, x̄′) + xdict[ℓ] = subtract!(x, x̄′) end tree, newmodel # note that tree is guaranteed to be updated end diff --git a/test/runtests.jl b/test/runtests.jl index 7b54f24d..f152c76d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -305,6 +305,22 @@ end # Error: no constructors for type Any @test_broken s4, t4 = Optimisers.update(stri, tri, g4) end + + @testset "artificial" begin + # Interpret shared Leaf as implying shared parameters, even if this did not arise from shared arrays. + # No API for setting this at the moment, but can construct one by hand: + model = (a = [1,2.0], b = [1, 2.0], c = [1, 2.0], d = [1, 2.0]) + honest = Optimisers.setup(Momentum(0.1), model) + trick = (a = honest.a, b = honest.a, c = honest.c, d= honest.d) # makes a & b shared + + trick2, model2 = Optimisers.update(trick, model, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) + trick3, model3 = Optimisers.update(trick2, model2, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) + + @test model3.a == model3.b == model3.d # same as having the gradients added + @test !(model3.a ≈ model3.c) + @test trick3.a === trick3.b # leaves remain shared + model3.a === model3.b # in fact arrays end up shared, but this is not required + end end @testset "higher order interface" begin From 670e49a9bbe09a8a8b15873b430a23c1514bf88c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 12:37:37 -0400 Subject: [PATCH 03/11] a tidier idea, just replace _default_walk --- src/interface.jl | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 4cef995e..f4936e38 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -10,10 +10,10 @@ abstract type AbstractRule end ### setup ### -mutable struct Leaf{R,S} +mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing rule::R state::S - frozen::Bool + frozen::Bool # mutability also allows this flag to be changed end @functor Leaf @@ -45,23 +45,15 @@ end function update!(tree, model, grad) # First walk is to accumulate the gradient. This recursion visits every copy of # shared leaves, but stops when branches are absent from the gradient: - gdict = IdDict{Leaf, Any}() - grads!(gdict, tree, model, grad) - # Second walk is to update the model, using same fmap walk as setup: - xdict = IdDict{Leaf, Any}() # (this exists to allow for shared ℓ without shared x) - newmodel = fmap(model, tree; exclude = isnumeric) do x, ℓ - ℓ isa Leaf || error("this state does not match the model, expected a Leaf here") + dict = IdDict{Leaf, Any}() + grads!(dict, tree, model, grad) + # Second walk is to update the model. The walk taken follows Leaf identity + newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x ℓ.frozen && return x - haskey(gdict, ℓ) || return x # no gradient seen, nothing to do - if haskey(xdict, ℓ) - # This means that shared ℓ encodes sharing not noted in x. Won't happen with setup above, no API yet. - x′ = xdict[ℓ] # ... and is why xdict exists. 
- size(x′) == size(x) || error("the same Leaf belongs to arrays of size $(size(x)) and $(size(x′))") - return x′ - end - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, gdict[ℓ]) + haskey(dict, ℓ) || return x # no gradient seen, nothing to do + s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) ℓ.state = s′ # to get state out of here, rely on mutability of Leaf - xdict[ℓ] = subtract!(x, x̄′) + subtract!(x, x̄′) end tree, newmodel # note that tree is guaranteed to be updated end @@ -89,6 +81,13 @@ function update(tree, x, x̄s...) update!(t′, x′, x̄s...) end +# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first +function _second_walk(f, x, y) + x′, _ = functor(typeof(y), x) + y′, re = functor(y) + re(map(f, x′, y′)) +end + # default all rules to first order calls apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) From 6db7a36fafa58d9e064f265b284d65c6081f74fc Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 15:42:57 -0400 Subject: [PATCH 04/11] add a LeafCache type, to make fmap ignore () singleton --- src/interface.jl | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index f4936e38..a8ed3c66 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -48,7 +48,7 @@ function update!(tree, model, grad) dict = IdDict{Leaf, Any}() grads!(dict, tree, model, grad) # Second walk is to update the model. The walk taken follows Leaf identity - newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk) do ℓ, x + newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x ℓ.frozen && return x haskey(dict, ℓ) || return x # no gradient seen, nothing to do s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) @@ -88,6 +88,21 @@ function _second_walk(f, x, y) re(map(f, x′, y′)) end +# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state. +# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work. +struct LeafCache <: AbstractDict{Leaf,Any} + dict::IdDict{Leaf,Any} +end +LeafCache() = LeafCache(IdDict{Leaf,Any}()) + +Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ) +Base.setindex!(c::LeafCache, x, _) = nothing +Base.in(k, c::LeafCache) = k in c.dict +Base.haskey(c::LeafCache, k) = haskey(c.dict, k) +Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ) +Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i) +Base.length(c::LeafCache) = length(c.dict) + # default all rules to first order calls apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) From 5e5d5db33397bd6395e0539b12025bed92ae05e1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 15:44:40 -0400 Subject: [PATCH 05/11] remove leaf.frozen field --- src/adjust.jl | 4 ++-- src/interface.jl | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/adjust.jl b/src/adjust.jl index 3ec2f220..d6f3647d 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -47,8 +47,8 @@ adjust(tree; kw...) = map(st -> adjust(st; kw...), tree) adjust(::Nothing, ::Real) = nothing adjust(::Nothing; kw...) = nothing -adjust(ℓ::Leaf, eta::Real) = ℓ.frozen ? ℓ : Leaf(adjust(ℓ.rule, eta), ℓ.state, ℓ.frozen) -adjust(ℓ::Leaf; kw...) = ℓ.frozen ? 
ℓ : Leaf(adjust(ℓ.rule; kw...), ℓ.state, ℓ.frozen) +adjust(ℓ::Leaf, eta::Real) = Leaf(adjust(ℓ.rule, eta), ℓ.state) +adjust(ℓ::Leaf; kw...) = Leaf(adjust(ℓ.rule; kw...), ℓ.state) """ diff --git a/src/interface.jl b/src/interface.jl index a8ed3c66..ea7fa954 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -13,7 +13,6 @@ abstract type AbstractRule end mutable struct Leaf{R,S} # mutable so that its identity encodes parameter sharing rule::R state::S - frozen::Bool # mutability also allows this flag to be changed end @functor Leaf @@ -25,7 +24,7 @@ function setup(rule::AbstractRule, model) # Rely on Functors to identify shared arrays, they will share a Leaf in this tree: tree = fmapstructure(model, exclude = isnumeric) do x cnt[] += 1 - Leaf(rule, init(rule, x), false) + Leaf(rule, init(rule, x)) end cnt[] == 0 && @warn "setup found no parameters in the given model" tree @@ -35,7 +34,7 @@ function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long ioc = IOContext(io, :compact => true) print(ioc, "Leaf(", ℓ.rule, ", ") show(ioc, ℓ.state) - print(ioc, ", ", ℓ.frozen, ")") + print(ioc, ")") end ### @@ -49,7 +48,6 @@ function update!(tree, model, grad) grads!(dict, tree, model, grad) # Second walk is to update the model. The walk taken follows Leaf identity newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x - ℓ.frozen && return x haskey(dict, ℓ) || return x # no gradient seen, nothing to do s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) ℓ.state = s′ # to get state out of here, rely on mutability of Leaf From 522f66a38e601f96329f8b3ea45ea79cc8c10562 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 15:56:19 -0400 Subject: [PATCH 06/11] eager accumulation --- src/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index ea7fa954..cdd43e45 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -60,8 +60,8 @@ subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing function grads!(dict::IdDict, ℓ::Leaf, x, x̄) - x̄₀ = get(dict, ℓ, false) - dict[ℓ] = Broadcast.broadcasted(+, x̄, x̄₀) + x̄₀ = get(dict, ℓ, ZeroTangent()) + dict[ℓ] = x̄ + x̄₀ # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible. 
nothing end grads!(dict::IdDict, t, x, ::Zero) = nothing From 3172f1349c19eed5dd1a7d930330dd00912b9b42 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 28 Aug 2022 21:37:42 -0400 Subject: [PATCH 07/11] give up on customising fmap & write the recursion, add evil tests --- src/interface.jl | 107 ++++++++++++++++++++++++----------------------- test/runtests.jl | 60 ++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 60 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index cdd43e45..6ba3f739 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -20,16 +20,28 @@ end Base.:(==)(a::Leaf, b::Leaf) = children(a) == children(b) function setup(rule::AbstractRule, model) - cnt = Ref(0) - # Rely on Functors to identify shared arrays, they will share a Leaf in this tree: - tree = fmapstructure(model, exclude = isnumeric) do x - cnt[] += 1 - Leaf(rule, init(rule, x)) - end - cnt[] == 0 && @warn "setup found no parameters in the given model" + cache = IdDict() + tree = _setup(rule, model; cache) + isempty(cache) && @warn "setup found no trainable parameters in this model" tree end +# _setup is almost fmapstructure, but needs a _trainable_walk, and a cache which ignores numbers etc. +function _setup(rule, x; cache) + haskey(cache, x) && return cache[x] + if isnumeric(x) + ℓ = Leaf(rule, init(rule, x)) + if isbits(x) + cache[nothing] = nothing # just to disable the warning + ℓ + else + cache[x] = ℓ + end + else + map(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x)) + end +end + function Base.show(io::IO, ℓ::Leaf) # show method is mostly to hide its long type! ioc = IOContext(io, :compact => true) print(ioc, "Leaf(", ℓ.rule, ", ") @@ -41,65 +53,56 @@ end ### update ### -function update!(tree, model, grad) +function update(tree, model, grad, higher...) + t′ = fmap(copy, tree; exclude = maywrite) # walks inside Leaf + x′ = fmap(copy, model; exclude = maywrite) + update!(t′, x′, grad, higher...) +end + +function update!(tree, model, grad, higher...) # First walk is to accumulate the gradient. This recursion visits every copy of # shared leaves, but stops when branches are absent from the gradient: - dict = IdDict{Leaf, Any}() - grads!(dict, tree, model, grad) - # Second walk is to update the model. The walk taken follows Leaf identity - newmodel = fmap(tree, model; exclude = ℓ -> ℓ isa Leaf, walk = _second_walk, cache = LeafCache()) do ℓ, x - haskey(dict, ℓ) || return x # no gradient seen, nothing to do - s′, x̄′ = apply!(ℓ.rule, ℓ.state, x, dict[ℓ]) - ℓ.state = s′ # to get state out of here, rely on mutability of Leaf + grads = IdDict{Leaf, Any}() + _grads!(grads, tree, model, grad, higher...) + # Second walk is to update the model. The params cache indexed by (tree,x), + # so that identified Leafs can tie isbits parameters, but setup won't do that for you: + newmodel = _update!(tree, model; grads, params = IdDict()) + tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree. +end + +function _update!(tree, x; grads, params) + haskey(params, (tree,x)) && return params[(tree,x)] + isbits(tree) && return x # means () is not cached, and also (((),),) + x′, re = functor(x) + x′′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′) + params[(tree,x)] = re(x′′) +end +function _update!(ℓ::Leaf, x; grads, params) + haskey(params, (ℓ,x)) && return params[(ℓ,x)] + params[(ℓ,x)] = if haskey(grads, ℓ) + ℓ.state, x̄′ = apply!(ℓ.rule, ℓ.state, x, grads[ℓ]...) 
subtract!(x, x̄′) + else + x # no gradient seen end - tree, newmodel # note that tree is guaranteed to be updated end subtract!(x, x̄) = maywrite(x) ? (x .= x .- x̄) : eltype(x).(x .- x̄) -grads!(dict::IdDict, ℓ::Leaf, x, ::Zero) = nothing -function grads!(dict::IdDict, ℓ::Leaf, x, x̄) - x̄₀ = get(dict, ℓ, ZeroTangent()) - dict[ℓ] = x̄ + x̄₀ # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible. +_grads!(dict::IdDict, ℓ::Leaf, x, ::Zero...) = nothing +function _grads!(dict::IdDict, ℓ::Leaf, x, x̄s...) + x̄s₀ = get(dict, ℓ, map(_ -> ZeroTangent(), x̄s)) + dict[ℓ] = map(+, x̄s, x̄s₀) # adding Zero should be free. Lazy accumulation broadcasted(+, x̄, x̄₀) also possible. nothing end -grads!(dict::IdDict, t, x, ::Zero) = nothing -function grads!(dict::IdDict, tree, x, x̄s...) - # The only reason grads! takes model is that functor(typeof(x), base(x̄)) may differ from +_grads!(dict::IdDict, t, x, ::Zero...) = nothing +function _grads!(dict::IdDict, tree, x, x̄s...) + # The only reason _grads! takes model is that functor(typeof(x), base(x̄)) may differ from # functor(typeof(tree), base(x̄)), for things like Transpose x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s) x′, _ = functor(typeof(x), x) - foreach((tᵢ, xᵢ, x̄sᵢ...) -> grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) -end - -function update(tree, x, x̄s...) - t′ = fmap(copy, tree; exclude = maywrite) # goes inside Leaf - x′ = fmap(copy, x; exclude = maywrite) - update!(t′, x′, x̄s...) -end - -# This differs from _default_walk(f,x,y) in taking re from 2nd argument, but cache will still operate on the first -function _second_walk(f, x, y) - x′, _ = functor(typeof(y), x) - y′, re = functor(y) - re(map(f, x′, y′)) -end - -# When fmap reconstructs for update!, it should not cache results with trivial nodes like () in the state. -# This cache type has just enough methods to work in Functors, which possibly should be upgraded to just work. -struct LeafCache <: AbstractDict{Leaf,Any} - dict::IdDict{Leaf,Any} + foreach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...) end -LeafCache() = LeafCache(IdDict{Leaf,Any}()) - -Base.setindex!(c::LeafCache, x, ℓ::Leaf) = setindex!(c.dict, x, ℓ) -Base.setindex!(c::LeafCache, x, _) = nothing -Base.in(k, c::LeafCache) = k in c.dict -Base.haskey(c::LeafCache, k) = haskey(c.dict, k) -Base.getindex(c::LeafCache, ℓ::Leaf) = getindex(c.dict, ℓ) -Base.iterate(c::LeafCache, i = 0) = iterate(c.dict, i) -Base.length(c::LeafCache) = length(c.dict) # default all rules to first order calls apply!(o, state, x, dx, dx2, dxs...) = apply!(o, state, x, dx) diff --git a/test/runtests.jl b/test/runtests.jl index f152c76d..88eaf976 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,19 +13,26 @@ struct TwoThirds a; b; c; end Functors.@functor TwoThirds (a, c) Optimisers.trainable(x::TwoThirds) = (a = x.a,) -struct DummyHigherOrder <: AbstractRule end +mutable struct MutTwo; x; y; end +Functors.@functor MutTwo +struct DummyHigherOrder <: AbstractRule end Optimisers.init(::DummyHigherOrder, x::AbstractArray) = (ones(eltype(x), size(x)), zero(x)) - dummy_update_rule(st, p, dx, dx2) = @. p - (st[1] * dx + st[2] * dx2) function Optimisers.apply!(::DummyHigherOrder, state, x, dx, dx2) a, b = state @.. 
dx = a * dx + b * dx2 - return (a .+ 1, b .+ 1), dx end +struct BiRule <: Optimisers.AbstractRule end +Optimisers.init(o::BiRule, x::AbstractArray) = nothing +function Optimisers.apply!(o::BiRule, state, x, dx, dx2) + dx == dx2 || error("expected 1st & 2nd gradients to agree") + return state, dx +end + @testset verbose=true "Optimisers.jl" begin @testset verbose=true "Features" begin @@ -220,6 +227,23 @@ end @test_throws MethodError Optimisers.update(sm, m) end + @testset "2nd order gradient" begin + m = (α = ([1.0], sin), γ = Float32[4,3,2]) + + # Special rule which requires this: + s = Optimisers.setup(BiRule(), m) + g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) + s1, m1 = Optimisers.update(s, m, g, g) + @test m1.α[1] == [0.9] + @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) + + # Ordinary rule which doesn't need it: + s2 = Optimisers.setup(Adam(), m) + s3, m3 = Optimisers.update(s2, m, g) + s4, m4 = Optimisers.update(s2, m, g, g) + @test m3.γ == m4.γ + end + @testset "broadcasting macros" begin x = [1.0, 2.0]; y = [3,4]; z = [5,6] @test (@lazy x + y * z) isa Broadcast.Broadcasted @@ -305,13 +329,15 @@ end # Error: no constructors for type Any @test_broken s4, t4 = Optimisers.update(stri, tri, g4) end - + @testset "artificial" begin # Interpret shared Leaf as implying shared parameters, even if this did not arise from shared arrays. # No API for setting this at the moment, but can construct one by hand: - model = (a = [1,2.0], b = [1, 2.0], c = [1, 2.0], d = [1, 2.0]) - honest = Optimisers.setup(Momentum(0.1), model) - trick = (a = honest.a, b = honest.a, c = honest.c, d= honest.d) # makes a & b shared + model = (a = SA[1,2.0], b = SA[1, 2.0], c = SA[1, 2.0], d = SA[1, 2.0]) + auto = Optimisers.setup(Momentum(0.1), model) + @test auto.a !== auto.b # not tied just by value + + trick = (a = auto.a, b = auto.a, c = auto.c, d= auto.d) # makes a & b tied trick2, model2 = Optimisers.update(trick, model, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) trick3, model3 = Optimisers.update(trick2, model2, (a=[3,3], b=[7,7], c=[3,3], d=[10, 10])) @@ -319,8 +345,26 @@ end @test model3.a == model3.b == model3.d # same as having the gradients added @test !(model3.a ≈ model3.c) @test trick3.a === trick3.b # leaves remain shared - model3.a === model3.b # in fact arrays end up shared, but this is not required end + + @testset "mutable containers" begin + tmp = MutTwo([1.0], [2.0]) + model = (a=tmp, b=tmp, c=MutTwo(tmp.x, tmp.y)) + state = Optimisers.setup(Momentum(), model) + + @test model.a === model.b + @test model.a !== model.c # fields are identified, but struct is not + + @test state.a.x === state.b.x + @test state.a === state.b + @test state.a === state.c # unavoidable, but means we can't use leaf ID alone + + mgrad = (a=(x=[1], y=[10]), b=(x=[100], y=[1000]), c=(x=[1/3], y=[1/30])) + state2, model2 = Optimisers.update(state, model, mgrad) + + @test model2.a === model2.b # tie of MutTwo structs is restored + @test model2.a !== model2.c # but a new tie is not created + end end @testset "higher order interface" begin From 37521c892b185bdd0b93777e46d06daf120b0ad6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 29 Aug 2022 21:27:43 -0400 Subject: [PATCH 08/11] add ismutable check --- src/interface.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 6ba3f739..1903d615 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -74,8 +74,12 @@ function _update!(tree, x; 
grads, params) haskey(params, (tree,x)) && return params[(tree,x)] isbits(tree) && return x # means () is not cached, and also (((),),) x′, re = functor(x) - x′′ = map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′) - params[(tree,x)] = re(x′′) + x′′ = re(map((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′)) + if ismutable(x′′) + params[(tree,x)] = x′′ + else # no ties to preserve between immutable structs, right? + x′′ + end end function _update!(ℓ::Leaf, x; grads, params) haskey(params, (ℓ,x)) && return params[(ℓ,x)] From 0d6619a0320203fc4356742154183bf1e8413182 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 Oct 2022 18:01:23 -0400 Subject: [PATCH 09/11] docs etc --- Project.toml | 4 +-- docs/src/index.md | 66 ++++++++++++++++++++++++++++++++++++++++++----- src/Optimisers.jl | 8 ++++++ src/interface.jl | 2 +- 4 files changed, 70 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index f1d8501e..bac796ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Optimisers" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" authors = ["Mike J Innes "] -version = "0.2.9" +version = "0.2.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Functors = "0.2.8, 0.3" +Functors = "0.3" Zygote = "0.6.40" julia = "1.6" diff --git a/docs/src/index.md b/docs/src/index.md index aaee5d7f..65b441bb 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,6 @@ # Optimisers.jl -## Defining an optimisation rule +## An optimisation rule A new optimiser must overload two functions, [`apply!`](@ref) and [`init`](@ref). These act on one array of parameters: @@ -60,6 +60,11 @@ Notice that a completely new instance of the model is returned. Internally, this is handled by [Functors.jl](https://fluxml.ai/Functors.jl), where we do a walk over the tree formed by the model and update the parameters using the gradients. +There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state, +but is free to mutate arrays within the old one for efficiency. +The method of `apply!` for each rule is likewise free to mutate arrays within its state; +they are defensively copied when this rule is used with `update`. + Optimisers.jl does not depend on any one automatic differentiation package, but for now the most likely source of gradients is [Zygote.jl](https://fluxml.ai/Zygote.jl). Note that `update` always wants the gradient from Zygote's "explicit" mode, as shown above. @@ -67,11 +72,6 @@ This `∇model` is another tree structure, rather than the dictionary-like objec Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference. -There is also [`Optimisers.update!`](@ref) which similarly returns a new model and new state, -but is free to mutate arrays within the old one for efficiency. -The method of `apply!` you write is likewise free to mutate arrays within its state; -they are defensively copied when this rule is used with `update`. - ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl) The main design difference of Lux is that the tree of parameters is separate from @@ -110,6 +110,57 @@ Besides the parameters stored in `params` and gradually optimised, any other mod is stored in `lux_state`. 
For simplicity this example does not show how to propagate the updated `lux_state` to the next iteration, see Lux's documentation. +## Non-`trainable` Parameters + +Optimisers.jl uses [Functors.jl](https://fluxml.ai/Functors.jl) to walk the `struct`s +making up the model, for which they must be annotated `@functor Type`. +By default optimisation will alter all [`isnumeric`](@ref) arrays. + +If some arrays of a particular layer should not be treated this way, +you can define a method for [`trainable`](@ref) + +```julia +struct Layer{T} + alpha::T + beta::T + length::Int +end +Layer(n::Int) = Layer(randn(n), zeros(n), n) + +Functors.@functor Layer + +# Both array fields will be, for example, moved to the GPU: +Functors.children(Layer(3)) # (alpha = [...], beta = [...], length) + +Optimisers.trainable(x::Layer) = (; alpha = x.alpha) # must be a subset of chidlren + +# Only the first field will be optimised: +st = Optimisers.setup(DecayDescent(0.1), Layer(3)) +``` + +## Tied Parameters + +If the same array appears twice (or more) in the model, [Functors.jl](https://fluxml.ai/Functors.jl) should recognise this. +Within Optimisers.jl, `setup` will initialise once, and use the same `Leaf` for both parameters. +Then `update` will accumulate the gradient from both, and the updated model returned will have the tie maintained. + +```julia +using Flux, Optimisers + +enc = Chain(Dense(40 => 20, tanh), Dense(20 => 10)); +dec = Chain(Dense(enc[1].weight', true, tanh), Dense(enc[2].weight', true, tanh)); +model = Chain(; enc, dec) + +st = Optimisers.setup(Optimisers.Adam(), model); + +st.layers.enc.layers[1].weight === st.layers.dec.layers[1].weight.parent # true +``` + +This identification relies on `===`, and will work for ordinary `Array`s and `CuArray`s. +It will not at present work for `reshape`d arrays, nor for immutable arrays such as those +from StaticArrays.jl. + + ## Obtaining a flat parameter vector Instead of a nested tree-like structure, sometimes is is convenient to have all the @@ -143,10 +194,11 @@ st, flat = Optimisers.update(st, flat, ∇flat) ``` Here `flat` contains only the 283 trainable parameters, while the non-trainable -ones are preserved inside `re`. +ones are preserved inside `re`, an object of type `Restructure`. When defining new layers, these can be specified if necessary by overloading [`trainable`](@ref). By default, all numeric arrays visible to [Functors.jl](https://github.com/FluxML/Functors.jl) are assumed to contain trainable parameters. +Tied parameters (arrays appearing in different layers) are included only once in `flat`. Lux stores only the trainable parameters in `params`. This can also be flattened to a plain `Vector` in the same way: diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 7c94e233..73c56305 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -16,6 +16,10 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp, AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief, WeightDecay, ClipGrad, ClipNorm, OptimiserChain +### +### one-array functions +### + """ Optimisers.apply!(rule::RuleType, state, parameters, gradient) -> (state, gradient) @@ -57,6 +61,10 @@ julia> Optimisers.init(Momentum(), [1.0, 2.0]) """ init +### +### whole-model functions +### + """ Optimisers.setup(rule, model) -> tree diff --git a/src/interface.jl b/src/interface.jl index 1903d615..79d03396 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -166,7 +166,7 @@ end ### """ - @.. x = x + y + @.. 
x = y + z Sometimes in-place broadcasting macro, for use in `apply!` rules. If `maywrite(x)` then it is just `@. x = rhs`, but if not, it becomes `x = @. rhs`. From d13e52abdf57b69ef7b713d3348d72aa181bc13c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 Oct 2022 18:24:06 -0400 Subject: [PATCH 10/11] fix doctests --- src/Optimisers.jl | 10 +++++----- src/adjust.jl | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 73c56305..8e8cb19f 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -77,7 +77,7 @@ or [`update!`](@ref). julia> m = (x = rand(3), y = (true, false), z = tanh); julia> Optimisers.setup(Momentum(), m) # same field names as m -(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = (nothing, nothing), z = nothing) +(x = Leaf(Momentum{Float32}(0.01, 0.9), [0.0, 0.0, 0.0]), y = ((), ()), z = ()) ``` The recursion into structures uses Functors.jl, and any new `struct`s containing parameters @@ -90,7 +90,7 @@ julia> struct Layer; mat; fun; end julia> model = (lay = Layer([1 2; 3 4f0], sin), vec = [5, 6f0]); julia> Optimisers.setup(Momentum(), model) # new struct is by default ignored -(lay = nothing, vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) +(lay = (), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) julia> destructure(model) (Float32[5.0, 6.0], Restructure(NamedTuple, ..., 2)) @@ -98,7 +98,7 @@ julia> destructure(model) julia> using Functors; @functor Layer # annotate this type as containing parameters julia> Optimisers.setup(Momentum(), model) -(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = nothing), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) +(lay = (mat = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0 0.0; 0.0 0.0]), fun = ()), vec = Leaf(Momentum{Float32}(0.01, 0.9), Float32[0.0, 0.0])) julia> destructure(model) (Float32[1.0, 3.0, 2.0, 4.0, 5.0, 6.0], Restructure(NamedTuple, ..., 6)) @@ -120,12 +120,12 @@ See also [`update!`](@ref), which will be faster for models of ordinary `Array`s julia> m = (x = Float32[1,2,3], y = tanh); julia> t = Optimisers.setup(Descent(0.1f0), m) -(x = Leaf(Descent{Float32}(0.1), nothing), y = nothing) +(x = Leaf(Descent{Float32}(0.1), nothing), y = ()) julia> g = (x = [1,1,1], y = nothing); # fake gradient julia> Optimisers.update(t, m, g) -((x = Leaf(Descent{Float32}(0.1), nothing), y = nothing), (x = Float32[0.9, 1.9, 2.9], y = tanh)) +((x = Leaf(Descent{Float32}(0.1), nothing), y = ()), (x = Float32[0.9, 1.9, 2.9], y = tanh)) ``` """ update diff --git a/src/adjust.jl b/src/adjust.jl index d6f3647d..78b3d452 100644 --- a/src/adjust.jl +++ b/src/adjust.jl @@ -13,15 +13,15 @@ To change just the learning rate, provide a number `η::Real`. 
julia> m = (vec = rand(Float32, 2), fun = sin); julia> st = Optimisers.setup(Nesterov(), m) # stored momentum is initialised to zero -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[0.0, 0.0]), fun = ()) julia> st, m = Optimisers.update(st, m, (vec = [16, 88], fun = nothing)); # with fake gradient julia> st -(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.001, 0.9), Float32[-0.016, -0.088]), fun = ()) julia> st = Optimisers.adjust(st, 0.123) # change learning rate, stored momentum untouched -(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = nothing) +(vec = Leaf(Nesterov{Float32}(0.123, 0.9), Float32[-0.016, -0.088]), fun = ()) ``` To change other parameters, `adjust` also accepts keyword arguments matching the field From 1577b8890f51bfaaad3ee32eb84848e02f690992 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 Oct 2022 18:30:07 -0400 Subject: [PATCH 11/11] group the tests --- test/runtests.jl | 94 +++++++++++++++++++++++++----------------------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 88eaf976..51e76053 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,8 @@ using Optimisers: @.., @lazy Random.seed!(1) +# Fake "models" for testing + struct Foo; x; y; end Functors.@functor Foo Optimisers.trainable(x::Foo) = (x.y, x.x) @@ -16,6 +18,8 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) mutable struct MutTwo; x; y; end Functors.@functor MutTwo +# Simple rules for testing + struct DummyHigherOrder <: AbstractRule end Optimisers.init(::DummyHigherOrder, x::AbstractArray) = (ones(eltype(x), size(x)), zero(x)) @@ -227,23 +231,6 @@ end @test_throws MethodError Optimisers.update(sm, m) end - @testset "2nd order gradient" begin - m = (α = ([1.0], sin), γ = Float32[4,3,2]) - - # Special rule which requires this: - s = Optimisers.setup(BiRule(), m) - g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) - s1, m1 = Optimisers.update(s, m, g, g) - @test m1.α[1] == [0.9] - @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) - - # Ordinary rule which doesn't need it: - s2 = Optimisers.setup(Adam(), m) - s3, m3 = Optimisers.update(s2, m, g) - s4, m4 = Optimisers.update(s2, m, g, g) - @test m3.γ == m4.γ - end - @testset "broadcasting macros" begin x = [1.0, 2.0]; y = [3,4]; z = [5,6] @test (@lazy x + y * z) isa Broadcast.Broadcasted @@ -365,34 +352,53 @@ end @test model2.a === model2.b # tie of MutTwo structs is restored @test model2.a !== model2.c # but a new tie is not created end - end + end # tied weights + + @testset "2nd-order interface" begin + @testset "BiRule" begin + m = (α = ([1.0], sin), γ = Float32[4,3,2]) + + # Special rule which requires this: + s = Optimisers.setup(BiRule(), m) + g = (α = ([0.1], ZeroTangent()), γ = [1,10,100],) + s1, m1 = Optimisers.update(s, m, g, g) + @test m1.α[1] == [0.9] + @test_throws Exception Optimisers.update(s, m, g, map(x->2 .* x, g)) + + # Ordinary rule which doesn't need it: + s2 = Optimisers.setup(Adam(), m) + s3, m3 = Optimisers.update(s2, m, g) + s4, m4 = Optimisers.update(s2, m, g, g) + @test m3.γ == m4.γ + end - @testset "higher order interface" begin - w, b = rand(3, 4), rand(3) - - o = DummyHigherOrder() - psin = (w, b) - dxs = map(x -> rand(size(x)...), psin) - dx2s = map(x -> rand(size(x)...), psin) - stin = Optimisers.setup(o, psin) - 
stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - - # hardcoded rule behavior for dummy rule - @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) - @test stout[1].state[1] == stin[1].state[1] .+ 1 - @test stout[2].state[2] == stin[2].state[2] .+ 1 - - # error if only given one derivative - @test_throws MethodError Optimisers.update(stin, psin, dxs) - - # first-order rules compose with second-order - ochain = OptimiserChain(Descent(0.1), o) - stin = Optimisers.setup(ochain, psin) - stout, psout = Optimisers.update(stin, psin, dxs, dx2s) - @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) - @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) - end + @testset "DummyHigherOrder" begin + w, b = rand(3, 4), rand(3) + + o = DummyHigherOrder() + psin = (w, b) + dxs = map(x -> rand(size(x)...), psin) + dx2s = map(x -> rand(size(x)...), psin) + stin = Optimisers.setup(o, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + + # hardcoded rule behavior for dummy rule + @test psout[1] == dummy_update_rule(stin[1].state, psin[1], dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state, psin[2], dxs[2], dx2s[2]) + @test stout[1].state[1] == stin[1].state[1] .+ 1 + @test stout[2].state[2] == stin[2].state[2] .+ 1 + + # error if only given one derivative + @test_throws MethodError Optimisers.update(stin, psin, dxs) + + # first-order rules compose with second-order + ochain = OptimiserChain(Descent(0.1), o) + stin = Optimisers.setup(ochain, psin) + stout, psout = Optimisers.update(stin, psin, dxs, dx2s) + @test psout[1] == dummy_update_rule(stin[1].state[2], psin[1], 0.1 * dxs[1], dx2s[1]) + @test psout[2] == dummy_update_rule(stin[2].state[2], psin[2], 0.1 * dxs[2], dx2s[2]) + end + end # 2nd-order end @testset verbose=true "Destructure" begin
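The tied-parameter behaviour this series implements is easiest to see in a small script. Below is a minimal sketch adapted from the "tied weights" testset above (the names `twice`, `mtup`, `gtup` are taken from those tests); it assumes a build of Optimisers.jl with these patches applied.

```julia
using Optimisers

twice = [1.0, 2.0]
mtup = (twice, (copy(twice), twice))   # the 1st and 3rd arrays are the same object

stup = Optimisers.setup(Momentum(0.1), mtup)
stup[1] === stup[2][2]                 # true: tied arrays share a single (mutable) Leaf

gtup = ([3.0, 3.0], ([10.0, 10.0], [7.0, 7.0]))
snew, mnew = Optimisers.update(stup, mtup, gtup)

mnew[1] === mnew[2][2]                 # true: the tie survives the update
mnew[1] ≈ mnew[2][1]                   # true: the tied copies saw the accumulated gradient [3,3] .+ [7,7]
```

Making `Leaf` mutable is what lets its identity encode the tie, and is also why `update!` can write the new optimiser state back into the existing tree rather than rebuilding it.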