Merge pull request #1875 from darsnack/load-structured
Add a structural `loadparams!`
  • Loading branch information
2 parents 5f17f1c + 6b533b8 commit 674527e
Showing 8 changed files with 269 additions and 37 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -12,6 +12,7 @@ been removed in favour of MLDatasets.jl.
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
* Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.
* The `Flux.Diagonal` layer is now called `Scale`, and accepts an activation function.
* `loadparams!` is replaced by [`loadmodel!`](https://github.com/FluxML/Flux.jl/pull/1875), which copies trainable + non-trainable parameters and performs more thorough structural checking.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
2 changes: 1 addition & 1 deletion Project.toml
@@ -29,7 +29,7 @@ Adapt = "3.0"
ArrayInterface = "3.1, 4, 5"
CUDA = "3"
ChainRulesCore = "1.12"
Functors = "0.2.1"
Functors = "0.2.8"
MLUtils = "0.2"
MacroTools = "0.5"
NNlib = "0.8.2"
42 changes: 21 additions & 21 deletions docs/src/saving.md
@@ -2,7 +2,7 @@

You may wish to save models so that they can be loaded and run in a later
session. The easiest way to do this is via
[BSON.jl](https://github.com/JuliaIO/BSON.jl).

Save a model:

@@ -36,7 +36,6 @@ Chain(
Dense(5 => 2), # 12 parameters
NNlib.softmax,
) # Total: 4 arrays, 67 parameters, 524 bytes.

```

Models are just normal Julia structs, so it's fine to use any Julia storage
@@ -46,15 +45,17 @@ versions of Flux).

!!! note

    If a saved model's parameters are stored on the GPU, the model will not load
    later on if there is no GPU support available. It's best to [move your model
    to the CPU](gpu.md) with `cpu(model)` before saving it.
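
For example, a minimal sketch of that workflow (assuming `model` currently lives on the GPU):

```julia
using BSON: @save

model = cpu(model)          # move the parameters back to CPU memory
@save "mymodel.bson" model  # the file can now be loaded without GPU support
```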

## Saving Model Weights

In some cases it may be useful to save only the model parameters themselves, and
rebuild the model architecture in your code. You can use `params(model)` to get
model parameters.

!!! warning

    Previous versions of Flux suggested saving only the model weights using
    `@save "mymodel.bson" params(model)`.
    This is no longer recommended and even strongly discouraged.
    Saving models this way will only store the trainable parameters which
    will result in incorrect behavior for layers like `BatchNorm`.
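
As a quick illustration (a sketch, not from the original docs): `BatchNorm`'s running mean and variance are non-trainable, so they never appear in `params`:

```julia
julia> using Flux

julia> m = Chain(Dense(3 => 3), BatchNorm(3));

julia> length(Flux.params(m))  # Dense weight + bias and BatchNorm's γ + β, but not μ or σ²
4
```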

```Julia
julia> using Flux
@@ -64,28 +65,27 @@ Chain(Dense(10, 5, NNlib.relu), Dense(5, 2), NNlib.softmax)

julia> weights = Flux.params(model);

julia> using BSON: @save

julia> @save "mymodel.bson" weights
```

Loading the model as shown above will return a new model with the stored parameters.
But sometimes you already have a model, and you want to load stored parameters into it.
This can be done as follows:

```julia
using Flux: loadmodel!
using BSON: @load

# some predefined model
model = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)

# load one model into another
model = loadmodel!(model, @load("mymodel.bson"))
```

This ensures that the model loaded from `"mymodel.bson"` matches the structure of `model`; after the call, `model` holds the stored parameter values. [`Flux.loadmodel!`](@ref) is also convenient for copying parameters between models in memory.

```@docs
Flux.loadmodel!
```

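For instance, a quick sketch of the in-memory case (`trained` here is a hypothetical model with the same architecture):

```julia
using Flux

model   = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)  # freshly initialised
trained = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)  # e.g. the result of a training run

Flux.loadmodel!(model, trained)  # `model` now holds `trained`'s parameters, in place
```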

## Checkpointing

In longer training runs it's a good idea to periodically save your model, so that you can resume if training is interrupted (for example, if there's a power cut). You can do this by saving the model in the [callback provided to `train!`](training/training.md).
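
For instance, a sketch of such a callback (the file name and the 30-second interval are illustrative):

```julia
using Flux: cpu, throttle
using BSON: @save

# save a CPU copy of `model` at most once every 30 seconds
checkpoint = throttle(30) do
  model_cpu = cpu(model)
  @save "model-checkpoint.bson" model_cpu
end

# Flux.train!(loss, Flux.params(model), data, opt; cb = checkpoint)
```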
2 changes: 2 additions & 0 deletions src/Flux.jl
@@ -46,6 +46,8 @@ include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/show.jl")

include("loading.jl")

include("outputsize.jl")

include("data/Data.jl")
10 changes: 10 additions & 0 deletions src/deprecations.jl
@@ -48,6 +48,16 @@ function Diagonal(size::Tuple; kw...)
Scale(size...; kw...)
end

# Deprecate this eventually once saving models w/o structure is no more
function loadparams!(m, xs)
  Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||
      error("Expected param size $(size(p)), got $(size(x))")
    copyto!(p, x)
  end
end
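
A migration sketch for callers of the deprecated method (`saved_model` below is a hypothetical model loaded whole, e.g. with BSON, rather than a `Params` collection):

```julia
# before (deprecated): positional copy of the trainable parameters only
Flux.loadparams!(model, weights)

# after: structural copy of trainable and non-trainable state, with shape checks
Flux.loadmodel!(model, saved_model)
```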

# Channel notation: Changed to match Conv, but very softly deprecated!
# Perhaps change to @deprecate for v0.14, but there is no plan to remove these.
Dense(in::Integer, out::Integer, σ = identity; kw...) =
8 changes: 0 additions & 8 deletions src/functor.jl
@@ -85,14 +85,6 @@ function params(m...)
return ps
end

function loadparams!(m, xs)
  for (p, x) in zip(params(m), xs)
    size(p) == size(x) ||
      error("Expected param size $(size(p)), got $(size(x))")
    copyto!(p, x)
  end
end

struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
99 changes: 99 additions & 0 deletions src/loading.jl
@@ -0,0 +1,99 @@
loadleaf!(dst, src, err) = dst
loadleaf!(dst::AbstractArray, src, err) =
  error("Tried to copy $src into an array destination; this is not allowed.")
loadleaf!(dst, src::AbstractArray, err) =
  error("Tried to copy an array to $dst; this is not allowed.")
function loadleaf!(dst::AbstractArray, src::Bool, err)
  if iszero(src)
    dst .= src
  else
    error("Cannot copy boolean parameter == true to non-zero parameter.")
  end
  return dst
end
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :
  error("Cannot copy non-zero parameter to boolean parameter == true.")
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)
  (size(dst) == size(src)) || throw(err)
  copyto!(dst, src)
end

_tie_check(dst::Bool, src::AbstractArray) = iszero(dst) ||
  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::Bool) = (iszero(dst) && iszero(src)) ||
  error("Encountered tied parameter with boolean source at some nodes and non-boolean sources at others.")
_tie_check(dst::AbstractArray, src::AbstractArray) = (dst == src) ||
  error("Encountered tied destination parameters with untied and mismatched sources.")
_tie_check(dst, src) = true

_bool_tie_check(dst, src) = true

"""
loadmodel!(dst, src)
Copy all the parameters (trainable and non-trainable) from `src` into `dst`.
Recursively walks `dst` and `src` together using [`Functors.children`](@ref),
and calling `copyto!` on parameter arrays or throwing an error when there is a mismatch.
Non-array elements (such as activation functions) are not copied and need not match.
Zero bias vectors and `bias=false` are considered equivalent
(see extended help for more details).
# Examples
```julia
julia> dst = Chain(Dense(Flux.ones32(2, 5, tanh)), Dense(2 => 1; bias = [1f0]))
Chain(
Dense(5 => 2, tanh), # 12 parameters
Dense(2 => 1), # 3 parameters
) # Total: 4 arrays, 15 parameters, 316 bytes.
julia> dst[1].weight ≈ ones(2, 5) # by construction
true
julia> src = Chain(Dense(5 => 2, relu), Dense(2 => 1, bias=false));
julia> Flux.loadmodel!(dst, src);
julia> dst[1].weight ≈ ones(2, 5) # values changed
false
julia> iszero(dst[2].bias)
true
```
# Extended help
Throws an error when:
- `dst` and `src` do not share the same fields (at any level)
- the sizes of leaf nodes are mismatched between `dst` and `src`
- copying non-array values to/from an array parameter
(except inactive parameters described below)
- `dst` is a "tied" parameter (i.e. refers to another parameter) and
loaded into multiple times with mismatched source values
Inactive parameters can be encoded by using the boolean value `false` instead of an array.
If `dst == false` and `src` is an all-zero array, no error will be raised (and no values copied);
however, attempting to copy a non-zero array to an inactive parameter will throw an error.
Likewise, copying a `src` value of `false` to any `dst` array is valid,
but copying a `src` value of `true` will error.
"""
function loadmodel!(dst, src; cache = Base.IdSet())
  ldsts, _ = functor(dst)
  lsrcs, _ = functor(src)
  (keys(ldsts) == keys(lsrcs)) ||
    throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))

  err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
  foreach(ldsts, lsrcs) do ldst, lsrc
    if ldst in cache # we already loaded this parameter before
      _tie_check(ldst, lsrc) && return ldst
    elseif Functors.isleaf(ldst) # our first time loading this leaf
      push!(cache, ldst)
      loadleaf!(ldst, lsrc, err)
    else # this isn't a leaf
      loadmodel!(ldst, lsrc; cache = cache)
    end
  end

  return dst
end
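
As a quick check of the tied-parameter branch above (a sketch, not part of this file): two layers sharing one weight array are loaded once, and the second visit only runs `_tie_check` against the cached leaf.

```julia
using Flux

shared = Flux.zeros32(2, 2)                      # one array used by both layers (tied weights)
dst = Chain(Dense(shared), Dense(shared))
src = Chain(Dense(Flux.ones32(2, 2)), Dense(Flux.ones32(2, 2)))

Flux.loadmodel!(dst, src)        # `shared` is filled once; the repeat visit passes `_tie_check`
dst[1].weight === dst[2].weight  # still the same array, now all ones
```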