Use destructure from Optimisers.jl (#1901)
* rm destructure

* try to fix Downstream.yml by copying NNlib

* Optimisers 0.2.1

* rm trainable fallback defn

* more tests

* test no longer broken

* enlarge downstream for now

* revert steps for downstream testing
mcabbott authored Mar 21, 2022
1 parent ed78e8a commit 81eea84
Showing 5 changed files with 95 additions and 62 deletions.
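
The change itself is mechanical: Flux's hand-rolled `destructure`/`_restructure` pair is removed, and the function is instead imported from Optimisers.jl (hence the compat bump to 0.2.1). The user-facing API is intended to stay the same; a minimal sketch of that API, with the layer sizes taken from the removed docstring rather than anything else in this diff:

```julia
using Flux  # destructure is now forwarded to Optimisers.jl

m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)

# Flatten all trainable parameters into one vector and get a rebuilder `re`.
θ, re = Flux.destructure(m)
length(θ)          # 67 = 10*5 + 5 + 5*2 + 2
m2 = re(θ .* 2)    # same architecture, rebuilt from a modified flat vector

# Gradients can be taken with respect to the flat vector.
x = rand(Float32, 10, 3)
g = Flux.gradient(p -> sum(re(p)(x)), θ)[1]
length(g) == length(θ)
```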
2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ MLUtils = "0.2"
MacroTools = "0.5"
NNlib = "0.8.2"
NNlibCUDA = "0.2"
Optimisers = "0.2"
Optimisers = "0.2.1"
ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -7,7 +7,7 @@ using MacroTools: @forward

@reexport using NNlib
using MLUtils
import Optimisers: trainable # before v0.13, Flux owned this function
import Optimisers: trainable, destructure # before v0.13, Flux owned these functions

using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback, @nograd
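
Because `destructure` is now `import`ed rather than defined in Flux, the two names below refer to literally the same function, so any method or docstring added in Optimisers.jl is what `Flux.destructure` uses. A small hedged check of that identity (not part of the commit):

```julia
using Flux, Optimisers

# `import Optimisers: trainable, destructure` in src/Flux.jl means these
# are one function object, not two functions with matching behaviour.
Flux.destructure === Optimisers.destructure  # true
Flux.trainable === Optimisers.trainable      # true
```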
2 changes: 0 additions & 2 deletions src/functor.jl
@@ -4,8 +4,6 @@ using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

trainable(m) = functor(m)[1]

"""
testmode!(m, mode = true)
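
The deleted fallback `trainable(m) = functor(m)[1]` is not lost: Optimisers.jl ships the same default, and Flux now imports it. Layers that want only some fields optimised keep overloading `trainable`, and `destructure` collects only those fields. A hedged sketch, with an invented `Affine` layer that is not part of this commit:

```julia
using Flux
import Optimisers

struct Affine{W,B}
    weight::W
    bias::B
end
Flux.@functor Affine
(a::Affine)(x) = a.weight * x .+ a.bias

# Mark only the weight as trainable; the bias stays frozen.
Optimisers.trainable(a::Affine) = (; weight = a.weight)

a = Affine(randn(Float32, 2, 3), zeros(Float32, 2))
θ, re = Flux.destructure(a)
length(θ)  # 6: only the 2×3 weight is flattened, the frozen bias is skipped
```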
53 changes: 0 additions & 53 deletions src/utils.jl
@@ -475,59 +475,6 @@ function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
bias
end

# Flattening models to weight vectors, and back

function _restructure(m, xs)
i = 0
m̄ = fmap(m) do x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end

@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end

"""
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Vector{Float32}:
-0.1407104
...
The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
fmap(m) do x
x isa AbstractArray && push!(xs, x)
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

# Other

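
Everything deleted above is the old Zygote-specific machinery: `destructure` pushed each array leaf into a `Zygote.Buffer`, and `_restructure` plus its hand-written `@adjoint` rebuilt the model and flattened the gradient on the way back. Optimisers.jl provides the same behaviour through ChainRulesCore rules, which the enlarged test set below exercises. A rough sketch of the reverse-pass contract, with illustrative sizes:

```julia
using Flux

m = Dense(3, 2)
p, re = Flux.destructure(m)   # 2×3 weight + 2 bias = 8 numbers

x = rand(Float32, 3, 5)
y, back = Flux.pullback(p -> sum(re(p)(x)), p)

# The pullback maps the structural gradient of the rebuilt model
# back onto a flat vector aligned with p, as the old @adjoint did.
dp = back(one(y))[1]
length(dp) == length(p)
```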
98 changes: 93 additions & 5 deletions test/utils.jl
@@ -390,11 +390,7 @@ end
∇m = gradient(m -> sum(m(x)), m)[1]
p, re = destructure(m)
∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
if VERSION >= v"1.7"
@test_broken ∇p ≈ destructure(∇m)[1]
else
@test ∇p ≈ destructure(∇m)[1]
end
@test ∇p ≈ destructure(∇m)[1]
end
end
end
@@ -538,3 +534,95 @@
@test n_iter == 3
end
end

@testset "Various destructure bugs" begin

@testset "issue 1601" begin
struct TwoDenses
dense::Dense
dense2::Dense
end
Flux.@functor TwoDenses

function (m::TwoDenses)(x)
out = m.dense(x)
end

model = TwoDenses(
Dense(3,1),
Dense(3,2)
)
p, re = Flux.destructure(model)

x = [1., 2., 3.]
y, back = Flux.Zygote.pullback((x, p) -> re(p)(x), x, p)

dy = [4.]
dx, dp = back(dy)
@test length(p) == length(dp)
end

@testset "issue 1727" begin
p, re = Flux.destructure(BatchNorm(3)) # 6 parameters, plus 6 non-trainable
@test length(p) == 6

x = rand(Float32, 3, 4)
y, back = Flux.pullback(x, p) do x, p
vec(re(p)(x))
end
@test_nowarn back(y)
b = back(y)

@test size(b[1]) == size(x)
@test size(b[2]) == size(p)
end

@testset "issue 1767" begin
struct Model{A}
a::A
b::A
end
Flux.@functor Model
(m::Model)(x) = m.a(x) .+ m.b(x)

d = Dense(1, 1)
x = rand(Float32, 1, 1)

# Sharing the parameters
model = Model(d, d)

# Works
g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model))

p, re = Flux.destructure(model)
# Fails
g2 = Flux.gradient(p -> sum(re(p)(x)), p)

@test g2[1] ≈ vcat(g1[d.weight], g1[d.bias])
end

@testset "issue 1826" begin
struct Split{T} # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
paths::T
end
Split(paths...) = Split(paths)
Flux.@functor Split
(m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)

n_input, n_batch, n_shared = 5, 13, 11
n_outputs = [3, 7]

data = rand(Float32, n_input, n_batch)
model = Chain(
Dense(n_input, n_shared),
Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2]))
)

pvec, re = Flux.destructure(model)
loss(x, idx, pv) = sum(abs2, re(pv)(x)[idx]) # loss wrt `idx`th output term

g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec)
@test g ≈ Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1]
end

end
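
Of the new tests, the "issue 1767" case is the subtle one: when two fields share one array, the flat vector stores those parameters once, and the reverse pass accumulates the contributions from every use into the same slots (which is what `g2[1] ≈ vcat(g1[d.weight], g1[d.bias])` checks). A hedged illustration outside the test file, using a NamedTuple instead of a custom struct:

```julia
using Flux

d = Dense(1, 1)
tied = (a = d, b = d)                 # the same layer appears twice
apply(m, x) = m.a(x) .+ m.b(x)

p, re = Flux.destructure(tied)
length(p)                             # 2: the shared weight and bias are stored once

x = rand(Float32, 1, 1)
g = Flux.gradient(p -> sum(apply(re(p), x)), p)[1]
# Each entry of g sums the gradient contributions from both m.a and m.b.
```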
