Commit

restructure
isentropic committed Mar 8, 2024
1 parent 7f234d6 commit 5514952
Showing 2 changed files with 14 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
@@ -55,7 +55,7 @@ makedocs(
=#
# Not really sure where this belongs... some in Fluxperimental, aim to delete?
"Custom Layers" => "models/advanced.md",
"Freezing model params" => "models/freezing-params.md",
"Advanced tweaking of models" => "tutorials/misc-model-tweaking.md",
],
],
format = Documenter.HTML(
@@ -1,4 +1,10 @@
# Freezing model weights
# Choosing differentiable/gpu parts of the model
!!! note
    This tutorial covers several loosely related ways of customizing your models
    even further. It assumes familiarity with [`Flux.@layer`](@ref),
    [`Flux.@functor`](@ref), [`freeze!`](@ref Flux.freeze!) and other Flux basics.

Flux provides several ways of freezing: excluding parameters from backprop
entirely, and marking custom struct fields so they are not moved to the GPU
([`Functors.@functor`](@ref)) and hence not trained either. The following
@@ -37,21 +43,19 @@ end
```
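
For a concrete starting point, here is a minimal sketch of on-the-fly freezing with [`freeze!`](@ref Flux.freeze!); the model shape and optimiser are illustrative only:

```julia
using Flux

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
opt_state = Flux.setup(Adam(), model)

# freeze the first layer's slice of the optimiser state;
# `Flux.update!` will skip its parameters from now on
Flux.freeze!(opt_state.layers[1])

# ... train as usual: only the second layer is updated ...

Flux.thaw!(opt_state.layers[1])  # lift the freeze again when needed
```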

## Static freezing per model definition
Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params
Sometimes some parts of the model ([`Flux.@layer`](@ref)) need not be trained
at all, but these params still need to reside on the GPU, since they are used
in the forward and/or backward pass.
```julia
struct MaskedLayer{T}
  chain::Chain
  mask::T
end
Flux.@functor MaskedLayer

# mark the trainable part
Flux.trainable(a::MaskedLayer)=(;a.chain)
# a.mask will not be updated in the training loop
Flux.@layer MaskedLayer trainable=(chain,)
# mask field will not be updated in the training loop

function (m::MaskedLayer)(x)
  # the mask field will still move to the GPU for efficient operations:
  return m.chain(x) + x + m.mask
end
```

@@ -61,7 +65,7 @@ Note how this method permanently marks some model fields as excluded from
training, with no way to change that on the fly.
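
As a rough usage sketch for the layer above (the data and optimiser are made up for illustration), only the `chain` field receives updates:

```julia
model = MaskedLayer(Chain(Dense(3 => 3)), ones(Float32, 3))
opt_state = Flux.setup(Adam(), model)  # optimiser state is created for the trainable `chain` only

grads = Flux.gradient(m -> sum(m(ones(Float32, 3))), model)
Flux.update!(opt_state, model, grads[1])  # `chain` is updated, `mask` stays exactly as it was
```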

## Excluding from model definition
Sometimes some parameters are just "not trainable" but they shouldn't even
Sometimes some parameters aren't just "not trainable": they shouldn't even be
transferred to the GPU. All scalar fields are like this by default, so things
like learning rate multipliers are neither trainable nor transferred to the GPU
by default.
@@ -82,7 +86,7 @@ function (m::CustomLayer)(x)
return result
end
```
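
As a minimal sketch of this default behaviour, consider a made-up `ScaledDense` layer with a plain scalar field; the name and the scaling itself are purely illustrative:

```julia
struct ScaledDense{D,S}
  layer::D
  lr_mult::S   # a plain number: not an array, so neither trained nor moved to the GPU
end
Flux.@layer ScaledDense

(m::ScaledDense)(x) = m.lr_mult .* m.layer(x)

m = ScaledDense(Dense(2 => 3), 0.1f0)
m = gpu(m)  # `layer`'s arrays move (when a GPU is available); `lr_mult` stays a plain Float32
```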
See more about this in [`Flux.@functor`](@ref) and
See more about this in [`Flux.@functor`](@ref).


## Freezing Layer Parameters (deprecated)
