docs: improve freezing docs
isentropic committed Feb 29, 2024
1 parent 48a43db commit 383fb84
Showing 3 changed files with 129 additions and 41 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
@@ -54,7 +54,8 @@ makedocs(
"Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md",
=#
# Not really sure where this belongs... some in Fluxperimental, aim to delete?
"Custom Layers" => "models/advanced.md", # TODO move freezing to Training
"Custom Layers" => "models/advanced.md",
"Freezing model params" => "models/freezing-params.md",
],
],
format = Documenter.HTML(
40 changes: 0 additions & 40 deletions docs/src/models/advanced.md
@@ -69,46 +69,6 @@ Params([])

It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).

## Freezing Layer Parameters

When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`.

!!! compat "Flux ≤ 0.14"
The mechanism described here is for Flux's old "implicit" training style.
When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
Dense(784 => 64, relu),
Dense(64 => 64, relu),
Dense(32 => 10)
);

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed.

`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, a more fine-tuned control is needed.
We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`,
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```

## Custom multiple input or output layer

Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf).
127 changes: 127 additions & 0 deletions docs/src/models/freezing-params.md
@@ -0,0 +1,127 @@
# Freezing model weights
Flux provides several ways of freezing parameters: temporarily excluding them
from training on a per-instance basis, permanently marking parts of a model as
non-trainable, and marking custom struct fields so that they are neither
trained nor moved to the GPU ([Functors.@functor](@ref)). The following
subsections should make it clear which approach best suits your needs.

## On-the-fly freezing per model instance
Perhaps you'd like to freeze some of the weights of a particular model instance, even in the
middle of training. Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`,
which act on the optimiser state returned by `Flux.setup`.

```julia
m = Chain(
      Dense(784 => 64, relu), # freeze this one
      Dense(64 => 64, relu),
      Dense(64 => 10)
    )
opt_state = Flux.setup(Momentum(), m);

# Freeze the first layer right away
Flux.freeze!(opt_state.layers[1])

for data in train_set
  input, label = data

  # Some params could also be frozen mid-training:
  Flux.freeze!(opt_state.layers[2])

  grads = Flux.gradient(m) do m
    result = m(input)
    loss(result, label)
  end
  Flux.update!(opt_state, m, grads[1])

  # Optionally unfreeze the params later
  Flux.thaw!(opt_state.layers[1])
end
```
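
Since `freeze!` and `thaw!` act on whatever part of the optimiser state tree
they are given, you can also target a single parameter rather than a whole
layer. A minimal sketch, reusing `m` and `opt_state` from above:

```julia
# Freeze only the bias of the final Dense layer; its weight keeps training.
Flux.freeze!(opt_state.layers[3].bias)

# ... train as usual ...

# Re-enable training of that parameter afterwards.
Flux.thaw!(opt_state.layers[3].bias)
```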

## Static freezing per model definition
Sometimes some parts of a model (still marked with [`Flux.@functor`](@ref)) need not be
trained at all, but their parameters must still reside on the GPU because they are used
in the forward and/or backward pass. This can be achieved by overloading `Flux.trainable`:
```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end
Flux.@functor MaskedLayer

# mark which fields are trainable
Flux.trainable(a::MaskedLayer) = (; chain = a.chain)
# `a.mask` will not be updated in the training loop

function (m::MaskedLayer)(x)
  return m.chain(x) + x + m.mask
end

model = MaskedLayer(...) # this model will not have the `mask` field trained
```
Note that this method permanently marks some model fields as excluded from
training; unlike `freeze!`/`thaw!`, it cannot be toggled on the fly.
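
As a quick check of this behaviour, setting up an optimiser for such a layer only
creates update state for the `chain` field. A minimal sketch, with hypothetical sizes:

```julia
layer = MaskedLayer(Chain(Dense(4 => 4)), zeros(Float32, 4))
opt_state = Flux.setup(Adam(), layer)
# `opt_state` holds Adam state for `layer.chain` only; `layer.mask` is
# ignored by `Flux.update!`, yet `gpu(layer)` would still move it.
```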

## Excluding from model definition
Sometimes some fields are simply "not trainable" and should not even be
transferred to the GPU. All scalar fields behave like this by default, so
things such as learning rate multipliers are neither trained nor moved to the
GPU. Other fields can be excluded explicitly by listing only the relevant ones
in [`Flux.@functor`](@ref):
```julia
using Statistics: mean

struct CustomLayer{F}
  chain::Chain
  activation_results::Vector{F}
  lr_multiplier::Float32
end
Flux.@functor CustomLayer (chain,) # explicitly leave out `activation_results` and `lr_multiplier`

function (m::CustomLayer)(x)
  result = m.chain(x) + x

  # `activation_results` is never moved to the GPU, hence we can
  # freely `push!` to it
  push!(m.activation_results, mean(result))
  return result
end
```
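
For illustration, a small usage sketch with made-up sizes: `gpu` converts the
arrays inside `chain` but leaves `activation_results` (and the scalar
`lr_multiplier`) untouched on the CPU (if no GPU is available, `gpu` is a no-op):

```julia
layer = CustomLayer(Chain(Dense(4 => 4)), Float32[], 0.1f0)
layer = gpu(layer)               # only `chain` is moved; the rest stays on the CPU

x = gpu(rand(Float32, 4))
layer(x)                         # records one mean activation
layer.activation_results         # still a plain Vector{Float32}
```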
See more about this in [`Flux.@functor`](@ref).

## Freezing Layer Parameters (deprecated)

When it is desired to not include all the model parameters (e.g. for transfer learning), we can simply omit those layers from our call to `params`.

!!! compat "Flux ≤ 0.14"
The mechanism described here is for Flux's old "implicit" training style.
When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain
this using the slicing features `Chain` provides:

```julia
m = Chain(
      Dense(784 => 64, relu),
      Dense(64 => 64, relu),
      Dense(64 => 10)
    );

ps = Flux.params(m[3:end])
```

The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it.

During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed.
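
For context, a minimal sketch of the old implicit-style training call this
refers to (Flux ≤ 0.14), assuming a `loss` function and a `train_set` of
`(input, label)` pairs:

```julia
opt = Descent(0.1)
Flux.train!((x, y) -> loss(m(x), y), ps, train_set, opt)
```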

`Flux.params` also accepts multiple inputs to make it easy to collect parameters from heterogeneous models with a single call. For example, to omit optimising only the second `Dense` layer in the previous example:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, finer-grained control is needed.
We can freeze a specific parameter of a specific layer that has already entered a `Params` object `ps`
by simply deleting it from `ps`:

```julia
ps = Flux.params(m)
delete!(ps, m[2].bias)
```
