diff --git a/NEWS.md b/NEWS.md
index 9c4907dc33..d97c94fea5 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -12,6 +12,7 @@ been removed in favour of MLDatasets.jl.
 * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
 * Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
 * The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
+* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875), which copies both trainable and non-trainable parameters and performs more thorough structural checking.
 
 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
diff --git a/Project.toml b/Project.toml
index bf2bc008de..adfa20e719 100644
--- a/Project.toml
+++ b/Project.toml
@@ -29,7 +29,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4, 5"
 CUDA = "3"
 ChainRulesCore = "1.12"
-Functors = "0.2.1"
+Functors = "0.2.8"
 MLUtils = "0.2"
 MacroTools = "0.5"
 NNlib = "0.8.2"
diff --git a/docs/src/saving.md b/docs/src/saving.md
index d9db750d1e..6cfe6648a8 100644
--- a/docs/src/saving.md
+++ b/docs/src/saving.md
@@ -2,7 +2,7 @@
 
 You may wish to save models so that they can be loaded and run
 in a later session. The easiest way to do this is via
-[BSON.jl](https://github.com/MikeInnes/BSON.jl).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl).
 
 Save a model:
 
@@ -36,7 +36,6 @@ Chain(
   Dense(5 => 2),                        # 12 parameters
   NNlib.softmax,
 ) # Total: 4 arrays, 67 parameters, 524 bytes.
-
 ```
 
 Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -46,15 +45,17 @@ versions of Flux).
 
 !!! note
 
-    If a saved model's weights are stored on the GPU, the model will not load
+    If a saved model's parameters are stored on the GPU, the model will not load
     later on if there is no GPU support available. It's best to [move your model
     to the CPU](gpu.md) with `cpu(model)` before saving it.
 
-## Saving Model Weights
+!!! warning
 
-In some cases it may be useful to save only the model parameters themselves, and
-rebuild the model architecture in your code. You can use `params(model)` to get
-model parameters.
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and is strongly discouraged.
+    Saving models this way only stores the trainable parameters, which
+    results in incorrect behavior for layers like `BatchNorm`.
 
 ```Julia
 julia> using Flux
@@ -64,28 +65,27 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
 
 julia> weights = Flux.params(model);
 
-julia> using BSON: @save
-
-julia> @save "mymodel.bson" weights
-```
-
-You can easily load parameters back into a model with `Flux.loadparams!`.
+Loading the model as shown above will return a new model with the stored parameters.
+But sometimes you already have a model, and you want to load stored parameters into it.
+This can be done as follows:
 
 ```julia
-julia> using Flux
+using Flux: loadmodel!
+using BSON
 
-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
+# some predefined model
+model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
 
-julia> using BSON: @load
+# load the saved parameters into the predefined model
+model = loadmodel!(model, BSON.load("mymodel.bson")[:model])
+```
 
-julia> @load "mymodel.bson" weights
+This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`.
[`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory. -julia> Flux.loadparams!(model, weights) +```@docs +Flux.loadmodel! ``` -The new `model` we created will now be identical to the one we saved parameters for. - ## Checkpointing In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md). diff --git a/src/Flux.jl b/src/Flux.jl index aa3f021595..0458dc6cea 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -46,6 +46,8 @@ include("layers/normalise.jl") include("layers/upsample.jl") include("layers/show.jl") +include("loading.jl") + include("outputsize.jl") include("data/Data.jl") diff --git a/src/deprecations.jl b/src/deprecations.jl index 597fc5a913..118e25dbb5 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -48,6 +48,16 @@ function Diagonal(size::Tuple; kw...) Scale(size...; kw...) end +# Deprecate this eventually once saving models w/o structure is no more +function loadparams!(m, xs) + Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!) + for (p, x) in zip(params(m), xs) + size(p) == size(x) || + error("Expected param size $(size(p)), got $(size(x))") + copyto!(p, x) + end +end + # Channel notation: Changed to match Conv, but very softly deprecated! # Perhaps change to @deprecate for v0.14, but there is no plan to remove these. Dense(in::Integer, out::Integer, σ = identity; kw...) = diff --git a/src/functor.jl b/src/functor.jl index ebd6dd1102..905a37b8d6 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -85,14 +85,6 @@ function params(m...) return ps end -function loadparams!(m, xs) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end - struct FluxCUDAAdaptor end adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) diff --git a/src/loading.jl b/src/loading.jl new file mode 100644 index 0000000000..ece58d991c --- /dev/null +++ b/src/loading.jl @@ -0,0 +1,99 @@ +loadleaf!(dst, src, err) = dst +loadleaf!(dst::AbstractArray, src, err) = + error("Tried to copy $src into an array destination; this is not allowed.") +loadleaf!(dst, src::AbstractArray, err) = + error("Tried to copy an array to $dst; this is not allowed.") +function loadleaf!(dst::AbstractArray, src::Bool, err) + if iszero(src) + dst .= src + else + error("Cannot copy boolean parameter == true to non-zero parameter.") + end + return dst +end +loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? 
dst :
+  error("Cannot copy non-zero parameter to boolean parameter == true.")
+function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
+  (size(dst) == size(src)) || throw(err)
+  copyto!(dst, src)
+end
+
+_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
+  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
+_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
+  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
+_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
+  error("Encountered tied destination parameters with untied and mismatched sources.")
+_tie_check(dst, src) = true
+
+_bool_tie_check(dst, src) = true
+
+"""
+    loadmodel!(dst, src)
+
+Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
+
+Recursively walks `dst` and `src` together using [`Functors.children`](@ref),
+calling `copyto!` on parameter arrays or throwing an error when there is a mismatch.
+Non-array elements (such as activation functions) are not copied and need not match.
+Zero bias vectors and `bias=false` are considered equivalent
+(see extended help for more details).
+
+# Examples
+```julia
+julia> dst = Chain(Dense(Flux.ones32(2, 5), true, tanh), Dense(2 => 1; bias = [1f0]))
+Chain(
+  Dense(5 => 2, tanh),                  # 12 parameters
+  Dense(2 => 1),                        # 3 parameters
+)                   # Total: 4 arrays, 15 parameters, 316 bytes.
+
+julia> dst[1].weight ≈ ones(2, 5) # by construction
+true
+
+julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));
+
+julia> Flux.loadmodel!(dst, src);
+
+julia> dst[1].weight ≈ ones(2, 5) # values changed
+false
+
+julia> iszero(dst[2].bias)
+true
+```
+
+# Extended help
+
+Throws an error when:
+- `dst` and `src` do not share the same fields (at any level)
+- the sizes of leaf nodes are mismatched between `dst` and `src`
+- copying non-array values to/from an array parameter
+  (except inactive parameters described below)
+- `dst` is a "tied" parameter (i.e. refers to another parameter) and is
+  loaded into multiple times with mismatched source values
+
+Inactive parameters can be encoded by using the boolean value `false` instead of an array.
+If `dst == false` and `src` is an array, no error will be raised (and no values copied):
+the values of `src` are ignored, even if they are non-zero.
+Likewise, copying a `src` value of `false` to any `dst` array is valid,
+but copying a `src` value of `true` will error.
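+
+For example, a destination whose bias is turned off (`bias = false`) stays turned off
+when the source's bias is an ordinary trainable (all-zero) vector:
+```julia
+julia> dst = Dense(Flux.ones32(2, 5), false, tanh);
+
+julia> src = Dense(5 => 2);  # bias is a trainable all-zero vector
+
+julia> Flux.loadmodel!(dst, src).bias  # the `false` bias is kept, not overwritten
+false
+```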
+""" +function loadmodel!(dst, src; cache = Base.IdSet()) + ldsts, _ = functor(dst) + lsrcs, _ = functor(src) + (keys(ldsts) == keys(lsrcs)) || + throw(ArgumentError("Tried to load $src into $dst but the structures do not match.")) + + err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") + foreach(ldsts, lsrcs) do ldst, lsrc + if ldst in cache # we already loaded this parameter before + _tie_check(ldst, lsrc) && return ldst + elseif Functors.isleaf(ldst) # our first time loading this leaf + push!(cache, ldst) + loadleaf!(ldst, lsrc, err) + else # this isn't a leaf + loadmodel!(ldst, lsrc; cache = cache) + end + end + + return dst +end diff --git a/test/utils.jl b/test/utils.jl index 14b5ad9bbc..054589755b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,7 +2,7 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, identity_init, stack, unstack, batch, unbatch, - unsqueeze, params, loadparams! + unsqueeze, params, loadparams!, loadmodel! using StatsBase: var, std using Statistics, LinearAlgebra using Random @@ -373,11 +373,139 @@ end weights(m) = mapreduce(l -> [l.weight], vcat, m) @testset "Bias type $bt" for bt in (Flux.zeros32, nobias) m = dm(bt) - loadparams!(m, params(m)) + Flux.loadparams!(m, params(m)) testdense(m, bt) end end + @testset "loadmodel!(dst, src)" begin + m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m2 = Chain(Dense(10, 5), Dense(5, 2)) + m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) + m4 = Chain(Dense(10, 6), Dense(6, 2)) + m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) + m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) + + loadmodel!(m1, m2) + # trainable parameters copy over + @test m1[1].weight == m2[1].weight + @test m1[1].bias == m2[1].bias + # non-array leaves are untouched + @test m1[2].σ == relu + + loadmodel!(m5, m6) + # more complex nested structures also work + @test m5[1].weight == m6[1].weight + @test m5[2][1].weight == m6[2][1].weight + # false bias is not overwritten + @test m5[2][1].bias == false + + # mismatched nodes throw an error + @test_throws ArgumentError loadmodel!(m1, m3) + @test_throws ArgumentError loadmodel!(m1, m5) + # size mismatches throw an error + @test_throws DimensionMismatch loadmodel!(m1, m4) + + # tests for BatchNorm and Dropout + m1 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dropout(0.2)) + m2 = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), x -> reshape(x, :, size(x)[end]), Dropout(0.1)) + m2[2].μ .= rand(Float32, size(m2[2].μ)...) + loadmodel!(m1, m2) + # non-trainable parameters are copied as well + @test m1[2].μ == m2[2].μ + # functions are not copied + @test m1[3] == Flux.flatten + # dropout rate is not copied + @test m1[4].p == 0.2 + + # from LegolasFlux (https://github.com/beacon-biosignals/LegolasFlux.jl/blob/80569ab63a8248a8a063c76e0bbf701f4ada9bd4/examples/digits.jl#L33) + # tests Chain(...) vs Chain([...]) + # tests MaxPool + # tests testmode!/trainmode! 
is not copied + # tests Dense, Conv, BatchNorm, Dropout (like above) but in a bigger model + chain1 = Chain(Dropout(0.2), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 32 => 16, relu), + Dropout(0.2), + MaxPool((2, 2)), + Dropout(0.2), + Conv((3, 3), 16 => 10, relu), + Dropout(0.2), + x -> reshape(x, :, size(x, 4)), + Dropout(0.2), + Dense(90, 10), + softmax) + chain2 = Chain([Dropout(0.1), + Conv((3, 3), 1 => 32, relu), + BatchNorm(32, relu), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 32 => 16, relu), + Dropout(0.1), + MaxPool((3, 3)), + Dropout(0.1), + Conv((3, 3), 16 => 10, relu), + Dropout(0.1), + x -> reshape(x, :, size(x, 4)), + Dropout(0.1), + Dense(90, 10), + softmax]) + chain2[3].μ .= 5f0 + chain2[3].σ² .= 2f0 + testmode!(chain2) + loadmodel!(chain1, chain2) + for (dst, src) in zip(chain1, chain2) + if dst isa Dropout + @test dst.p == 0.2 + elseif dst isa Union{Conv, Dense} + @test dst.weight == src.weight + @test dst.bias == src.bias + elseif dst isa MaxPool + @test dst.k == (2, 2) + elseif dst isa BatchNorm + @test dst.μ == src.μ + @test dst.σ² == src.σ² + @test isnothing(dst.active) + end + end + + # copy only a subset of the model + chain1[end - 1].weight .= 1f0 + chain1[3].μ .= 3f0 + chain1[2].bias .= 5f0 + loadmodel!(chain2[end - 1], chain1[end - 1]) + loadmodel!(chain2[3], chain1[3]) + @test chain2[end - 1].weight == chain1[end - 1].weight + @test chain2[3].μ == chain1[3].μ + @test chain2[2].bias != chain1[2].bias + + # test shared weights + shared_dst = Dense(10 => 10) + shared_src = Dense(10 => 10) + # matched weights are okay + m1 = Chain(shared_dst, Dense(shared_dst.weight)) + m2 = Chain(shared_src, Dense(shared_src.weight)) + loadmodel!(m1, m2) + @test m1[1].weight === m1[2].weight + @test m1[1].weight == m2[2].weight + # mismatched weights are an error + m2 = Chain(Dense(10 => 10), Dense(10 => 10)) + @test_throws ErrorException loadmodel!(m1, m2) + # loading into tied weights with absent parameter is okay when the dst == zero + b = Flux.zeros32(5) + m1 = Chain(Dense(10 => 5; bias = b), Dense(5 => 5; bias = b)) + m2 = Chain(Dense(10 => 5; bias = Flux.zeros32(5)), Dense(5 => 5; bias = false)) + loadmodel!(m1, m2) + @test m1[1].bias === m1[2].bias + @test iszero(m1[1].bias) + # loading into tied weights with absent parameter is bad when the dst != zero + m2[1].bias .= 1 + @test_throws ErrorException loadmodel!(m1, m2) + end + @testset "destructure" begin import Flux: destructure @testset "Bias type $bt" for bt in (zeros, nobias) @@ -397,22 +525,22 @@ end end end -@testset "loadparams! & absent bias" begin +@testset "loadmodel! & absent bias" begin m0 = Chain(Dense(2 => 3; bias=false, init = Flux.ones32), Dense(3 => 1)) m1 = Chain(Dense(2 => 3; bias = Flux.randn32(3)), Dense(3 => 1)) m2 = Chain(Dense(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]), Dense(3 => 1)) - Flux.loadparams!(m1, Flux.params(m2)) + Flux.loadmodel!(m1, m2) @test m1[1].bias == 7:9 @test sum(m1[1].weight) == 21 # load from a model without bias -- should ideally recognise the `false` but `Params` doesn't store it - @test_broken Flux.loadparams!(m1, Flux.params(m0)) - @test_broken iszero(m1[1].bias) + m1 = Flux.loadmodel!(m1, m0) + @test iszero(m1[1].bias) @test sum(m1[1].weight) == 6 # written before error # load into a model without bias -- should it ignore the parameter which has no home, or error? 
- @test_broken Flux.loadparams!(m0, Flux.params(m2)) + m0 = Flux.loadmodel!(m0, m2) @test iszero(m0[1].bias) # obviously unchanged @test sum(m0[1].weight) == 21 end
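Taken together, the documentation and test changes above describe the following end-to-end workflow. This is only an illustrative sketch: the architecture and file name are placeholders, and it assumes BSON.jl is available with the model saved under the key `model`, as in `docs/src/saving.md`.

```julia
using Flux, BSON
using Flux: loadmodel!

# Build (and presumably train) a model, then save the whole struct.
# @save keeps non-trainable state such as the BatchNorm statistics,
# unlike the old params(model) approach warned about above.
model = Chain(Dense(10 => 5, relu), BatchNorm(5), Dense(5 => 2), softmax)
BSON.@save "mymodel.bson" model

# Later (e.g. in a new session): rebuild the same architecture and
# copy the saved trainable and non-trainable parameters into it in place.
model2 = Chain(Dense(10 => 5, relu), BatchNorm(5), Dense(5 => 2), softmax)
model2 = loadmodel!(model2, BSON.load("mymodel.bson")[:model])
```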