From aad1fbcf52198f7a26a4dc0f1535352248eee0ba Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Mar 2022 17:17:27 -0500 Subject: [PATCH 1/8] rm destructure --- src/Flux.jl | 2 +- src/utils.jl | 53 ---------------------------------------------------- 2 files changed, 1 insertion(+), 54 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index c6b7fbe975..aa3f021595 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,7 +7,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils -import Optimisers: trainable # before v0.13, Flux owned this function +import Optimisers: trainable, destructure # before v0.13, Flux owned these functions using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback, @nograd diff --git a/src/utils.jl b/src/utils.jl index e93b83a89b..c198960927 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -475,59 +475,6 @@ function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer. bias end -# Flattening models to weight vectors, and back - -function _restructure(m, xs) - i = 0 - m̄ = fmap(m) do x - x isa AbstractArray || return x - x = reshape(xs[i.+(1:length(x))], size(x)) - i += length(x) - return x - end - length(xs) == i || @warn "Expected $(i) params, got $(length(xs))" - return m̄ -end - -@adjoint function _restructure(m, xs) # TODO ChainRulesCore.rrule - m̄, numel = _restructure(m, xs), length(xs) - function _restructure_pullback(dm) - xs′ = destructure(dm)[1] - numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))" - return (nothing, xs′) - end - return m̄, _restructure_pullback -end - -""" - destructure(m) - -Flatten a model's parameters into a single weight vector. - - julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax) - Chain(Dense(10, 5, std), Dense(5, 2), softmax) - - julia> θ, re = destructure(m); - - julia> θ - 67-element Vector{Float32}: - -0.1407104 - ... - -The second return value `re` allows you to reconstruct the original network after making -modifications to the weight vector (for example, with a hypernetwork). - - julia> re(θ .* 2) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) -""" -function destructure(m) - xs = Zygote.Buffer([]) - fmap(m) do x - x isa AbstractArray && push!(xs, x) - return x - end - return vcat(vec.(copy(xs))...), p -> _restructure(m, p) -end # Other From dc487c69c2c1807ffdc569d3e98f6595569bf14f Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Mar 2022 22:23:31 -0500 Subject: [PATCH 2/8] try to fix Downstream.yml by copying NNlib --- .github/workflows/Downstream.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 009252c23c..e59af1c40c 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -28,7 +28,7 @@ jobs: - {user: SciML, repo: DiffEqFlux.jl, group: Layers} - {user: SciML, repo: NeuralPDE.jl, group: NNPDE} - {user: SciML, repo: OperatorLearning.jl, group: All} - if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + # if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 From 26e8474b6a04c30016df9213ad6df7f8e5780ce2 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 5 Mar 2022 22:23:49 -0500 Subject: [PATCH 3/8] Optimisers 0.2.1 --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 370ad66b11..653e881e09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.0-DEV" +version = "0.12.99" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -34,7 +34,7 @@ MLUtils = "0.2" MacroTools = "0.5" NNlib = "0.8.2" NNlibCUDA = "0.2" -Optimisers = "0.2" +Optimisers = "0.2.1" ProgressLogging = "0.1" Reexport = "0.2, 1.0" SpecialFunctions = "1.8.2, 2.1.2" From 14302bed8a0192ecf3c0cfc23ab8974515825f75 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 7 Mar 2022 19:27:53 -0500 Subject: [PATCH 4/8] rm trainable fallback defn --- src/functor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index b056ff9574..71c1ec0500 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -4,8 +4,6 @@ using Zygote: IdSet import Functors: Functors, @functor, functor, fmap, isleaf using SparseArrays: AbstractSparseArray -trainable(m) = functor(m)[1] - """ testmode!(m, mode = true) From 2ff00baf2e6977fe808dd48155f363f738909c7a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 7 Mar 2022 23:43:09 -0500 Subject: [PATCH 5/8] more tests --- test/utils.jl | 92 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index d04fd85e06..14ea558b06 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -538,3 +538,95 @@ end @test n_iter == 3 end end + +@testset "Various destructure bugs" begin + + @testset "issue 1601" begin + struct TwoDenses + dense::Dense + dense2::Dense + end + Flux.@functor TwoDenses + + function (m::TwoDenses)(x) + out = m.dense(x) + end + + model = TwoDenses( + Dense(3,1), + Dense(3,2) + ) + p, re = Flux.destructure(model) + + x = [1., 2., 3.] + y, back = Flux.Zygote.pullback((x, p) -> re(p)(x), x, p) + + dy = [4.] + dx, dp = back(dy) + @test length(p) == length(dp) + end + + @testset "issue 1727" begin + p, re = Flux.destructure(BatchNorm(3)) # 6 parameters, plus 6 non-trainable + @test length(p) == 6 + + x = rand(Float32, 3, 4) + y, back = Flux.pullback(x, p) do x, p + vec(re(p)(x)) + end + @test_nowarn back(y) + b = back(y) + + @test size(b[1]) == size(x) + @test size(b[2]) == size(p) + end + + @testset "issue 1767" begin + struct Model{A} + a::A + b::A + end + Flux.@functor Model + (m::Model)(x) = m.a(x) .+ m.b(x) + + d = Dense(1, 1) + x = rand(Float32, 1, 1) + + # Sharing the parameters + model = Model(d, d) + + # Works + g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model)) + + p, re = Flux.destructure(model) + # Fails + g2 = Flux.gradient(p -> sum(re(p)(x)), p) + + @test g2[1] ≈ vcat(g1[d.weight], g1[d.bias]) + end + + @testset "issue 1826" begin + struct Split{T} # taken from: https://fluxml.ai/Flux.jl/stable/models/advanced/#Multiple-outputs:-a-custom-Split-layer + paths::T + end + Split(paths...) = Split(paths) + Flux.@functor Split + (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) + + n_input, n_batch, n_shared = 5, 13, 11 + n_outputs = [3, 7] + + data = rand(Float32, n_input, n_batch) + model = Chain( + Dense(n_input, n_shared), + Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2])) + ) + + pvec, re = Flux.destructure(model) + loss(x, idx, pv) = sum(abs2, re(pv)(x)[idx]) # loss wrt `idx`th output term + + g = Flux.Zygote.ForwardDiff.gradient(pv -> loss(data, 1, pv), pvec) + @test g ≈ Flux.Zygote.gradient(pv -> loss(data, 1, pv), pvec)[1] + end + +end From 2d510d780f22c338c5b2058e9e6fa64b48c4a618 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 8 Mar 2022 00:18:32 -0500 Subject: [PATCH 6/8] test no longer broken --- test/utils.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 14ea558b06..c5de12a39e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -390,11 +390,7 @@ end ∇m = gradient(m -> sum(m(x)), m)[1] p, re = destructure(m) ∇p = gradient(θ -> sum(re(θ)(x)), p)[1] - if VERSION >= v"1.7" - @test_broken ∇p ≈ destructure(∇m)[1] - else - @test ∇p ≈ destructure(∇m)[1] - end + @test ∇p ≈ destructure(∇m)[1] end end end From b90973a4531abd72c619b11a9fb59b9d19a5e182 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Mar 2022 15:21:40 -0500 Subject: [PATCH 7/8] enlarge downstream for now --- .github/workflows/Downstream.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index e59af1c40c..3a74979e91 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -25,8 +25,8 @@ jobs: - {user: FluxML, repo: Torch.jl, group: All} - {user: FluxML, repo: Metalhead.jl, group: All} - {user: Chemellia, repo: AtomicGraphNets.jl, group: All} - - {user: SciML, repo: DiffEqFlux.jl, group: Layers} - - {user: SciML, repo: NeuralPDE.jl, group: NNPDE} + - {user: SciML, repo: DiffEqFlux.jl, group: All} # Layers} + - {user: SciML, repo: NeuralPDE.jl, group: All} # NNPDE} - {user: SciML, repo: OperatorLearning.jl, group: All} # if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: From 91d88fe306f78d1db5e95e50cd2674d0d19cd982 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Mar 2022 22:44:09 -0400 Subject: [PATCH 8/8] revert steps for downstream testing --- .github/workflows/Downstream.yml | 6 +++--- Project.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 3a74979e91..009252c23c 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -25,10 +25,10 @@ jobs: - {user: FluxML, repo: Torch.jl, group: All} - {user: FluxML, repo: Metalhead.jl, group: All} - {user: Chemellia, repo: AtomicGraphNets.jl, group: All} - - {user: SciML, repo: DiffEqFlux.jl, group: All} # Layers} - - {user: SciML, repo: NeuralPDE.jl, group: All} # NNPDE} + - {user: SciML, repo: DiffEqFlux.jl, group: Layers} + - {user: SciML, repo: NeuralPDE.jl, group: NNPDE} - {user: SciML, repo: OperatorLearning.jl, group: All} - # if: contains(github.event.pull_request.labels.*.name, 'run downstream test') + if: contains(github.event.pull_request.labels.*.name, 'run downstream test') steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 diff --git a/Project.toml b/Project.toml index 653e881e09..43c48109e6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.12.99" +version = "0.13.0-DEV" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"