From 1e34fa2c99ca9e66d346c9e32d0414963357eeb5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 14 Feb 2022 17:40:23 -0500 Subject: [PATCH] Add `destructure`, take II (#54) * destructure, take II * add a test * tidy * replace append! with reduce(vcat, ...) * testset names * rename everything * tweak * two broken tests * make len positional, fix a bug * second derivatives * arrays of arrays * more... the dimensionmismatch bug is not here * warnings --- Project.toml | 2 +- docs/src/api.md | 7 ++ src/Optimisers.jl | 5 +- src/destructure.jl | 152 ++++++++++++++++++++++++++++++++++++++++ src/interface.jl | 5 +- test/destructure.jl | 166 ++++++++++++++++++++++++++++++++++++++++++++ test/rules.jl | 8 +-- test/runtests.jl | 5 +- 8 files changed, 341 insertions(+), 9 deletions(-) create mode 100644 src/destructure.jl create mode 100644 test/destructure.jl diff --git a/Project.toml b/Project.toml index d91c01d2..66c062ed 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Functors = "0.2.7" +Functors = "0.2.8" julia = "1.6" [extras] diff --git a/docs/src/api.md b/docs/src/api.md index 5671140b..edd8be32 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -42,6 +42,13 @@ optimiser to act on all suitable fields. To restrict this, define `trainable`: Optimisers.trainable ``` +Such restrictions are also obeyed by this function for flattening a model: + +```@docs +Optimisers.destructure +Optimisers.Restructure +``` + ## Rule Definition ```@docs diff --git a/src/Optimisers.jl b/src/Optimisers.jl index 9f93e041..417b90d4 100644 --- a/src/Optimisers.jl +++ b/src/Optimisers.jl @@ -4,8 +4,11 @@ using Functors: functor, fmap, isleaf using LinearAlgebra include("interface.jl") -include("rules.jl") +include("destructure.jl") +export destructure, total, total2 + +include("rules.jl") export Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief, WeightDecay, ClipGrad, ClipNorm, OptimiserChain diff --git a/src/destructure.jl b/src/destructure.jl new file mode 100644 index 00000000..3ace52ec --- /dev/null +++ b/src/destructure.jl @@ -0,0 +1,152 @@ + +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk +const NoT = NoTangent() + +""" + destructure(model) -> vector, reconstructor + +Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model +to a vector, and returns also a function which reverses this transformation. +Differentiable. + +# Example +```jldoctest +julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im]))) +(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3)) + +julia> re([3, 5-im, 7+11im]) +(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im])) +``` +""" +function destructure(x) + flat, off, len = _flatten(x) + flat, Restructure(x, off, len) +end + +""" + Restructure(Model, ..., length) + +This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with +new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`. + +# Example +```julia +julia> using Flux, Optimisers + +julia> _, re = destructure(Dense([1 2; 3 4], [0, 0], sigmoid)) +([1, 3, 2, 4, 0, 0], Restructure(Dense, ..., 6)) + +julia> m = re(-4:1) +Dense(2, 2, σ) # 6 parameters + +julia> m([0.2, 0.3]) ≈ re([0.2, 0.3], -4:1) +true +``` +""" +struct Restructure{T,S} + model::T + offsets::S + length::Int +end +(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length) +(re::Restructure)(x, flat::AbstractVector) = re(flat)(x) +Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") +Base.length(re::Restructure) = re.length + +# This flattens a model, and returns a web of offsets for later use: +function _flatten(x) + isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case + arrays = AbstractVector[] + len = Ref(0) + off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y + push!(arrays, _vec(y)) + o = len[] + len[] = o + length(y) + o + end + reduce(vcat, arrays), off, len[] +end + +_vec(x::Number) = LinRange(x,x,1) +_vec(x::AbstractArray) = vec(x) + +function ChainRulesCore.rrule(::typeof(_flatten), x) + flat, off, len = _flatten(x) + _maybewarn() + _flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT)) + (flat, off, len), _flatten_back +end + +# This reconstructs either a model like x, or a gradient for it: +function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...) + len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))")) + fmap(x, off; exclude = isnumeric, walk, kw...) do y, o + _getat(y, o, flat) + end +end + +_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) +_getat(y::AbstractArray, o::Int, flat::AbstractVector) = + ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes + +function _trainable_biwalk(f, x, aux) + ch, re = functor(typeof(x), x) + au, _ = functor(typeof(x), aux) + _trainmap(f, ch, _trainable(x), au) |> re +end + +function _trainmap(f, ch, tr, aux) + map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c) + isnothing(t) ? c : f(t, a) + end +end + +function _Tangent_biwalk(f, x, aux) # use with prune = NoT + ch, re = functor(typeof(x), x) + au, _ = functor(typeof(x), aux) + y = _trainmap(f, ch, _trainable(x), au) + y isa Tuple{} && return NoT + p = ProjectTo(x) + if p isa ProjectTo # e.g. Array, NamedTuple + p(y) + else # p === identity for unknown structs + Tangent{typeof(x), typeof(y)}(y) + end +end + +function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...) + _rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT) + _rebuild(x, off, flat, len; kw...), _rebuild_back +end + +_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad! +ChainRulesCore.@non_differentiable _zero(x) + +# This is the gradient of model reconstruction, accumulating duplicates: +function _grad!(x, dx, off, flat::AbstractVector) + x′, _ = functor(typeof(x), x) + dx′, _ = functor(typeof(x), base(dx)) + off′, _ = functor(typeof(x), off) + foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) + flat +end +function _grad!(x, dx, off::Integer, flat::AbstractVector) + @views flat[off .+ (1:length(x))] .+= dx # must visit all tied nodes + flat +end +_grad!(x, dx::Zero, off, flat::AbstractVector) = dx +_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity + +# These are only needed for 2nd derivatives: +function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) + @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3 + _grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT) + _grad!(x, dx, off, flat), _grad_back +end +base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure)) +base(dx::Tangent{Any, <:NamedTuple{(:backing,)}}) = base(backing(dx).backing) # Zygote version +_maybewarn() = nothing +function ChainRulesCore.rrule(::typeof(_maybewarn)) + @warn "second derivatives of destructure may not work yet, sorry!" maxlog=3 + nothing, _ -> (NoT,) +end diff --git a/src/interface.jl b/src/interface.jl index 80f87dcc..235c2e94 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -70,9 +70,10 @@ trainable(x) = functor(x)[1] _trainable(x) = _trainable(functor(x)[1], trainable(x)) _trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr) -_trainable(ch::Tuple, tr::Tuple) = tr +_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr +_trainable(ch::AbstractArray, tr::AbstractArray) = tr function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple - @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" + @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3 map(c -> c in tr ? c : nothing, ch) end diff --git a/test/destructure.jl b/test/destructure.jl new file mode 100644 index 00000000..40c4360c --- /dev/null +++ b/test/destructure.jl @@ -0,0 +1,166 @@ + +m1 = collect(1:3.0) +m2 = (collect(1:3.0), collect(4:6.0)) +m3 = (x = m1, y = sin, z = collect(4:6.0)) +m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied +m5 = (a = (m3, true), b = (m1, false), c = (m4, true)) +m6 = (a = m1, b = [4.0 + im], c = m1) +m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0))) +m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]] + +@testset "flatten & rebuild" begin + @test destructure(m1)[1] isa Vector{Float64} + @test destructure(m1)[1] == 1:3 + @test destructure(m2)[1] == 1:6 + @test destructure(m3)[1] == 1:6 + @test destructure(m4)[1] == 1:6 + @test destructure(m5)[1] == vcat(1:6, 4:6) + @test destructure(m6)[1] == vcat(1:3, 4 + im) + + @test destructure(m1)[2](7:9) == [7,8,9] + @test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9]) + @test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9]) + m4′ = destructure(m4)[2](4:9) + @test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9]) + @test m4′.x === m4′.y + m5′ = destructure(m5)[2](reverse(1:9)) + @test m5′.a[1].x === m5′.b[1] + @test m5′.b[2] === false + m6′ = destructure(m6)[2]((4:7) .+ (1:4) .* im) + @test m6′.a isa Vector{Float64} + @test m6′.a == 4:6 + @test m6′.a === m6′.c + @test m6′.b == [7 + 4im] + + # struct, trainable + @test destructure(m7)[1] == 1:3 + m7′ = destructure(m7)[2]([10,20,30]) + @test m7′.a == (sin, [10,20,30]) + @test m7′.b == (cos, [4,5,6]) + @test m7′.c == (tan, [7,8,9]) + + @test destructure(m8)[1] == 1:5 + m8′ = destructure(m8)[2](1:5) + @test m8′[1].x === m8′[1].y + @test m8′[2].b.y === false + @test m8′[3][1] == [5.0] + + # errors + @test_throws Exception destructure(m7)[2]([10,20]) + @test_throws Exception destructure(m7)[2]([10,20,30,40]) +end + +@testset "gradient of flatten" begin + @test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0] + @test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0]) + @test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing) + @test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0]) + @test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0]) + + g5 = gradient(m -> destructure(m)[1][3], m5)[1] + @test g5.a[1].x == [0,0,1] + @test g5.a[2] === nothing + + g6 = gradient(m -> imag(destructure(m)[1][4]), m6)[1] + @test g6.a == [0,0,0] + @test g6.a isa Vector{Float64} + @test g6.b == [0+im] + + g8 = gradient(m -> sum(abs2, destructure(m)[1]), m8)[1] + @test g8[1].x == [2,4,6] + @test g8[2].b.x == [8] + @test g8[3] == [[10.0]] + + @testset "second derivative" begin + @test gradient([1,2,3.0]) do v + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1]) + end[1] ≈ [8,16,24] + # With Diffractor, non-leaf _grad!(x, dx, off, flat::AbstractVector) gets double-wrapped dx: + # off = (0, 3), dx = Tangent{Tangent{Tuple{Vector{Float64}, Vector{Float64}}, ... + # until you add explicit double-unwrap: base(dx::Tangent{<:Tangent}) = backing(dx).backing + # With Zygote, instead: + # dx = Tangent{Any}(backing = Tangent{Any}([4.0, 8.0, 12.0], ZeroTangent()),) + + @test gradient([1,2,3.0]) do v + sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1]) + end[1] == [378, 378, 378] + + @test_broken gradient([1,2,3.0]) do v + sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1]) + end[1] ≈ [8,16,24] + # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z) + # Diffractor error in perform_optic_transform + end +end + +@testset "gradient of rebuild" begin + re1 = destructure(m1)[2] + @test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0] + re2 = destructure(m2)[2] + @test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0] + re3 = destructure(m3)[2] + @test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0] + @test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0] + + re4 = destructure(m4)[2] + @test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0] + @test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0] + @test gradient(rand(6)) do x + m = re4(x) + m.x[1] + 2*m.y[2] + 3*m.z[3] + end[1] == [1,2,0, 0,0,3] + + re7 = destructure(m7)[2] + @test gradient(x -> re7(x).a[2][3], rand(3))[1] == [0,0,1] + @test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0] + @test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0] + + v8, re8 = destructure(m8) + @test gradient(x -> sum(abs2, re8(x)[1].y), v8)[1] == [2,4,6,0,0] + @test gradient(x -> only(sum(re8(x)[3]))^2, v8)[1] == [0,0,0,0,10] + + @testset "second derivative" begin + @test_broken gradient(collect(1:6.0)) do y + sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1]) + end[1] ≈ [8,16,24,0,0,0] + # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}} + # with Zygote, which can be fixed by: + # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,) + + @test_broken gradient(collect(1:6.0)) do y + sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1]) + end[1] ≈ [0,0,0,32,40,48] + # Not fixed by this: + # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,) + end +end + +@testset "Flux issue 1826" begin + v, re = destructure((x=[1,2.0], y=[3,4,5.0])) + @test gradient(zero(v)) do w + m = re(w) + 5 * sum(m.x) + 7 * sum(m[2]) # uses both x and y + end == ([5.0, 5.0, 7.0, 7.0, 7.0],) + # This, using only x, was broken on Flux: + @test gradient(w -> sum(re(w).x), zero(v)) == ([1.0, 1.0, 0.0, 0.0, 0.0],) + + sh = [7,7.0]; + v, re = destructure((x=sh, y=[3.0,4.0], z=sh)) # shared array in the model + @test v == [7, 7, 3, 4] + @test re([1,10,100,1000]) == (x = [1, 10], y = [100, 1000], z = [1, 10]) + + @test gradient(zero(v)) do w + m = re(w) + 3 * sum(m.x) + 13 * sum(m.z) # no dependence on y, but two distinct gradient arrays + end == ([16, 16, 0, 0],) # Flux gave ([3.0, 3.0, 13.0, 13.0],) + + @test gradient(zero(v)) do w + m = re(w) + 4(sum(m.x) + sum(m.z)) # now two gradients are ===, so it eliminates one + end == ([8,8,0,0],) + + @test gradient(zero(v)) do w + m = re(w) + 4(sum(m.x) + sum(m.y)) + 13*sum(m.z) # again two gradients are ===, so it eliminates one + end == ([17,17,4,4],) # Flux gave ([4.0, 4.0, 13.0, 13.0],) +end diff --git a/test/rules.jl b/test/rules.jl index c8697683..ffb4ca65 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -44,7 +44,7 @@ end end end -@testset verbose=true "simple sum" begin +@testset "simple sum" begin empty!(LOG) @testset "$(name(o))" for o in RULES m = shuffle!(reshape(1:64, 8, 8) .+ 0.0) @@ -79,7 +79,7 @@ end end end -@testset verbose=true "StaticArrays" begin +@testset "StaticArrays" begin empty!(LOG) @testset "$(name(o))" for o in RULES W1 = @SMatrix randn(10, 10) @@ -157,7 +157,7 @@ end end end -@testset verbose=true "mutation check" begin +@testset "mutation check" begin # If @lazy captures a matrix which is later mutated, the results won't agree here: @testset "$(name(o))" for o in RULES model = Float64.(rand(Int8, 8)) @@ -174,7 +174,7 @@ end end end -@testset "with complex numebers: Flux#1776" begin +@testset "with complex numbers: Flux#1776" begin empty!(LOG) @testset "$(name(opt))" for opt in [ # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too: diff --git a/test/runtests.jl b/test/runtests.jl index 825d977e..d47bce08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -80,6 +80,7 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) end @testset "trainable subset" begin + @info "ignore these warnings about trainable, testing the old path" # Foo has an old-style tuple trainable, both elements mf = Foo([1.0, 2.0], (a = sin, b = [3.0, 4.0], c = 5)) sf = Optimisers.setup(Descent(0.1), mf) @@ -164,7 +165,9 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,) @test_throws ArgumentError Optimisers.setup(ADAMW(), m2) end - @info "finished feature testing" + end + @testset verbose=true "Destructure" begin + include("destructure.jl") end @testset verbose=true "Optimisation Rules" begin include("rules.jl")