From cd06023fd2bd62187a1d809e9c8fce183550aebe Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 14 Feb 2022 17:54:42 -0600 Subject: [PATCH 01/19] Add initial implementation --- Project.toml | 2 +- src/Flux.jl | 2 ++ src/deprecations.jl | 2 ++ src/functor.jl | 8 -------- src/loading.jl | 46 +++++++++++++++++++++++++++++++++++++++++++++ test/utils.jl | 19 ++++++++++++++++++- 6 files changed, 69 insertions(+), 10 deletions(-) create mode 100644 src/loading.jl diff --git a/Project.toml b/Project.toml index bf2bc008de..adfa20e719 100644 --- a/Project.toml +++ b/Project.toml @@ -29,7 +29,7 @@ Adapt = "3.0" ArrayInterface = "3.1, 4, 5" CUDA = "3" ChainRulesCore = "1.12" -Functors = "0.2.1" +Functors = "0.2.8" MLUtils = "0.2" MacroTools = "0.5" NNlib = "0.8.2" diff --git a/src/Flux.jl b/src/Flux.jl index aa3f021595..0458dc6cea 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -46,6 +46,8 @@ include("layers/normalise.jl") include("layers/upsample.jl") include("layers/show.jl") +include("loading.jl") + include("outputsize.jl") include("data/Data.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 597fc5a913..33a0fa3060 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -15,6 +15,8 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...) ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type")) zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type")) +@deprecate loadparams!(m, xs) loadmodel!(m, xs) + # v0.13 deprecations function Broadcast.broadcasted(f::Recur, args...) diff --git a/src/functor.jl b/src/functor.jl index ebd6dd1102..905a37b8d6 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -85,14 +85,6 @@ function params(m...) 
return ps end -function loadparams!(m, xs) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end - struct FluxCUDAAdaptor end adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) diff --git a/src/loading.jl b/src/loading.jl new file mode 100644 index 0000000000..6e2019566f --- /dev/null +++ b/src/loading.jl @@ -0,0 +1,46 @@ +_loadleaf(x) = isleaf(x) +for T in [:Dense, :Diagonal, :Bilinear, :Embedding, + :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor] + @eval _loadleaf(::$T) = true +end + +loadto!(x, x̄) = x +loadto!(x::AbstractArray, x̄::AbstractArray) = copyto!(x, x̄) +for T in [:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor] + @eval begin + function loadto!(m::$T, m̄::$T) + if (size(m.weight) != size(m̄.weight)) || (size(m.bias) != size(m̄.bias)) + throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) + else + return fmap(loadto!, m, m̄) + end + end + loadto!(m::$T, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) + end +end +function loadto!(m::Diagonal, m̄::Diagonal) + if (size(m.α) != size(m̄.α)) || (size(m.β) != size(m̄.β)) + throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) + else + return fmap(loadto!, m, m̄) + end +end +loadto!(m::Diagonal, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) +function loadto!(m::Embedding, m̄::Embedding) + if size(m.weight) != size(m̄.weight) + throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) + else + return fmap(loadto!, m, m̄) + end +end +loadto!(m::Embedding, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) + +function loadmodel!(m, xs::Params) + for (p, x) in zip(params(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end +end +loadmodel!(m, xs::AbstractVector) = loadmodel!(m, params(xs)) +loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = _loadleaf) diff --git a/test/utils.jl b/test/utils.jl index 14b5ad9bbc..89d7f9f76c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -373,11 +373,28 @@ end weights(m) = mapreduce(l -> [l.weight], vcat, m) @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) m = dm(bt) - loadparams!(m, params(m)) + loadmodel!(m, params(m)) testdense(m, bt) end end + @testset "loadmodel!(m, m̄)" begin + import Flux: loadmodel! 
+ + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dense(5, 2)) + m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) + m4 = Chain(Dense(10, 6), Dense(6, 2)) + + loadmodel!(m1, m2) + @test m1[1].weight == m2[1].weight + @test m1[1].bias == m2[1].bias + @test m1[2].σ == relu + + @test_throws ArgumentError loadmodel!(m1, m3) + @test_throws DimensionMismatch loadmodel!(m1, m4) + end + @testset "destructure" begin import Flux: destructure @testset "Bias type $bt" for bt in (zeros, nobias) From 99a18ec1b69141508472926000e065c05ce07652 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 14 Feb 2022 18:12:08 -0600 Subject: [PATCH 02/19] Add more tests --- src/loading.jl | 1 + test/utils.jl | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/loading.jl b/src/loading.jl index 6e2019566f..975f25a908 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -5,6 +5,7 @@ for T in [:Dense, :Diagonal, :Bilinear, :Embedding, end loadto!(x, x̄) = x +loadto!(x::Zeros, x̄) = x loadto!(x::AbstractArray, x̄::AbstractArray) = copyto!(x, x̄) for T in [:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor] @eval begin diff --git a/test/utils.jl b/test/utils.jl index 89d7f9f76c..df5f29a6cc 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -379,20 +379,27 @@ end end @testset "loadmodel!(m, m̄)" begin - import Flux: loadmodel! + import Flux: loadmodel!, Zeros m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) m2 = Chain(Dense(10, 5), Dense(5, 2)) m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) m4 = Chain(Dense(10, 6), Dense(6, 2)) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(5, 2), Zeros()), Dense(5, 2))) + m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) loadmodel!(m1, m2) @test m1[1].weight == m2[1].weight @test m1[1].bias == m2[1].bias @test m1[2].σ == relu + loadmodel!(m5, m6) + @test m5[1].weight == m6[1].weight + @test m5[2][1].weight == m6[2][1].weight + @test m5[2][1].bias == Zeros() @test_throws ArgumentError loadmodel!(m1, m3) @test_throws DimensionMismatch loadmodel!(m1, m4) + @test_throws ArgumentError loadmodel!(m1, m5) end @testset "destructure" begin From 492c34e408bf2ae62fdc4b64836cf00c2f939069 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 14 Feb 2022 18:34:40 -0600 Subject: [PATCH 03/19] Fix typo in tests --- test/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index df5f29a6cc..e11397562d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -385,7 +385,7 @@ end m2 = Chain(Dense(10, 5), Dense(5, 2)) m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) m4 = Chain(Dense(10, 6), Dense(6, 2)) - m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(5, 2), Zeros()), Dense(5, 2))) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), Zeros()), Dense(5, 2))) m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) loadmodel!(m1, m2) From dee58429c5ed24facecd6e7eb383eac000b0a320 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 15 Feb 2022 13:41:39 -0600 Subject: [PATCH 04/19] Refactor to allow better support for loading errors with custom models --- src/loading.jl | 55 +++++++++++++++++++------------------------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 975f25a908..3f51c4a3ff 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,40 +1,27 @@ -_loadleaf(x) = isleaf(x) -for T in [:Dense, :Diagonal, :Bilinear, :Embedding, - :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor] - @eval 
_loadleaf(::$T) = true -end +isloadleaf(x) = all(Functors.isleaf, Functors.children(x)) -loadto!(x, x̄) = x -loadto!(x::Zeros, x̄) = x -loadto!(x::AbstractArray, x̄::AbstractArray) = copyto!(x, x̄) -for T in [:Dense, :Bilinear, :Conv, :ConvTranspose, :DepthwiseConv, :CrossCor] - @eval begin - function loadto!(m::$T, m̄::$T) - if (size(m.weight) != size(m̄.weight)) || (size(m.bias) != size(m̄.bias)) - throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) - else - return fmap(loadto!, m, m̄) - end - end - loadto!(m::$T, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) - end +loadnumeric!(x, x̄, err) = x +loadnumeric!(x::Zeros, x̄, err) = x +function loadnumeric!(x::AbstractArray, x̄::AbstractArray, err) + (size(x) == size(x̄)) || throw(err) + copyto!(x, x̄) end -function loadto!(m::Diagonal, m̄::Diagonal) - if (size(m.α) != size(m̄.α)) || (size(m.β) != size(m̄.β)) - throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) - else - return fmap(loadto!, m, m̄) - end + +function _loadto!(m, m̄) + ls, _ = functor(m) + l̄s, _ = functor(m̄) + (keys(ls) == keys(l̄s)) || + throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match.")) + + err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.") + foreach((l, l̄) -> loadnumeric!(l, l̄, err), ls, l̄s) + + return m end -loadto!(m::Diagonal, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) -function loadto!(m::Embedding, m̄::Embedding) - if size(m.weight) != size(m̄.weight) - throw(DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.")) - else - return fmap(loadto!, m, m̄) - end +function loadto!(m::T, m̄::S) where {T, S} + (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) + _loadto!(m, m̄) end -loadto!(m::Embedding, m̄) = throw(ArgumentError("Tried to load $m̄ into $m.")) function loadmodel!(m, xs::Params) for (p, x) in zip(params(m), xs) @@ -44,4 +31,4 @@ function loadmodel!(m, xs::Params) end end loadmodel!(m, xs::AbstractVector) = loadmodel!(m, params(xs)) -loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = _loadleaf) +loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = isloadleaf) From b4fe66b569d39127f72d33efa22f7809a4ef558a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 10:36:47 -0600 Subject: [PATCH 05/19] Add documentation for `loadmodel!` --- docs/src/saving.md | 43 +++++++++++++++++++++++-------------------- src/loading.jl | 46 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 24 deletions(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index d9db750d1e..5d5c3432c4 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -2,7 +2,7 @@ You may wish to save models so that they can be loaded and run in a later session. The easiest way to do this is via -[BSON.jl](https://github.com/MikeInnes/BSON.jl). +[BSON.jl](https://github.com/JuliaIO/BSON.jl). Save a model: @@ -46,15 +46,17 @@ versions of Flux). !!! note - If a saved model's weights are stored on the GPU, the model will not load + If a saved model's parameters are stored on the GPU, the model will not load later on if there is no GPU support available. It's best to [move your model to the CPU](gpu.md) with `cpu(model)` before saving it. -## Saving Model Weights +!!! warning -In some cases it may be useful to save only the model parameters themselves, and -rebuild the model architecture in your code. 
You can use `params(model)` to get -model parameters. + Previous versions of Flux suggested saving only the model weights using + `@save "mymodel.bson" params(model)`. + This is no longer recommended and even strongly discouraged. + Saving models this way will only store the trainable parameters which + will result in incorrect behavior for layers like `BatchNorm`. ```Julia julia> using Flux @@ -64,28 +66,29 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax) julia> weights = Flux.params(model); -julia> using BSON: @save - -julia> @save "mymodel.bson" weights -``` - -You can easily load parameters back into a model with `Flux.loadparams!`. +Loading the model as shown above will return a new model with the stored parameters. +But sometimes you already have a model, and you want to load stored parameters into it. +This can be done as ```julia -julia> using Flux +using Flux: loadmodel! +using BSON: @load -julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax) -Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax) +# some predefined model +model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax) -julia> using BSON: @load +# load one model into another +model = loadmodel!(model, @load("mymodel.bson")) +``` -julia> @load "mymodel.bson" weights +This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory. -julia> Flux.loadparams!(model, weights) +```@docs +Flux.loadmodel! +Flux.isloadleaf +Flux.loadleaf! ``` -The new `model` we created will now be identical to the one we saved parameters for. - ## Checkpointing In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md). diff --git a/src/loading.jl b/src/loading.jl index 3f51c4a3ff..8646f63b00 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,8 +1,25 @@ +""" + isloadleaf(x) + +Return `true` whenever `x` should be treated as a "leaf node" +for the purposes of loading parameters. +By default, `isloadleaf` returns `true` if [`Functors.isleaf`](@ref) +is `true` for all [`Functors.children(x)`](@ref `Functors.children`). + +You can override this function for a specific type if needed. +""" isloadleaf(x) = all(Functors.isleaf, Functors.children(x)) -loadnumeric!(x, x̄, err) = x -loadnumeric!(x::Zeros, x̄, err) = x -function loadnumeric!(x::AbstractArray, x̄::AbstractArray, err) +""" + loadleaf!(x, x̄, err) + +Copy `x̄` to `x` or throw `err` when their sizes are mismatched. +By default, use `copyto!` when `x` and `x̄` are arrays. +Otherwise, just return `x`. +""" +loadleaf!(x, x̄, err) = x +loadleaf!(x::Zeros, x̄, err) = x +function loadleaf!(x::AbstractArray, x̄::AbstractArray, err) (size(x) == size(x̄)) || throw(err) copyto!(x, x̄) end @@ -14,7 +31,7 @@ function _loadto!(m, m̄) throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match.")) err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.") - foreach((l, l̄) -> loadnumeric!(l, l̄, err), ls, l̄s) + foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s) return m end @@ -23,6 +40,27 @@ function loadto!(m::T, m̄::S) where {T, S} _loadto!(m, m̄) end +""" + loadmodel!(m, m̄) + +Copy all the parameters (trainable and non-trainable) from `m̄` to `m`. 
+ +`loadmodel!` recursively walks `m` and `m̄` until it encounters +a subfield, `x`, (i.e. layer) where `isloadleaf(x)` is true. +The parameters of the matching subfield, `x̄`, are copied to `x`, +throwing an error whenever: +- `x` and `x̄` are not the same type (e.g. loading a `Conv` to a `Dense`) +- `x` and `x̄` do not share the same fields +- the parameter sizes are mismatched between `x` and `x̄` + +See [`loadleaf!`](@ref) for more details on the copy behavior. +See [`isloadleaf`](@ref) for more details on which layers are considered leaves. + +!!! warning + This function allows `m̄` to be a vector or `Params` for backwards-compatibility. + You should avoid using `loadmodel!` this way, because it skips most of the structural + checking used when `m̄` is also a struct. Silent errors may occur. +""" function loadmodel!(m, xs::Params) for (p, x) in zip(params(m), xs) size(p) == size(x) || From 0790f249b92388d7e1ea01aca22ad1ce0aea2963 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 10:43:02 -0600 Subject: [PATCH 06/19] Spacing in docs --- docs/src/saving.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index 5d5c3432c4..e5008296ea 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -75,7 +75,7 @@ using Flux: loadmodel! using BSON: @load # some predefined model -model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax) +model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax) # load one model into another model = loadmodel!(model, @load("mymodel.bson")) From a155a446fc31fa266bce50934e2ad94f3725b23f Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 10:54:49 -0600 Subject: [PATCH 07/19] Fix tests --- src/loading.jl | 27 +++++++++++++++------------ test/utils.jl | 16 ++++++++-------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 8646f63b00..f133c064ca 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -18,26 +18,29 @@ By default, use `copyto!` when `x` and `x̄` are arrays. Otherwise, just return `x`. 
""" loadleaf!(x, x̄, err) = x -loadleaf!(x::Zeros, x̄, err) = x +function loadleaf!(x::AbstractArray, x̄, err) + x .= x̄ + return x +end function loadleaf!(x::AbstractArray, x̄::AbstractArray, err) - (size(x) == size(x̄)) || throw(err) - copyto!(x, x̄) + (size(x) == size(x̄)) || throw(err) + copyto!(x, x̄) end function _loadto!(m, m̄) - ls, _ = functor(m) - l̄s, _ = functor(m̄) - (keys(ls) == keys(l̄s)) || - throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match.")) + ls, _ = functor(m) + l̄s, _ = functor(m̄) + (keys(ls) == keys(l̄s)) || + throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match.")) - err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.") - foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s) + err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.") + foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s) - return m + return m end function loadto!(m::T, m̄::S) where {T, S} - (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) - _loadto!(m, m̄) + (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) + _loadto!(m, m̄) end """ diff --git a/test/utils.jl b/test/utils.jl index e11397562d..82b1bbbd0a 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,7 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, identity_init, stack, unstack, batch, unbatch, - unsqueeze, params, loadparams! + unsqueeze, params, loadmodel! using StatsBase: var, std using Statistics, LinearAlgebra using Random @@ -366,14 +366,14 @@ end @test_skip typeof(l1.bias) === typeof(l2.bias) end - @testset "loadparams!" begin + @testset "loadmodel!" begin pars(w, b) = [w, b] pars(l) = pars(l.weight, l.bias) pararray(m) = mapreduce(pars, vcat, m) weights(m) = mapreduce(l -> [l.weight], vcat, m) @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) m = dm(bt) - loadmodel!(m, params(m)) + Flux.loadmodel!(m, params(m)) testdense(m, bt) end end @@ -421,22 +421,22 @@ end end end -@testset "loadparams! & absent bias" begin +@testset "loadmodel! & absent bias" begin m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - Flux.loadparams!(m1, Flux.params(m2)) + Flux.loadmodel!(m1, m2) @test m1[1].bias == 7:9 @test sum(m1[1].weight) == 21 # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it - @test_broken Flux.loadparams!(m1, Flux.params(m0)) - @test_broken iszero(m1[1].bias) + m1 = Flux.loadmodel!(m1, m0) + @test iszero(m1[1].bias) @test sum(m1[1].weight) == 6 # written before error # load into a model without bias -- should it ignore the parameter which has no home, or error? - @test_broken Flux.loadparams!(m0, Flux.params(m2)) + m0 = Flux.loadmodel!(m0, m2) @test iszero(m0[1].bias) # obviously unchanged @test sum(m0[1].weight) == 21 end From fbc9fafe5907b86c0e14a3670c6bbb2797535e01 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 10:57:34 -0600 Subject: [PATCH 08/19] Add NEWS entry for `loadmodel!` --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 9c4907dc33..d97c94fea5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,7 @@ been removed in favour of MLDatasets.jl. 
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`. * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights. * The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function. +* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875) which copies trainable + non-trainable parameters and performs more thorough structural checking ## v0.12.10 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838) From b2a2664c426deb308eff6be7d6e78058b260e485 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 11:26:36 -0600 Subject: [PATCH 09/19] Better docs --- docs/src/saving.md | 13 ++++++++++++ src/loading.jl | 50 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index e5008296ea..6418638bd4 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -85,10 +85,23 @@ This ensures that the model loaded from `"mymodel.bson"` matches the structure o ```@docs Flux.loadmodel! +Flux.loadto! Flux.isloadleaf Flux.loadleaf! ``` +### Customizing `loadmodel!` for a custom layer + +By default, [`loadmodel!`](@ref) will recursively walk a nested model (like a `Chain`) using [`Functors.fmap`](@ref) until it encounters a loading *leaf node*. A leaf node is defined as any node for which [`Flux.isloadleaf`](@ref) returns `true`. For example, consider the model + +```julia +model = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2))) +``` + +Here, the `Chain` and `Parallel` layers are not leaf nodes, but all the `Dense` layers are leaf nodes. This makes sense, because `Dense` layers are the ones with parameters that we need to copy. The default behavior for [`Flux.isloadleaf`](@ref) should work for most custom layers, but you can override this function for your type. + +Once a pair of leaf nodes is encountered, `loadmodel!` will call [`Flux.loadto!](@ref) on them. By default, this just copies the parameters from one leaf node to the other, but you can customize the behavior by overriding `loadto!` for your pair of types. + ## Checkpointing In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md). diff --git a/src/loading.jl b/src/loading.jl index f133c064ca..46fed73dc7 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -27,7 +27,17 @@ function loadleaf!(x::AbstractArray, x̄::AbstractArray, err) copyto!(x, x̄) end -function _loadto!(m, m̄) +""" + loadto!(m, m̄) + +Load a leaf node `m̄` into `m`. + +By default, call [`Flux.loadleaf!`](@ref) on each pair of children +in `zip(Functors.children(m), Functors.children(m̄))`. 
+""" +function loadto!(m::T, m̄::S) where {T, S} + (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) + ls, _ = functor(m) l̄s, _ = functor(m̄) (keys(ls) == keys(l̄s)) || @@ -38,10 +48,6 @@ function _loadto!(m, m̄) return m end -function loadto!(m::T, m̄::S) where {T, S} - (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) - _loadto!(m, m̄) -end """ loadmodel!(m, m̄) @@ -56,8 +62,38 @@ throwing an error whenever: - `x` and `x̄` do not share the same fields - the parameter sizes are mismatched between `x` and `x̄` -See [`loadleaf!`](@ref) for more details on the copy behavior. -See [`isloadleaf`](@ref) for more details on which layers are considered leaves. +```julia +julia> using Flux: loadmodel! + +julia> m = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1)) +Chain( + Dense(5 => 2), # 12 parameters + Dense(2 => 1), # 3 parameters +) # Total: 4 arrays, 15 parameters, 316 bytes. + +julia> m̄ = Chain(Dense(5 => 2), Dense(2 => 1)); + +julia> all(isone, m[1].weight) +true + +julia> m = loadmodel!(m, m̄) +Chain( + Dense(5 => 2), # 12 parameters + Dense(2 => 1), # 3 parameters +) # Total: 4 arrays, 15 parameters, 316 bytes. + +julia> all(isone, m[1].weight) +false + +julia> m[1].weight == m̄[1].weight +true + +julia> m[2].bias == m̄[2].bias +true +``` + +See [`Flux.loadleaf!`](@ref) for more details on the copy behavior. +See [`Flux.isloadleaf`](@ref) for more details on which layers are considered leaves. !!! warning This function allows `m̄` to be a vector or `Params` for backwards-compatibility. From a6cdfdd7022e3646696cea09f84ffc991e7314d8 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 7 Mar 2022 11:15:19 -0600 Subject: [PATCH 10/19] Refactor `loadmodel!` to use a custom recursion instead of `fmap`. Add more tests. --- docs/src/saving.md | 14 ------ src/loading.jl | 107 +++++++++++++++++---------------------------- test/utils.jl | 100 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 138 insertions(+), 83 deletions(-) diff --git a/docs/src/saving.md b/docs/src/saving.md index 6418638bd4..20cd5eaf36 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -85,23 +85,9 @@ This ensures that the model loaded from `"mymodel.bson"` matches the structure o ```@docs Flux.loadmodel! -Flux.loadto! -Flux.isloadleaf Flux.loadleaf! ``` -### Customizing `loadmodel!` for a custom layer - -By default, [`loadmodel!`](@ref) will recursively walk a nested model (like a `Chain`) using [`Functors.fmap`](@ref) until it encounters a loading *leaf node*. A leaf node is defined as any node for which [`Flux.isloadleaf`](@ref) returns `true`. For example, consider the model - -```julia -model = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2))) -``` - -Here, the `Chain` and `Parallel` layers are not leaf nodes, but all the `Dense` layers are leaf nodes. This makes sense, because `Dense` layers are the ones with parameters that we need to copy. The default behavior for [`Flux.isloadleaf`](@ref) should work for most custom layers, but you can override this function for your type. - -Once a pair of leaf nodes is encountered, `loadmodel!` will call [`Flux.loadto!](@ref) on them. By default, this just copies the parameters from one leaf node to the other, but you can customize the behavior by overriding `loadto!` for your pair of types. - ## Checkpointing In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). 
You can do this by saving the model in the [callback provided to `train!`](training/training.md). diff --git a/src/loading.jl b/src/loading.jl index 46fed73dc7..19bb3293c4 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,104 +1,68 @@ """ - isloadleaf(x) - -Return `true` whenever `x` should be treated as a "leaf node" -for the purposes of loading parameters. -By default, `isloadleaf` returns `true` if [`Functors.isleaf`](@ref) -is `true` for all [`Functors.children(x)`](@ref `Functors.children`). - -You can override this function for a specific type if needed. -""" -isloadleaf(x) = all(Functors.isleaf, Functors.children(x)) + loadleaf!(dst, src, err) +Copy `src` to `dst` or throw `err` when their sizes are mismatched. +By default, use `copyto!` when `dst` and `src` are arrays. +When only `dst` is an array, set every element to `src`. +Otherwise, just return `dst`. """ - loadleaf!(x, x̄, err) - -Copy `x̄` to `x` or throw `err` when their sizes are mismatched. -By default, use `copyto!` when `x` and `x̄` are arrays. -Otherwise, just return `x`. -""" -loadleaf!(x, x̄, err) = x -function loadleaf!(x::AbstractArray, x̄, err) - x .= x̄ - return x -end -function loadleaf!(x::AbstractArray, x̄::AbstractArray, err) - (size(x) == size(x̄)) || throw(err) - copyto!(x, x̄) +loadleaf!(dst, src, err) = dst +function loadleaf!(dst::AbstractArray, src, err) + dst .= src + return dst end - -""" - loadto!(m, m̄) - -Load a leaf node `m̄` into `m`. - -By default, call [`Flux.loadleaf!`](@ref) on each pair of children -in `zip(Functors.children(m), Functors.children(m̄))`. -""" -function loadto!(m::T, m̄::S) where {T, S} - (nameof(T) == nameof(S)) || throw(ArgumentError("Tried to load $m̄ into $m.")) - - ls, _ = functor(m) - l̄s, _ = functor(m̄) - (keys(ls) == keys(l̄s)) || - throw(ArgumentError("Tried to load $m̄ into $m but the structures do not match.")) - - err = DimensionMismatch("Tried to load $m̄ into $m but the parameter sizes do not match.") - foreach((l, l̄) -> loadleaf!(l, l̄, err), ls, l̄s) - - return m +function loadleaf!(dst::AbstractArray, src::AbstractArray, err) + (size(dst) == size(src)) || throw(err) + copyto!(dst, src) end """ - loadmodel!(m, m̄) + loadmodel!(dst, src) -Copy all the parameters (trainable and non-trainable) from `m̄` to `m`. +Copy all the parameters (trainable and non-trainable) from `src` to `dst`. -`loadmodel!` recursively walks `m` and `m̄` until it encounters -a subfield, `x`, (i.e. layer) where `isloadleaf(x)` is true. -The parameters of the matching subfield, `x̄`, are copied to `x`, -throwing an error whenever: -- `x` and `x̄` are not the same type (e.g. loading a `Conv` to a `Dense`) -- `x` and `x̄` do not share the same fields -- the parameter sizes are mismatched between `x` and `x̄` +`loadmodel!` recursively walks the [`Functors.children`](@ref) of `dst` and `src` +calling `loadleaf!` on any pair of children where [`Functors.isleaf`](@ref) is true. +It throws an error whenever: +- `dst` and `src` do not share the same fields (at any level) +- the sizes of leaf nodes are mismatched between `dst` and `src` ```julia julia> using Flux: loadmodel! -julia> m = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1)) +julia> dst = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1)) Chain( Dense(5 => 2), # 12 parameters Dense(2 => 1), # 3 parameters ) # Total: 4 arrays, 15 parameters, 316 bytes. 
-julia> m̄ = Chain(Dense(5 => 2), Dense(2 => 1)); +julia> src = Chain(Dense(5 => 2), Dense(2 => 1)); -julia> all(isone, m[1].weight) +julia> all(isone, dst[1].weight) true -julia> m = loadmodel!(m, m̄) +julia> dst = loadmodel!(dst, src) Chain( Dense(5 => 2), # 12 parameters Dense(2 => 1), # 3 parameters ) # Total: 4 arrays, 15 parameters, 316 bytes. -julia> all(isone, m[1].weight) +julia> all(isone, dst[1].weight) false -julia> m[1].weight == m̄[1].weight +julia> dst[1].weight == src[1].weight true -julia> m[2].bias == m̄[2].bias +julia> dst[2].bias == src[2].bias true ``` See [`Flux.loadleaf!`](@ref) for more details on the copy behavior. -See [`Flux.isloadleaf`](@ref) for more details on which layers are considered leaves. !!! warning - This function allows `m̄` to be a vector or `Params` for backwards-compatibility. + This function allows `src` to be a `Params` for backwards-compatibility. You should avoid using `loadmodel!` this way, because it skips most of the structural - checking used when `m̄` is also a struct. Silent errors may occur. + checking used when `src` is also a nested structure. Silent errors may occur. """ function loadmodel!(m, xs::Params) for (p, x) in zip(params(m), xs) @@ -107,5 +71,16 @@ function loadmodel!(m, xs::Params) copyto!(p, x) end end -loadmodel!(m, xs::AbstractVector) = loadmodel!(m, params(xs)) -loadmodel!(m, m̄) = fmap(loadto!, m, m̄; exclude = isloadleaf) +function loadmodel!(dst, src) + ldsts, _ = functor(dst) + lsrcs, _ = functor(src) + (keys(ldsts) == keys(lsrcs)) || + throw(ArgumentError("Tried to load $src into $dst but the structures do not match.")) + + err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") + foreach(ldsts, lsrcs) do ldst, lsrc + Functors.isleaf(ldst) ? loadleaf!(ldst, lsrc, err) : loadmodel!(ldst, lsrc) + end + + return dst +end diff --git a/test/utils.jl b/test/utils.jl index 82b1bbbd0a..0136ba49ff 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -378,7 +378,7 @@ end end end - @testset "loadmodel!(m, m̄)" begin + @testset "loadmodel!(dst, src)" begin import Flux: loadmodel!, Zeros m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) @@ -389,17 +389,111 @@ end m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) loadmodel!(m1, m2) + # trainable parameters copy over @test m1[1].weight == m2[1].weight @test m1[1].bias == m2[1].bias + # non-array leaves are untouched @test m1[2].σ == relu + loadmodel!(m5, m6) + # more complex nested structures also work @test m5[1].weight == m6[1].weight @test m5[2][1].weight == m6[2][1].weight - @test m5[2][1].bias == Zeros() + # false bias is not overwritten + @test m5[2][1].bias == false + # mismatched nodes throw an error @test_throws ArgumentError loadmodel!(m1, m3) - @test_throws DimensionMismatch loadmodel!(m1, m4) @test_throws ArgumentError loadmodel!(m1, m5) + # size mismatches throw an error + @test_throws DimensionMismatch loadmodel!(m1, m4) + + m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) + m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) + m2[2].μ .= rand(Float32, size(m2[2].μ)...) + loadmodel!(m1, m2) + # non-trainable parameters are copied as well + @test m1[2].μ == m2[2].μ + # functions are not copied + @test m1[3] == Flux.flatten + # dropout rate is not copied + @test m1[4].p == 0.2 + + # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) + # tests Chain(...) 
vs Chain([...]) + # tests MaxPool + # tests testmode!/trainmode! is not copied + # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model + chain1 = Chain(Dropout(0.2), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 32 => 16, relu), + Dropout(0.2), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 16 => 10, relu), + Dropout(0.2), + x -> reshape(x, :, size(x, 4)), + Dropout(0.2), + Dense(90, 10), + softmax) + chain2 = Chain([Dropout(0.1), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 32 => 16, relu), + Dropout(0.1), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 16 => 10, relu), + Dropout(0.1), + x -> reshape(x, :, size(x, 4)), + Dropout(0.1), + Dense(90, 10), + softmax]) + chain2[3].μ .= 5f0 + chain2[3].σ² .= 2f0 + testmode!(chain2) + loadmodel!(chain1, chain2) + for (dst, src) in zip(chain1, chain2) + if dst isa Dropout + @test dst.p == 0.2 + elseif dst isa Union{Conv, Dense} + @test dst.weight == src.weight + @test dst.bias == src.bias + elseif dst isa MaxPool + @test dst.k == (2, 2) + elseif dst isa BatchNorm + @test dst.μ == src.μ + @test dst.σ² == src.σ² + @test isnothing(dst.active) + end + end + + # copy only a subset of the model + chain1[end - 1].weight .= 1f0 + chain1[3].μ .= 3f0 + chain1[2].bias .= 5f0 + loadmodel!(chain2[end - 1], chain1[end - 1]) + loadmodel!(chain2[3], chain1[3]) + @test chain2[end - 1].weight == chain1[end - 1].weight + @test chain2[3].μ == chain1[3].μ + @test chain2[2].bias != chain1[2].bias + + # test shared weights + m1 = Chain(Dense(10 => 5), Dense(5 => 2)) + m2 = Chain(Dense(transpose(m1[2].weight)), Dense(permutedims(m1[1].weight))) + m3 = Chain(Dense(m1[1].weight), Dense(m1[2].weight)) + m2[2].weight .= 1f0 + loadmodel!(m1, m3) + @test m1[2].weight === parent(m2[1].weight) + @test m1[2].weight == transpose(m2[1].weight) + @test m1[1].weight === m3[1].weight + @test m2[2].weight != transpose(m1[1].weight) + @test m3[2].weight == transpose(m2[1].weight) end @testset "destructure" begin From 29662b26a4147f3f8017811a47f3e08a77a22496 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 2 Apr 2022 16:07:21 -0500 Subject: [PATCH 11/19] Add better support for `loadmodel!` w/ tied parameters and address some other review comments --- src/deprecations.jl | 12 +++++++-- src/loading.jl | 64 ++++++++++++++++++++++++++------------------- test/utils.jl | 47 +++++++++++++++++++++------------ 3 files changed, 78 insertions(+), 45 deletions(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index 33a0fa3060..118e25dbb5 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -15,8 +15,6 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...) ones32(::Type, dims...) = throw(ArgumentError("Flux.ones32 is always Float32, use Base.ones to specify the element type")) zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32, use Base.zeros to specify the element type")) -@deprecate loadparams!(m, xs) loadmodel!(m, xs) - # v0.13 deprecations function Broadcast.broadcasted(f::Recur, args...) @@ -50,6 +48,16 @@ function Diagonal(size::Tuple; kw...) Scale(size...; kw...) end +# Deprecate this eventually once saving models w/o structure is no more +function loadparams!(m, xs) + Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!) 
+ for (p, x) in zip(params(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end +end + # Channel notation: Changed to match Conv, but very softly deprecated! # Perhaps change to @deprecate for v0.14, but there is no plan to remove these. Dense(in::Integer, out::Integer, σ = identity; kw...) = diff --git a/src/loading.jl b/src/loading.jl index 19bb3293c4..867f213594 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,31 +1,42 @@ -""" - loadleaf!(dst, src, err) - -Copy `src` to `dst` or throw `err` when their sizes are mismatched. -By default, use `copyto!` when `dst` and `src` are arrays. -When only `dst` is an array, set every element to `src`. -Otherwise, just return `dst`. -""" loadleaf!(dst, src, err) = dst -function loadleaf!(dst::AbstractArray, src, err) - dst .= src +function loadleaf!(dst::AbstractArray, src::Bool, err) + if iszero(src) + dst .= src + else + error("Cannot copy boolean parameter == true to non-zero parameter.") + end return dst end +loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst : + error("Cannot copy non-zero parameter to boolean parameter == true.") function loadleaf!(dst::AbstractArray, src::AbstractArray, err) (size(dst) == size(src)) || throw(err) copyto!(dst, src) end +_parent(x) = x +_parent(x::AbstractArray) = parent(x) + +_tie_check(dst::AbstractArray, src::AbstractArray) = dst == src +_tie_check(dst, src) = true + +_bool_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) +_bool_tie_check(dst::AbstractArray, src::Bool) = iszero(dst) && iszero(src) +_bool_tie_check(dst, src) = true + """ loadmodel!(dst, src) Copy all the parameters (trainable and non-trainable) from `src` to `dst`. `loadmodel!` recursively walks the [`Functors.children`](@ref) of `dst` and `src` -calling `loadleaf!` on any pair of children where [`Functors.isleaf`](@ref) is true. +calling `copyto!` on any pair of children where [`Functors.isleaf`](@ref) is true. +It also handles "absent" parameters such as `bias == false`. It throws an error whenever: - `dst` and `src` do not share the same fields (at any level) - the sizes of leaf nodes are mismatched between `dst` and `src` +- `dst` is a "tied" parameter (e.g. `transpose` of another parameter) and + loaded into multiple times with mismatched source values ```julia julia> using Flux: loadmodel! @@ -56,22 +67,8 @@ true julia> dst[2].bias == src[2].bias true ``` - -See [`Flux.loadleaf!`](@ref) for more details on the copy behavior. - -!!! warning - This function allows `src` to be a `Params` for backwards-compatibility. - You should avoid using `loadmodel!` this way, because it skips most of the structural - checking used when `src` is also a nested structure. Silent errors may occur. """ -function loadmodel!(m, xs::Params) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end -function loadmodel!(dst, src) +function loadmodel!(dst, src; cache = Base.IdSet()) ldsts, _ = functor(dst) lsrcs, _ = functor(src) (keys(ldsts) == keys(lsrcs)) || @@ -79,7 +76,20 @@ function loadmodel!(dst, src) err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") foreach(ldsts, lsrcs) do ldst, lsrc - Functors.isleaf(ldst) ? 
loadleaf!(ldst, lsrc, err) : loadmodel!(ldst, lsrc) + if _parent(ldst) in cache # we already loaded this parameter before + if !_bool_tie_check(ldst, lsrc) # special case to handle tied + absent parameters + error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") + elseif _tie_check(ldst, lsrc) # the arrays match and we already loaded (or these are not arrays) + return ldst + else # tied dst but mismatched src case + error("Encountered tied destination parameters with untied and mismatched sources.") + end + elseif Functors.isleaf(ldst) # our first time loading this leaf + push!(cache, ldst) + loadleaf!(ldst, lsrc, err) + else # this isn't a leaf + loadmodel!(ldst, lsrc; cache = cache) + end end return dst diff --git a/test/utils.jl b/test/utils.jl index 0136ba49ff..e4637b21c1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,7 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, identity_init, stack, unstack, batch, unbatch, - unsqueeze, params, loadmodel! + unsqueeze, params, loadparams!, loadmodel! using StatsBase: var, std using Statistics, LinearAlgebra using Random @@ -366,26 +366,24 @@ end @test_skip typeof(l1.bias) === typeof(l2.bias) end - @testset "loadmodel!" begin + @testset "loadparams!" begin pars(w, b) = [w, b] pars(l) = pars(l.weight, l.bias) pararray(m) = mapreduce(pars, vcat, m) weights(m) = mapreduce(l -> [l.weight], vcat, m) @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) m = dm(bt) - Flux.loadmodel!(m, params(m)) + Flux.loadparams!(m, params(m)) testdense(m, bt) end end @testset "loadmodel!(dst, src)" begin - import Flux: loadmodel!, Zeros - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) m2 = Chain(Dense(10, 5), Dense(5, 2)) m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) m4 = Chain(Dense(10, 6), Dense(6, 2)) - m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), Zeros()), Dense(5, 2))) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) loadmodel!(m1, m2) @@ -408,6 +406,7 @@ end # size mismatches throw an error @test_throws DimensionMismatch loadmodel!(m1, m4) + # tests for BatchNorm and Dropout m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) m2[2].μ .= rand(Float32, size(m2[2].μ)...) 
@@ -484,16 +483,32 @@ end @test chain2[2].bias != chain1[2].bias # test shared weights - m1 = Chain(Dense(10 => 5), Dense(5 => 2)) - m2 = Chain(Dense(transpose(m1[2].weight)), Dense(permutedims(m1[1].weight))) - m3 = Chain(Dense(m1[1].weight), Dense(m1[2].weight)) - m2[2].weight .= 1f0 - loadmodel!(m1, m3) - @test m1[2].weight === parent(m2[1].weight) - @test m1[2].weight == transpose(m2[1].weight) - @test m1[1].weight === m3[1].weight - @test m2[2].weight != transpose(m1[1].weight) - @test m3[2].weight == transpose(m2[1].weight) + encoder_dst = Chain(Dense(10 => 5), Dense(5 => 2)) + decoder_dst = Chain(Dense(transpose(encoder_dst[2].weight)), + Dense(permutedims(encoder_dst[1].weight))) + encoder_src = Chain(Dense(10 => 5), Dense(5 => 2)) + decoder_src = Chain(Dense(transpose(encoder_src[2].weight)), + Dense(5 => 10)) + # matched weights are okay + m1 = Chain(encoder_dst, decoder_dst) + m2 = Chain(encoder_src, decoder_src) + loadmodel!(m1, m2) + @test m1[1][2].weight === parent(m1[2][1].weight) + @test m1[1][1].weight == m2[1][1].weight + @test m1[1][1].weight != permutedims(m1[2][2].weight) + # mismatched weights are an error + m2 = Chain(Chain(Dense(10 => 5), Dense(5 => 2)), Chain(Dense(2 => 5), Dense(5 => 10))) + @test_throws ErrorException loadmodel!(m1, m2) + # loading into tied weights with absent parameter is okay when the dst == zero + b = Flux.zeros32(5) + m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) + m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) + loadmodel!(m1, m2) + @test m1[1].bias === m1[2].bias + @test iszero(m1[1].bias) + # loading into tied weights with absent parameter is bad when the dst != zero + m2[1].bias .= 1 + @test_throws ErrorException loadmodel!(m1, m2) end @testset "destructure" begin From c831955fd678431fb373e30bad141c0a8cf80d4a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 3 Apr 2022 14:05:27 -0500 Subject: [PATCH 12/19] Combine `_bool_tie_check` and `_tie_check`. 
--- src/loading.jl | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 867f213594..8d277693a1 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -17,11 +17,14 @@ end _parent(x) = x _parent(x::AbstractArray) = parent(x) -_tie_check(dst::AbstractArray, src::AbstractArray) = dst == src +_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) || + error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") +_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) || + error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") +_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) || + error("Encountered tied destination parameters with untied and mismatched sources.") _tie_check(dst, src) = true -_bool_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) -_bool_tie_check(dst::AbstractArray, src::Bool) = iszero(dst) && iszero(src) _bool_tie_check(dst, src) = true """ @@ -77,13 +80,7 @@ function loadmodel!(dst, src; cache = Base.IdSet()) err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") foreach(ldsts, lsrcs) do ldst, lsrc if _parent(ldst) in cache # we already loaded this parameter before - if !_bool_tie_check(ldst, lsrc) # special case to handle tied + absent parameters - error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") - elseif _tie_check(ldst, lsrc) # the arrays match and we already loaded (or these are not arrays) - return ldst - else # tied dst but mismatched src case - error("Encountered tied destination parameters with untied and mismatched sources.") - end + _tie_check(ldst, lsrc) && return ldst elseif Functors.isleaf(ldst) # our first time loading this leaf push!(cache, ldst) loadleaf!(ldst, lsrc, err) From 0d55c00d0d012ae829eafcc0ba3296d841bec138 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 3 Apr 2022 14:30:13 -0500 Subject: [PATCH 13/19] Remove `_parent` --- src/loading.jl | 5 +---- test/utils.jl | 19 +++++++------------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 8d277693a1..fe8cc03596 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -14,9 +14,6 @@ function loadleaf!(dst::AbstractArray, src::AbstractArray, err) copyto!(dst, src) end -_parent(x) = x -_parent(x::AbstractArray) = parent(x) - _tie_check(dst::Bool, src::AbstractArray) = iszero(dst) || error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.") _tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) || @@ -79,7 +76,7 @@ function loadmodel!(dst, src; cache = Base.IdSet()) err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") foreach(ldsts, lsrcs) do ldst, lsrc - if _parent(ldst) in cache # we already loaded this parameter before + if ldst in cache # we already loaded this parameter before _tie_check(ldst, lsrc) && return ldst elseif Functors.isleaf(ldst) # our first time loading this leaf push!(cache, ldst) diff --git a/test/utils.jl b/test/utils.jl index e4637b21c1..054589755b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -483,21 +483,16 @@ end @test chain2[2].bias != chain1[2].bias # test shared weights - encoder_dst = Chain(Dense(10 => 5), Dense(5 => 2)) - decoder_dst = Chain(Dense(transpose(encoder_dst[2].weight)), - 
Dense(permutedims(encoder_dst[1].weight))) - encoder_src = Chain(Dense(10 => 5), Dense(5 => 2)) - decoder_src = Chain(Dense(transpose(encoder_src[2].weight)), - Dense(5 => 10)) + shared_dst = Dense(10 => 10) + shared_src = Dense(10 => 10) # matched weights are okay - m1 = Chain(encoder_dst, decoder_dst) - m2 = Chain(encoder_src, decoder_src) + m1 = Chain(shared_dst, Dense(shared_dst.weight)) + m2 = Chain(shared_src, Dense(shared_src.weight)) loadmodel!(m1, m2) - @test m1[1][2].weight === parent(m1[2][1].weight) - @test m1[1][1].weight == m2[1][1].weight - @test m1[1][1].weight != permutedims(m1[2][2].weight) + @test m1[1].weight === m1[2].weight + @test m1[1].weight == m2[2].weight # mismatched weights are an error - m2 = Chain(Chain(Dense(10 => 5), Dense(5 => 2)), Chain(Dense(2 => 5), Dense(5 => 10))) + m2 = Chain(Dense(10 => 10), Dense(10 => 10)) @test_throws ErrorException loadmodel!(m1, m2) # loading into tied weights with absent parameter is okay when the dst == zero b = Flux.zeros32(5) From 3c824715a3b8a43db82b6ac20a895b898b19b977 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 4 Apr 2022 08:05:33 -0500 Subject: [PATCH 14/19] Apply suggestions from code review Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/loading.jl | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index fe8cc03596..5aabbed9c1 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -27,12 +27,14 @@ _bool_tie_check(dst, src) = true """ loadmodel!(dst, src) -Copy all the parameters (trainable and non-trainable) from `src` to `dst`. +Copy all the parameters (trainable and non-trainable) from `src` into `dst`. -`loadmodel!` recursively walks the [`Functors.children`](@ref) of `dst` and `src` -calling `copyto!` on any pair of children where [`Functors.isleaf`](@ref) is true. -It also handles "absent" parameters such as `bias == false`. -It throws an error whenever: +Recursively walks `dst` and `src` together using [`Functors.children`](@ref), +and calling `copyto!` on parameter arrays. +Non-array elements (such as activation functions) need not match. +An all-zero bias array can be copied to or from absent bias, encoded `bias = false`. + +Throws an error when: - `dst` and `src` do not share the same fields (at any level) - the sizes of leaf nodes are mismatched between `dst` and `src` - `dst` is a "tied" parameter (e.g. `transpose` of another parameter) and @@ -52,13 +54,9 @@ julia> src = Chain(Dense(5 => 2), Dense(2 => 1)); julia> all(isone, dst[1].weight) true -julia> dst = loadmodel!(dst, src) -Chain( - Dense(5 => 2), # 12 parameters - Dense(2 => 1), # 3 parameters -) # Total: 4 arrays, 15 parameters, 316 bytes. +julia> loadmodel!(dst, src); -julia> all(isone, dst[1].weight) +julia> dst[1].weight ≈ ones(2, 5) false julia> dst[1].weight == src[1].weight From 9b067307bd71f6c692eb822d7a83e6712fa9b3df Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 4 Apr 2022 08:17:16 -0500 Subject: [PATCH 15/19] Clarify docstrings as per review --- src/loading.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 5aabbed9c1..f985087836 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -31,13 +31,16 @@ Copy all the parameters (trainable and non-trainable) from `src` into `dst`. Recursively walks `dst` and `src` together using [`Functors.children`](@ref), and calling `copyto!` on parameter arrays. 
-Non-array elements (such as activation functions) need not match. -An all-zero bias array can be copied to or from absent bias, encoded `bias = false`. +Non-array elements (such as activation functions) are not copied +and do not need to match between `dst` and `src`. +Inactive parameters, encoded by `false` in place on an array, +can be copied to and from all-zero arrays. +Attempting to copy a non-zero array to/from an inactive parameter will throw an error. Throws an error when: - `dst` and `src` do not share the same fields (at any level) - the sizes of leaf nodes are mismatched between `dst` and `src` -- `dst` is a "tied" parameter (e.g. `transpose` of another parameter) and +- `dst` is a "tied" parameter (i.e. refers to another parameter) and loaded into multiple times with mismatched source values ```julia From cba299bb6189d039883968df54c504856ecc0001 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 4 Apr 2022 08:39:19 -0500 Subject: [PATCH 16/19] More clarification --- src/loading.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index f985087836..01090d95b0 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -33,9 +33,10 @@ Recursively walks `dst` and `src` together using [`Functors.children`](@ref), and calling `copyto!` on parameter arrays. Non-array elements (such as activation functions) are not copied and do not need to match between `dst` and `src`. -Inactive parameters, encoded by `false` in place on an array, -can be copied to and from all-zero arrays. -Attempting to copy a non-zero array to/from an inactive parameter will throw an error. +Inactive parameters can be encoded by using the boolean value `false` instead of an array. +If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied); +however, attempting to copy a non-zero array to an inactive parameter will throw an error. +Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error. Throws an error when: - `dst` and `src` do not share the same fields (at any level) From a59f68874f03874f9a3f86c06ab37a2a7c0db6e9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 4 Apr 2022 10:32:40 -0500 Subject: [PATCH 17/19] Use extended help for `loadmodel!` docstring --- src/loading.jl | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 01090d95b0..a4f33a3304 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,4 +1,8 @@ loadleaf!(dst, src, err) = dst +loadleaf!(dst::AbstractArray, src, err) = + error("Tried to copy $src into an array destination; this is not allowed.") +loadleaf!(dst, src::AbstractArray, err) = + error("Tried to copy an array to $dst; this is not allowed.") function loadleaf!(dst::AbstractArray, src::Bool, err) if iszero(src) dst .= src @@ -30,20 +34,12 @@ _bool_tie_check(dst, src) = true Copy all the parameters (trainable and non-trainable) from `src` into `dst`. Recursively walks `dst` and `src` together using [`Functors.children`](@ref), -and calling `copyto!` on parameter arrays. -Non-array elements (such as activation functions) are not copied -and do not need to match between `dst` and `src`. -Inactive parameters can be encoded by using the boolean value `false` instead of an array. 
-If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
-however, attempting to copy a non-zero array to an inactive parameter will throw an error.
-Likewise, copying `src == false` to any `dst` array is valid, but copying `src == true` will error.
-
-Throws an error when:
-- `dst` and `src` do not share the same fields (at any level)
-- the sizes of leaf nodes are mismatched between `dst` and `src`
-- `dst` is a "tied" parameter (i.e. refers to another parameter) and
-  loaded into multiple times with mismatched source values
+and calling `copyto!` on parameter arrays or throwing an error when there is a mismatch.
+Non-array elements (such as activation functions) are not copied and need not match.
+Zero bias vectors and `bias=false` are considered equivalent
+(see extended help for more details).
 
+# Examples
 ```julia
 julia> using Flux: loadmodel!
@@ -69,6 +65,22 @@ true
 julia> dst[2].bias == src[2].bias
 true
 ```
+
+# Extended help
+
+Throws an error when:
+- `dst` and `src` do not share the same fields (at any level)
+- the sizes of leaf nodes are mismatched between `dst` and `src`
+- copying non-array values to/from an array parameter
+  (except inactive parameters described below)
+- `dst` is a "tied" parameter (i.e. refers to another parameter) and
+  loaded into multiple times with mismatched source values
+
+Inactive parameters can be encoded by using the boolean value `false` instead of an array.
+If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
+however, attempting to copy a non-zero array to an inactive parameter will throw an error.
+Likewise, copying a `src` value of `false` to any `dst` array is valid,
+but copying a `src` value of `true` will error.
 """

From d08072a4da5aa5c5002c6e0b43cb6cc67ff53d43 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Mon, 4 Apr 2022 12:57:39 -0500
Subject: [PATCH 18/19] Updated docstring examples for `loadmodel!` to cover more cases

---
 src/loading.jl | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/src/loading.jl b/src/loading.jl
index a4f33a3304..ece58d991c 100644
--- a/src/loading.jl
+++ b/src/loading.jl
@@ -41,28 +41,23 @@ Zero bias vectors and `bias=false` are considered equivalent
 
 # Examples
 ```julia
-julia> using Flux: loadmodel!
-
-julia> dst = Chain(Dense(Flux.ones32(2, 5)), Dense(2 => 1))
+julia> dst = Chain(Dense(Flux.ones32(2, 5), tanh), Dense(2 => 1; bias = [1f0]))
 Chain(
-  Dense(5 => 2),                        # 12 parameters
+  Dense(5 => 2, tanh),                  # 12 parameters
   Dense(2 => 1),                        # 3 parameters
 )                   # Total: 4 arrays, 15 parameters, 316 bytes.
-julia> src = Chain(Dense(5 => 2), Dense(2 => 1));
-
-julia> all(isone, dst[1].weight)
-true
+julia> dst[1].weight ≈ ones(2, 5)  # by construction
+true
 
-julia> loadmodel!(dst, src);
+julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));
 
-julia> dst[1].weight ≈ ones(2, 5)
-false
+julia> Flux.loadmodel!(dst, src);
 
-julia> dst[1].weight == src[1].weight
-true
+julia> dst[1].weight ≈ ones(2, 5)  # values changed
+false
 
-julia> dst[2].bias == src[2].bias
+julia> iszero(dst[2].bias)
 true
 ```

From 6b533b896e077d9c0e2951b2dc099a459bb4d461 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Mon, 4 Apr 2022 16:29:53 -0500
Subject: [PATCH 19/19] Don't do `loadleaf!` docstring

---
 docs/src/saving.md | 2 --
 1 file changed, 2 deletions(-)

diff --git a/docs/src/saving.md b/docs/src/saving.md
index 20cd5eaf36..6cfe6648a8 100644
--- a/docs/src/saving.md
+++ b/docs/src/saving.md
@@ -36,7 +36,6 @@ Chain(
   Dense(5 => 2),                        # 12 parameters
   NNlib.softmax,
 )                   # Total: 4 arrays, 67 parameters, 524 bytes.
-
 ```
 
 Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -85,7 +84,6 @@ This ensures that the model loaded from `"mymodel.bson"` matches the structure o
 
 ```@docs
 Flux.loadmodel!
-Flux.loadleaf!
 ```
 
 ## Checkpointing
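
To summarize where the series lands, here is a minimal usage sketch of the final API. The toy `Chain` models are hypothetical and not taken from any patch above; the copy, size-mismatch, and tied-parameter behavior follows the final `loadmodel!` docstring and the `test/utils.jl` cases in this series.

```julia
using Flux
using Flux: loadmodel!

# Hypothetical "pretrained" source and freshly initialized destination
# with matching structure; loadmodel! copies every parameter array across.
src = Chain(Dense(10 => 5, relu), Dense(5 => 2))
dst = Chain(Dense(10 => 5, relu), Dense(5 => 2))
loadmodel!(dst, src)
@assert dst[1].weight == src[1].weight  # parameter arrays were copied
@assert dst[1].σ == relu                # non-array fields are left untouched

# Mismatched sizes throw instead of silently misloading:
bad = Chain(Dense(10 => 6, relu), Dense(6 => 2))
try
    loadmodel!(dst, bad)
catch err
    @assert err isa DimensionMismatch   # 5×10 vs 6×10 weight
end

# Tied destination parameters (patches 11-13) must receive matching sources:
shared = Dense(10 => 10)
tied = Chain(shared, Dense(shared.weight))        # one weight array, used twice
untied = Chain(Dense(10 => 10), Dense(10 => 10))  # two independent arrays
try
    loadmodel!(tied, untied)
catch err
    @assert err isa ErrorException      # tied destination, mismatched sources
end
```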