update freezing docs

isentropic committed Jun 28, 2024
1 parent 4331dc8 commit 693604a

Showing 2 changed files with 29 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
makedocs(
  # ...
"OneHotArrays.jl" => "reference/data/onehot.md",
"Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md",
"Nested Structures -- Functors.jl" => "reference/models/functors.md",
"Advanced" => "reference/misc-model-tweaking.md"
],
"Tutorials" => [
# These walk you through various tasks. It's fine if they overlap quite a lot.

52 changes: 28 additions & 24 deletions docs/src/reference/misc-model-tweaking.md

The following subsections should make it clear which one suits your needs the best.

## On-the-fly freezing per model instance
Perhaps you'd like to freeze some of the weights of the model (even mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`.

```julia
m = Chain(
  # ... (see the full sketch below)
)
```
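A rough end-to-end sketch of this workflow is shown below; the layer sizes, the `Momentum` optimiser, and the `train_set`/`loss` names are illustrative assumptions rather than part of this page:

```julia
m = Chain(
      Dense(784 => 64, relu),
      Dense(64 => 64, relu),
      Dense(64 => 10),
    )
opt_state = Flux.setup(Momentum(), m)

# Freeze the first layer: `update!` will now skip its parameters.
Flux.freeze!(opt_state.layers[1])

for (input, label) in train_set
  grads = Flux.gradient(m) do m
    loss(m(input), label)
  end
  Flux.update!(opt_state, m, grads[1])
end

# Make the first layer trainable again.
Flux.thaw!(opt_state.layers[1])
```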
## Static freezing per model definition
Sometimes some parts of the model ([`Flux.@layer`](@ref)) need not be trained at all, but their parameters still need to reside on the GPU because they are used in the forward and/or backward pass. This is somewhat similar to PyTorch's [`register_buffer`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer).



```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end

Flux.@layer MaskedLayer trainable=(chain,)
# The `mask` field will not be updated in the training loop.
```
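To make this concrete, here is a minimal usage sketch; the element-wise masking in the forward pass is an assumption made purely for illustration:

```julia
# Hypothetical forward pass: run the chain, then apply the fixed mask.
(m::MaskedLayer)(x) = m.mask .* m.chain(x)

layer = MaskedLayer(Chain(Dense(4 => 4, relu)), rand(Float32, 4))

Flux.trainable(layer)  # only the `chain` field is reported as trainable
```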

Sometimes it also makes sense for some fields not to be trained at all and not even to
transfer to the GPU (or be part of the functor). All scalar fields are like this
by default, so things like learning rate multipliers are not trainable nor
transferred to the GPU by default.
```julia
struct CustomLayer{F}
  chain::Chain
  activation_results::Vector{F}
  lr_multiplier::Float32
end
Flux.@functor CustomLayer (chain,) # Explicitly leaving out `activation_results`

function (m::CustomLayer)(x)
  # ... (see the sketch below)
end
```

See more about this in [`Flux.@functor`](@ref).
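One possible forward pass for such a layer is sketched below; the exact computation and the use of `mean` to summarise activations are assumptions:

```julia
using Statistics: mean

function (m::CustomLayer)(x)
  y = m.lr_multiplier .* m.chain(x)
  # `activation_results` is not part of the functor, so `gpu`/`cpu` and training
  # leave it alone; it is only used here to record values for later inspection.
  push!(m.activation_results, mean(y))
  return y
end

layer = CustomLayer(Chain(Dense(3 => 3)), Float32[], 0.5f0)
layer(rand(Float32, 3))
```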


## Freezing Layer Parameters (deprecated)
When it is desired not to include all the model parameters (e.g. for transfer
learning), we can simply avoid passing those layers to our call to `params`.

!!! compat "Flux ≤ 0.14"
    The mechanism described here is for Flux's old "implicit" training style.
    When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`.

Consider a simple multi-layer perceptron model where we want to avoid optimising
the first two `Dense` layers. We can obtain this using the slicing features
`Chain` provides:

```julia
m = Chain(
  # ...
)

ps = Flux.params(m[3:end])
```
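Concretely, such a model and parameter collection might look as follows; the layer sizes are illustrative assumptions:

```julia
m = Chain(
      Dense(28^2 => 64, relu),   # left out of `ps`
      Dense(64 => 32, relu),     # left out of `ps`
      Dense(32 => 10),           # the only layer whose parameters we collect
    )

ps = Flux.params(m[3:end])
```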

The `Zygote.Params` object `ps` now holds a reference to only the parameters of
the layers passed to it.

During training, gradients will only be computed for (and applied to) the last
`Dense` layer, so only that layer's parameters will change.
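In the implicit style, training with this restricted `ps` would be driven by a loop along these lines; `train_set`, the `loss` function and the `Descent` optimiser are assumptions:

```julia
opt = Descent(0.1)
for (x, y) in train_set
  gs = Flux.gradient(() -> loss(m(x), y), ps)  # gradients only for entries of `ps`
  Flux.Optimise.update!(opt, ps, gs)
end
```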

`Flux.params` also takes multiple inputs to make it easy to collect parameters
from heterogeneous models with a single call. A simple demonstration would be if
we wanted to omit optimising the second `Dense` layer in the previous example.
It would look something like this:

```julia
Flux.params(m[1], m[3:end])
```

Sometimes, finer-grained control is needed. We can freeze a specific parameter
of a specific layer which has already entered a `Params` object `ps`, simply by
deleting it from `ps`:

```julia
ps = Flux.params(m)
# ... (see the sketch below)
```
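For example, one could drop the bias of the second layer from `ps`; the particular parameter chosen here is an illustrative assumption:

```julia
ps = Flux.params(m)

# Remove one specific array from the set of trainable parameters:
delete!(ps, m[2].bias)
```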
