diff --git a/Project.toml b/Project.toml
index 370ad66b11..43c48109e6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -34,7 +34,7 @@ MLUtils = "0.2"
 MacroTools = "0.5"
 NNlib = "0.8.2"
 NNlibCUDA = "0.2"
-Optimisers = "0.2"
+Optimisers = "0.2.1"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
diff --git a/src/Flux.jl b/src/Flux.jl
index c6b7fbe975..aa3f021595 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ using MacroTools: @forward
 
 @reexport using NNlib
 using MLUtils
-import Optimisers: trainable # before v0.13, Flux owned this function
+import Optimisers: trainable, destructure # before v0.13, Flux owned these functions
 
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
diff --git a/src/functor.jl b/src/functor.jl
index b056ff9574..71c1ec0500 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -4,8 +4,6 @@ using Zygote: IdSet
 import Functors: Functors, @functor, functor, fmap, isleaf
 using SparseArrays: AbstractSparseArray
 
-trainable(m) = functor(m)[1]
-
 """
     testmode!(m, mode = true)
 
diff --git a/src/utils.jl b/src/utils.jl
index e93b83a89b..c198960927 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -475,59 +475,6 @@ function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
   bias
 end
 
-# Flattening models to weight vectors, and back
-
-function _restructure(m, xs)
-  i = 0
-  m̄ = fmap(m) do x
-    x isa AbstractArray || return x
-    x = reshape(xs[i.+(1:length(x))], size(x))
-    i += length(x)
-    return x
-  end
-  length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
-  return m̄
-end
-
-@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule
-  m̄, numel = _restructure(m, xs), length(xs)
-  function _restructure_pullback(dm)
-    xs′ = destructure(dm)[1]
-    numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
-    return (nothing, xs′)
-  end
-  return m̄, _restructure_pullback
-end
-
-"""
-    destructure(m)
-
-Flatten a model's parameters into a single weight vector.
-
-    julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax)
-    Chain(Dense(10, 5, std), Dense(5, 2), softmax)
-
-    julia> θ, re = destructure(m);
-
-    julia> θ
-    67-element Vector{Float32}:
-    -0.1407104
-    ...
-
-The second return value `re` allows you to reconstruct the original network after making
-modifications to the weight vector (for example, with a hypernetwork).
-
-    julia> re(θ .* 2)
-    Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
-"""
-function destructure(m)
-  xs = Zygote.Buffer([])
-  fmap(m) do x
-    x isa AbstractArray && push!(xs, x)
-    return x
-  end
-  return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
-end
 
 # Other
 
diff --git a/test/utils.jl b/test/utils.jl
index d04fd85e06..c5de12a39e 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -390,11 +390,7 @@ end
       ∇m = gradient(m -> sum(m(x)), m)[1]
       p, re = destructure(m)
       ∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
-      if VERSION >= v"1.7"
-        @test_broken ∇p ≈ destructure(∇m)[1]
-      else
-        @test ∇p ≈ destructure(∇m)[1]
-      end
+      @test ∇p ≈ destructure(∇m)[1]
     end
   end
 end
@@ -538,3 +534,95 @@
     @test n_iter == 3
   end
 end
+
+@testset "Various destructure bugs" begin
+
+  @testset "issue 1601" begin
+    struct TwoDenses
+      dense::Dense
+      dense2::Dense
+    end
+    Flux.@functor TwoDenses
+
+    function (m::TwoDenses)(x)
+      out = m.dense(x)
+    end
+
+    model = TwoDenses(
+      Dense(3,1),
+      Dense(3,2)
+    )
+    p, re = Flux.destructure(model)
+
+    x = [1., 2., 3.]
+    y, back = Flux.Zygote.pullback((x, p) -> re(p)(x), x, p)
+
+    dy = [4.]
+    dx, dp = back(dy)
+    @test length(p) == length(dp)
+  end
+
+  @testset "issue 1727" begin
+    p, re = Flux.destructure(BatchNorm(3))  # 6 parameters, plus 6 non-trainable
+    @test length(p) == 6
+
+    x = rand(Float32, 3, 4)
+    y, back = Flux.pullback(x, p) do x, p
+      vec(re(p)(x))
+    end
+    @test_nowarn back(y)
+    b = back(y)
+
+    @test size(b[1]) == size(x)
+    @test size(b[2]) == size(p)
+  end
+
+  @testset "issue 1767" begin
+    struct Model{A}
+      a::A
+      b::A
+    end
+    Flux.@functor Model
+    (m::Model)(x) = m.a(x) .+ m.b(x)
+
+    d = Dense(1, 1)
+    x = rand(Float32, 1, 1)
+
+    # Sharing the parameters
+    model = Model(d, d)
+
+    # Works
+    g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model))
+
+    p, re = Flux.destructure(model)
+    # Fails
+    g2 = Flux.gradient(p -> sum(re(p)(x)), p)
+
+    @test g2[1] ≈ vcat(g1[d.weight], g1[d.bias])
+  end
+
+  @testset "issue 1826" begin
+    struct Split{T}  # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer
+      paths::T
+    end
+    Split(paths...) = Split(paths)
+    Flux.@functor Split
+    (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
+
+    n_input, n_batch, n_shared = 5, 13, 11
+    n_outputs = [3, 7]
+
+    data = rand(Float32, n_input, n_batch)
+    model = Chain(
+      Dense(n_input, n_shared),
+      Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2]))
+    )
+
+    pvec, re = Flux.destructure(model)
+    loss(x, idx, pv) = sum(abs2, re(pv)(x)[idx])  # loss wrt `idx`th output term
+
+    g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec)
+    @test g ≈ Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1]
+  end
+
+end
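For reference, a minimal sketch (not part of this diff) of the flatten/reconstruct round-trip the new tests exercise, assuming a Flux build where `destructure` comes from Optimisers.jl as imported above; the model and sizes are arbitrary illustration values:

```julia
using Flux  # `Flux.destructure` forwards to Optimisers.destructure per the import above

# Build a small model, flatten its parameters, and differentiate through the reconstructor
model = Chain(Dense(3, 2, tanh), Dense(2, 1))
θ, re = Flux.destructure(model)      # flat parameter vector + reconstructor `re`

x = rand(Float32, 3, 4)
g = Flux.gradient(p -> sum(abs2, re(p)(x)), θ)[1]  # gradient w.r.t. the flat vector
@assert length(g) == length(θ)       # one gradient entry per flattened parameter
```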