From 383fb8403110a0a441485ee1f4b908e6573961e2 Mon Sep 17 00:00:00 2001 From: Zhanibek Date: Thu, 29 Feb 2024 19:39:33 +0900 Subject: [PATCH 01/11] docs: improve freezing docs --- docs/make.jl | 3 +- docs/src/models/advanced.md | 40 --------- docs/src/models/freezing-params.md | 127 +++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 41 deletions(-) create mode 100644 docs/src/models/freezing-params.md diff --git a/docs/make.jl b/docs/make.jl index a1b588d618..2ec6843126 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -54,7 +54,8 @@ makedocs( "Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md", =# # Not really sure where this belongs... some in Fluxperimental, aim to delete? - "Custom Layers" => "models/advanced.md", # TODO move freezing to Training + "Custom Layers" => "models/advanced.md", + "Freezing model params" => "models/freezing-params.md", ], ], format = Documenter.HTML( diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index b7161b8c59..be145858fe 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -69,46 +69,6 @@ Params([]) It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired). -## Freezing Layer Parameters - -When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. - -!!! compat "Flux ≤ 0.14" - The mechanism described here is for Flux's old "implicit" training style. - When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. - -Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain -this using the slicing features `Chain` provides: - -```julia -m = Chain( - Dense(784 => 64, relu), - Dense(64 => 64, relu), - Dense(32 => 10) - ); - -ps = Flux.params(m[3:end]) -``` - -The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. - -During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed. - -`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this: - -```julia -Flux.params(m[1], m[3:end]) -``` - -Sometimes, a more fine-tuned control is needed. -We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`, -by simply deleting it from `ps`: - -```julia -ps = Flux.params(m) -delete!(ps, m[2].bias) -``` - ## Custom multiple input or output layer Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf). diff --git a/docs/src/models/freezing-params.md b/docs/src/models/freezing-params.md new file mode 100644 index 0000000000..b1f5f26369 --- /dev/null +++ b/docs/src/models/freezing-params.md @@ -0,0 +1,127 @@ +# Freezing model weights +Flux provides several ways of freezing, excluding from backprop entirely and +marking custom struct fields not to be moved to the GPU +([Functors.@functor](@ref)) hence excluded from being trained. The following +subsections should make it clear which one suits your needs the best. + +## On-the-fly freezing per model instance +Perhaps you'd like to freeze some of the weights of the model (even at +mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`. + +```julia +m = Chain( + Dense(784 => 64, relu), # freeze this one + Dense(64 => 64, relu), + Dense(32 => 10) + ) +opt_state = Flux.setup(Momentum(), m); + +# Freeze some layers right away +Flux.freeze!(opt_state.layers[1]) + +for data in train_set + input, label = data + + # Some params could be frozen during the training: + Flux.freeze!(opt_state.layers[2]) + + grads = Flux.gradient(m) do m + result = m(input) + loss(result, label) + end + Flux.update!(opt_state, m, grads[1]) + + # Optionally unfreeze the params later + Flux.thaw!(opt_state.layers[1]) +end +``` + +## Static freezing per model definition +Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params +still need to reside on the GPU (these params are still needed in the forward +and/or backward pass). +```julia +struct MaskedLayer{T} + chain::Chain + mask::T +end +Flux.@functor MaskedLayer + +# mark the trainable part +Flux.trainable(a::MaskedLayer)=(;a.chain) +# a.mask will not be updated in the training loop + +function (m::MaskedLayer)(x) + return m.chain(x) + x + m.mask +end + +model = MaskedLayer(...) # this model will not have the `mask` field trained +``` +Note how this method permanently sets some model fields to be excluded from +training without on-the-fly changing. + +## Excluding from model definition +Sometimes some parameters are just "not trainable" but they shouldn't even +transfer to the GPU. All scalar fields are like this by default, so things like +learning rate multipliers are not trainable nor transferred to the GPU by +default. +```julia +struct CustomLayer{T, F} + chain::Chain + activation_results::Vector{F} + lr_multiplier::Float32 +end +Flux.@functor CustomLayer (chain, ) # Explicitly leaving out `activation_results` + +function (m::CustomLayer)(x) + result = m.chain(x) + x + + # `activation_results` are not part of the GPU loop, hence we could do + # things like `push!` + push!(m.activation_results, mean(result)) + return result +end +``` +See more about this in [`Flux.@functor`](@ref) and + + +## Freezing Layer Parameters (deprecated) + +When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. + +!!! compat "Flux ≤ 0.14" + The mechanism described here is for Flux's old "implicit" training style. + When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. + +Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain +this using the slicing features `Chain` provides: + +```julia +m = Chain( + Dense(784 => 64, relu), + Dense(64 => 64, relu), + Dense(32 => 10) + ); + +ps = Flux.params(m[3:end]) +``` + +The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. + +During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed. + +`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this: + +```julia +Flux.params(m[1], m[3:end]) +``` + +Sometimes, a more fine-tuned control is needed. +We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`, +by simply deleting it from `ps`: + +```julia +ps = Flux.params(m) +delete!(ps, m[2].bias) +``` + From 7f234d6267997ef8bc9ef67fd7a10e4a9ad10c07 Mon Sep 17 00:00:00 2001 From: Zhanibek Date: Mon, 4 Mar 2024 17:08:12 +0900 Subject: [PATCH 02/11] fix broken link --- docs/src/training/optimisers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 25d817454e..04e5c03ef3 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -76,7 +76,7 @@ Flux.Optimise.Optimiser ## Scheduling Optimisers -In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/dev/README.html). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. +In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/dev). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. First, we import ParameterSchedulers.jl and initialize a cosine annealing schedule to vary the learning rate between `1e-4` and `1e-2` every 10 steps. We also create a new [`Momentum`](@ref) optimiser. ```julia From 5514952317c36d2dcb26289bd7a4282f2befd253 Mon Sep 17 00:00:00 2001 From: Zhanibek Date: Fri, 8 Mar 2024 17:48:38 +0900 Subject: [PATCH 03/11] restructre --- docs/make.jl | 2 +- .../misc-model-tweaking.md} | 22 +++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) rename docs/src/{models/freezing-params.md => tutorials/misc-model-tweaking.md} (84%) diff --git a/docs/make.jl b/docs/make.jl index 2ec6843126..251a9bb78c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -55,7 +55,7 @@ makedocs( =# # Not really sure where this belongs... some in Fluxperimental, aim to delete? "Custom Layers" => "models/advanced.md", - "Freezing model params" => "models/freezing-params.md", + "Advanced tweaking of models" => "tutorials/misc-model-tweaking.md", ], ], format = Documenter.HTML( diff --git a/docs/src/models/freezing-params.md b/docs/src/tutorials/misc-model-tweaking.md similarity index 84% rename from docs/src/models/freezing-params.md rename to docs/src/tutorials/misc-model-tweaking.md index b1f5f26369..a563106585 100644 --- a/docs/src/models/freezing-params.md +++ b/docs/src/tutorials/misc-model-tweaking.md @@ -1,4 +1,10 @@ -# Freezing model weights +# Choosing differentiable/gpu parts of the model +!!! note + This tutorial features somewhat disconnected topics about customizing your + models even further. It is advised to be familiar with + [`Flux.@layer`](@ref), [`Flux.@functor`](@ref), [`freeze!`](@ref + Flux.freeze!) and other basics of Flux. + Flux provides several ways of freezing, excluding from backprop entirely and marking custom struct fields not to be moved to the GPU ([Functors.@functor](@ref)) hence excluded from being trained. The following @@ -37,7 +43,7 @@ end ``` ## Static freezing per model definition -Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params +Sometimes some parts of the model ([`Flux.@layer`](@ref)) needn't to be trained at all but these params still need to reside on the GPU (these params are still needed in the forward and/or backward pass). ```julia @@ -45,13 +51,11 @@ struct MaskedLayer{T} chain::Chain mask::T end -Flux.@functor MaskedLayer - -# mark the trainable part -Flux.trainable(a::MaskedLayer)=(;a.chain) -# a.mask will not be updated in the training loop +Flux.@layer MyLayer trainable=(chain,) +# mask field will not be updated in the training loop function (m::MaskedLayer)(x) + # mask field will still move to to gpu for efficient operations: return m.chain(x) + x + m.mask end @@ -61,7 +65,7 @@ Note how this method permanently sets some model fields to be excluded from training without on-the-fly changing. ## Excluding from model definition -Sometimes some parameters are just "not trainable" but they shouldn't even +Sometimes some parameters aren't just "not trainable" but they shouldn't even transfer to the GPU. All scalar fields are like this by default, so things like learning rate multipliers are not trainable nor transferred to the GPU by default. @@ -82,7 +86,7 @@ function (m::CustomLayer)(x) return result end ``` -See more about this in [`Flux.@functor`](@ref) and +See more about this in [`Flux.@functor`](@ref) ## Freezing Layer Parameters (deprecated) From 10e9efd8d27a8b3c46fd1a42c377ebe7b3b9f825 Mon Sep 17 00:00:00 2001 From: Diego Javier Zea Date: Wed, 28 Feb 2024 09:54:28 +0100 Subject: [PATCH 04/11] Fix https://github.com/FluxML/Flux.jl/issues/2380 (#2384) --- docs/src/tutorials/logistic_regression.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/tutorials/logistic_regression.md b/docs/src/tutorials/logistic_regression.md index 858e907cd5..b6775e9f89 100644 --- a/docs/src/tutorials/logistic_regression.md +++ b/docs/src/tutorials/logistic_regression.md @@ -35,7 +35,7 @@ julia> x |> summary The `y` values here corresponds to a type of iris plant, with a total of 150 data points. The `x` values depict the sepal length, sepal width, petal length, and petal width (all in `cm`) of 150 iris plant (hence the matrix size `4×150`). Different type of iris plants have different lengths and widths of sepals and petals associated with them, and there is a definitive pattern for this in nature. We can leverage this to train a simple classifier that outputs the type of iris plant using the length and width of sepals and petals as inputs. -Our next step would be to convert this data into a form that can be fed to a machine learning model. The `x` values are arranged in a matrix and should ideally be converted to `Float32` type (see [Performance tips](@ref man-performance-tips)), but the labels must be one hot encoded. [Here](https://discourse.julialang.org/t/all-the-ways-to-do-one-hot-encoding/64807) is a great discourse thread on different techniques that can be used to one hot encode data with or without using any external Julia package. +Our next step would be to convert this data into a form that can be fed to a machine learning model. The `x` values are arranged in a matrix and should ideally be converted to `Float32` type (see [Performance tips](@ref id-man-performance-tips)), but the labels must be one hot encoded. [Here](https://discourse.julialang.org/t/all-the-ways-to-do-one-hot-encoding/64807) is a great discourse thread on different techniques that can be used to one hot encode data with or without using any external Julia package. ```jldoctest logistic_regression julia> x = Float32.(x); From bd234f197bd99ac67e062a50804ba6d29cb3cd5c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 28 Feb 2024 09:55:11 +0100 Subject: [PATCH 05/11] Bump codecov/codecov-action from 3 to 4 (#2376) Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 3 to 4. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v3...v4) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fbbe6a8b32..84863d36ad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' - uses: julia-actions/julia-processcoverage@v1 if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 if: contains(fromJson('["1", "1.9"]'), matrix.version) && matrix.os == 'ubuntu-latest' with: file: lcov.info From 06654aa91d6d9840ba9aa7d97acae83325ca2fa6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:16:41 +0000 Subject: [PATCH 06/11] Bump dorny/paths-filter from 3.0.1 to 3.0.2 Bumps [dorny/paths-filter](https://github.com/dorny/paths-filter) from 3.0.1 to 3.0.2. - [Release notes](https://github.com/dorny/paths-filter/releases) - [Changelog](https://github.com/dorny/paths-filter/blob/master/CHANGELOG.md) - [Commits](https://github.com/dorny/paths-filter/compare/v3.0.1...v3.0.2) --- updated-dependencies: - dependency-name: dorny/paths-filter dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- .github/workflows/JuliaFormatter.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/JuliaFormatter.yml b/.github/workflows/JuliaFormatter.yml index f3bd1fe96d..a929470b0c 100644 --- a/.github/workflows/JuliaFormatter.yml +++ b/.github/workflows/JuliaFormatter.yml @@ -17,7 +17,7 @@ jobs: steps: - uses: actions/checkout@v4.1.1 - - uses: dorny/paths-filter@v3.0.1 + - uses: dorny/paths-filter@v3.0.2 id: filter with: filters: | From 960f57346e6b1ef8e3e60136e02b127850aca127 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:38:07 -0600 Subject: [PATCH 07/11] Allow `cpu(::DataLoader)` (#2388) --- src/functor.jl | 15 ++++++++++++++- test/data.jl | 6 ++++++ test/ext_cuda/cuda.jl | 5 ++++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index 2c8f3360db..8215b92863 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -403,8 +403,9 @@ function _metal end """ gpu(data::DataLoader) + cpu(data::DataLoader) -Transforms a given `DataLoader` to apply `gpu` to each batch of data, +Transforms a given `DataLoader` to apply `gpu` or `cpu` to each batch of data, when iterated over. (If no GPU is available, this does nothing.) # Example @@ -456,6 +457,18 @@ function gpu(d::MLUtils.DataLoader) ) end +function cpu(d::MLUtils.DataLoader) + MLUtils.DataLoader(MLUtils.mapobs(cpu, d.data), + d.batchsize, + d.buffer, + d.partial, + d.shuffle, + d.parallel, + d.collate, + d.rng, + ) +end + # Defining device interfaces. """ Flux.AbstractDevice <: Function diff --git a/test/data.jl b/test/data.jl index 4e4c485064..b97c4dae80 100644 --- a/test/data.jl +++ b/test/data.jl @@ -1,3 +1,4 @@ +using Flux: DataLoader using Random @testset "DataLoader" begin @@ -14,6 +15,11 @@ using Random @test batches[2] == X[:,3:4] @test batches[3] == X[:,5:5] + d_cpu = d |> cpu # does nothing but shouldn't error + @test d_cpu isa DataLoader + @test first(d_cpu) == X[:,1:2] + @test length(d_cpu) == 3 + d = DataLoader(X, batchsize=2, partial=false) # @inferred first(d) batches = collect(d) diff --git a/test/ext_cuda/cuda.jl b/test/ext_cuda/cuda.jl index b52fa6c296..bbfd2854ba 100644 --- a/test/ext_cuda/cuda.jl +++ b/test/ext_cuda/cuda.jl @@ -182,11 +182,14 @@ end X = randn(Float64, 3, 33) pre1 = Flux.DataLoader(X |> gpu; batchsize=13, shuffle=false) post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> gpu + rev1 = pre1 |> cpu # inverse operation for epoch in 1:2 - for (p, q) in zip(pre1, post1) + for (p, q, a) in zip(pre1, post1, rev1) @test p isa CuArray{Float32} @test q isa CuArray{Float32} @test p ≈ q + @test a isa Array{Float32} + @test a ≈ Array(p) end end From 97fdcd14abfeca5420630b09ac0054dcdc08d0e6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:38:27 -0600 Subject: [PATCH 08/11] Add a macro to opt-in to fancy printing, and to everything else (#1932) --- NEWS.md | 8 ++ docs/src/models/advanced.md | 25 +++--- docs/src/models/basics.md | 9 ++- src/Flux.jl | 5 +- src/functor.jl | 1 + src/layers/attention.jl | 44 +++++++++- src/layers/basic.jl | 18 ++--- src/layers/conv.jl | 6 +- src/layers/macro.jl | 156 ++++++++++++++++++++++++++++++++++++ src/layers/normalise.jl | 19 ++--- src/layers/recurrent.jl | 11 ++- src/layers/show.jl | 69 ++++++++++------ test/layers/macro.jl | 47 +++++++++++ test/runtests.jl | 1 + test/utils.jl | 2 +- 15 files changed, 350 insertions(+), 71 deletions(-) create mode 100644 src/layers/macro.jl create mode 100644 test/layers/macro.jl diff --git a/NEWS.md b/NEWS.md index ac8883a091..68d36fdc34 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,13 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release. +## v0.14.13 +* New macro `Flux.@layer` which should be used in place of `@functor`. + This also adds `show` methods for pretty printing. + +## v0.14.12 +* New `SignDecay` optimiser, like `` WeightNorm` but for L1 norm. + ## v0.14.0 (July 2023) * Flux now requires julia v1.9 or later. * CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`. @@ -51,6 +58,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl ## v0.13.6 * Use the package [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) instead of having the same code here. +* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078) ## v0.13.4 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index be145858fe..46931b6547 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -18,8 +18,8 @@ function (m::CustomModel)(x) return m.chain(x) + x end -# Call @functor to allow for training. Described below in more detail. -Flux.@functor CustomModel +# Call @layer to allow for training. Described below in more detail. +Flux.@layer CustomModel ``` You can then use the model like: @@ -39,7 +39,7 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function: ```julia-repl -julia> Flux.@functor Affine +julia> @layer Affine julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]) Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) @@ -47,7 +47,7 @@ Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) julia> Flux.params(a) # default behavior Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]]) -julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name +julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name julia> Flux.params(a) Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]]) @@ -67,7 +67,14 @@ julia> Flux.params(Affine(true, [10, 11, 12.0])) Params([]) ``` -It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired). +The exact same method of `trainable` can also be defined using the macro, for convenience: + +```julia +Flux.@layer Affine trainable=(W,) +``` + +There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. + ## Custom multiple input or output layer @@ -95,9 +102,9 @@ Join(combine, paths...) = Join(combine, paths) ``` Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field. -The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. +The next step is to use [`Functors.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. ```julia -Flux.@functor Join +Flux.@layer Join ``` Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results. @@ -154,7 +161,7 @@ model(xs) Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. -We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass. +We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass. ```julia using Flux using CUDA @@ -166,7 +173,7 @@ end Split(paths...) = Split(paths) -Flux.@functor Split +Flux.@layer Split (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths) ``` diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index ca95dc747d..fb0f2d5488 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -257,8 +257,8 @@ m(5) # => 26 There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro: -``` -Flux.@functor Affine +```julia +Flux.@layer Affine ``` Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias): @@ -272,3 +272,8 @@ end Affine(3 => 1, bias=false, init=ones) |> gpu ``` + +```@docs +Flux.@layer +Flux.create_bias +``` diff --git a/src/Flux.jl b/src/Flux.jl index d3ca611dbd..5675f7c10f 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,7 @@ using MacroTools: @forward @reexport using NNlib using MLUtils +const stack = MLUtils.stack # now exported by Base import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions using Optimisers: freeze!, thaw!, adjust! using Random: default_rng @@ -69,6 +70,9 @@ include("functor.jl") # Pirate error to catch a common mistake. Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.") +include("layers/show.jl") +include("layers/macro.jl") + include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") @@ -76,7 +80,6 @@ include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") include("layers/attention.jl") -include("layers/show.jl") include("loading.jl") diff --git a/src/functor.jl b/src/functor.jl index 8215b92863..34fe52db35 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -81,6 +81,7 @@ function params!(p::Params, x, seen = IdSet()) elseif x in seen nothing else + _check_new_macro(x) # complains if you used @functor not @layer push!(seen, x) for child in trainable(x) params!(p, child, seen) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 3701be2bb0..d4a33283d9 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2} out_proj::P2 end -@functor MultiHeadAttention +@layer MultiHeadAttention function MultiHeadAttention(dims; nheads::Int = 8, @@ -83,8 +83,8 @@ function MultiHeadAttention(dims; dropout_prob = 0.0) dims = normalize_mha_dims(dims) - @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" - @assert dims.v % nheads == 0 "v_dim should be divisible by nheads" + dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)")) + dims.v % nheads == 0 || throw(ArgumentError( "v_dim = $(dims.v) should be divisible by nheads = $(nheads)")) q_proj = Dense(dims.q_in => dims.qk; bias, init) k_proj = Dense(dims.k_in => dims.qk; bias, init) v_proj = Dense(dims.v_in => dims.v; bias, init) @@ -131,3 +131,41 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, # [α] = [kv_len, q_len, nheads, batch_size] return x, α end + +function Base.show(io::IO, mha::MultiHeadAttention) + qk, q_in = size(mha.q_proj.weight) + qk, k_in = size(mha.k_proj.weight) + v, v_in = size(mha.v_proj.weight) + out, v = size(mha.out_proj.weight) + # @show q_in, k_in, v_in, qk, v, out + print(io, "MultiHeadAttention(") + if q_in == k_in == v_in == qk == v == out + print(io, q_in) + elseif q_in == k_in == v_in && qk == v + print(io, q_in, " => ", qk, " => ", out) + elseif q_in == k_in == v_in + print(io, q_in, " => (", qk, ", ", v,") => ", out) + else + print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v,") => ", out) + end + print(io, "; nheads=", mha.nheads) + if mha.q_proj.bias !== false + print(io, ", bias=true") + end + if mha.attn_drop.p != 0 + print(io, ", dropout_prob=", mha.attn_drop.p) # can't we rename this? + end + print(io, ")") +end + + +#= + +# Test cases for printing: + +MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => (6, 7) => 8; nheads=1) +MultiHeadAttention(3 => 6 => 8; nheads=1) +MultiHeadAttention(8; bias=true) + +=# diff --git a/src/layers/basic.jl b/src/layers/basic.jl index b7027f5007..018b19b31d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -46,7 +46,7 @@ end @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.iterate, Base.lastindex, Base.keys, Base.firstindex -@functor Chain +@layer :expand Chain # the + opts-in to container-style pretty-printing (c::Chain)(x) = _applychain(c.layers, x) @@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity; Dense(init(out, in), bias, σ) end -@functor Dense +@layer Dense function (a::Dense)(x::AbstractVecOrMat) _size_check(a, x, 1 => size(a.weight, 2)) @@ -251,7 +251,7 @@ end Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act) Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end]) -@functor Scale +@layer Scale function (a::Scale)(x::AbstractArray) σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc @@ -306,7 +306,7 @@ end Maxout(layers...) = Maxout(layers) Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...) -@functor Maxout +@layer :expand Maxout function (mo::Maxout)(input::AbstractArray) # Perhaps surprisingly, pairwise max broadcast is often faster, @@ -353,7 +353,7 @@ struct SkipConnection{T,F} connection::F #user can pass arbitrary connections here, such as (a,b) -> a + b end -@functor SkipConnection +@layer :expand SkipConnection function (skip::SkipConnection)(input) skip.connection(skip.layers(input), input) @@ -423,7 +423,7 @@ struct Bilinear{F,A,B} end end -@functor Bilinear +@layer Bilinear function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity; bias = true, init = glorot_uniform) @@ -522,7 +522,7 @@ function Parallel(connection; kw...) Parallel(connection, layers) end -@functor Parallel +@layer :expand Parallel (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) (m::Parallel)(xs::Tuple) = m(xs...) @@ -643,7 +643,7 @@ end end applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x) -@functor PairwiseFusion +@layer :expand PairwiseFusion Base.getindex(m::PairwiseFusion, i) = m.layers[i] Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i]) @@ -701,7 +701,7 @@ struct Embedding{W<:AbstractMatrix} weight::W end -@functor Embedding +@layer Embedding Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ca275d4a16..4e6044dcfb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}; init(filter..., cin÷groups, cout) end -@functor Conv +@layer Conv conv_dims(c::Conv, x::AbstractArray) = DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups) @@ -309,7 +309,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = ConvTranspose(weight, bias, σ; stride, pad, dilation, groups) end -@functor ConvTranspose +@layer ConvTranspose function conv_transpose_dims(c::ConvTranspose, x::AbstractArray) # Calculate size of "input", from ∇conv_data()'s perspective... @@ -460,7 +460,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden return CrossCor(weight, bias, σ; stride, pad, dilation) end -@functor CrossCor +@layer CrossCor function crosscor(x, w, ddims::DenseConvDims) ddims = DenseConvDims(ddims, F=true) diff --git a/src/layers/macro.jl b/src/layers/macro.jl new file mode 100644 index 0000000000..9e770add87 --- /dev/null +++ b/src/layers/macro.jl @@ -0,0 +1,156 @@ + +""" + @layer Dense + @layer :expand Chain + @layer BatchNorm trainable=(β,γ) + +This macro replaces most uses of `@functor`. Its basic purpose is the same: +When you define a new layer, this tells Flux to explore inside it +to see the parameters it trains, and also to move them to the GPU, change precision, etc. +Like `@functor`, this assumes your struct has the default constructor, to enable re-building. + +The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. +Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. +* If some fields look like parameters but should not be trained, + then `trainable` lets you specify which fields to include, while the rest are ignored. + +The macro also handles overloads of `show` for pretty printing. +* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. +* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents. +* To disable all `show` overloads, there is an `:ignore` option too. + +(You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.) + +Note that re-running the macro with different options may not overwrite all methods, you will need to restart. + +# Example +```jldoctest +julia> struct Trio; a; b; c end + +julia> tri = Trio(Dense([1.1 2.2], [0.0], tanh), Dense(hcat(3.3), false), Dropout(0.4)) +Trio(Dense(2 => 1, tanh), Dense(1 => 1; bias=false), Dropout(0.4)) + +julia> Flux.destructure(tri) # parameters are not yet visible to Flux +(Bool[], Restructure(Trio, ..., 0)) + +julia> Flux.@layer :expand Trio + +julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too +([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4)) + +julia> tri # and layer is printed like Chain +Trio( + Dense(2 => 1, tanh), # 3 parameters + Dense(1 => 1; bias=false), # 1 parameters + Dropout(0.4), +) # Total: 3 arrays, 4 parameters, 224 bytes. +``` + +""" +macro layer(exs...) + out = quote end + + # These functions are defined in show.jl, and each return an expression overloading Base.show + type, rest... = if exs[1] == QuoteNode(:expand) + push!(out.args, _macro_big_show(esc(exs[2]))) + exs[2:end] + elseif exs[1] == QuoteNode(:ignore) + exs[2:end] + elseif exs[1] isa QuoteNode + error("`@layer` accepts only two options before the layer type, `:expand` and `:ignore` (to control `show`)") + else + push!(out.args, _macro_layer_show(esc(exs[1]))) + exs + end + + # This function exists only for depwarns when you use @functor directly + push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing)) + + push!(out.args, _macro_functor(esc(type))) + + for j in 1:length(rest) + ex = rest[j] + Meta.isexpr(ex, :(=)) || error("The macro `@layer` expects here `keyword = (fields...,)`, got $ex") + + name = if ex.args[1] == :trainable + :(Optimisers.trainable) + else + error("`@layer` cannot define a method for `$(ex.args[1])` at the moment, sorry.") + # @warn "Trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1 + # esc(ex.args[1]) + end + push!(out.args, _macro_trainable(esc(type), name, ex.args[2])) + end + + out +end + +# Temporary depwarn function, called within `params`, is also called by `show`. + +function _check_new_macro(x::T) where T + Functors.isleaf(x) && return + Base.depwarn("This type should probably now use `Flux.@layer` instead of `@functor`: $T", Symbol("@functor")) +end +_check_new_macro(::Tuple) = nothing # defined by Functors.jl, not by users +_check_new_macro(::NamedTuple) = nothing +_check_new_macro(::AbstractArray) = nothing +_check_new_macro(::Ref) = nothing + +# @layer's code for Functors & Adapt +# Unlike @functor, _default_functor doesn't need to eval anything + +function _macro_functor(type) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end + +function _macro_functor(type, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quote + Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols)) + Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer) + end +end +_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,))) # lets you forget a comma + +function _default_functor(::Type{T}, x) where {T} + if @generated + F = fieldnames(T) + args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F) + C = Base.typename(T).wrapper # constructor + # recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C)) + recon = :(Base.splat($C)) + :((NamedTuple{$F}(($(args...),)), $recon)) + else + # Getting this parameterless type takes about 2μs, every time: + # spl = VERSION > v"1.9-" ? Splat : Base.splat + spl = Base.splat + namedtuple(x), spl(Base.typename(T).wrapper) + end +end + +function namedtuple(x::T) where T + F = fieldnames(T) + NamedTuple{F}(map(sy -> getfield(x, sy), F)) +end + +# @layer's code for Optimisers.trainable, and perhaps anything else, +# with the pattern that keywords mean function names & what fields they pick. + +function _macro_trainable(type, fun, fields) + Meta.isexpr(fields, :tuple) || error("expected a tuple of field names") + symbols = Tuple(map(_noquotenode, fields.args)) + quoted = map(QuoteNode, symbols) + gets = [:(getfield(x, $f)) for f in quoted] + quote + $fun(x::$type) = NamedTuple{$symbols}(($(gets...),)) + end +end +_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,))) # lets you forget a comma + +_noquotenode(s::Symbol) = s +_noquotenode(q::QuoteNode) = q.value # lets you write trainable=(:x,:y) instead of (x,y) +_noquotenode(ex) = error("expected a symbol here, as a field name, but got $ex") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1c8fbff5a1..c0a86c8796 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -78,8 +78,7 @@ function Dropout(p::Real; dims=:, active::Union{Bool,Nothing} = nothing, rng = d Dropout(p, dims, active, rng) end -@functor Dropout -trainable(a::Dropout) = (;) +@layer Dropout trainable=() (a::Dropout)(x) = dropout(a.rng, x, a.p * _isactive(a, x); dims=a.dims) @@ -131,8 +130,7 @@ function AlphaDropout(p; rng = default_rng(), active::Union{Bool,Nothing} = noth AlphaDropout(p, active, rng) end -@functor AlphaDropout -trainable(a::AlphaDropout) = (;) +@layer AlphaDropout trainable=() function (a::AlphaDropout)(x::AbstractArray{T}) where T _isactive(a, x) || return x @@ -151,6 +149,8 @@ end testmode!(m::AlphaDropout, mode=true) = (m.active = isnothing(_tidy_active(mode)) ? nothing : !mode; m) +Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")") + """ LayerNorm(size..., λ=identity; affine=true, eps=1f-5) @@ -199,7 +199,7 @@ end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) -@functor LayerNorm +@layer LayerNorm function (a::LayerNorm)(x::AbstractArray) ChainRulesCore.@ignore_derivatives if a.diag isa Scale @@ -343,8 +343,7 @@ function BatchNorm(chs::Int, λ=identity; active, chs) end -@functor BatchNorm -trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) +@layer BatchNorm trainable=(β,γ) function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(BN, x, N-1 => BN.chs) @@ -437,8 +436,7 @@ function InstanceNorm(chs::Int, λ=identity; active, chs) end -@functor InstanceNorm -trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;) +@layer InstanceNorm trainable=(β,γ) function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(l, x, N-1 => l.chs) @@ -517,8 +515,7 @@ mutable struct GroupNorm{F,V,N} chs::Int # number of channels end -@functor GroupNorm -trainable(gn::GroupNorm) = hasaffine(gn) ? (β = gn.β, γ = gn.γ) : (;) +@layer GroupNorm trainable=(β,γ) function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 375ff43d52..f55ebb1741 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -135,8 +135,7 @@ function (m::Recur)(x) return y end -@functor Recur -trainable(a::Recur) = (; cell = a.cell) +@layer :expand Recur trainable=(cell,) Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -209,7 +208,7 @@ function (m::RNNCell{F,I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where return h, reshape_cell_output(h, x) end -@functor RNNCell +@layer RNNCell # state0 is trainable, see issue 807 about this. function Base.show(io::IO, l::RNNCell) print(io, "RNNCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)) @@ -318,7 +317,7 @@ function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractV return (h′, c′), reshape_cell_output(h′, x) end -@functor LSTMCell +@layer LSTMCell Base.show(io::IO, l::LSTMCell) = print(io, "LSTMCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷4, ")") @@ -391,7 +390,7 @@ function (m::GRUCell{I,H,V,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) where { return h′, reshape_cell_output(h′, x) end -@functor GRUCell +@layer GRUCell Base.show(io::IO, l::GRUCell) = print(io, "GRUCell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") @@ -461,7 +460,7 @@ function (m::GRUv3Cell{I,H,V,HH,<:AbstractMatrix{T}})(h, x::AbstractVecOrMat) wh return h′, reshape_cell_output(h′, x) end -@functor GRUv3Cell +@layer GRUv3Cell Base.show(io::IO, l::GRUv3Cell) = print(io, "GRUv3Cell(", size(l.Wi, 2), " => ", size(l.Wi, 1)÷3, ")") diff --git a/src/layers/show.jl b/src/layers/show.jl index 0ae14dd9ee..a03ddf3754 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -1,15 +1,21 @@ +@nospecialize # just for this file, for startup time -for T in [ - :Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL - _big_show(io, x) - elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix - _layer_show(io, x) - else - show(io, x) +# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression: +function _macro_big_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL + _big_show(io, x) + elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix + _layer_show(io, x) + else + show(io, x) + end end + + # Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state: + Flux._show_children(x::$ex) = _flat_children(trainable(x)) end end @@ -17,6 +23,8 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing) pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")") children = _show_children(obj) if all(_show_leaflike, children) + # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids, + # but once all layers use @layer, they stop the recursion by defining a method for _big_show. _layer_show(io, obj, indent, name) else println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre) @@ -49,25 +57,32 @@ _show_leaflike(x) = isleaf(x) # mostly follow Functors, except for: # note the covariance of tuple, using <:T causes warning or error _show_leaflike(::Tuple{Vararg{Number}}) = true # e.g. stride of Conv _show_leaflike(::Tuple{Vararg{AbstractArray}}) = true # e.g. parameters of LSTMcell -_show_leaflike(::Scale) = true # appears inside LayerNorm _show_leaflike(::AbstractArray{<:Number}) = true # e.g. transposed arrays -_show_children(x) = trainable(x) # except for layers which hide their Tuple: -_show_children(c::Chain) = c.layers -_show_children(m::Maxout) = m.layers -_show_children(p::Parallel) = (p.connection, p.layers...) -_show_children(f::PairwiseFusion) = (f.connection, f.layers...) - -for T in [ - :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag, - :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, - ] - @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) - if !get(io, :compact, false) - _layer_show(io, x) - else - show(io, x) +_show_children(x) = trainable(x) +# This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead +# writes a method to use this function. It flattens the Tuple within Chain etc. +# (The remaining special cases are for printing of layer names when a NamedTuple, above.) +function _flat_children(x) + alpha = map(f -> getfield(x, f), fieldnames(typeof(x))) + beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha) + gamma = ((beta...)...,) +end + +# This is called by @layer, on layers which should be treated like Dense, and returns an expression: +function _macro_layer_show(ex) + quote + # Entry point: + function Base.show(io::IO, m::MIME"text/plain", x::$ex) + if !get(io, :compact, false) + _layer_show(io, x) + else + show(io, x) + end end + + # Exit from _big_show recursion: + Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name) end end @@ -126,6 +141,8 @@ function _nan_show(io::IO, x) end end +@specialize # un-does @nospecialze at the top of this file + _any(f, xs::AbstractArray{<:Number}) = any(f, xs) # _any(f, xs::Union{Tuple,NamedTuple,Zygote.Params}) = any(x -> _any(f, x), xs) _any(f, xs) = any(x -> _any(f, x), xs) diff --git a/test/layers/macro.jl b/test/layers/macro.jl new file mode 100644 index 0000000000..e41d5a2240 --- /dev/null +++ b/test/layers/macro.jl @@ -0,0 +1,47 @@ +using Flux, Functors, Optimisers + +module MacroTest + using Flux: @layer + + struct Duo{T,S}; x::T; y::S; end + @layer :expand Duo + + struct Trio; a; b; c end + # @layer Trio trainable=(a,b) test=(c) # should be (c,) but it lets you forget + @layer Trio trainable=(a,b) # defining a method for test is made an error, for now + + struct TwoThirds; a; b; c; end +end + +@testset "@layer macro" begin + @test !isdefined(MacroTest, :Flux) # That's why the module, to check scope + + m2 = MacroTest.Duo(Dense(2=>2), Chain(Flux.Scale(2), Dropout(0.2))) + + @test Functors.children(m2) isa NamedTuple{(:x, :y)} + @test length(Optimisers.destructure(m2)[1]) == 10 + + m3 = MacroTest.Trio([1.0], [2.0], [3.0]) + + @test Functors.children(m3) isa NamedTuple{(:a, :b, :c)} + @test fmap(zero, m3) isa MacroTest.Trio + + @test Optimisers.trainable(m3) isa NamedTuple{(:a, :b)} + @test Optimisers.destructure(m3)[1] == [1, 2] + + # @test MacroTest.test(m3) == (c = [3.0],) # removed, for now + + m23 = MacroTest.TwoThirds([1 2], [3 4], [5 6]) + # Check that we can use the macro with a qualified type name, outside the defining module: + Flux.@layer :expand MacroTest.TwoThirds trainable=(:a) # documented as (a,c) but allow quotes + + m23re = Functors.functor(m23)[2]((a = [10 20], b = [3 4], c = [50 60])) + @test m23re isa MacroTest.TwoThirds + @test Flux.namedtuple(m23re) == (a = [10 20], b = [3 4], c = [50 60]) + + @test Optimisers.trainable(m23) == (a = [1 2],) + + @test_throws LoadError @eval Flux.@layer :zzz MacroTest.TwoThirds + @test_throws LoadError @eval Flux.@layer MacroTest.TwoThirds chidren=(a, b) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 94e0c466e6..8dca6becdd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -48,6 +48,7 @@ Random.seed!(0) include("layers/conv.jl") include("layers/upsample.jl") include("layers/show.jl") + include("layers/macro.jl") end @testset "outputsize" begin diff --git a/test/utils.jl b/test/utils.jl index 620a4d40b4..e175eb1f5b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -616,7 +616,7 @@ end a::A b::A end - Flux.@functor Model + Flux.@layer Model (m::Model)(x) = m.a(x) .+ m.b(x) d = Dense(1, 1) From 56180591bb0ed5f9d68a6926e8946b24df582bc8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:39:03 -0600 Subject: [PATCH 09/11] Small upgrades to training docs (#2331) --- docs/src/training/reference.md | 12 +++++++----- docs/src/training/training.md | 8 +++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 77dc0f81d0..1bf0cfd1bf 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -10,10 +10,6 @@ Because of this: * Flux defines its own version of `setup` which checks this assumption. (Using instead `Optimisers.setup` will also work, they return the same thing.) -The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.14, `Flux.Adam()` returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. -The available rules are listed the [optimisation rules](@ref man-optimisers) page here; -see the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. - ```@docs Flux.Train.setup Flux.Train.train!(loss, model, data, state; cb) @@ -47,10 +43,16 @@ Flux 0.13 and 0.14 are the transitional versions which support both; Flux 0.15 w The blue-green boxes in the [training section](@ref man-training) describe the changes needed to upgrade old code. +The available rules are listed the [optimisation rules](@ref man-optimisers) page here. + +!!! compat "Old & new rules" + The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.14, `Flux.Adam()` still returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. + For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). +See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. !!! compat "Flux ≤ 0.12" - Earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` + Much earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. If you get an error `UndefVarError: params not defined`, this probably means that you are following code for Flux 0.12 or earlier on a more recent version. diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 623b4788fc..6dd80897b5 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -225,6 +225,9 @@ callback API. Here is an example, in which it may be helpful to note: returns the value of the function, for logging or diagnostic use. * Logging or printing is best done outside of the `gradient` call, as there is no need to differentiate these commands. +* To use `result` for logging purposes, you could change the `do` block to end with + `return my_loss(result, label), result`, i.e. make the function passed to `withgradient` + return a tuple. The first element is always the loss. * Julia's `break` and `continue` keywords let you exit from parts of the loop. ```julia @@ -319,9 +322,12 @@ The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times original par matching the gradient of the penalty above (with the same, unrealistically large, constant). After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. +The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is +`pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. + The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Flux.Optimise.ClipValue) or [`ClipNorm`](@ref Flux.Optimise.ClipNorm). -Besides L2 / weight decay, another common and quite different kind of regularisation is +Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the previous layer during training. It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. From 4ce7033ecdd81ece96da80a1671bca98c33bee8a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 6 Mar 2024 18:46:54 -0600 Subject: [PATCH 10/11] doc changes re at-functor and at-layer (#2390) * doc changes re at-functor and at-layer * fix a doctest * more fixes * public at-layer * add a sentence comparing to freeze/thaw * Apply suggestions from code review Co-authored-by: Kyle Daruwalla * two fixes re SignDecay --------- Co-authored-by: Kyle Daruwalla --- NEWS.md | 2 +- Project.toml | 2 +- docs/src/models/advanced.md | 2 +- docs/src/models/basics.md | 2 +- docs/src/models/functors.md | 6 +++++- docs/src/models/layers.md | 2 +- docs/src/saving.md | 4 ++-- docs/src/training/optimisers.md | 1 + docs/src/training/training.md | 3 +++ src/Flux.jl | 4 ++-- src/functor.jl | 8 ++++---- src/layers/macro.jl | 6 +++--- 12 files changed, 25 insertions(+), 17 deletions(-) diff --git a/NEWS.md b/NEWS.md index 68d36fdc34..87333f8717 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,7 +7,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl This also adds `show` methods for pretty printing. ## v0.14.12 -* New `SignDecay` optimiser, like `` WeightNorm` but for L1 norm. +* New `SignDecay` optimiser, like `WeightDecay` but for L1 norm. ## v0.14.0 (July 2023) * Flux now requires julia v1.9 or later. diff --git a/Project.toml b/Project.toml index 5ec702c6b1..bc31cd5d3f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.12" +version = "0.14.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 46931b6547..255a7d68e3 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -102,7 +102,7 @@ Join(combine, paths...) = Join(combine, paths) ``` Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field. -The next step is to use [`Functors.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. +The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `Flux.setup` on a `Join` maps over the underlying trainable arrays on each path. ```julia Flux.@layer Join ``` diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index fb0f2d5488..cf83764349 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -255,7 +255,7 @@ m(5) # => 26 ## Layer Helpers -There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro: +There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: ```julia Flux.@layer Affine diff --git a/docs/src/models/functors.md b/docs/src/models/functors.md index ab0883c95e..861528cda9 100644 --- a/docs/src/models/functors.md +++ b/docs/src/models/functors.md @@ -2,7 +2,11 @@ Flux models are deeply nested structures, and [Functors.jl](https://github.com/FluxML/Functors.jl) provides tools needed to explore such objects, apply functions to the parameters they contain, and re-build them. -New layers should be annotated using the `Functors.@functor` macro. This will enable [`params`](@ref Flux.params) to see the parameters inside, and [`gpu`](@ref) to move them to the GPU. +!!! compat "Flux ≤ 0.14" + All layers were previously defined with the `Functors.@functor` macro. + This still works, but it is recommended that you use the new [`Flux.@layer`](@ref Flux.@layer) macro instead. + Both allow [`Flux.setup`](@ref Flux.setup) to see the parameters inside, and [`gpu`](@ref) to move them to the GPU, but [`Flux.@layer`](@ref Flux.@layer) also overloads printing, + and offers a way to define `trainable` at the same time. `Functors.jl` has its own [notes on basic usage](https://fluxml.ai/Functors.jl/stable/#Basic-Usage-and-Implementation) for more details. Additionally, the [Advanced Model Building and Customisation](@ref man-advanced) page covers the use cases of `Functors` in greater details. diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 31db9cd204..177a3eca94 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -12,7 +12,7 @@ The `Dense` exemplifies several features: * The bias vector is always initialised [`Flux.zeros32`](@ref). The keyword `bias=false` will turn this off, i.e. keeping the bias permanently zero. -* It is annotated with [`@functor`](@ref Functors.@functor), which means that [`params`](@ref Flux.params) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU. +* It is annotated with [`@layer`](@ref Flux.@layer), which means that [`Flux.setup`](@ref Flux.setup) will see the contents, and [`gpu`](@ref Flux.gpu) will move their arrays to the GPU. By contrast, `Chain` itself contains no parameters, but connects other layers together. The section on [dataflow layers](@ref man-dataflow-layers) introduces others like this. diff --git a/docs/src/saving.md b/docs/src/saving.md index 16f944ef08..37c0470704 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -16,12 +16,12 @@ julia> struct MyModel net end -julia> Flux.@functor MyModel +julia> Flux.@layer MyModel julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))); julia> model = MyModel() -MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) +MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters julia> model_state = Flux.state(model); diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 04e5c03ef3..263ebc32db 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -112,6 +112,7 @@ Similar to optimisers, Flux also defines some simple decays that can be used in ExpDecay InvDecay WeightDecay +SignDecay ``` ## Gradient Clipping diff --git a/docs/src/training/training.md b/docs/src/training/training.md index 6dd80897b5..0370c86a3d 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -384,6 +384,9 @@ Flux.thaw!(opt_state) The earlier "implicit" equivalent was to pass to `gradient` an object referencing only part of the model, such as `Flux.params(bimodel.layers.enc)`. +While `adjust!` and `freeze!`/`thaw!` make temporary modifications to the optimiser state, +permanently removing some fields of a new layer type from training is usually done +when defining the layer, by calling for example [`@layer`](@ref Flux.@layer)` NewLayer trainable=(weight,)`. ## Implicit or Explicit? diff --git a/src/Flux.jl b/src/Flux.jl index 5675f7c10f..a8720b7905 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -34,11 +34,11 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion @compat(public, ( # mark unexported symbols as API, on Julia 1.11 # modules - Losses, + Losses, Train, # layers Bilinear, Scale, dropout, # utils - outputsize, state, + outputsize, state, create_bias, @layer, )) include("optimise/Optimise.jl") diff --git a/src/functor.jl b/src/functor.jl index 34fe52db35..f09ac6ae93 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -286,7 +286,7 @@ _paramtype(::Type{T}, x::AbstractArray{<:Complex{<:AbstractFloat}}) where {T<:Ab f32(m) Converts the `eltype` of model's *floating point* parameters to `Float32` (which is Flux's default). -Recurses into structs marked with [`@functor`](@ref). +Recurses into structs marked with [`@layer`](@ref Flux.@layer). See also [`f64`](@ref) and [`f16`](@ref). """ @@ -296,7 +296,7 @@ f32(m) = _paramtype(Float32, m) f64(m) Converts the `eltype` of model's *floating point* parameters to `Float64`. -Recurses into structs marked with [`@functor`](@ref). +Recurses into structs marked with [`@layer`](@ref Flux.@layer). See also [`f32`](@ref) and [`f16`](@ref). """ @@ -306,7 +306,7 @@ f64(m) = _paramtype(Float64, m) f16(m) Converts the `eltype` of model's *floating point* parameters to `Float16`. -Recurses into structs marked with [`@functor`](@ref). +Recurses into structs marked with [`@layer`](@ref Flux.@layer). Support for `Float16` is limited on many CPUs. Julia may convert to `Float32` for each operation, which is slow. @@ -330,7 +330,7 @@ Chain( """ f16(m) = _paramtype(Float16, m) -# Functors for certain Julia data structures +# Functors for certain Julia data structures -- PIRACY, should move to Functors.jl @functor Cholesky trainable(c::Cholesky) = () diff --git a/src/layers/macro.jl b/src/layers/macro.jl index 9e770add87..2fb6db0faf 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -7,12 +7,12 @@ This macro replaces most uses of `@functor`. Its basic purpose is the same: When you define a new layer, this tells Flux to explore inside it to see the parameters it trains, and also to move them to the GPU, change precision, etc. + Like `@functor`, this assumes your struct has the default constructor, to enable re-building. +If you define an inner constructor (i.e. a function within the `struct` block) things may break. The keyword `trainable` allows you to limit this exploration, instead of visiting all `fieldnames(T)`. Note that it is never necessary to tell Flux to ignore non-array objects such as functions or sizes. -* If some fields look like parameters but should not be trained, - then `trainable` lets you specify which fields to include, while the rest are ignored. The macro also handles overloads of `show` for pretty printing. * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`. @@ -21,7 +21,7 @@ The macro also handles overloads of `show` for pretty printing. (You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.) -Note that re-running the macro with different options may not overwrite all methods, you will need to restart. +Note that re-running the macro with different options may not remove all methods, you will need to restart. # Example ```jldoctest From 8f3fe209e692fa880f1e0e124f33dd6dad05e965 Mon Sep 17 00:00:00 2001 From: Zhanibek Date: Thu, 29 Feb 2024 19:39:33 +0900 Subject: [PATCH 11/11] docs: improve freezing docs --- docs/src/models/freezing-params.md | 127 +++++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 docs/src/models/freezing-params.md diff --git a/docs/src/models/freezing-params.md b/docs/src/models/freezing-params.md new file mode 100644 index 0000000000..b1f5f26369 --- /dev/null +++ b/docs/src/models/freezing-params.md @@ -0,0 +1,127 @@ +# Freezing model weights +Flux provides several ways of freezing, excluding from backprop entirely and +marking custom struct fields not to be moved to the GPU +([Functors.@functor](@ref)) hence excluded from being trained. The following +subsections should make it clear which one suits your needs the best. + +## On-the-fly freezing per model instance +Perhaps you'd like to freeze some of the weights of the model (even at +mid-training), and Flux accomplishes this through [`freeze!`](@ref Flux.freeze!) and `thaw!`. + +```julia +m = Chain( + Dense(784 => 64, relu), # freeze this one + Dense(64 => 64, relu), + Dense(32 => 10) + ) +opt_state = Flux.setup(Momentum(), m); + +# Freeze some layers right away +Flux.freeze!(opt_state.layers[1]) + +for data in train_set + input, label = data + + # Some params could be frozen during the training: + Flux.freeze!(opt_state.layers[2]) + + grads = Flux.gradient(m) do m + result = m(input) + loss(result, label) + end + Flux.update!(opt_state, m, grads[1]) + + # Optionally unfreeze the params later + Flux.thaw!(opt_state.layers[1]) +end +``` + +## Static freezing per model definition +Sometimes some parts of the model ([`Flux.@functor`](@ref)) needn't to be trained at all but these params +still need to reside on the GPU (these params are still needed in the forward +and/or backward pass). +```julia +struct MaskedLayer{T} + chain::Chain + mask::T +end +Flux.@functor MaskedLayer + +# mark the trainable part +Flux.trainable(a::MaskedLayer)=(;a.chain) +# a.mask will not be updated in the training loop + +function (m::MaskedLayer)(x) + return m.chain(x) + x + m.mask +end + +model = MaskedLayer(...) # this model will not have the `mask` field trained +``` +Note how this method permanently sets some model fields to be excluded from +training without on-the-fly changing. + +## Excluding from model definition +Sometimes some parameters are just "not trainable" but they shouldn't even +transfer to the GPU. All scalar fields are like this by default, so things like +learning rate multipliers are not trainable nor transferred to the GPU by +default. +```julia +struct CustomLayer{T, F} + chain::Chain + activation_results::Vector{F} + lr_multiplier::Float32 +end +Flux.@functor CustomLayer (chain, ) # Explicitly leaving out `activation_results` + +function (m::CustomLayer)(x) + result = m.chain(x) + x + + # `activation_results` are not part of the GPU loop, hence we could do + # things like `push!` + push!(m.activation_results, mean(result)) + return result +end +``` +See more about this in [`Flux.@functor`](@ref) and + + +## Freezing Layer Parameters (deprecated) + +When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. + +!!! compat "Flux ≤ 0.14" + The mechanism described here is for Flux's old "implicit" training style. + When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. + +Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain +this using the slicing features `Chain` provides: + +```julia +m = Chain( + Dense(784 => 64, relu), + Dense(64 => 64, relu), + Dense(32 => 10) + ); + +ps = Flux.params(m[3:end]) +``` + +The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. + +During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed. + +`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this: + +```julia +Flux.params(m[1], m[3:end]) +``` + +Sometimes, a more fine-tuned control is needed. +We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`, +by simply deleting it from `ps`: + +```julia +ps = Flux.params(m) +delete!(ps, m[2].bias) +``` +