Add a structural loadparams! #1875

Merged (19 commits) on Apr 5, 2022
1 change: 1 addition & 0 deletions NEWS.md
@@ -12,6 +12,7 @@ been removed in favour of MLDatasets.jl.
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
* Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
+* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875), which copies trainable + non-trainable parameters and performs more thorough structural checking.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
2 changes: 1 addition & 1 deletion Project.toml
@@ -29,7 +29,7 @@ Adapt = "3.0"
ArrayInterface = "3.1, 4, 5"
CUDA = "3"
ChainRulesCore = "1.12"
Functors = "0.2.1"
Functors = "0.2.8"
MLUtils = "0.2"
MacroTools = "0.5"
NNlib = "0.8.2"
42 changes: 21 additions & 21 deletions docs/src/saving.md
@@ -2,7 +2,7 @@

You may wish to save models so that they can be loaded and run in a later
session. The easiest way to do this is via
-[BSON.jl](https://github.com/MikeInnes/BSON.jl).
+[BSON.jl](https://github.com/JuliaIO/BSON.jl).

Save a model:

@@ -36,7 +36,6 @@ Chain(
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.
-
```

Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -46,15 +45,17 @@ versions of Flux).

!!! note

-If a saved model's weights are stored on the GPU, the model will not load
+If a saved model's parameters are stored on the GPU, the model will not load
later on if there is no GPU support available. It's best to [move your model
to the CPU](gpu.md) with `cpu(model)` before saving it.

-## Saving Model Weights
+!!! warning

-In some cases it may be useful to save only the model parameters themselves, and
-rebuild the model architecture in your code. You can use `params(model)` to get
-model parameters.
+    Previous versions of Flux suggested saving only the model weights using
+    `@save "mymodel.bson" params(model)`.
+    This is no longer recommended and even strongly discouraged.
+    Saving models this way will only store the trainable parameters which
+    will result in incorrect behavior for layers like `BatchNorm`.
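As a quick sketch of what a weights-only save loses (assuming only that Flux is loaded): `params` collects just the trainable arrays, so a `BatchNorm`'s running statistics are skipped entirely:

```julia
using Flux

bn = BatchNorm(2)
length(Flux.params(bn))  # 2: only the trainable β and γ

# the running statistics bn.μ and bn.σ² are not in params(bn),
# so saving params(bn) would silently drop them
```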

```Julia
julia> using Flux
@@ -64,28 +64,27 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

julia> weights = Flux.params(model);

julia> using BSON: @save

julia> @save "mymodel.bson" weights
```

-You can easily load parameters back into a model with `Flux.loadparams!`.
+Loading the model as shown above will return a new model with the stored parameters.
+But sometimes you already have a model, and you want to load stored parameters into it.
+This can be done as

```julia
-julia> using Flux
+using Flux: loadmodel!
+using BSON: @load

-julia> model = Chain(Dense(10 => 5,relu),Dense(5 => 2),softmax)
-Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)
+# some predefined model
+model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

-julia> using BSON: @load
+# load one model into another
+model = loadmodel!(model, @load("mymodel.bson"))
```

-julia> @load "mymodel.bson" weights
+This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.
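As a minimal sketch of that in-memory use (the models `m1` and `m2` are hypothetical, but must be structurally identical):

```julia
using Flux
using Flux: loadmodel!

m1 = Chain(Dense(3 => 4, relu), Dense(4 => 2))  # freshly initialised
m2 = Chain(Dense(3 => 4, relu), Dense(4 => 2))  # stands in for a trained model

loadmodel!(m1, m2)            # copy m2's arrays into m1 in place
m1[1].weight == m2[1].weight  # true: values are copied, not aliased
```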

-julia> Flux.loadparams!(model, weights)
+```@docs
+Flux.loadmodel!
+```

-The new `model` we created will now be identical to the one we saved parameters for.

## Checkpointing

In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
2 changes: 2 additions & 0 deletions src/Flux.jl
@@ -46,6 +46,8 @@ include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/show.jl")

include("loading.jl")

include("outputsize.jl")

include("data/Data.jl")
10 changes: 10 additions & 0 deletions src/deprecations.jl
@@ -48,6 +48,16 @@ function Diagonal(size::Tuple; kw...)
Scale(size...; kw...)
end

+# Deprecate this eventually once saving models w/o structure is no more
+function loadparams!(m, xs)
+  Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!)
+  for (p, x) in zip(params(m), xs)
+    size(p) == size(x) ||
+      error("Expected param size $(size(p)), got $(size(x))")
+    copyto!(p, x)
+  end
+end

# Channel notation: Changed to match Conv, but very softly deprecated!
# Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
Dense(in::Integer, out::Integer, σ = identity; kw...) =
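The shim above keeps old code running while pointing at the replacement. A sketch of the migration path (the names `model`, `weights`, and `saved_model` are hypothetical):

```julia
# before (soft-deprecated): copies trainable arrays positionally
Flux.loadparams!(model, weights)

# after: copies trainable and non-trainable state with structural checks
Flux.loadmodel!(model, saved_model)
```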
8 changes: 0 additions & 8 deletions src/functor.jl
@@ -85,14 +85,6 @@ function params(m...)
return ps
end

-function loadparams!(m, xs)
-  for (p, x) in zip(params(m), xs)
-    size(p) == size(x) ||
-      error("Expected param size $(size(p)), got $(size(x))")
-    copyto!(p, x)
-  end
-end

struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
99 changes: 99 additions & 0 deletions src/loading.jl
@@ -0,0 +1,99 @@
loadleaf!(dst, src, err) = dst
@mcabbott (Member) commented on Apr 4, 2022:

I also wonder if there should be more errors here:

Suggested change:

-loadleaf!(dst, src, err) = dst
+loadleaf!(dst, src, err) = dst
+loadleaf!(dst::AbstractArray, src, err) = error()
+loadleaf!(dst, src::AbstractArray, err) = error()

I can imagine that allowing `src` to have `nothing` could mean "don't change the existing weight", which is what #1875 (comment) would generate. But it may also allow truncations of branches, not just leaves, which aren't allowed right now but would, I think, be easy:

loadleaf!(dst, src::Nothing, err) = dst
loadleaf!(dst::AbstractArray, src::Nothing, err) = dst

loadmodel!(dst, src::Nothing; cache = Base.IdSet()) = dst

loadleaf!(dst::AbstractArray, src, err) =
  error("Tried to copy $src into an array destination; this is not allowed.")
loadleaf!(dst, src::AbstractArray, err) =
  error("Tried to copy an array to $dst; this is not allowed.")

# a boolean `false` encodes an inactive parameter (e.g. `bias = false`):
# it may only be copied into an array as zeros
function loadleaf!(dst::AbstractArray, src::Bool, err)
  if iszero(src)
    dst .= src
  else
    error("Cannot copy boolean parameter == true to non-zero parameter.")
  end
  return dst
end
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :
  error("Cannot copy non-zero parameter to boolean parameter == true.")
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
  (size(dst) == size(src)) || throw(err)
  copyto!(dst, src)
end

# consistency checks for parameters that appear more than once in `dst`
# (tied weights): repeated visits must see matching sources
_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
  error("Encountered tied destination parameters with untied and mismatched sources.")
_tie_check(dst, src) = true

_bool_tie_check(dst, src) = true

"""
    loadmodel!(dst, src)

Copy all the parameters (trainable and non-trainable) from `src` into `dst`.

Recursively walks `dst` and `src` together using [`Functors.children`](@ref),
and calls `copyto!` on parameter arrays, throwing an error when there is a mismatch.
Non-array elements (such as activation functions) are not copied and need not match.
Zero bias vectors and `bias=false` are considered equivalent
(see extended help for more details).

# Examples
```julia
julia> dst = Chain(Dense(Flux.ones32(2, 5), tanh), Dense(2 => 1; bias = [1f0]))
Chain(
  Dense(5 => 2, tanh),                  # 12 parameters
  Dense(2 => 1),                        # 3 parameters
)                   # Total: 4 arrays, 15 parameters, 316 bytes.

julia> dst[1].weight ≈ ones(2, 5) # by construction
true

julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));

julia> Flux.loadmodel!(dst, src);

julia> dst[1].weight ≈ ones(2, 5) # values changed
false

julia> iszero(dst[2].bias)
true
```

# Extended help

Throws an error when:
- `dst` and `src` do not share the same fields (at any level)
- the sizes of leaf nodes are mismatched between `dst` and `src`
- copying non-array values to/from an array parameter
  (except inactive parameters described below)
- `dst` is a "tied" parameter (i.e. refers to another parameter) and
  is loaded into multiple times with mismatched source values

Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying a `src` value of `false` to any `dst` array is valid,
but copying a `src` value of `true` will error.
"""
function loadmodel!(dst, src; cache = Base.IdSet())
  ldsts, _ = functor(dst)
  lsrcs, _ = functor(src)
  (keys(ldsts) == keys(lsrcs)) ||
    throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))

  err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
  foreach(ldsts, lsrcs) do ldst, lsrc
    if ldst in cache # we already loaded this parameter before
      _tie_check(ldst, lsrc) && return ldst
    elseif Functors.isleaf(ldst) # our first time loading this leaf
      push!(cache, ldst)
      loadleaf!(ldst, lsrc, err)
    else # this isn't a leaf
      loadmodel!(ldst, lsrc; cache = cache)
    end
  end

  return dst
end
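To make the checks concrete, here is a sketch of how `loadmodel!` reacts to a size mismatch, tied weights, and `bias = false` (illustrative layer sizes assumed; not part of the new file):

```julia
using Flux

# a size mismatch at a leaf throws the DimensionMismatch built above
dst = Chain(Dense(3 => 4), Dense(4 => 2))
src = Chain(Dense(3 => 5), Dense(5 => 2))
# Flux.loadmodel!(dst, src)      # throws DimensionMismatch

# tied destination: the same weight array shared by two layers
W = Flux.glorot_uniform(2, 2)
tied = Chain(Dense(W), Dense(W))
untied = Chain(Dense(2 => 2), Dense(2 => 2))
# Flux.loadmodel!(tied, untied)  # errors via _tie_check unless both
                                 # source weights happen to be equal

# bias = false is stored as the boolean `false` and counts as all-zero
src2 = Dense(3 => 2; bias = false)
dst2 = Dense(3 => 2)              # has a real bias array
Flux.loadmodel!(dst2, src2)       # ok: dst2.bias is overwritten with zeros
```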