
Initialising weights outside of layer declarations #1879

Open · theabhirath opened this issue Feb 19, 2022 · 18 comments
@theabhirath (Member)

Following the discussions in FluxML/Metalhead.jl#119, I realised that there is currently no way for the user to programmatically pass weight initialisation strategies for layers in a Chain-like structure based on the type of the layer (that is, after the layer has already been declared). This would be quite a useful feature to have, given that many recent models use specific weight initialisations for certain types of layers.

An initial idea I had was to add mutating versions of the existing initialisation functions. We could then have a wrapper function that mutates the weights of an already-existing layer, instead of constructing an entirely new layer just to change the initial weights. I'm unsure whether this clashes with anything (and I don't know whether there are already efficient ways to do this with existing functionality), so I'm opening this up for discussion in case there's some conflict before I sit down to write it up.
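For concreteness, a mutating variant might look something like the sketch below (the name glorot_uniform! is hypothetical, and a real implementation would fill the array in place rather than allocate and copy):

using Flux

# Hypothetical mutating counterpart to Flux.glorot_uniform: overwrite an
# existing array with freshly initialised values of the same size.
glorot_uniform!(W::AbstractArray) = copyto!(W, Flux.glorot_uniform(size(W)...))

# Usage on a layer that has already been constructed:
d = Dense(3 => 4)
glorot_uniform!(d.weight)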

cc @darsnack

@darsnack (Member)

Mutating variants sound like a good idea. For applying them to already-initialized models, we could make a wrapper interface like init!(method, model), but that would require defining an init!(..., ::Dense), an init!(..., ::Chain), etc. Instead, I was thinking we could fmap the mutating variant over a given model's trainable parameters. The main issue there is that one usually wants to initialize a weight parameter but not the bias.
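A rough sketch of that naive approach, using Flux.params rather than fmap to collect the trainable arrays (init! here is a hypothetical mutating init), which also shows exactly where the bias problem bites:

using Flux

# Apply a (hypothetical) mutating init to every trainable array.
# This illustrates the problem above: biases get overwritten too.
function naive_reinit!(init!, m)
    for p in Flux.params(m)   # collects all trainable arrays, weights and biases alike
        init!(p)
    end
    return m
end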

@mcabbott (Member)

One way with fmap is:

julia> m = Chain(Dense(2,3));

julia> fmap(m; exclude = x -> hasproperty(x, :weight)) do x
         x.weight .= (1:3)
         x
       end
Chain(
  Dense(2, 3),                          # 9 parameters
)

julia> m.layers[1].weight
3×2 Matrix{Float32}:
 1.0  1.0
 2.0  2.0
 3.0  3.0

julia> m.layers[1].bias
3-element Vector{Float32}:
 0.0
 0.0
 0.0

Is going by field name good enough? It might be. This could be wrapped up as something like weight!(init!, m, fields = :weight), perhaps, in case you want to specify other fields?

It may also not be worth the hassle of making this mutate, since it will only run once. The fmap needed to reconstruct the layers is perhaps slightly more tricky, but it should be doable, and it could reuse all the existing init functions.
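A rough sketch of that non-mutating route, stopping the walk at whole layers and rebuilding them with the existing init functions (Dense only here, purely illustrative):

using Flux, Functors

rebuild(x) = x   # fallback: leave anything we don't recognise untouched
rebuild(d::Dense) = Dense(Flux.glorot_normal(size(d.weight)...), d.bias, d.σ)

m  = Chain(Dense(2 => 3, relu), Dense(3 => 1))
m2 = fmap(rebuild, m; exclude = x -> x isa Dense || Functors.isleaf(x))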

@darsnack (Member)

Yeah, my concern was relying on a particular field name. We could always make reinit!(method, model) default to fmap over all trainables, and then allow custom overrides. Similar approach to #1875.

@ToucheSir (Member) commented Feb 19, 2022

My fear is that creating a single function for re-init would be too niche, for the reasons discussed already (e.g. different parameters in the same layer wanting different init functions). Mutating variants of the init functions make sense to me, however. They'll at least allow users to do things manually until we can think of good higher-level APIs.

@mcabbott (Member)

I'd like to avoid adding yet another special function that you have to remember to overload for any new layer so that other people can re-weight it. My sketch above is much too rough, but could there be some nice API a bit like that?

Most layers call the bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and to act on (default: every other trainable array); then you can target your custom layer?

If it targets trainable, should it live in Optimisers.jl?
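Something like this rough sketch, say (the name reweight! and the is_layer helper are only illustrative, not an existing API):

using Flux, Functors

is_layer(x) = x isa Union{Dense, Conv}   # extend for other layer types

# Walk the model, stop at whole layers, and overwrite every array field
# that is not named in `skip` with freshly initialised values of the same size.
function reweight!(init, m; skip = (:b, :bias))
    fmap(m; exclude = x -> is_layer(x) || Functors.isleaf(x)) do layer
        Functors.isleaf(layer) && return layer   # leave bare arrays etc. alone
        for f in fieldnames(typeof(layer))
            f in skip && continue
            x = getfield(layer, f)
            x isa AbstractArray && copyto!(x, init(size(x)...))
        end
        layer
    end
    return m
end

reweight!(Flux.glorot_normal, Chain(Dense(2 => 3), Dense(3 => 1)))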

@darsnack (Member)

Most layers call the bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and to act on (default: every other trainable array); then you can target your custom layer?

I like this. Let's write it so that the keys to ignore are a nested NamedTuple with the same structure as the model (with nothing meaning "don't descend this branch"). It's easy enough to go from a plain vector ignores = [:b, :bias] to this nested structure (i.e. fmapstructure(x -> ignores, model)), but having the core be nested means branches can have separate, overlapping ignore patterns.

If it targets trainable, should it live in Optimisers.jl?

My thought is no: Flux will depend on Optimisers, so it can still live here. Initialization is specific to neural network models, not to optimization.

@mcabbott (Member)

So long as we are only doing mutable models, the easy way to apply this only to some branch is probably something like reweight!(m.layers.enc, glorot_normal; skip = (:b, :bias)). That is, instead of building an API for specifying what branch of the model to act on, just pass that branch in.

@darsnack (Member)

True, that's better!

@darsnack (Member) commented Feb 19, 2022

Small sidenote: I would make the initialization method the first argument, to support do syntax in case anyone needs it.

@ToucheSir (Member)

reweight! would have to return a changed model in order to handle bias=false, or do we not care about those?

@darsnack (Member)

The semantic definition of bias=false means that trying to load a numeric value into it is ignored. I think that extends to reweight! too.

@mcabbott (Member)

Indeed [re argument order]. I guess the next question is what gets passed to that function. Should this work, or should it get the size?

reinit!(model) do x
  x isa AbstractVector && return x
  randn(Float32, size(x))
end

Is what it returns (as here) always copied back into the old array, or only if you do that yourself? I presume it should return a re-built model à la fmap, but does it guarantee that the old one matches?

@theabhirath (Member, Author) commented Apr 23, 2022

Is there a way to get Functors to only "see" down to a certain level? If fmap can somehow be told to stop at the Flux layer level (for custom layers, I reckon that means stopping when a struct is found? I'm not sure how Flux recognises those), then instead of passing a skip-list for params, we could leave it to the user to define the behaviour for the parameters they want to re-initialise (somewhat like PyTorch, whose behaviour I found quite intuitive in this case). First, define an _init_weights! function that takes care of the necessary behaviour:

function _init_weights!(m)
    if m isa Conv
        m.weight .*= 2
        m.bias .+= 5
    end
    return m
end

Now all that is required is a recursive function (fmap-like, or like apply on the torch side of things) that can walk through the model and apply this function. I was trying to get this to happen, but I couldn't figure out how to get Functors to stop at the Flux layer level. Is there a simple way to make this happen?

@ToucheSir (Member)

The exclude kwarg of fmap can be used to stop traversing at any point in the tree. It's set to Functors.isleaf by default, but it's relatively straightforward to write a custom callback:

is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf(::Conv) = true

fmap(_init_weights!, m; exclude=is_layer_or_leaf)

_init_weights! could likewise be written in a dispatch-oriented style.

@theabhirath (Member, Author)

That's great! I tried something that's a pretty typical use case, and it worked quite well:

julia> is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf (generic function with 1 method)

julia> is_layer_or_leaf(::Conv) = true
is_layer_or_leaf (generic function with 2 methods)

julia> is_layer_or_leaf(::Dense) = true
is_layer_or_leaf (generic function with 3 methods)

julia> l = Chain(Dense(3, 3), Conv((3, 3), 3 => 10))
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> function _init_weights!(m::Conv)
           m.weight .*= 2
           m.bias .+= 5
           return m
       end
_init_weights! (generic function with 1 method)

julia> function _init_weights!(m::Dense)
           m.weight .*= 3
           m.bias .+= 4
           return m
       end
_init_weights! (generic function with 2 methods)

julia> fmap(_init_weights!, l; exclude = is_layer_or_leaf)
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> l[1].bias
3-element Vector{Float32}:
 4.0
 4.0
 4.0

julia> l[2].bias
10-element Vector{Float32}:
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0

If this approach has no problems, then it seems pretty straightforward to define a reinit function that passes exclude = is_layer_or_leaf to fmap by default. The only problem I can imagine is with layers like LayerNorm, which itself has a Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're inside a LayerNorm or not.
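To make that concrete, the two choices correspond to two different extensions of the definitions above (illustrative only; you'd pick one or the other):

# (a) Treat LayerNorm itself as the unit, so an _init_weights!(::LayerNorm)
#     method decides what happens to the Scale it wraps:
is_layer_or_leaf(::LayerNorm) = true

# (b) Don't stop at LayerNorm; instead target every Scale, nested or not:
is_layer_or_leaf(::Flux.Scale) = true
_init_weights!(m::Flux.Scale) = (m.scale .= 1f0; m)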

@ToucheSir (Member)

The only problem I can imagine is with layers like LayerNorm, which itself has a Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're inside a LayerNorm or not.

This ambiguity is part of why we don't already have a built-in reinit function, IMO. If we had an option for pre-order traversal like https://chengchingwen.github.io/StructWalk.jl/dev/#StructWalk.prewalk-Tuple{Any,%20Any} or https://fluxml.ai/MacroTools.jl/stable/pattern-matching/#Expression-Walking-1, a user could easily choose whether they want to handle LayerNorm.affine separately or not.
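For reference, a rough sketch of what such a pre-order walk could look like in plain Functors terms (prewalk_apply is a hypothetical helper, not an existing Functors/Flux API):

using Functors

# Apply f to the node first, then recurse into the children of whatever f
# returned. With this, f(::LayerNorm) can adjust the whole layer and the
# walk will still reach the Scale inside it afterwards.
# (No caching of shared nodes here, unlike fmap.)
function prewalk_apply(f, x)
    y = f(x)
    Functors.isleaf(y) && return y
    children, rebuild = Functors.functor(y)
    rebuild(map(c -> prewalk_apply(f, c), children))
end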

@theabhirath (Member, Author)

I was trying to give this another go, but I noticed the above example (from here) doesn't work with DenseNet. The error was quite cryptic:

julia> model = DenseNet();

julia> fmap(_init_weights!, model; exclude = is_layer_or_leaf)
ERROR: MethodError: no method matching copyto!(::Bool, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{}, typeof(+), Tuple{Bool, Int64}})
Closest candidates are:
  copyto!(::Zygote.Buffer, ::Any) at ~/.julia/packages/Zygote/DkIUK/src/tools/buffer.jl:54
  copyto!(::Any, ::Base.Broadcast.Broadcasted{<:StaticArrays.StaticArrayStyle}) at ~/.julia/packages/StaticArrays/G7IlJ/src/broadcast.jl:68
  copyto!(::AbstractArray, ::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}) at broadcast.jl:929
  ...
Stacktrace:
  [1] broadcasted
    @ ./broadcast.jl:1319 [inlined]
  [2] broadcasted
    @ ./broadcast.jl:1317 [inlined]
  [3] _init_weights!(m::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Main ./REPL[9]:3
  [4] #fmap#17
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
  [5] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:804
  [8] collect_similar
    @ ./array.jl:713 [inlined]
  [9] map
    @ ./abstractarray.jl:2976 [inlined]
 [10] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [11] fmap(f::typeof(_init_weights!), x::Vector{Any}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [12] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Vector{Any})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [13] map
    @ ./tuple.jl:273 [inlined]
 [14] map(::Function, ::NamedTuple{(:layers,), Tuple{Vector{Any}}})
    @ Base ./namedtuple.jl:218
 [15] _default_walk(f::Function, x::Chain{Vector{Any}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [16] fmap(f::typeof(_init_weights!), x::Chain{Vector{Any}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [17] #18
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
 [18] map
    @ ./tuple.jl:274 [inlined]
 [19] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [20] fmap(f::typeof(_init_weights!), x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [21] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [22] map
    @ ./tuple.jl:273 [inlined]
 [23] map(::Function, ::NamedTuple{(:layers,), Tuple{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
    @ Base ./namedtuple.jl:218
 [24] _default_walk(f::Function, x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [25] fmap(f::typeof(_init_weights!), x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [26] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [27] map
    @ ./tuple.jl:273 [inlined]
 [28] map(::Function, ::NamedTuple{(:layers,), Tuple{Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}})
    @ Base ./namedtuple.jl:218
 [29] _default_walk(f::Function, x::DenseNet)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [30] fmap(f::typeof(_init_weights!), x::DenseNet; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [31] top-level scope
    @ REPL[16]:1
 [32] top-level scope
    @ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52

Am I missing something here? Why isn't this working the way it's supposed to?

@ToucheSir (Member)

Most likely you are trying to accumulate into a bias = false field. Bool params (if not all non-array ones) are probably safe to ignore when re-initializing, but at some point (probably out of scope for now) we'd want to consider immutable arrays like SArray as well. Those would require returning an updated layer from _init_weights!, much like Optimisers.update! now does with state + gradients.
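For instance, the Conv method from the example above could guard the bias update so that bias = false layers are simply left alone (a minimal sketch of that fix):

function _init_weights!(m::Conv)
    m.weight .*= 2
    # bias = false is stored as a Bool, so only touch it when it is a real array;
    # this avoids the copyto!(::Bool, ...) MethodError above.
    m.bias isa AbstractArray && (m.bias .+= 5)
    return m
end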
