From 1012cbf26f01ea4ff8082a3277ccc5bf24b801d4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 Apr 2024 15:44:06 +0200 Subject: [PATCH 1/6] update docs --- docs/src/destructure.md | 19 +++-- docs/src/models/advanced.md | 76 +++++------------- docs/src/models/basics.md | 76 ++++++------------ docs/src/models/quickstart.md | 32 +++----- docs/src/models/recurrence.md | 7 +- docs/src/saving.md | 11 +-- docs/src/training/optimisers.md | 77 ++++++++----------- docs/src/training/reference.md | 60 +-------------- docs/src/training/training.md | 61 ++++----------- docs/src/training/zygote.md | 16 ---- .../2020-09-15-deep-learning-flux.md | 77 +++++++------------ docs/src/tutorials/2021-01-26-mlp.md | 4 +- docs/src/tutorials/2021-10-08-dcgan-mnist.md | 2 +- docs/src/tutorials/2021-10-14-vanilla-gan.md | 14 ++-- 14 files changed, 157 insertions(+), 375 deletions(-) diff --git a/docs/src/destructure.md b/docs/src/destructure.md index 1cdcad5ce7..16089380c4 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -49,20 +49,27 @@ julia> Flux.destructure(grad) # acts on non-models, too (Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5)) ``` -!!! compat "Flux ≤ 0.12" - Old versions of Flux had an entirely different implementation of `destructure`, which - had many bugs (and almost no tests). Many comments online still refer to that now-deleted - function, or to memories of it. +In order to collect all parameters of a model into a list instead, you can use the `trainables` function: +```julia +julia> Flux.trainables(model) +5-element Vector{AbstractArray}: + [0.863101 1.2454957] + [0.0] + [1.290355429422727;;] + [0.0] +``` +Any mutation of the elements of the resulting list will affect the model's parameters. ### All Parameters -The function `destructure` now lives in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). -(Be warned this package is unrelated to the `Flux.Optimisers` sub-module! The confusion is temporary.) +The functions `destructure` and `trainables` live in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). + ```@docs Optimisers.destructure Optimisers.trainable +Optimisers.trainables Optimisers.isnumeric ``` diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index cf2d1fedb3..9569944b2e 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -26,7 +26,7 @@ Notice that we parameterized the type of the `chain` field. This is necessary fo You can then use the model like: ```julia -chain = Chain(Dense(10, 10)) +chain = Chain(Dense(10 => 10)) model = CustomModel(chain) model(rand(10)) ``` @@ -40,33 +40,37 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function: ```julia-repl -julia> @layer Affine +julia> struct Affine + W + b + end + +julia> Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out)); + +julia> (m::Affine)(x) = m.W * x .+ m.b; + +julia> Flux.@layer Affine julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]) Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) -julia> Flux.params(a) # default behavior -Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]]) +julia> Flux.trainable(a) # default behavior +(W = Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], b = Float32[7.0, 8.0, 9.0]) julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name -julia> Flux.params(a) -Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]]) +julia> Flux.trainable(a) +(W = Float32[1.0 2.0; 3.0 4.0; 5.0 6.0],) ``` -Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`, and only these fields will be seen by `Flux.setup` and `Flux.update!` for training. But all fields wil be seen by `gpu` and similar functions, for example: +Only the fields returned by `trainable` will be seen by `Flux.setup` and `Flux.update!` for training. But all fields wil be seen by `gpu` and similar functions, for example: ```julia-repl julia> a |> f16 Affine(Float16[1.0 2.0; 3.0 4.0; 5.0 6.0], Float16[7.0, 8.0, 9.0]) ``` -Note that there is no need to overload `trainable` to hide fields which do not contain trainable parameters. (For example, activation functions, or Boolean flags.) These are always ignored by `params` and by training: - -```julia-repl -julia> Flux.params(Affine(true, [10, 11, 12.0])) -Params([]) -``` +Note that there is no need to overload `trainable` to hide fields which do not contain numerical array (for example, activation functions, or Boolean flags). These are always ignored by training. The exact same method of `trainable` can also be defined using the macro, for convenience: @@ -76,52 +80,14 @@ Flux.@layer Affine trainable=(W,) There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. - -## Freezing Layer Parameters - -When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. - -!!! compat "Flux ≤ 0.14" - The mechanism described here is for Flux's old "implicit" training style. - When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. - -Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain -this using the slicing features `Chain` provides: - -```julia -m = Chain( - Dense(784 => 64, relu), - Dense(64 => 64, relu), - Dense(32 => 10) - ); - -ps = Flux.params(m[3:end]) -``` - -The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. - -During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed. - -`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this: - -```julia -Flux.params(m[1], m[3:end]) -``` - -Sometimes, a more fine-tuned control is needed. -We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`, -by simply deleting it from `ps`: - -```julia -ps = Flux.params(m) -delete!(ps, m[2].bias) -``` - ## Custom multiple input or output layer Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf). -Naively, we could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path changes. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. +We could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. That would mean a new struct for each different block, +e.g. one would have a `TransformerBlock` struct for a transformer block, and a `ResNetBlock` struct for a ResNet block, each block being composed by smaller sub-blocks. This is often the simplest and cleanest way to implement complex models. + +This guide instead will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. ### Multiple inputs: a custom `Join` layer diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index cf83764349..4334f47d33 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -74,50 +74,24 @@ julia> Flux.withgradient(g, nt) (val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),)) ``` -!!! note "Implicit gradients" - Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function. - This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary - through which the resulting gradients can be looked up: - - ```jldoctest basics - julia> x = [2, 1]; - - julia> y = [2, 0]; - - julia> gs = gradient(Flux.params(x, y)) do - f(x, y) - end - Grads(...) - - julia> gs[x] - 2-element Vector{Float64}: - 0.0 - 2.0 - - julia> gs[y] - 2-element Vector{Float64}: - -0.0 - -2.0 - ``` - - ## Building Simple Models Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. ```julia -W = rand(2, 5) -b = rand(2) -predict(x) = W*x .+ b +predict(W, b, x) = W*x .+ b -function loss(x, y) - ŷ = predict(x) +function loss(W, b, x, y) + ŷ = predict(W, b, x) sum((y .- ŷ).^2) end x, y = rand(5), rand(2) # Dummy data -loss(x, y) # ~ 3 +W = rand(2, 5) +b = rand(2) + +loss(W, b, x, y) # ~ 3 ``` To improve the prediction we can take the gradients of the loss with respect to `W` and `b` and perform gradient descent. @@ -125,17 +99,15 @@ To improve the prediction we can take the gradients of the loss with respect to ```julia using Flux -gs = gradient(() -> loss(x, y), Flux.params(W, b)) +dW, db = gradient((W, b) -> loss(W, b, x, y), W, b) ``` Now that we have gradients, we can pull them out and update `W` to train the model. ```julia -W̄ = gs[W] +W .-= 0.1 .* dW -W .-= 0.1 .* W̄ - -loss(x, y) # ~ 2.5 +loss(W, b, x, y) # ~ 2.5 ``` The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md). @@ -144,7 +116,7 @@ All deep learning in Flux, however complex, is a simple generalisation of this e ## Building Layers -It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as: +It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) in between them. We could write this as: ```julia using Flux @@ -157,7 +129,7 @@ W2 = rand(2, 3) b2 = rand(2) layer2(x) = W2 * x .+ b2 -model(x) = layer2(σ.(layer1(x))) +model(x) = layer2(sigmoid.(layer1(x))) model(rand(5)) # => 2-element vector ``` @@ -174,7 +146,7 @@ end linear1 = linear(5, 3) # we can access linear1.W etc linear2 = linear(3, 2) -model(x) = linear2(σ.(linear1(x))) +model(x) = linear2(sigmoid.(linear1(x))) model(rand(5)) # => 2-element vector ``` @@ -188,7 +160,7 @@ struct Affine end Affine(in::Integer, out::Integer) = - Affine(randn(out, in), randn(out)) + Affine(randn(out, in), zeros(out)) # Overload call, so the object can be used as a function (m::Affine)(x) = m.W * x .+ m.b @@ -198,16 +170,16 @@ a = Affine(10, 5) a(rand(10)) # => 5-element vector ``` -Congratulations! You just built the `Dense` layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily. +Congratulations! You just built the [`Dense`](@ref) layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily. -(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, σ)`.) +(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, sigmoid)`.) ## Stacking It Up It's pretty common to write models that look something like: ```julia -layer1 = Dense(10 => 5, σ) +layer1 = Dense(10 => 5, relu) # ... model(x) = layer3(layer2(layer1(x))) ``` @@ -217,7 +189,7 @@ For long chains, it might be a bit more intuitive to have a list of layers, like ```julia using Flux -layers = [Dense(10 => 5, σ), Dense(5 => 2), softmax] +layers = [Dense(10 => 5, relu), Dense(5 => 2), softmax] model(x) = foldl((x, m) -> m(x), layers, init = x) @@ -228,7 +200,7 @@ Handily, this is also provided for in Flux: ```julia model2 = Chain( - Dense(10 => 5, σ), + Dense(10 => 5, relu), Dense(5 => 2), softmax) @@ -255,7 +227,7 @@ m(5) # => 26 ## Layer Helpers -There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: +There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref Flux.train!) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: ```julia Flux.@layer Affine @@ -263,14 +235,14 @@ Flux.@layer Affine Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias): -``` -function Affine((in, out)::Pair; bias=true, init=Flux.randn32) +```julia +function Affine((in, out)::Pair; bias=true, init=glorot_uniform) W = init(out, in) b = Flux.create_bias(W, bias, out) - Affine(W, b) + return Affine(W, b) end -Affine(3 => 1, bias=false, init=ones) |> gpu +Affine(3 => 1, bias=false) |> gpu ``` ```@docs diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index dfef1f0c04..e7e379f17f 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -16,11 +16,11 @@ truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)] # 1000-element model = Chain( Dense(2 => 3, tanh), # activation function inside layer BatchNorm(3), - Dense(3 => 2), - softmax) |> gpu # move model to GPU, if available + Dense(3 => 2)) |> gpu # move model to GPU, if available # The model encapsulates parameters, randomly initialised. Its initial output is: out1 = model(noisy |> gpu) |> cpu # 2×1000 Matrix{Float32} +probs1 = softmax(out1) # normalise to get probabilities # To train the model, we use batches of 64 samples, and one-hot encoding: target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix @@ -36,7 +36,7 @@ losses = [] loss, grads = Flux.withgradient(model) do m # Evaluate model and loss inside gradient context: y_hat = m(x) - Flux.crossentropy(y_hat, y) + Flux.logitcrossentropy(y_hat, y) end Flux.update!(optim, model, grads[1]) push!(losses, loss) # logging, outside gradient context @@ -45,8 +45,8 @@ end optim # parameters, momenta and output have all changed out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) - -mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! +probs2 = softmax(out2) # normalise to get probabilities +mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far! ``` ![](../assets/quickstart/oneminute.png) @@ -55,8 +55,8 @@ mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! using Plots # to draw the above figure p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false) -p_raw = scatter(noisy[1,:], noisy[2,:], zcolor=out1[1,:], title="Untrained network", label="", clims=(0,1)) -p_done = scatter(noisy[1,:], noisy[2,:], zcolor=out2[1,:], title="Trained network", legend=false) +p_raw = scatter(noisy[1,:], noisy[2,:], zcolor=probs1[1,:], title="Untrained network", label="", clims=(0,1)) +p_done = scatter(noisy[1,:], noisy[2,:], zcolor=probs2[1,:], title="Trained network", legend=false) plot(p_true, p_raw, p_done, layout=(1,3), size=(1000,330)) ``` @@ -87,7 +87,7 @@ Some things to notice in this example are: * The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)). -* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function. +* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.logitcrossentropy`](@ref Flux.Losses.logitcrossentropy) is an ordinary function that combines the [`softmax`](@ref Flux.softmax) and [`crossentropy`](@ref Flux.crossentropy) functions. * The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated. @@ -97,21 +97,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. for epoch in 1:1_000 Flux.train!(model, loader, optim) do m, x, y y_hat = m(x) - Flux.crossentropy(y_hat, y) + Flux.logitcrossentropy(y_hat, y) end end ``` - -!!! compat "Implicit-style training, Flux ≤ 0.14" - Until recently Flux's training worked a bit differently. - Any code which looks like - ``` - gradient(() -> loss(model, x, y), Flux.params(model)) - ``` - (gradient of a zero-argument function) or - ``` - train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt) - ``` - (with `Flux.params`) is in the old "implicit" style. - This still works on Flux 0.14, but will be removed from Flux 0.15. - See the [training section](@ref man-training) for more details. diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md index dab24edff6..87cd944f4f 100644 --- a/docs/src/models/recurrence.md +++ b/docs/src/models/recurrence.md @@ -154,7 +154,7 @@ In such a model, only the last two outputs are used to compute the loss, hence t Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update: ```julia -function loss(x, y) +function loss(m, x, y) sum(mse(m(xi), yi) for (xi, yi) in zip(x, y)) end @@ -172,9 +172,8 @@ data = zip(X,Y) Flux.reset!(m) [m(x) for x in seq_init] -ps = Flux.params(m) -opt= Adam(1e-3) -Flux.train!(loss, ps, data, opt) +opt = Flux.setup(Adam(1e-3), m) +Flux.train!(loss, m, data, opt) ``` In this previous example, model's state is first reset with `Flux.reset!`. Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss. diff --git a/docs/src/saving.md b/docs/src/saving.md index 066795bfc5..0b1e4fc91b 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -18,7 +18,7 @@ julia> struct MyModel julia> Flux.@layer MyModel -julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))); +julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))); julia> model = MyModel() MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters @@ -113,7 +113,7 @@ Save a model: ```jldoctest saving julia> using Flux -julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2)); +julia> model = Chain(Dense(10 => 5, NNlib.relu), Dense(5 => 2)); julia> using BSON: @save @@ -138,10 +138,3 @@ Chain( and across Flux versions if some of the Flux layers' internals are changed. It is therefore not recommended for long term storage, use [`Flux.state`](@ref) instead. -!!! warning - - Previous versions of Flux suggested saving only the model weights using - `@save "mymodel.bson" params(model)`. - This is no longer recommended and even strongly discouraged. - Saving models this way will only store the trainable parameters which - will result in incorrect behavior for layers like `BatchNorm`. diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index fc4e38eebe..bc6dc0628f 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -4,76 +4,63 @@ CurrentModule = Flux # [Optimisation Rules](@id man-optimisers) -Flux builds in many optimisation rules for use with [`train!`](@ref Flux.Optimise.train!) and +Any optimization rule from Optimisers.jl can be used with [`train!`](@ref) and other training functions. -The mechanism by which these work is gradually being replaced as part of the change -from "implicit" dictionary-based to "explicit" tree-like structures. -At present, the same struct (such as `Adam`) can be used with either form, -and will be automatically translated. - For full details of how the new interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). -For full details on how the old "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface). - -## Optimiser Reference +## Optimisers Reference All optimisers return an object that, when passed to `train!`, will update the parameters passed to it. ```@docs -Descent -Momentum -Nesterov -RMSProp -Adam -RAdam -AdaMax -AdaGrad -AdaDelta -AMSGrad -NAdam -AdamW -OAdam -AdaBelief +Optimisers.Descent +Optimisers.Momentum +Optimisers.Nesterov +Optimisers.RMSProp +Optimisers.Adam +Optimisers.RAdam +Optimisers.AdaMax +Optimisers.AdaGrad +Optimisers.AdaDelta +Optimisers.AMSGrad +Optimisers.NAdam +Optimisers.AdamW +Optimisers.OAdam +Optimisers.AdaBelief ``` ## Composing Optimisers -Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient -that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Flux defines some basic decays including `ExpDecay`, `InvDecay` etc. +Flux (through Optimisers.jl) defines a special kind of optimiser called `OptimiserChain` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient +that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Optimisers.jl defines the basic decay corresponding to an $L_2$ regularization in the loss as `WeighDecay`. ```julia -opt = Optimiser(ExpDecay(1, 0.1, 1000, 1e-4), Descent()) +opt = OptimiserChain(WeightDecay(1e-4), Descent()) ``` -Here we apply exponential decay to the `Descent` optimiser. The defaults of `ExpDecay` say that its learning rate will be decayed every 1000 steps. -It is then applied like any optimiser. +Here we apply the weight decay to the `Descent` optimiser. +The resulting optimiser `opt` can be used as any optimiser. ```julia -w = randn(10, 10) -w1 = randn(10,10) -ps = Params([w, w1]) +w = [randn(10, 10), randn(10, 10)] +opt_state = Flux.setup(opt, w) -loss(x) = Flux.Losses.mse(w * x, w1 * x) +loss(w, x) = Flux.mse(w[1] * x, w[2] * x) -loss(rand(10)) # around 9 +loss(w, rand(10)) # around 0.9 for t = 1:10^5 - θ = Params([w, w1]) - θ̄ = gradient(() -> loss(rand(10)), θ) - Flux.Optimise.update!(opt, θ, θ̄) + g = gradient(w -> loss(w[1], w[2], rand(10)), w) + Flux.update!(opt_state, w, g) end -loss(rand(10)) # around 0.9 +loss(w, rand(10)) # around 0.9 ``` It is possible to compose optimisers for some added flexibility. -```@docs -Flux.Optimise.Optimiser -``` - ## Scheduling Optimisers In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/stable). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. @@ -109,10 +96,8 @@ ParameterSchedulers.jl allows for many more scheduling policies including arbitr Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone. ```@docs -ExpDecay -InvDecay -WeightDecay SignDecay +WeightDecay ``` ## Gradient Clipping @@ -120,11 +105,11 @@ SignDecay Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is ```julia -opt = Optimiser(ClipValue(1e-3), Adam(1e-3)) +opt = OptimiserChain(ClipValue(1e-3), Adam(1e-3)) ``` ```@docs -ClipValue +ClipGrad ClipNorm ``` diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 1bf0cfd1bf..67980831f9 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -10,9 +10,11 @@ Because of this: * Flux defines its own version of `setup` which checks this assumption. (Using instead `Optimisers.setup` will also work, they return the same thing.) +The available optimization rules are listed the [optimisation rules](@ref man-optimisers) page here. See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the rules work. + ```@docs Flux.Train.setup -Flux.Train.train!(loss, model, data, state; cb) +Flux.Train.train!(loss, model, data, state) Optimisers.update! ``` @@ -32,59 +34,3 @@ Optimisers.adjust! Optimisers.freeze! Optimisers.thaw! ``` - -## Implicit style (Flux ≤ 0.14) - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 and 0.14 are the transitional versions which support both; Flux 0.15 will remove the old. - -!!! compat "How to upgrade" - The blue-green boxes in the [training section](@ref man-training) describe - the changes needed to upgrade old code. - -The available rules are listed the [optimisation rules](@ref man-optimisers) page here. - -!!! compat "Old & new rules" - The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.14, `Flux.Adam()` still returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. - -For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). -See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. - -!!! compat "Flux ≤ 0.12" - Much earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` - after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. - If you get an error `UndefVarError: params not defined`, this probably means that you are - following code for Flux 0.12 or earlier on a more recent version. - - -```@docs -Flux.params -Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs) -Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) -``` - -## Callbacks - -Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: - -```julia -train!(objective, ps, data, opt, cb = () -> println("training")) -``` - -Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. - -A more typical callback might look like this: - -```julia -test_x, test_y = # ... create single batch of test data ... -evalcb() = @show(loss(test_x, test_y)) -throttled_cb = throttle(evalcb, 5) -for epoch in 1:20 - @info "Epoch $epoch" - Flux.train!(objective, ps, data, opt, cb = throttled_cb) -end -``` - -See the page about [callback helpers](@ref man-callback-helpers) for more. - diff --git a/docs/src/training/training.md b/docs/src/training/training.md index f516f4ace9..0407820794 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -64,16 +64,6 @@ in order for the influence of the model's parameters to be observed by Zygote. It is also important that every `update!` step receives a newly computed gradient, as it will change whenever the model's parameters are changed, and for each new data point. -!!! compat "Implicit gradients" - Flux ≤ 0.14 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. - It looks like this: - ``` - pars = Flux.params(model) - grad = gradient(() -> loss(model(input), label), pars) - ``` - Here `pars::Params` and `grad::Grads` are two dictionary-like structures. - Support for this will be removed from Flux 0.15, and these blue (teal?) boxes - explain what needs to change. ## Loss Functions @@ -117,13 +107,13 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(opt_state, model, grads[1])`. -And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. +A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref)`(opt_state, model, grads[1])`. +And the learning rate is the only thing stored in the [`Descent`](@ref) struct. However, there are many other optimisation rules, which adjust the step size and direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always -walking straight downhill -- [`Momentum`](@ref Flux.Optimise.Momentum) is the simplest. +walking straight downhill -- [`Momentum`](@ref) is the simplest. The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. It should be called once, before training, and returns a tree-like object which is the first argument of `update!`. Like this: @@ -140,7 +130,7 @@ for data in train_set end ``` -Many commonly-used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam), are built-in. +Many commonly-used optimisation rules, such as [`Adam`](@ref), are built-in. These are listed on the [optimisers](@ref man-optimisers) page. !!! compat "Implicit-style optimiser state" @@ -208,15 +198,6 @@ end Or explicitly writing the anonymous function which this `do` block creates, `train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)` is exactly equivalent. -!!! compat "Implicit-style `train!`" - This is a new method of `train!`, which takes the result of `setup` as its 4th argument. - The 1st argument is a function which accepts the model itself. - Flux versions ≤ 0.14 provided a method of `train!` for "implicit" parameters, - which works like this: - ``` - train!((x,y) -> loss(model(x), y), Flux.params(model), train_set, Adam()) - ``` - Real training loops often need more flexibility, and the best way to do this is just to write the loop. This is ordinary Julia code, without any need to work through some callback API. Here is an example, in which it may be helpful to note: @@ -284,12 +265,12 @@ A very simple model could be implemented as follows: grads = Flux.gradient(densemodel) do m result = m(input) penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 - my_loss(result, label) + 0.42 * penalty + my_loss(result, label) + 0.42f0 * penalty end ``` Accessing each individual parameter array by hand won't work well for large models. -Instead, we can use [`Flux.params`](@ref) to collect all of them, +Instead, we can use [`Flux.trainables`](@ref Optimisers.trainables) to collect all of them, and then apply a function to each one, and sum the result: ```julia @@ -297,8 +278,8 @@ pen_l2(x::AbstractArray) = sum(abs2, x)/2 grads = Flux.gradient(model) do m result = m(input) - penalty = sum(pen_l2, Flux.params(m)) - my_loss(result, label) + 0.42 * penalty + penalty = sum(pen_l2, Flux.trainables(m)) + my_loss(result, label) + 0.42f0 * penalty end ``` @@ -317,7 +298,7 @@ decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model ``` Flux's optimisers are really modifications applied to the gradient before using it to update -the parameters, and `OptimiserChain` applies two such modifications. +the parameters, and [`OptimiserChain`](@ref Optimisers.OptimiserChain) applies two such modifications. The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times the original parameter to the gradient, matching the gradient of the penalty above (with the same, unrealistically large, constant). After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. @@ -325,14 +306,14 @@ After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is `pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. -The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Flux.Optimise.ClipValue) or [`ClipNorm`](@ref Flux.Optimise.ClipNorm). +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref) or [`ClipNorm`](@ref). Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the previous layer during training. It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. -## Freezing & Schedules +## Learning Rate Schedules Finer control of training, you may wish to alter the learning rate mid-way through training. This can be done with [`adjust!`](@ref Flux.adjust!), like this: @@ -348,10 +329,6 @@ for epoch in 1:1000 end ``` -!!! compat "Flux ≤ 0.14" - With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to - directly mutate the `Adam` struct, `opt.eta = 0.001`. - Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`. And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder: @@ -367,6 +344,8 @@ opt_state = Flux.setup(Adam(0.02), bimodel) Flux.adjust!(opt_state.layers.enc, 0.03) ``` +## Freezing layer parameters + To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). This is a temporary modification, reversed by `thaw!`: @@ -380,21 +359,7 @@ train!(loss, bimodel, data, opt_state) Flux.thaw!(opt_state) ``` -!!! compat "Flux ≤ 0.14" - The earlier "implicit" equivalent was to pass to `gradient` an object referencing only - part of the model, such as `Flux.params(bimodel.layers.enc)`. - While `adjust!` and `freeze!`/`thaw!` make temporary modifications to the optimiser state, permanently removing some fields of a new layer type from training is usually done when defining the layer, by calling for example [`@layer`](@ref Flux.@layer)` NewLayer trainable=(weight,)`. -## Implicit or Explicit? - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 and 0.14 are the transitional versions which support both. - -The blue-green boxes above describe the changes. -For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). - -For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). diff --git a/docs/src/training/zygote.md b/docs/src/training/zygote.md index 385e7dde7b..33d30d6ee8 100644 --- a/docs/src/training/zygote.md +++ b/docs/src/training/zygote.md @@ -18,22 +18,6 @@ Zygote.hessian_reverse Zygote.diaghessian ``` -## Implicit style (Flux ≤ 0.14) - -Flux used to use what Zygote calls "implicit" gradients, [described here](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) in its documentation. -However, support for this will be removed from Flux 0.15. - -!!! compat "Training" - The blue-green boxes in the [training section](@ref man-training) describe - the changes needed to upgrade old code from implicit to explicit style. - -```@docs -Zygote.gradient(loss, ::Params) -Zygote.Params -Zygote.Grads -Zygote.jacobian(loss, ::Params) -``` - ## ChainRules Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl): diff --git a/docs/src/tutorials/2020-09-15-deep-learning-flux.md b/docs/src/tutorials/2020-09-15-deep-learning-flux.md index 7cb2a366b6..c386e5f3c4 100755 --- a/docs/src/tutorials/2020-09-15-deep-learning-flux.md +++ b/docs/src/tutorials/2020-09-15-deep-learning-flux.md @@ -167,52 +167,33 @@ gradient(myloss, W, b, x) Now we get gradients for each of the inputs `W`, `b` and `x`, which will come in handy when we want to train models. -Because ML models can contain hundreds of parameters, Flux provides a slightly different way of writing `gradient`. We instead mark arrays with `param` to indicate that we want their derivatives. `W` and `b` represent the weight and bias respectively. - -```julia -using Flux: params - -W = randn(3, 5) -b = zeros(3) -x = rand(5) - -y(x) = sum(W * x .+ b) - -grads = gradient(()->y(x), params([W, b])) - -grads[W], grads[b] -``` - - -We can now grab the gradients of `W` and `b` directly from those parameters. - -This comes in handy when working with *layers*. A layer is just a handy container for some parameters. For example, `Dense` does a linear transform for you. +ML models can contain hundreds of parameter arrays, therefore it is handy to group them into **layers**. +A layer is just a handy container for some parameters. For example, `Dense` does a linear transform for you. ```julia using Flux -m = Dense(10, 5) +m = Dense(10 => 5) x = rand(Float32, 10) ``` -We can easily get the parameters of any layer or model with params with `params`. +We can easily get the parameters of any layer or model with `trainables`. ```julia -params(m) +trainables(m) ``` -This makes it very easy to calculate the gradient for all parameters in a network, even if it has many parameters. +It very easy to calculate the gradient for all parameters in a network, even if it has many parameters. +The function `gradient` is not limited to array but can compute the gradient with respect to generic composite types. ```julia x = rand(Float32, 10) -m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax) -l(x) = sum(Flux.crossentropy(m(x), [0.5, 0.5])) -grads = gradient(params(m)) do - l(x) -end -for p in params(m) - println(grads[p]) +model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) +loss(model, x) = Flux.logitcrossentropy(model(x), [0.5, 0.5]) +grad = gradient(m -> loss(m, x), model)[1] +for (k, p) in trainables(model, path=true) + println("$k => $(getkeypath(grad, k))") end ``` @@ -221,27 +202,26 @@ You don't have to use layers, but they can be convient for many simple kinds of The next step is to update our weights and perform optimisation. As you might be familiar, *Gradient Descent* is a simple algorithm that takes the weights and steps using a learning rate and the gradients. `weights = weights - learning_rate * gradient`. ```julia -using Flux.Optimise: update!, Descent η = 0.1 -for p in params(m) - update!(p, -η * grads[p]) +for (k, p) in trainables(m) + p .+= -η * getkeypath(grads, p) end ``` While this is a valid way of updating our weights, it can get more complicated as the algorithms we use get more involved. -Flux comes with a bunch of pre-defined optimisers and makes writing our own really simple. We just give it the learning rate η: +Flux comes with a bunch of pre-defined optimisers and makes writing our own really simple. We just give it the learning rate `η`: ```julia -opt = Descent(0.01) +opt_state = Flux.setup(Descent(η), model) ``` -`Training` a network reduces down to iterating on a dataset mulitple times, performing these steps in order. Just for a quick implementation, let’s train a network that learns to predict `0.5` for every input of 10 floats. `Flux` defines the `train!` function to do it for us. +Training a network reduces down to iterating on a dataset mulitple times, performing these steps in order. Just for a quick implementation, let’s train a network that learns to predict `0.5` for every input of 10 floats. `Flux` defines the `train!` function to do it for us. ```julia data, labels = rand(10, 100), fill(0.5, 2, 100) -loss(x, y) = sum(Flux.crossentropy(m(x), y)) -Flux.train!(loss, params(m), [(data,labels)], opt) +loss(m, x, y) = Flux.logitcrossentropy(m(x), y) +Flux.train!(loss, model, [(data, labels)], opt) ``` You don't have to use `train!`. In cases where arbitrary logic might be better suited, you could open up this training loop like so: @@ -249,10 +229,10 @@ You don't have to use `train!`. In cases where arbitrary logic might be better s ```julia for d in training_set # assuming d looks like (data, labels) # our super logic - gs = gradient(params(m)) do #m is our model - l = loss(d...) + g = gradient(model) do model + l = loss(model, d...) end - update!(opt, params(m), gs) + Flux.update!(opt_state, model, g) end ``` @@ -272,7 +252,7 @@ We will do the following steps in order: ```julia using Statistics -using Flux, Flux.Optimise +using Flux using MLDatasets: CIFAR10 using Images.ImageCore using Flux: onehotbatch, onecold @@ -321,18 +301,17 @@ m = Chain( Conv((5,5), 16=>8, relu), MaxPool((2,2)), x -> reshape(x, :, size(x, 4)), - Dense(200, 120), - Dense(120, 84), - Dense(84, 10), - softmax) |> gpu + Dense(200 => 120), + Dense(120 => 84), + Dense(84 => 10)) |> gpu ``` We will use a crossentropy loss and an Momentum optimiser here. Crossentropy will be a good option when it comes to working with mulitple independent classes. Momentum gradually lowers the learning rate as we proceed with the training. It helps maintain a bit of adaptivity in our optimisation, preventing us from over shooting from our desired destination. ```julia -using Flux: crossentropy, Momentum +using Flux: logitcrossentropy, Momentum -loss(x, y) = sum(crossentropy(m(x), y)) +loss(m, x, y) = logitcrossentropy(m(x), y) opt = Momentum(0.01) ``` diff --git a/docs/src/tutorials/2021-01-26-mlp.md b/docs/src/tutorials/2021-01-26-mlp.md index 2af8d3645c..763f711195 100644 --- a/docs/src/tutorials/2021-01-26-mlp.md +++ b/docs/src/tutorials/2021-01-26-mlp.md @@ -80,8 +80,8 @@ We define our model with the `build_model` function: ```julia function build_model(; imgsize=(28,28,1), nclasses=10) return Chain( - Dense(prod(imgsize), 32, relu), - Dense(32, nclasses)) + Dense(prod(imgsize) => 32, relu), + Dense(32 => nclasses)) end ``` diff --git a/docs/src/tutorials/2021-10-08-dcgan-mnist.md b/docs/src/tutorials/2021-10-08-dcgan-mnist.md index 4da32e5f2c..a746935eb6 100644 --- a/docs/src/tutorials/2021-10-08-dcgan-mnist.md +++ b/docs/src/tutorials/2021-10-08-dcgan-mnist.md @@ -109,7 +109,7 @@ dcgan_init(shape...) = randn(Float32, shape) * 0.02f0 ```julia function Generator(latent_dim) Chain( - Dense(latent_dim, 7*7*256, bias=false), + Dense(latent_dim => 7*7*256, bias=false), BatchNorm(7*7*256, relu), x -> reshape(x, 7, 7, 256, :), diff --git a/docs/src/tutorials/2021-10-14-vanilla-gan.md b/docs/src/tutorials/2021-10-14-vanilla-gan.md index f92ae54a8b..f07c7757bd 100644 --- a/docs/src/tutorials/2021-10-14-vanilla-gan.md +++ b/docs/src/tutorials/2021-10-14-vanilla-gan.md @@ -96,13 +96,13 @@ calling the model in a gradient context. As a final non-linearity, we use the `sigmoid` activation function. ```julia -discriminator = Chain(Dense(n_features, 1024, x -> leakyrelu(x, 0.2f0)), +discriminator = Chain(Dense(n_features => 1024, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(1024, 512, x -> leakyrelu(x, 0.2f0)), + Dense(1024 => 512, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(512, 256, x -> leakyrelu(x, 0.2f0)), + Dense(512 => 256, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(256, 1, sigmoid)) |> gpu + Dense(256 => 1, sigmoid)) |> gpu ``` Let's define the generator in a similar fashion. This network maps a latent @@ -113,9 +113,9 @@ the training data onto. ```julia generator = Chain(Dense(latent_dim, 256, x -> leakyrelu(x, 0.2f0)), - Dense(256, 512, x -> leakyrelu(x, 0.2f0)), - Dense(512, 1024, x -> leakyrelu(x, 0.2f0)), - Dense(1024, n_features, tanh)) |> gpu + Dense(256 => 512, x -> leakyrelu(x, 0.2f0)), + Dense(512 => 1024, x -> leakyrelu(x, 0.2f0)), + Dense(1024 => n_features, tanh)) |> gpu ``` From 2f28960b43975bae9baf58a0c4379e947a6520cc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 Apr 2024 16:00:32 +0200 Subject: [PATCH 2/6] update docs --- docs/Project.toml | 2 +- docs/src/{ => reference}/models/activation.md | 0 docs/src/{ => reference}/models/functors.md | 0 docs/src/{ => reference}/models/layers.md | 0 docs/src/{ => reference}/models/losses.md | 0 docs/src/{ => reference}/models/nnlib.md | 0 docs/src/reference/training/callbacks.md | 75 ++++ docs/src/reference/training/reference.md | 36 ++ docs/src/reference/training/training.md | 365 ++++++++++++++++++ docs/src/reference/training/zygote.md | 38 ++ docs/src/{ => reference}/utilities.md | 0 docs/src/training/optimisers.md | 116 ------ 12 files changed, 515 insertions(+), 117 deletions(-) rename docs/src/{ => reference}/models/activation.md (100%) rename docs/src/{ => reference}/models/functors.md (100%) rename docs/src/{ => reference}/models/layers.md (100%) rename docs/src/{ => reference}/models/losses.md (100%) rename docs/src/{ => reference}/models/nnlib.md (100%) create mode 100644 docs/src/reference/training/callbacks.md create mode 100644 docs/src/reference/training/reference.md create mode 100644 docs/src/reference/training/training.md create mode 100644 docs/src/reference/training/zygote.md rename docs/src/{ => reference}/utilities.md (100%) delete mode 100644 docs/src/training/optimisers.md diff --git a/docs/Project.toml b/docs/Project.toml index a4d907e63e..1990368231 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -16,4 +16,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Documenter = "0.27" +Documenter = "1.3" diff --git a/docs/src/models/activation.md b/docs/src/reference/models/activation.md similarity index 100% rename from docs/src/models/activation.md rename to docs/src/reference/models/activation.md diff --git a/docs/src/models/functors.md b/docs/src/reference/models/functors.md similarity index 100% rename from docs/src/models/functors.md rename to docs/src/reference/models/functors.md diff --git a/docs/src/models/layers.md b/docs/src/reference/models/layers.md similarity index 100% rename from docs/src/models/layers.md rename to docs/src/reference/models/layers.md diff --git a/docs/src/models/losses.md b/docs/src/reference/models/losses.md similarity index 100% rename from docs/src/models/losses.md rename to docs/src/reference/models/losses.md diff --git a/docs/src/models/nnlib.md b/docs/src/reference/models/nnlib.md similarity index 100% rename from docs/src/models/nnlib.md rename to docs/src/reference/models/nnlib.md diff --git a/docs/src/reference/training/callbacks.md b/docs/src/reference/training/callbacks.md new file mode 100644 index 0000000000..148aa02128 --- /dev/null +++ b/docs/src/reference/training/callbacks.md @@ -0,0 +1,75 @@ +# [Callback Helpers](@id man-callback-helpers) + +```@docs +Flux.throttle +``` + +## Patience Helpers + +Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum `patience`. For example, you can use `early_stopping` to stop training when the model is converging or deteriorating, or you can use `plateau` to check if the model is stagnating. + +For example, below we create a pseudo-loss function that decreases, bottoms out, and then increases. The early stopping trigger will break the loop before the loss increases too much. +```julia +# create a pseudo-loss that decreases for 4 calls, then starts increasing +# we call this like loss() +loss = let t = 0 + () -> begin + t += 1 + (t - 4) ^ 2 + end +end + +# create an early stopping trigger +# returns true when the loss increases for two consecutive steps +es = early_stopping(loss, 2; init_score = 9) + +# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch +for epoch in 1:10 + es() && break +end +``` + +The keyword argument `distance` of `early_stopping` is a function of the form `distance(best_score, score)`. By default `distance` is `-`, which implies that the monitored metric `f` is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the `distance` function: `(best_score, score) -> score - best_score`. +```julia +# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1 +# we call this like acc() +acc = let v = 0 + () -> v = max(1, v + 0.01) +end + +# create an early stopping trigger for accuracy +es = early_stopping(acc, 3; delta = (best_score, score) -> score - best_score) + +# this will iterate until the 10th epoch +for epoch in 1:10 + es() && break +end +``` + +`early_stopping` and `plateau` are both built on top of `patience`. You can use `patience` to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations: +```julia +threshold(f, thresh, delay) = patience(delay) do + f() < thresh +end +``` + +Both `predicate` in `patience` and `f` in `early_stopping` / `plateau` can accept extra arguments. You can pass such extra arguments to `predicate` or `f` through the returned function: +```julia +trigger = patience((a; b) -> a > b, 3) + +# this will iterate until the 10th epoch +for epoch in 1:10 + trigger(1; b = 2) && break +end + +# this will stop at the 3rd epoch +for epoch in 1:10 + trigger(3; b = 2) && break +end +``` + +```@docs +Flux.patience +Flux.early_stopping +Flux.plateau +``` diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md new file mode 100644 index 0000000000..67980831f9 --- /dev/null +++ b/docs/src/reference/training/reference.md @@ -0,0 +1,36 @@ +# Training API Reference + +The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). +Only the function `train!` belongs to Flux itself. + +The Optimisers package is designed to allow for immutable objects. But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) which can be updated in-place. +Because of this: + +* The objects returned by `Optimisers.update!` can be ignored. +* Flux defines its own version of `setup` which checks this assumption. + (Using instead `Optimisers.setup` will also work, they return the same thing.) + +The available optimization rules are listed the [optimisation rules](@ref man-optimisers) page here. See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the rules work. + +```@docs +Flux.Train.setup +Flux.Train.train!(loss, model, data, state) +Optimisers.update! +``` + +`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. +To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) +and follow its setup instructions. + +## Optimisation Modifiers + +The state returned by `setup` can be modified to temporarily prevent training of +some parts of the model, or to change the learning rate or other hyperparameter. +The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust!`. +All mutate the state (or part of it) and return `nothing`. + +```@docs +Optimisers.adjust! +Optimisers.freeze! +Optimisers.thaw! +``` diff --git a/docs/src/reference/training/training.md b/docs/src/reference/training/training.md new file mode 100644 index 0000000000..0407820794 --- /dev/null +++ b/docs/src/reference/training/training.md @@ -0,0 +1,365 @@ +# [Training a Flux Model](@id man-training) + +Training refers to the process of slowly adjusting the parameters of a model to make it work better. +Besides the model itself, we will need three things: + +* An *objective function* that evaluates how well a model is doing on some input. +* An *optimisation rule* which describes how the model's parameters should be adjusted. +* Some *training data* to use as the input during this process. + +Usually the training data is some collection of examples (or batches of examples) which +are handled one-by-one. One *epoch* of training means that each example is used once, +something like this: + +```julia +# Initialise the optimiser for this model: +opt_state = Flux.setup(rule, model) + +for data in train_set + # Unpack this element (for supervised training): + input, label = data + + # Calculate the gradient of the objective + # with respect to the parameters within the model: + grads = Flux.gradient(model) do m + result = m(input) + loss(result, label) + end + + # Update the parameters so as to reduce the objective, + # according the chosen optimisation rule: + Flux.update!(opt_state, model, grads[1]) +end +``` + +This loop can also be written using the function [`train!`](@ref Flux.Train.train!), +but it's helpful to understand the pieces first: + +```julia +train!(model, train_set, opt_state) do m, x, y + loss(m(x), y) +end +``` + +## Model Gradients + +Fist recall from the section on [taking gradients](@ref man-taking-gradients) that +`Flux.gradient(f, a, b)` always calls `f(a, b)`, and returns a tuple `(∂f_∂a, ∂f_∂b)`. +In the code above, the function `f` passed to `gradient` is an anonymous function with +one argument, created by the `do` block, hence `grads` is a tuple with one element. +Instead of a `do` block, we could have written: + +```julia +grads = Flux.gradient(m -> loss(m(input), label), model) +``` + +Since the model is some nested set of layers, `grads[1]` is a similarly nested set of +`NamedTuple`s, ultimately containing gradient components. If (for example) +`θ = model.layers[1].weight[2,3]` is one scalar parameter, an entry in a matrix of weights, +then the derivative of the loss with respect to it is `∂f_∂θ = grads[1].layers[1].weight[2,3]`. + +It is important that the execution of the model takes place inside the call to `gradient`, +in order for the influence of the model's parameters to be observed by Zygote. + +It is also important that every `update!` step receives a newly computed gradient, +as it will change whenever the model's parameters are changed, and for each new data point. + + +## Loss Functions + +The objective function must return a number representing how far the model is from +the desired result. This is termed the *loss* of the model. + +This number can be produced by any ordinary Julia code, but this must be executed +within the call to `gradient`. For instance, we could define a function +```julia +loss(y_hat, y) = sum((y_hat .- y).^2) +``` +or write this directly inside the `do` block above. Many commonly used functions, +like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@ref Flux.Losses.crossentropy) for cross-entropy loss, +are available from the [`Flux.Losses`](../models/losses.md) module. + +!!! compat "Implicit-style loss functions" + Flux ≤ 0.14 needed a loss function which closed over a reference to the model, + instead of being a pure function. Thus in old code you may see something like + ``` + loss(x, y) = sum((model(x) .- y).^2) + ``` + which defines a function making reference to a particular global variable `model`. + +## Optimisation Rules + +The simplest kind of optimisation using the gradient is termed *gradient descent* +(or sometimes *stochastic gradient descent* when, as here, it is not applied to the entire dataset at once). + +Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, +usually written as the Greek letter "eta", `η`. This is often described as a *hyperparameter*, +to distinguish it from the parameters which are being updated `θ = θ - η * ∂loss_∂θ`. +We want to update all the parameters in the model, like this: + +```julia +η = 0.01 # learning rate + +# For each parameter array, update +# according to the corresponding gradient: +fmap(model, grads[1]) do p, g + p .= p .- η .* g +end +``` + +A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref)`(opt_state, model, grads[1])`. +And the learning rate is the only thing stored in the [`Descent`](@ref) struct. + +However, there are many other optimisation rules, which adjust the step size and +direction in various clever ways. +Most require some memory of the gradients from earlier steps, rather than always +walking straight downhill -- [`Momentum`](@ref) is the simplest. +The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. +It should be called once, before training, and returns a tree-like object which is the +first argument of `update!`. Like this: + +```julia +# Initialise momentum +opt_state = Flux.setup(Momentum(0.01, 0.9), model) + +for data in train_set + grads = [...] + + # Update both model parameters and optimiser state: + Flux.update!(opt_state, model, grads[1]) +end +``` + +Many commonly-used optimisation rules, such as [`Adam`](@ref), are built-in. +These are listed on the [optimisers](@ref man-optimisers) page. + +!!! compat "Implicit-style optimiser state" + This `setup` makes another tree-like structure. Old versions of Flux did not do this, + and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. + This was initialised on first use of the version of `update!` for "implicit" parameters. + + +## Datasets & Batches + +The loop above iterates through `train_set`, expecting at each step a tuple `(input, label)`. +The very simplest such object is a vector of tuples, such as this: + +```julia +x = randn(28, 28) +y = rand(10) +data = [(x, y)] +``` + +or `data = [(x, y), (x, y), (x, y)]` for the same values three times. + +Very often, the initial data is large arrays which you need to slice into examples. +To produce one iterator of pairs `(x, y)`, you might want `zip`: + +```julia +X = rand(28, 28, 60_000); # many images, each 28 × 28 +Y = rand(10, 60_000) +data = zip(eachslice(X; dims=3), eachcol(Y)) + +first(data) isa Tuple{AbstractMatrix, AbstractVector} # true +``` + +Here each iteration will use one matrix `x` (an image, perhaps) and one vector `y`. +It is very common to instead train on *batches* of such inputs (or *mini-batches*, +the two words mean the same thing) both for efficiency and for better results. +This can be easily done using the [`DataLoader`](@ref Flux.Data.DataLoader): + +```julia +data = Flux.DataLoader((X, Y), batchsize=32) + +x1, y1 = first(data) +size(x1) == (28, 28, 32) +length(data) == 1875 === 60_000 ÷ 32 +``` + +Flux's layers are set up to accept such a batch of input data, +and the convolutional layers such as [`Conv`](@ref Flux.Conv) require it. +The batch index is always the last dimension. + +## Training Loops + +Simple training loops like the one above can be written compactly using +the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: + +```julia +opt_state = Flux.setup(Adam(), model) + +for epoch in 1:100 + Flux.train!(model, train_set, opt_state) do m, x, y + loss(m(x), y) + end +end +``` + +Or explicitly writing the anonymous function which this `do` block creates, +`train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)` is exactly equivalent. + +Real training loops often need more flexibility, and the best way to do this is just +to write the loop. This is ordinary Julia code, without any need to work through some +callback API. Here is an example, in which it may be helpful to note: + +* The function [`withgradient`](@ref Zygote.withgradient) is like `gradient` but also + returns the value of the function, for logging or diagnostic use. +* Logging or printing is best done outside of the `gradient` call, + as there is no need to differentiate these commands. +* To use `result` for logging purposes, you could change the `do` block to end with + `return my_loss(result, label), result`, i.e. make the function passed to `withgradient` + return a tuple. The first element is always the loss. +* Julia's `break` and `continue` keywords let you exit from parts of the loop. + +```julia +opt_state = Flux.setup(Adam(), model) + +my_log = [] +for epoch in 1:100 + losses = Float32[] + for (i, data) in enumerate(train_set) + input, label = data + + val, grads = Flux.withgradient(model) do m + # Any code inside here is differentiated. + # Evaluation of the model and loss must be inside! + result = m(input) + my_loss(result, label) + end + + # Save the loss from the forward pass. (Done outside of gradient.) + push!(losses, val) + + # Detect loss of Inf or NaN. Print a warning, and then skip update! + if !isfinite(val) + @warn "loss is $val on item $i" epoch + continue + end + + Flux.update!(opt_state, model, grads[1]) + end + + # Compute some accuracy, and save details as a NamedTuple + acc = my_accuracy(model, train_set) + push!(my_log, (; acc, losses)) + + # Stop training when some criterion is reached + if acc > 0.95 + println("stopping after $epoch epochs") + break + end +end +``` + +## Regularisation + +The term *regularisation* covers a wide variety of techniques aiming to improve the +result of training. This is often done to avoid overfitting. + +Some of these can be implemented by simply modifying the loss function. +*L₂ regularisation* (sometimes called ridge regression) adds to the loss a penalty +proportional to `θ^2` for every scalar parameter. +A very simple model could be implemented as follows: + +```julia +grads = Flux.gradient(densemodel) do m + result = m(input) + penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 + my_loss(result, label) + 0.42f0 * penalty +end +``` + +Accessing each individual parameter array by hand won't work well for large models. +Instead, we can use [`Flux.trainables`](@ref Optimisers.trainables) to collect all of them, +and then apply a function to each one, and sum the result: + +```julia +pen_l2(x::AbstractArray) = sum(abs2, x)/2 + +grads = Flux.gradient(model) do m + result = m(input) + penalty = sum(pen_l2, Flux.trainables(m)) + my_loss(result, label) + 0.42f0 * penalty +end +``` + +However, the gradient of this penalty term is very simple: It is proportional to the original weights. +So there is a simpler way to implement exactly the same thing, by modifying the optimiser +instead of the loss function. This is done by replacing this: + +```julia +opt_state = Flux.setup(Adam(0.1), model) +``` + +with this: + +```julia +decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) +``` + +Flux's optimisers are really modifications applied to the gradient before using it to update +the parameters, and [`OptimiserChain`](@ref Optimisers.OptimiserChain) applies two such modifications. +The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times the original parameter to the gradient, +matching the gradient of the penalty above (with the same, unrealistically large, constant). +After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. + +The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is +`pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. + +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref) or [`ClipNorm`](@ref). + +Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is +provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the +previous layer during training. +It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. + +## Learning Rate Schedules + +Finer control of training, you may wish to alter the learning rate mid-way through training. +This can be done with [`adjust!`](@ref Flux.adjust!), like this: + +```julia +opt_state = Flux.setup(Adam(0.1), model) # initialise once + +for epoch in 1:1000 + train!([...], state) # Train with η = 0.1 for first 100, + if epoch == 100 # then change to use η = 0.01 for the rest. + Flux.adjust!(opt_state, 0.01) + end +end +``` + +Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`. +And such modifications can be applied to just one part of the model. +For instance, this sets a different learning rate for the encoder and the decoder: + +```julia +# Consider some model with two parts: +bimodel = Chain(enc = [...], dec = [...]) + +# This returns a tree whose structure matches the model: +opt_state = Flux.setup(Adam(0.02), bimodel) + +# Adjust the learning rate to be used for bimodel.layers.enc +Flux.adjust!(opt_state.layers.enc, 0.03) +``` + +## Freezing layer parameters + +To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). +This is a temporary modification, reversed by `thaw!`: + +```julia +Flux.freeze!(opt_state.layers.enc) + +# Now training won't update parameters in bimodel.layers.enc +train!(loss, bimodel, data, opt_state) + +# Un-freeze the entire model: +Flux.thaw!(opt_state) +``` + +While `adjust!` and `freeze!`/`thaw!` make temporary modifications to the optimiser state, +permanently removing some fields of a new layer type from training is usually done +when defining the layer, by calling for example [`@layer`](@ref Flux.@layer)` NewLayer trainable=(weight,)`. + diff --git a/docs/src/reference/training/zygote.md b/docs/src/reference/training/zygote.md new file mode 100644 index 0000000000..33d30d6ee8 --- /dev/null +++ b/docs/src/reference/training/zygote.md @@ -0,0 +1,38 @@ +# [Automatic Differentiation using Zygote.jl](@id autodiff-zygote) + +Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/). + + +## Explicit style + +The preferred way of using Zygote, and the only way of using most other AD packages, +is to explicitly provide a function and its arguments. + +```@docs +Zygote.gradient(f, args...) +Zygote.withgradient(f, args...) +Zygote.jacobian(f, args...) +Zygote.withjacobian(f, args...) +Zygote.hessian +Zygote.hessian_reverse +Zygote.diaghessian +``` + +## ChainRules + +Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl): + +```@docs +ChainRulesCore.ignore_derivatives +ChainRulesCore.@non_differentiable +``` + +To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works. + +```@docs +ChainRulesCore.rrule +ChainRulesCore.frule +ChainRulesCore.@scalar_rule +ChainRulesCore.NoTangent +ChainRulesCore.ZeroTangent +``` diff --git a/docs/src/utilities.md b/docs/src/reference/utilities.md similarity index 100% rename from docs/src/utilities.md rename to docs/src/reference/utilities.md diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md deleted file mode 100644 index bc6dc0628f..0000000000 --- a/docs/src/training/optimisers.md +++ /dev/null @@ -1,116 +0,0 @@ -```@meta -CurrentModule = Flux -``` - -# [Optimisation Rules](@id man-optimisers) - -Any optimization rule from Optimisers.jl can be used with [`train!`](@ref) and -other training functions. - -For full details of how the new interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). - - -## Optimisers Reference - -All optimisers return an object that, when passed to `train!`, will update the parameters passed to it. - -```@docs -Optimisers.Descent -Optimisers.Momentum -Optimisers.Nesterov -Optimisers.RMSProp -Optimisers.Adam -Optimisers.RAdam -Optimisers.AdaMax -Optimisers.AdaGrad -Optimisers.AdaDelta -Optimisers.AMSGrad -Optimisers.NAdam -Optimisers.AdamW -Optimisers.OAdam -Optimisers.AdaBelief -``` - -## Composing Optimisers - -Flux (through Optimisers.jl) defines a special kind of optimiser called `OptimiserChain` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient -that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Optimisers.jl defines the basic decay corresponding to an $L_2$ regularization in the loss as `WeighDecay`. - -```julia -opt = OptimiserChain(WeightDecay(1e-4), Descent()) -``` - -Here we apply the weight decay to the `Descent` optimiser. -The resulting optimiser `opt` can be used as any optimiser. - -```julia -w = [randn(10, 10), randn(10, 10)] -opt_state = Flux.setup(opt, w) - -loss(w, x) = Flux.mse(w[1] * x, w[2] * x) - -loss(w, rand(10)) # around 0.9 - -for t = 1:10^5 - g = gradient(w -> loss(w[1], w[2], rand(10)), w) - Flux.update!(opt_state, w, g) -end - -loss(w, rand(10)) # around 0.9 -``` - -It is possible to compose optimisers for some added flexibility. - -## Scheduling Optimisers - -In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/stable). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. - -First, we import ParameterSchedulers.jl and initialize a cosine annealing schedule to vary the learning rate between `1e-4` and `1e-2` every 10 steps. We also create a new [`Momentum`](@ref) optimiser. -```julia -using ParameterSchedulers - -opt = Momentum() -schedule = Cos(λ0 = 1e-4, λ1 = 1e-2, period = 10) -for (eta, epoch) in zip(schedule, 1:100) - opt.eta = eta - # your training code here -end -``` -`schedule` can also be indexed (e.g. `schedule(100)`) or iterated like any iterator in Julia. - -ParameterSchedulers.jl schedules are stateless (they don't store their iteration state). If you want a _stateful_ schedule, you can use `ParameterSchedulers.Stateful`: -```julia -using ParameterSchedulers: Stateful, next! - -schedule = Stateful(Cos(λ0 = 1e-4, λ1 = 1e-2, period = 10)) -for epoch in 1:100 - opt.eta = next!(schedule) - # your training code here -end -``` - -ParameterSchedulers.jl allows for many more scheduling policies including arbitrary functions, looping any function with a given period, or sequences of many schedules. See the ParameterSchedulers.jl documentation for more info. - -## Decays - -Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone. - -```@docs -SignDecay -WeightDecay -``` - -## Gradient Clipping - -Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is - -```julia -opt = OptimiserChain(ClipValue(1e-3), Adam(1e-3)) -``` - -```@docs -ClipGrad -ClipNorm -``` - - From 5c0785214a83d8908de5554870c9f851545e37e0 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 Apr 2024 20:01:07 +0200 Subject: [PATCH 3/6] update docs --- docs/make.jl | 65 ++-- .../2020-09-15-deep-learning-flux.md | 0 .../2021-02-07-convnet.md | 0 .../2021-10-08-dcgan-mnist.md | 0 .../2021-10-14-vanilla-gan.md | 0 docs/src/{ => guide}/gpu.md | 4 +- docs/src/{ => guide}/models/basics.md | 4 - .../models/custom_layers.md} | 6 +- docs/src/{ => guide}/models/overview.md | 2 +- docs/src/{ => guide}/models/quickstart.md | 4 +- docs/src/{ => guide}/models/recurrence.md | 4 +- docs/src/{ => guide}/performance.md | 0 docs/src/{ => guide}/saving.md | 0 .../{reference => guide}/training/training.md | 23 +- docs/src/{ => reference}/data/mlutils.md | 0 docs/src/{ => reference}/data/onehot.md | 0 docs/src/{ => reference}/destructure.md | 8 + docs/src/reference/models/functors.md | 11 + docs/src/reference/models/layers.md | 4 - docs/src/reference/models/nnlib.md | 2 +- docs/src/{ => reference}/outputsize.md | 0 docs/src/reference/training/optimisers.md | 122 ++++++ docs/src/reference/training/reference.md | 2 + docs/src/reference/training/zygote.md | 8 + docs/src/training/callbacks.md | 75 ---- docs/src/training/reference.md | 36 -- docs/src/training/training.md | 365 ------------------ docs/src/training/zygote.md | 38 -- .../tutorials/{2021-01-26-mlp.md => mlp.md} | 69 ++-- src/Flux.jl | 7 +- src/functor.jl | 10 +- src/layers/basic.jl | 48 +-- src/optimise/train.jl | 2 +- src/train.jl | 12 +- src/utils.jl | 14 +- 35 files changed, 286 insertions(+), 659 deletions(-) rename docs/{src/tutorials => old_tutorials}/2020-09-15-deep-learning-flux.md (100%) rename docs/{src/tutorials => old_tutorials}/2021-02-07-convnet.md (100%) rename docs/{src/tutorials => old_tutorials}/2021-10-08-dcgan-mnist.md (100%) rename docs/{src/tutorials => old_tutorials}/2021-10-14-vanilla-gan.md (100%) rename docs/src/{ => guide}/gpu.md (97%) rename docs/src/{ => guide}/models/basics.md (99%) rename docs/src/{models/advanced.md => guide/models/custom_layers.md} (98%) rename docs/src/{ => guide}/models/overview.md (98%) rename docs/src/{ => guide}/models/quickstart.md (92%) rename docs/src/{ => guide}/models/recurrence.md (98%) rename docs/src/{ => guide}/performance.md (100%) rename docs/src/{ => guide}/saving.md (100%) rename docs/src/{reference => guide}/training/training.md (93%) rename docs/src/{ => reference}/data/mlutils.md (100%) rename docs/src/{ => reference}/data/onehot.md (100%) rename docs/src/{ => reference}/destructure.md (97%) rename docs/src/{ => reference}/outputsize.md (100%) create mode 100644 docs/src/reference/training/optimisers.md delete mode 100644 docs/src/training/callbacks.md delete mode 100644 docs/src/training/reference.md delete mode 100644 docs/src/training/training.md delete mode 100644 docs/src/training/zygote.md rename docs/src/tutorials/{2021-01-26-mlp.md => mlp.md} (79%) diff --git a/docs/make.jl b/docs/make.jl index a1b588d618..331e38fdac 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -13,32 +13,33 @@ makedocs( # You could read this end-to-end, or skip to what you need. # Aim is to cover each new concept exactly once (but not list all variants). # Hard to invent further divisions which aren't more confusing than helpful? - "Quick Start" => "models/quickstart.md", - "Fitting a Line" => "models/overview.md", - "Gradients and Layers" => "models/basics.md", - "Training" => "training/training.md", - "Recurrence" => "models/recurrence.md", - "GPU Support" => "gpu.md", - "Saving & Loading" => "saving.md", - "Performance Tips" => "performance.md", + "Quick Start" => "guide/models/quickstart.md", + "Fitting a Line" => "guide/models/overview.md", + "Gradients and Layers" => "guide/models/basics.md", + "Custom Layers" => "guide/models/custom_layers.md", + "Training" => "guide/training/training.md", + "Recurrence" => "guide/models/recurrence.md", + "GPU Support" => "guide/gpu.md", + "Saving & Loading" => "guide/saving.md", + "Performance Tips" => "guide/performance.md", ], "Ecosystem" => "ecosystem.md", "Reference" => [ # This essentially collects docstrings, with a bit of introduction. - "Built-in Layers" => "models/layers.md", - "Activation Functions" => "models/activation.md", - "Weight Initialisation" => "utilities.md", - "Loss Functions" => "models/losses.md", - "Training API" => "training/reference.md", - "Optimisation Rules" => "training/optimisers.md", - "Shape Inference" => "outputsize.md", - "Flat vs. Nested" => "destructure.md", - "Callback Helpers" => "training/callbacks.md", - "Gradients -- Zygote.jl" => "training/zygote.md", - "Batching Data -- MLUtils.jl" => "data/mlutils.md", - "OneHotArrays.jl" => "data/onehot.md", - "Low-level Operations -- NNlib.jl" => "models/nnlib.md", - "Nested Structures -- Functors.jl" => "models/functors.md", + "Built-in Layers" => "reference/models/layers.md", + "Activation Functions" => "reference/models/activation.md", + "Weight Initialisation" => "reference/utilities.md", + "Loss Functions" => "reference/models/losses.md", + "Training API" => "reference/training/reference.md", + "Optimisation Rules" => "reference/training/optimisers.md", + "Shape Inference" => "reference/outputsize.md", + "Flat vs. Nested" => "reference/destructure.md", + "Callback Helpers" => "reference/training/callbacks.md", + "Gradients -- Zygote.jl" => "reference/training/zygote.md", + "Batching Data -- MLUtils.jl" => "reference/data/mlutils.md", + "OneHotArrays.jl" => "reference/data/onehot.md", + "Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md", + "Nested Structures -- Functors.jl" => "reference/models/functors.md", ], "Tutorials" => [ # These walk you through various tasks. It's fine if they overlap quite a lot. @@ -46,15 +47,14 @@ makedocs( # Or perhaps those should just be trashed, model zoo versions are newer & more useful. "Linear Regression" => "tutorials/linear_regression.md", "Logistic Regression" => "tutorials/logistic_regression.md", + "Multi-layer Perceptron" => "tutorials/mlp.md", #= "Julia & Flux: 60 Minute Blitz" => "tutorials/2020-09-15-deep-learning-flux.md", - "Multi-layer Perceptron" => "tutorials/2021-01-26-mlp.md", "Simple ConvNet" => "tutorials/2021-02-07-convnet.md", "Generative Adversarial Net" => "tutorials/2021-10-14-vanilla-gan.md", "Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md", =# # Not really sure where this belongs... some in Fluxperimental, aim to delete? - "Custom Layers" => "models/advanced.md", # TODO move freezing to Training ], ], format = Documenter.HTML( @@ -63,19 +63,10 @@ makedocs( assets = ["assets/flux.css"], prettyurls = get(ENV, "CI", nothing) == "true" ), - doctest = false, - # linkcheck = true, - checkdocs = :exports, - # strict = true, - # strict = [ - # :cross_references, - # :missing_docs, - # :doctest, - # :linkcheck, - # :parse_error, - # :example_block, - # :autodocs_block, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :setup_block - # ], + doctest = false, # done later + checkdocs = :none, # :exports # Do not check if all functions appear in the docs + # since it considers all packages + warnonly = [:cross_references] ) doctest(Flux) # only test Flux modules diff --git a/docs/src/tutorials/2020-09-15-deep-learning-flux.md b/docs/old_tutorials/2020-09-15-deep-learning-flux.md similarity index 100% rename from docs/src/tutorials/2020-09-15-deep-learning-flux.md rename to docs/old_tutorials/2020-09-15-deep-learning-flux.md diff --git a/docs/src/tutorials/2021-02-07-convnet.md b/docs/old_tutorials/2021-02-07-convnet.md similarity index 100% rename from docs/src/tutorials/2021-02-07-convnet.md rename to docs/old_tutorials/2021-02-07-convnet.md diff --git a/docs/src/tutorials/2021-10-08-dcgan-mnist.md b/docs/old_tutorials/2021-10-08-dcgan-mnist.md similarity index 100% rename from docs/src/tutorials/2021-10-08-dcgan-mnist.md rename to docs/old_tutorials/2021-10-08-dcgan-mnist.md diff --git a/docs/src/tutorials/2021-10-14-vanilla-gan.md b/docs/old_tutorials/2021-10-14-vanilla-gan.md similarity index 100% rename from docs/src/tutorials/2021-10-14-vanilla-gan.md rename to docs/old_tutorials/2021-10-14-vanilla-gan.md diff --git a/docs/src/gpu.md b/docs/src/guide/gpu.md similarity index 97% rename from docs/src/gpu.md rename to docs/src/guide/gpu.md index 182661a0cb..ffce90b055 100644 --- a/docs/src/gpu.md +++ b/docs/src/guide/gpu.md @@ -1,8 +1,7 @@ # GPU Support Starting with v0.14, Flux doesn't force a specific GPU backend and the corresponding package dependencies on the users. -Thanks to the [package extension mechanism]( -https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) introduced in julia v1.9, Flux conditionally loads GPU specific code once a GPU package is made available (e.g. through `using CUDA`). +Thanks to the [package extension mechanism](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) introduced in julia v1.9, Flux conditionally loads GPU specific code once a GPU package is made available (e.g. through `using CUDA`). NVIDIA GPU support requires the packages `CUDA.jl` and `cuDNN.jl` to be installed in the environment. In the julia REPL, type `] add CUDA, cuDNN` to install them. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme. @@ -384,4 +383,5 @@ Flux.FluxAMDGPUDevice Flux.FluxMetalDevice Flux.supported_devices Flux.get_device +Flux.gpu_backend! ``` diff --git a/docs/src/models/basics.md b/docs/src/guide/models/basics.md similarity index 99% rename from docs/src/models/basics.md rename to docs/src/guide/models/basics.md index 4334f47d33..5aded62128 100644 --- a/docs/src/models/basics.md +++ b/docs/src/guide/models/basics.md @@ -245,7 +245,3 @@ end Affine(3 => 1, bias=false) |> gpu ``` -```@docs -Flux.@layer -Flux.create_bias -``` diff --git a/docs/src/models/advanced.md b/docs/src/guide/models/custom_layers.md similarity index 98% rename from docs/src/models/advanced.md rename to docs/src/guide/models/custom_layers.md index 9569944b2e..01942c5ea2 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/guide/models/custom_layers.md @@ -26,9 +26,9 @@ Notice that we parameterized the type of the `chain` field. This is necessary fo You can then use the model like: ```julia -chain = Chain(Dense(10 => 10)) +chain = Chain(Dense(10 => 10, relu), Dense(10 => 10)) model = CustomModel(chain) -model(rand(10)) +model(rand(Float32, 10)) ``` For an intro to Flux and automatic differentiation, see this [tutorial](https://fluxml.ai/tutorials/2020/09/15/deep-learning-flux.html). @@ -168,7 +168,7 @@ model(xs) Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. -We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass. +We start by following the same steps as the `Join` layer: define a struct, use [`Flux.@layer`](@ref), and define the forward pass. ```julia using Flux using CUDA diff --git a/docs/src/models/overview.md b/docs/src/guide/models/overview.md similarity index 98% rename from docs/src/models/overview.md rename to docs/src/guide/models/overview.md index adc7484fe4..71eff0d33f 100644 --- a/docs/src/models/overview.md +++ b/docs/src/guide/models/overview.md @@ -89,7 +89,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f ## 3. Improve the Prediction -Under the hood, the Flux [`Flux.train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md): +Under the hood, the Flux [`Flux.train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../../reference/training/optimisers.md): ```jldoctest overview julia> using Flux: train! diff --git a/docs/src/models/quickstart.md b/docs/src/guide/models/quickstart.md similarity index 92% rename from docs/src/models/quickstart.md rename to docs/src/guide/models/quickstart.md index e7e379f17f..664f56ff04 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/guide/models/quickstart.md @@ -49,7 +49,7 @@ probs2 = softmax(out2) # normalise to get probabilities mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far! ``` -![](../assets/quickstart/oneminute.png) +![](../../assets/quickstart/oneminute.png) ```julia using Plots # to draw the above figure @@ -87,7 +87,7 @@ Some things to notice in this example are: * The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)). -* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.logitcrossentropy`](@ref Flux.Losses.logitcrossentropy) is an ordinary function that combines the [`softmax`](@ref Flux.softmax) and [`crossentropy`](@ref Flux.crossentropy) functions. +* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Optimisers.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.logitcrossentropy`](@ref Flux.Losses.logitcrossentropy) is an ordinary function that combines the [`softmax`](@ref Flux.softmax) and [`crossentropy`](@ref Flux.crossentropy) functions. * The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated. diff --git a/docs/src/models/recurrence.md b/docs/src/guide/models/recurrence.md similarity index 98% rename from docs/src/models/recurrence.md rename to docs/src/guide/models/recurrence.md index 87cd944f4f..a93b0dc258 100644 --- a/docs/src/models/recurrence.md +++ b/docs/src/guide/models/recurrence.md @@ -4,7 +4,7 @@ To introduce Flux's recurrence functionalities, we will consider the following vanilla recurrent neural network structure: -![](../assets/rnn-basic.png) +![](../../assets/rnn-basic.png) In the above, we have a sequence of length 3, where `x1` to `x3` represent the input at each step (could be a timestamp or a word in a sentence), and `y1` to `y3` are their respective outputs. @@ -34,7 +34,7 @@ Notice how the above is essentially a `Dense` layer that acts on two inputs, `h` If you run the last line a few times, you'll notice the output `y` changing slightly even though the input `x` is the same. -There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](layers.md). The hand-written example above can be replaced with: +There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with: ```julia using Flux diff --git a/docs/src/performance.md b/docs/src/guide/performance.md similarity index 100% rename from docs/src/performance.md rename to docs/src/guide/performance.md diff --git a/docs/src/saving.md b/docs/src/guide/saving.md similarity index 100% rename from docs/src/saving.md rename to docs/src/guide/saving.md diff --git a/docs/src/reference/training/training.md b/docs/src/guide/training/training.md similarity index 93% rename from docs/src/reference/training/training.md rename to docs/src/guide/training/training.md index 0407820794..02e17eee41 100644 --- a/docs/src/reference/training/training.md +++ b/docs/src/guide/training/training.md @@ -77,15 +77,8 @@ loss(y_hat, y) = sum((y_hat .- y).^2) ``` or write this directly inside the `do` block above. Many commonly used functions, like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@ref Flux.Losses.crossentropy) for cross-entropy loss, -are available from the [`Flux.Losses`](../models/losses.md) module. +are available from the [`Flux.Losses`](../../reference/models/losses.md) module. -!!! compat "Implicit-style loss functions" - Flux ≤ 0.14 needed a loss function which closed over a reference to the model, - instead of being a pure function. Thus in old code you may see something like - ``` - loss(x, y) = sum((model(x) .- y).^2) - ``` - which defines a function making reference to a particular global variable `model`. ## Optimisation Rules @@ -107,13 +100,13 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref)`(opt_state, model, grads[1])`. -And the learning rate is the only thing stored in the [`Descent`](@ref) struct. +A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Optimisers.update!)`(opt_state, model, grads[1])`. +And the learning rate is the only thing stored in the [`Descent`](@ref Optimisers.Descent) struct. However, there are many other optimisation rules, which adjust the step size and direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always -walking straight downhill -- [`Momentum`](@ref) is the simplest. +walking straight downhill -- [`Momentum`](@ref Optimisers.Momentum) is the simplest. The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. It should be called once, before training, and returns a tree-like object which is the first argument of `update!`. Like this: @@ -130,7 +123,7 @@ for data in train_set end ``` -Many commonly-used optimisation rules, such as [`Adam`](@ref), are built-in. +Many commonly-used optimisation rules, such as [`Adam`](@ref Optimisers.Adam), are built-in. These are listed on the [optimisers](@ref man-optimisers) page. !!! compat "Implicit-style optimiser state" @@ -299,14 +292,14 @@ decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model Flux's optimisers are really modifications applied to the gradient before using it to update the parameters, and [`OptimiserChain`](@ref Optimisers.OptimiserChain) applies two such modifications. -The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times the original parameter to the gradient, +The first, [`WeightDecay`](@ref Optimisers.WeightDecay) adds `0.42` times the original parameter to the gradient, matching the gradient of the penalty above (with the same, unrealistically large, constant). -After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. +After that, in either case, [`Adam`](@ref Optimisers.Adam) computes the final update. The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is `pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. -The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref) or [`ClipNorm`](@ref). +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Optimisers.ClipGrad) or [`ClipNorm`](@ref Optimisers.ClipNorm). Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the diff --git a/docs/src/data/mlutils.md b/docs/src/reference/data/mlutils.md similarity index 100% rename from docs/src/data/mlutils.md rename to docs/src/reference/data/mlutils.md diff --git a/docs/src/data/onehot.md b/docs/src/reference/data/onehot.md similarity index 100% rename from docs/src/data/onehot.md rename to docs/src/reference/data/onehot.md diff --git a/docs/src/destructure.md b/docs/src/reference/destructure.md similarity index 97% rename from docs/src/destructure.md rename to docs/src/reference/destructure.md index 16089380c4..2071b5466b 100644 --- a/docs/src/destructure.md +++ b/docs/src/reference/destructure.md @@ -86,4 +86,12 @@ Flux.modules ```@docs Flux.state Flux.loadmodel! +``` + +### KeyPath + +```@docs +Functors.KeyPath +Functors.getkeypath +Functors.haskeypath ``` \ No newline at end of file diff --git a/docs/src/reference/models/functors.md b/docs/src/reference/models/functors.md index 861528cda9..1637a7b8a6 100644 --- a/docs/src/reference/models/functors.md +++ b/docs/src/reference/models/functors.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # Recursive transformations from Functors.jl Flux models are deeply nested structures, and [Functors.jl](https://github.com/FluxML/Functors.jl) provides tools needed to explore such objects, apply functions to the parameters they contain, and re-build them. @@ -11,13 +15,20 @@ Flux models are deeply nested structures, and [Functors.jl](https://github.com/F `Functors.jl` has its own [notes on basic usage](https://fluxml.ai/Functors.jl/stable/#Basic-Usage-and-Implementation) for more details. Additionally, the [Advanced Model Building and Customisation](@ref man-advanced) page covers the use cases of `Functors` in greater details. ```@docs +Flux.@layer Functors.@functor Functors.fmap +Functors.fmap_with_path Functors.isleaf Functors.children Functors.fcollect Functors.functor Functors.fmapstructure +Functors.fmapstructure_with_path +Functors.execute +Functors.AbstractWalk +Functors.ExcludeWalk +Functors.CachedWalk ``` ## Moving models, or data, to the GPU diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index 177a3eca94..509702e30c 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -27,10 +27,6 @@ Flux.Scale Perhaps `Scale` isn't quite fully connected, but it may be thought of as `Dense(Diagonal(s.weights), s.bias)`, and LinearAlgebra's `Diagonal` is a matrix which just happens to contain many zeros. -!!! compat "Flux ≤ 0.12" - Old versions of Flux accepted only `Dense(in, out, act)` and not `Dense(in => out, act)`. - This notation makes a `Pair` object. If you get an error like `MethodError: no method matching Dense(::Pair{Int64,Int64})`, this means that you should upgrade to newer Flux versions. - ## Convolution Models diff --git a/docs/src/reference/models/nnlib.md b/docs/src/reference/models/nnlib.md index d50af652ab..e7739f0ebf 100644 --- a/docs/src/reference/models/nnlib.md +++ b/docs/src/reference/models/nnlib.md @@ -5,7 +5,7 @@ Flux re-exports all of the functions exported by the [NNlib](https://github.com/ ## Attention -Primitives for the [`MultiHeadAttention`](ref) layer. +Primitives for the [`MultiHeadAttention`](@ref) layer. ```@docs NNlib.dot_product_attention diff --git a/docs/src/outputsize.md b/docs/src/reference/outputsize.md similarity index 100% rename from docs/src/outputsize.md rename to docs/src/reference/outputsize.md diff --git a/docs/src/reference/training/optimisers.md b/docs/src/reference/training/optimisers.md new file mode 100644 index 0000000000..dab36eeb3a --- /dev/null +++ b/docs/src/reference/training/optimisers.md @@ -0,0 +1,122 @@ +```@meta +CurrentModule = Flux +CollapsedDocStrings = true +``` + +# [Optimisation Rules](@id man-optimisers) + +Any optimization rule from Optimisers.jl can be used with [`train!`](@ref Flux.Train.train!) and +other training functions. + +For full details of how the new interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). + + +## Optimisers Reference + +All optimisers return an object that, when passed to `train!`, will update the parameters passed to it. + +```@docs +Optimisers.Descent +Optimisers.Momentum +Optimisers.Nesterov +Optimisers.RMSProp +Optimisers.Adam +Optimisers.RAdam +Optimisers.AdaMax +Optimisers.AdaGrad +Optimisers.AdaDelta +Optimisers.AMSGrad +Optimisers.NAdam +Optimisers.AdamW +Optimisers.OAdam +Optimisers.AdaBelief +Optimisers.Lion +``` + +## Composing Optimisers + +Flux (through Optimisers.jl) defines a special kind of optimiser called `OptimiserChain` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient +that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Optimisers.jl defines the basic decay corresponding to an $L_2$ regularization in the loss as `WeightDecay`. + +```julia +opt = OptimiserChain(WeightDecay(1e-4), Descent()) +``` + +Here we apply the weight decay to the `Descent` optimiser. +The resulting optimiser `opt` can be used as any optimiser. + +```julia +w = [randn(10, 10), randn(10, 10)] +opt_state = Flux.setup(opt, w) + +loss(w, x) = Flux.mse(w[1] * x, w[2] * x) + +loss(w, rand(10)) # around 0.9 + +for t = 1:10^5 + g = gradient(w -> loss(w[1], w[2], rand(10)), w) + Flux.update!(opt_state, w, g) +end + +loss(w, rand(10)) # around 0.9 +``` + +It is possible to compose optimisers for some added flexibility. + +```@docs +Optimisers.OptimiserChain +``` + +## Scheduling Optimisers + +In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/stable). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. + +First, we import ParameterSchedulers.jl and initialize a cosine annealing schedule to vary the learning rate between `1e-4` and `1e-2` every 10 steps. We also create a new [`Momentum`](@ref Optimisers.Momentum) optimiser. +```julia +using ParameterSchedulers + +opt = Momentum() +schedule = Cos(λ0 = 1e-4, λ1 = 1e-2, period = 10) +for (eta, epoch) in zip(schedule, 1:100) + opt.eta = eta + # your training code here +end +``` +`schedule` can also be indexed (e.g. `schedule(100)`) or iterated like any iterator in Julia. + +ParameterSchedulers.jl schedules are stateless (they don't store their iteration state). If you want a _stateful_ schedule, you can use `ParameterSchedulers.Stateful`: +```julia +using ParameterSchedulers: Stateful, next! + +schedule = Stateful(Cos(λ0 = 1e-4, λ1 = 1e-2, period = 10)) +for epoch in 1:100 + opt.eta = next!(schedule) + # your training code here +end +``` + +ParameterSchedulers.jl allows for many more scheduling policies including arbitrary functions, looping any function with a given period, or sequences of many schedules. See the ParameterSchedulers.jl documentation for more info. + +## Decays + +Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone. + +```@docs +Optimisers.SignDecay +Optimisers.WeightDecay +``` + +## Gradient Clipping + +Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is + +```julia +opt = OptimiserChain(ClipValue(1e-3), Adam(1e-3)) +``` + +```@docs +Optimisers.ClipGrad +Optimisers.ClipNorm +``` + + diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md index 67980831f9..aa55ec0927 100644 --- a/docs/src/reference/training/reference.md +++ b/docs/src/reference/training/reference.md @@ -15,7 +15,9 @@ The available optimization rules are listed the [optimisation rules](@ref man-op ```@docs Flux.Train.setup Flux.Train.train!(loss, model, data, state) +Optimisers.update Optimisers.update! +Optimisers.setup ``` `train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. diff --git a/docs/src/reference/training/zygote.md b/docs/src/reference/training/zygote.md index 33d30d6ee8..5641f4db23 100644 --- a/docs/src/reference/training/zygote.md +++ b/docs/src/reference/training/zygote.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # [Automatic Differentiation using Zygote.jl](@id autodiff-zygote) Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/). @@ -16,6 +20,7 @@ Zygote.withjacobian(f, args...) Zygote.hessian Zygote.hessian_reverse Zygote.diaghessian +Zygote.pullback ``` ## ChainRules @@ -35,4 +40,7 @@ ChainRulesCore.frule ChainRulesCore.@scalar_rule ChainRulesCore.NoTangent ChainRulesCore.ZeroTangent +ChainRulesCore.RuleConfig +ChainRulesCore.Tangent +ChainRulesCore.canonicalize ``` diff --git a/docs/src/training/callbacks.md b/docs/src/training/callbacks.md deleted file mode 100644 index 148aa02128..0000000000 --- a/docs/src/training/callbacks.md +++ /dev/null @@ -1,75 +0,0 @@ -# [Callback Helpers](@id man-callback-helpers) - -```@docs -Flux.throttle -``` - -## Patience Helpers - -Flux provides utilities for controlling your training procedure according to some monitored condition and a maximum `patience`. For example, you can use `early_stopping` to stop training when the model is converging or deteriorating, or you can use `plateau` to check if the model is stagnating. - -For example, below we create a pseudo-loss function that decreases, bottoms out, and then increases. The early stopping trigger will break the loop before the loss increases too much. -```julia -# create a pseudo-loss that decreases for 4 calls, then starts increasing -# we call this like loss() -loss = let t = 0 - () -> begin - t += 1 - (t - 4) ^ 2 - end -end - -# create an early stopping trigger -# returns true when the loss increases for two consecutive steps -es = early_stopping(loss, 2; init_score = 9) - -# this will stop at the 6th (4 decreasing + 2 increasing calls) epoch -for epoch in 1:10 - es() && break -end -``` - -The keyword argument `distance` of `early_stopping` is a function of the form `distance(best_score, score)`. By default `distance` is `-`, which implies that the monitored metric `f` is expected to be decreasing and minimized. If you use some increasing metric (e.g. accuracy), you can customize the `distance` function: `(best_score, score) -> score - best_score`. -```julia -# create a pseudo-accuracy that increases by 0.01 each time from 0 to 1 -# we call this like acc() -acc = let v = 0 - () -> v = max(1, v + 0.01) -end - -# create an early stopping trigger for accuracy -es = early_stopping(acc, 3; delta = (best_score, score) -> score - best_score) - -# this will iterate until the 10th epoch -for epoch in 1:10 - es() && break -end -``` - -`early_stopping` and `plateau` are both built on top of `patience`. You can use `patience` to build your own triggers that use a patient counter. For example, if you want to trigger when the loss is below a threshold for several consecutive iterations: -```julia -threshold(f, thresh, delay) = patience(delay) do - f() < thresh -end -``` - -Both `predicate` in `patience` and `f` in `early_stopping` / `plateau` can accept extra arguments. You can pass such extra arguments to `predicate` or `f` through the returned function: -```julia -trigger = patience((a; b) -> a > b, 3) - -# this will iterate until the 10th epoch -for epoch in 1:10 - trigger(1; b = 2) && break -end - -# this will stop at the 3rd epoch -for epoch in 1:10 - trigger(3; b = 2) && break -end -``` - -```@docs -Flux.patience -Flux.early_stopping -Flux.plateau -``` diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md deleted file mode 100644 index 67980831f9..0000000000 --- a/docs/src/training/reference.md +++ /dev/null @@ -1,36 +0,0 @@ -# Training API Reference - -The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). -Only the function `train!` belongs to Flux itself. - -The Optimisers package is designed to allow for immutable objects. But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s) which can be updated in-place. -Because of this: - -* The objects returned by `Optimisers.update!` can be ignored. -* Flux defines its own version of `setup` which checks this assumption. - (Using instead `Optimisers.setup` will also work, they return the same thing.) - -The available optimization rules are listed the [optimisation rules](@ref man-optimisers) page here. See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the rules work. - -```@docs -Flux.Train.setup -Flux.Train.train!(loss, model, data, state) -Optimisers.update! -``` - -`train!` uses [`@progress`](https://github.com/JuliaLogging/ProgressLogging.jl) which should show a progress bar in VSCode automatically. -To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl) -and follow its setup instructions. - -## Optimisation Modifiers - -The state returned by `setup` can be modified to temporarily prevent training of -some parts of the model, or to change the learning rate or other hyperparameter. -The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust!`. -All mutate the state (or part of it) and return `nothing`. - -```@docs -Optimisers.adjust! -Optimisers.freeze! -Optimisers.thaw! -``` diff --git a/docs/src/training/training.md b/docs/src/training/training.md deleted file mode 100644 index 0407820794..0000000000 --- a/docs/src/training/training.md +++ /dev/null @@ -1,365 +0,0 @@ -# [Training a Flux Model](@id man-training) - -Training refers to the process of slowly adjusting the parameters of a model to make it work better. -Besides the model itself, we will need three things: - -* An *objective function* that evaluates how well a model is doing on some input. -* An *optimisation rule* which describes how the model's parameters should be adjusted. -* Some *training data* to use as the input during this process. - -Usually the training data is some collection of examples (or batches of examples) which -are handled one-by-one. One *epoch* of training means that each example is used once, -something like this: - -```julia -# Initialise the optimiser for this model: -opt_state = Flux.setup(rule, model) - -for data in train_set - # Unpack this element (for supervised training): - input, label = data - - # Calculate the gradient of the objective - # with respect to the parameters within the model: - grads = Flux.gradient(model) do m - result = m(input) - loss(result, label) - end - - # Update the parameters so as to reduce the objective, - # according the chosen optimisation rule: - Flux.update!(opt_state, model, grads[1]) -end -``` - -This loop can also be written using the function [`train!`](@ref Flux.Train.train!), -but it's helpful to understand the pieces first: - -```julia -train!(model, train_set, opt_state) do m, x, y - loss(m(x), y) -end -``` - -## Model Gradients - -Fist recall from the section on [taking gradients](@ref man-taking-gradients) that -`Flux.gradient(f, a, b)` always calls `f(a, b)`, and returns a tuple `(∂f_∂a, ∂f_∂b)`. -In the code above, the function `f` passed to `gradient` is an anonymous function with -one argument, created by the `do` block, hence `grads` is a tuple with one element. -Instead of a `do` block, we could have written: - -```julia -grads = Flux.gradient(m -> loss(m(input), label), model) -``` - -Since the model is some nested set of layers, `grads[1]` is a similarly nested set of -`NamedTuple`s, ultimately containing gradient components. If (for example) -`θ = model.layers[1].weight[2,3]` is one scalar parameter, an entry in a matrix of weights, -then the derivative of the loss with respect to it is `∂f_∂θ = grads[1].layers[1].weight[2,3]`. - -It is important that the execution of the model takes place inside the call to `gradient`, -in order for the influence of the model's parameters to be observed by Zygote. - -It is also important that every `update!` step receives a newly computed gradient, -as it will change whenever the model's parameters are changed, and for each new data point. - - -## Loss Functions - -The objective function must return a number representing how far the model is from -the desired result. This is termed the *loss* of the model. - -This number can be produced by any ordinary Julia code, but this must be executed -within the call to `gradient`. For instance, we could define a function -```julia -loss(y_hat, y) = sum((y_hat .- y).^2) -``` -or write this directly inside the `do` block above. Many commonly used functions, -like [`mse`](@ref Flux.Losses.mse) for mean-squared error or [`crossentropy`](@ref Flux.Losses.crossentropy) for cross-entropy loss, -are available from the [`Flux.Losses`](../models/losses.md) module. - -!!! compat "Implicit-style loss functions" - Flux ≤ 0.14 needed a loss function which closed over a reference to the model, - instead of being a pure function. Thus in old code you may see something like - ``` - loss(x, y) = sum((model(x) .- y).^2) - ``` - which defines a function making reference to a particular global variable `model`. - -## Optimisation Rules - -The simplest kind of optimisation using the gradient is termed *gradient descent* -(or sometimes *stochastic gradient descent* when, as here, it is not applied to the entire dataset at once). - -Gradient descent needs a *learning rate* which is a small number describing how fast to walk downhill, -usually written as the Greek letter "eta", `η`. This is often described as a *hyperparameter*, -to distinguish it from the parameters which are being updated `θ = θ - η * ∂loss_∂θ`. -We want to update all the parameters in the model, like this: - -```julia -η = 0.01 # learning rate - -# For each parameter array, update -# according to the corresponding gradient: -fmap(model, grads[1]) do p, g - p .= p .- η .* g -end -``` - -A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref)`(opt_state, model, grads[1])`. -And the learning rate is the only thing stored in the [`Descent`](@ref) struct. - -However, there are many other optimisation rules, which adjust the step size and -direction in various clever ways. -Most require some memory of the gradients from earlier steps, rather than always -walking straight downhill -- [`Momentum`](@ref) is the simplest. -The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. -It should be called once, before training, and returns a tree-like object which is the -first argument of `update!`. Like this: - -```julia -# Initialise momentum -opt_state = Flux.setup(Momentum(0.01, 0.9), model) - -for data in train_set - grads = [...] - - # Update both model parameters and optimiser state: - Flux.update!(opt_state, model, grads[1]) -end -``` - -Many commonly-used optimisation rules, such as [`Adam`](@ref), are built-in. -These are listed on the [optimisers](@ref man-optimisers) page. - -!!! compat "Implicit-style optimiser state" - This `setup` makes another tree-like structure. Old versions of Flux did not do this, - and instead stored a dictionary-like structure within the optimiser `Adam(0.001)`. - This was initialised on first use of the version of `update!` for "implicit" parameters. - - -## Datasets & Batches - -The loop above iterates through `train_set`, expecting at each step a tuple `(input, label)`. -The very simplest such object is a vector of tuples, such as this: - -```julia -x = randn(28, 28) -y = rand(10) -data = [(x, y)] -``` - -or `data = [(x, y), (x, y), (x, y)]` for the same values three times. - -Very often, the initial data is large arrays which you need to slice into examples. -To produce one iterator of pairs `(x, y)`, you might want `zip`: - -```julia -X = rand(28, 28, 60_000); # many images, each 28 × 28 -Y = rand(10, 60_000) -data = zip(eachslice(X; dims=3), eachcol(Y)) - -first(data) isa Tuple{AbstractMatrix, AbstractVector} # true -``` - -Here each iteration will use one matrix `x` (an image, perhaps) and one vector `y`. -It is very common to instead train on *batches* of such inputs (or *mini-batches*, -the two words mean the same thing) both for efficiency and for better results. -This can be easily done using the [`DataLoader`](@ref Flux.Data.DataLoader): - -```julia -data = Flux.DataLoader((X, Y), batchsize=32) - -x1, y1 = first(data) -size(x1) == (28, 28, 32) -length(data) == 1875 === 60_000 ÷ 32 -``` - -Flux's layers are set up to accept such a batch of input data, -and the convolutional layers such as [`Conv`](@ref Flux.Conv) require it. -The batch index is always the last dimension. - -## Training Loops - -Simple training loops like the one above can be written compactly using -the [`train!`](@ref Flux.Train.train!) function. Including `setup`, this reads: - -```julia -opt_state = Flux.setup(Adam(), model) - -for epoch in 1:100 - Flux.train!(model, train_set, opt_state) do m, x, y - loss(m(x), y) - end -end -``` - -Or explicitly writing the anonymous function which this `do` block creates, -`train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)` is exactly equivalent. - -Real training loops often need more flexibility, and the best way to do this is just -to write the loop. This is ordinary Julia code, without any need to work through some -callback API. Here is an example, in which it may be helpful to note: - -* The function [`withgradient`](@ref Zygote.withgradient) is like `gradient` but also - returns the value of the function, for logging or diagnostic use. -* Logging or printing is best done outside of the `gradient` call, - as there is no need to differentiate these commands. -* To use `result` for logging purposes, you could change the `do` block to end with - `return my_loss(result, label), result`, i.e. make the function passed to `withgradient` - return a tuple. The first element is always the loss. -* Julia's `break` and `continue` keywords let you exit from parts of the loop. - -```julia -opt_state = Flux.setup(Adam(), model) - -my_log = [] -for epoch in 1:100 - losses = Float32[] - for (i, data) in enumerate(train_set) - input, label = data - - val, grads = Flux.withgradient(model) do m - # Any code inside here is differentiated. - # Evaluation of the model and loss must be inside! - result = m(input) - my_loss(result, label) - end - - # Save the loss from the forward pass. (Done outside of gradient.) - push!(losses, val) - - # Detect loss of Inf or NaN. Print a warning, and then skip update! - if !isfinite(val) - @warn "loss is $val on item $i" epoch - continue - end - - Flux.update!(opt_state, model, grads[1]) - end - - # Compute some accuracy, and save details as a NamedTuple - acc = my_accuracy(model, train_set) - push!(my_log, (; acc, losses)) - - # Stop training when some criterion is reached - if acc > 0.95 - println("stopping after $epoch epochs") - break - end -end -``` - -## Regularisation - -The term *regularisation* covers a wide variety of techniques aiming to improve the -result of training. This is often done to avoid overfitting. - -Some of these can be implemented by simply modifying the loss function. -*L₂ regularisation* (sometimes called ridge regression) adds to the loss a penalty -proportional to `θ^2` for every scalar parameter. -A very simple model could be implemented as follows: - -```julia -grads = Flux.gradient(densemodel) do m - result = m(input) - penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 - my_loss(result, label) + 0.42f0 * penalty -end -``` - -Accessing each individual parameter array by hand won't work well for large models. -Instead, we can use [`Flux.trainables`](@ref Optimisers.trainables) to collect all of them, -and then apply a function to each one, and sum the result: - -```julia -pen_l2(x::AbstractArray) = sum(abs2, x)/2 - -grads = Flux.gradient(model) do m - result = m(input) - penalty = sum(pen_l2, Flux.trainables(m)) - my_loss(result, label) + 0.42f0 * penalty -end -``` - -However, the gradient of this penalty term is very simple: It is proportional to the original weights. -So there is a simpler way to implement exactly the same thing, by modifying the optimiser -instead of the loss function. This is done by replacing this: - -```julia -opt_state = Flux.setup(Adam(0.1), model) -``` - -with this: - -```julia -decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model) -``` - -Flux's optimisers are really modifications applied to the gradient before using it to update -the parameters, and [`OptimiserChain`](@ref Optimisers.OptimiserChain) applies two such modifications. -The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times the original parameter to the gradient, -matching the gradient of the penalty above (with the same, unrealistically large, constant). -After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. - -The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is -`pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. - -The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref) or [`ClipNorm`](@ref). - -Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is -provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the -previous layer during training. -It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. - -## Learning Rate Schedules - -Finer control of training, you may wish to alter the learning rate mid-way through training. -This can be done with [`adjust!`](@ref Flux.adjust!), like this: - -```julia -opt_state = Flux.setup(Adam(0.1), model) # initialise once - -for epoch in 1:1000 - train!([...], state) # Train with η = 0.1 for first 100, - if epoch == 100 # then change to use η = 0.01 for the rest. - Flux.adjust!(opt_state, 0.01) - end -end -``` - -Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`. -And such modifications can be applied to just one part of the model. -For instance, this sets a different learning rate for the encoder and the decoder: - -```julia -# Consider some model with two parts: -bimodel = Chain(enc = [...], dec = [...]) - -# This returns a tree whose structure matches the model: -opt_state = Flux.setup(Adam(0.02), bimodel) - -# Adjust the learning rate to be used for bimodel.layers.enc -Flux.adjust!(opt_state.layers.enc, 0.03) -``` - -## Freezing layer parameters - -To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). -This is a temporary modification, reversed by `thaw!`: - -```julia -Flux.freeze!(opt_state.layers.enc) - -# Now training won't update parameters in bimodel.layers.enc -train!(loss, bimodel, data, opt_state) - -# Un-freeze the entire model: -Flux.thaw!(opt_state) -``` - -While `adjust!` and `freeze!`/`thaw!` make temporary modifications to the optimiser state, -permanently removing some fields of a new layer type from training is usually done -when defining the layer, by calling for example [`@layer`](@ref Flux.@layer)` NewLayer trainable=(weight,)`. - diff --git a/docs/src/training/zygote.md b/docs/src/training/zygote.md deleted file mode 100644 index 33d30d6ee8..0000000000 --- a/docs/src/training/zygote.md +++ /dev/null @@ -1,38 +0,0 @@ -# [Automatic Differentiation using Zygote.jl](@id autodiff-zygote) - -Flux re-exports the `gradient` from [Zygote](https://github.com/FluxML/Zygote.jl), and uses this function within [`train!`](@ref Flux.train!) to differentiate the model. Zygote has its own [documentation](https://fluxml.ai/Zygote.jl/dev/), in particular listing some [important limitations](https://fluxml.ai/Zygote.jl/dev/limitations/). - - -## Explicit style - -The preferred way of using Zygote, and the only way of using most other AD packages, -is to explicitly provide a function and its arguments. - -```@docs -Zygote.gradient(f, args...) -Zygote.withgradient(f, args...) -Zygote.jacobian(f, args...) -Zygote.withjacobian(f, args...) -Zygote.hessian -Zygote.hessian_reverse -Zygote.diaghessian -``` - -## ChainRules - -Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl): - -```@docs -ChainRulesCore.ignore_derivatives -ChainRulesCore.@non_differentiable -``` - -To manually supply the gradient for one function, you should define a method of `rrule`. ChainRules has [detailed documentation](https://juliadiff.org/ChainRulesCore.jl/stable/) on how this works. - -```@docs -ChainRulesCore.rrule -ChainRulesCore.frule -ChainRulesCore.@scalar_rule -ChainRulesCore.NoTangent -ChainRulesCore.ZeroTangent -``` diff --git a/docs/src/tutorials/2021-01-26-mlp.md b/docs/src/tutorials/mlp.md similarity index 79% rename from docs/src/tutorials/2021-01-26-mlp.md rename to docs/src/tutorials/mlp.md index 763f711195..fed886b49e 100644 --- a/docs/src/tutorials/2021-01-26-mlp.md +++ b/docs/src/tutorials/mlp.md @@ -6,11 +6,11 @@ To run this example, we need the following packages: ```julia using Flux, Statistics -using Flux.Data: DataLoader -using Flux: onehotbatch, onecold, logitcrossentropy, throttle, params -using Base.Iterators: repeated +using Flux: DataLoader +using Flux: onehotbatch, onecold, logitcrossentropy using CUDA -using MLDatasets +using MLDatasets: MNIST + if has_cuda() # Check if CUDA is available @info "CUDA is on" CUDA.allowscalar(false) @@ -24,7 +24,7 @@ Base.@kwdef mutable struct Args rate::Float64 = 3e-4 # learning rate batchsize::Int = 1024 # batch size epochs::Int = 10 # number of epochs - device::Function = gpu # set as gpu, if gpu available + usegpu::Bool = true end ``` @@ -40,8 +40,8 @@ function getdata(args) ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # Loading Dataset - xtrain, ytrain = MLDatasets.MNIST.traindata(Float32) - xtest, ytest = MLDatasets.MNIST.testdata(Float32) + xtrain, ytrain = MNIST(:train)[:] + xtest, ytest = MNIST(:test)[:] # Reshape Data in order to flatten each image into a linear array xtrain = Flux.flatten(xtrain) @@ -51,10 +51,10 @@ function getdata(args) ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9) # Batching - train_data = DataLoader((xtrain, ytrain), batchsize=args.batchsize, shuffle=true) - test_data = DataLoader((xtest, ytest), batchsize=args.batchsize) + train_loader = DataLoader((xtrain, ytrain), batchsize=args.batchsize, shuffle=true) + test_loader = DataLoader((xtest, ytest), batchsize=args.batchsize) - return train_data, test_data + return train_loader, test_loader end ``` @@ -90,15 +90,17 @@ Note that we use the functions [Dense](https://fluxml.ai/Flux.jl/stable/models/l ## Loss functions -Now, we define the loss function `loss_all`. It expects a DataLoader object and the `model` function we defined above as arguments. Notice that this function iterates through the `dataloader` object in mini-batches and uses the function [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) to compute the difference between the predicted and actual values. +Now, we define the loss function `loss_all`. It expects a DataLoader object and the `model` function we defined above as arguments. Notice that this function iterates through the `DataLoader` object in mini-batches and uses the function [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) to compute the difference between the predicted and actual values. ```julia function loss_all(dataloader, model) l = 0f0 - for (x,y) in dataloader - l += logitcrossentropy(model(x), y) + n = 0 + for (x, y) in dataloader + l += logitcrossentropy(model(x), y, agg=sum) + n += MLUtils.numobs(x) end - l/length(dataloader) + return l / n end ``` @@ -106,12 +108,14 @@ end In addition, we define the function (`accuracy`) to report the accuracy of our model during the training process. To compute the accuray, we need to decode the output of our model using the [onecold](https://fluxml.ai/Flux.jl/stable/data/onehot/#Flux.onecold) function. ```julia -function accuracy(data_loader, model) +function accuracy(dataloader, model) acc = 0 - for (x,y) in data_loader - acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y)))*1 / size(x,2) + n = 0 + for (x, y) in dataloader + acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y))) + n += MLUtils.numobs(x) end - acc/length(data_loader) + return acc / n end ``` @@ -125,28 +129,29 @@ function train(; kws...) # Initializing Model parameters args = Args(; kws...) + device = args.usegpu ? Flux.get_device() : Flux.get_device("CPU") + # Load Data - train_data,test_data = getdata(args) + train_loader, test_loader = getdata(args) # Construct model - m = build_model() - train_data = args.device.(train_data) - test_data = args.device.(test_data) - m = args.device(m) - loss(x,y) = logitcrossentropy(m(x), y) + model = build_model() |> device + + loss(model, x, y) = logitcrossentropy(model(x), y) ## Training - evalcb = () -> @show(loss_all(train_data, m)) - opt = Adam(args.rate) + opt_state = Flux.setup(Adam(args.rate), model) for epoch in 1:args.epochs @info "Epoch $epoch" - Flux.train!(loss, params(m), train_data, opt, cb = evalcb) + for d in train_loader + x, y = d |> device + g = gradient(m -> loss(m, x, y), model)[1] + Flux.update!(opt_state, model, g) + end + @show accuracy(train_loader, model) + @show accuracy(test_loader, model) end - - @show accuracy(train_data, m) - - @show accuracy(test_data, m) end ``` @@ -156,7 +161,7 @@ end * **Initializes the model parameters:** Creates the `args` object that contains the defult values for training our model. * **Loads the train and test data:** Calls the function `getdata` we defined above. * **Constructs the model:** Builds the model and loads the train and test data sets, and our model onto the GPU (if available). -* **Trains the model:** Defines the *callback* function `evalcb` to show the value of the `loss_all` function during the training process. Then, it sets [Adam](@ref Flux.Optimise.Adam) as the optimiser for training out model. Finally, it runs the training process for `10` epochs (as defined in the `args` object) and shows the `accuracy` value for the train and test data. +* **Trains the model:** Sets [Adam](@ref Optimisers.Adam) as the optimiser for training out model, runs the training process for `10` epochs (as defined in the `args` object) and shows the `accuracy` value for the train and test data. To see the full version of this example, see [Simple multi-layer perceptron - model-zoo](https://github.com/FluxML/model-zoo/blob/master/vision/mlp_mnist/mlp_mnist.jl). diff --git a/src/Flux.jl b/src/Flux.jl index a8720b7905..e9cc78b5dd 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,7 +11,7 @@ using MacroTools: @forward using MLUtils const stack = MLUtils.stack # now exported by Base import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions -using Optimisers: freeze!, thaw!, adjust! +using Optimisers: freeze!, thaw!, adjust!, trainables using Random: default_rng using Zygote, ChainRulesCore using Zygote: Params, @adjoint, gradient, pullback @@ -21,7 +21,8 @@ export gradient # Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`") -export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion, +export Chain, Dense, Embedding, EmbeddingBag, + Maxout, SkipConnection, Parallel, PairwiseFusion, RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, @@ -64,7 +65,7 @@ include("functor.jl") # from Functors.jl functor, @functor, # from Optimise/Train/Optimisers.jl - setup, update!, destructure, freeze!, adjust!, params, trainable + setup, update!, destructure, freeze!, adjust!, params, trainable, trainables )) # Pirate error to catch a common mistake. diff --git a/src/functor.jl b/src/functor.jl index e0168edf6b..689012ea4a 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -189,10 +189,18 @@ _isleaf(::AbstractRNG) = true _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) # the order below is important -const GPU_BACKENDS = ["CUDA", "AMDGPU", "Metal", "CPU"] +const GPU_BACKENDS = ("CUDA", "AMDGPU", "Metal", "CPU") const GPU_BACKEND_ORDER = Dict(collect(zip(GPU_BACKENDS, 1:length(GPU_BACKENDS)))) const GPU_BACKEND = @load_preference("gpu_backend", "CUDA") +""" + gpu_backend!(backend::String) + +Set the GPU backend to `backend` in the `LocalPreferences.toml` file in you project directory. +After restarting Julia, the new backend will affect all subsequent calls to [`gpu`](@ref) and [`get_device`](@ref). + +The supported backends are `"CUDA"`, `"AMDGPU"`, `"Metal"` and `"CPU"`. +""" function gpu_backend!(backend::String) if backend == GPU_BACKEND @info """ diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ef81c30872..e28f11a0b9 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -129,25 +129,26 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided # Examples ```jldoctest -julia> d = Dense(5 => 2) +julia> model = Dense(5 => 2) Dense(5 => 2) # 12 parameters -julia> d(rand32(5, 64)) |> size +julia> model(rand32(5, 64)) |> size (2, 64) -julia> d(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions +julia> model(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions (2, 6, 4, 64) -julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix +julia> model2 = Dense(ones(2, 5), false, tanh) # using provided weight matrix Dense(5 => 2, tanh; bias=false) # 10 parameters -julia> d1(ones(5)) +julia> model2(ones(5)) 2-element Vector{Float64}: 0.9999092042625951 0.9999092042625951 -julia> Flux.params(d1) # no trainable bias -Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]]) +julia> Flux.trainables(model2) # no trainable bias +1-element Vector{AbstractArray}: + [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0] ``` """ struct Dense{F, M<:AbstractMatrix, B} @@ -218,24 +219,27 @@ Used by [`LayerNorm`](@ref) with `affine=true`. julia> a = Flux.Scale(2) Scale(2) # 4 parameters -julia> Flux.params(a) -Params([Float32[1.0, 1.0], Float32[0.0, 0.0]]) +julia> Flux.trainables(a) +2-element Vector{AbstractArray}: + Float32[1.0, 1.0] + Float32[0.0, 0.0] julia> a([1 2 3]) 2×3 Matrix{Float32}: 1.0 2.0 3.0 1.0 2.0 3.0 -julia> b = Flux.Scale([1 2 3 4], false, abs2) +julia> b = Flux.Scale(Float32[1 2 3 4], false, abs2) Scale(1, 4, abs2; bias=false) # 4 parameters julia> b([1, 10]) -2×4 Matrix{Int64}: - 1 4 9 16 - 100 400 900 1600 +2×4 Matrix{Float32}: + 1.0 4.0 9.0 16.0 + 100.0 400.0 900.0 1600.0 -julia> Flux.params(b) -Params([[1 2 3 4]]) +julia> Flux.trainables(b) +1-element Vector{AbstractArray}: + Float32[1.0 2.0 3.0 4.0] ``` """ struct Scale{F, A<:AbstractArray, B} @@ -490,7 +494,7 @@ julia> model = Chain(Dense(3 => 5), julia> model(rand32(3)) |> size (17,) -julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2)) +julia> model2 = Parallel(+; α = Dense(10 => 2, tanh), β = Dense(5 => 2)) Parallel( +, α = Dense(10 => 2, tanh), # 22 parameters @@ -770,7 +774,7 @@ The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of t or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below. # Examples -```jldoctest +```jldoctest ebag julia> vocab_size = 26; # embed into 3 dimensions, with non-random vectors: julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100)) @@ -809,11 +813,11 @@ and a vector `at` stating where to split that up into "bags". The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on, with no overlaps and nothing left out (thus it requires `at[1]==1`). -```jldoctest +```jldoctest ebag julia> data = [11, 1, 12, 2, 13, 3, 14]; -julia> Flux._splitat(data, [1, 4]) |> println # internal function, makes data[1:3], data[4:end] -[[11, 1, 12], [2, 13, 3, 14]] +julia> data[1:3], data[4:end] +([11, 1, 12], [2, 13, 3, 14]) julia> eb(data, [1, 4]) # two bags, of 3 and 4 items 3×2 Matrix{Float32}: @@ -824,7 +828,7 @@ julia> eb(data, [1, 4]) # two bags, of 3 and 4 items Finally, each bag may also be also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch). -```jldoctest +```jldoctest ebag julia> eb(Flux.onehotbatch("bba", 'a':'z')) # same as [2,2,1], one bag of 3 items 3-element Vector{Float32}: 33.333332 @@ -843,7 +847,7 @@ struct EmbeddingBag{F, W<:AbstractMatrix} reduction::F end -@functor EmbeddingBag +@layer EmbeddingBag EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction) EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 111207d479..7bd3f9b277 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -56,7 +56,7 @@ For each `d in data`, first the gradient of the `loss` is computed like this: gradient(() -> loss(d...), pars) # if d isa Tuple gradient(() -> loss(d), pars) # otherwise ``` -Here `pars` is produced by calling [`Flux.params`](@ref) on your model. +Here `pars` is produced by calling `Flux.params` on your model. (Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.) This is the "implicit" style of parameter handling. diff --git a/src/train.jl b/src/train.jl index fd21e53f17..e72eedebf3 100644 --- a/src/train.jl +++ b/src/train.jl @@ -21,16 +21,12 @@ It differs from `Optimisers.setup` in that it: * has methods which accept Flux's old optimisers, and convert them. (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.) -!!! compat "New" - This function was added in Flux 0.13.9. It was not used by the old "implicit" - interface, using `Flux.Optimise` module and [`Flux.params`](@ref). - # Example ```jldoctest -julia> model = Dense(2=>1, leakyrelu; init=ones); +julia> model = Dense(2 => 1, leakyrelu; init=ones); julia> opt_state = Flux.setup(Momentum(0.1), model) # this encodes the optimiser and its state -(weight = Leaf(Momentum{Float64}(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [0.0]), σ = ()) +(weight = Leaf(Momentum(0.1, 0.9), [0.0 0.0]), bias = Leaf(Momentum(0.1, 0.9), [0.0]), σ = ()) julia> x1, y1 = [0.2, -0.3], [0.4]; # use the same data for two steps: @@ -43,7 +39,7 @@ julia> model.bias # was zero, mutated by Flux.train! 10.19 julia> opt_state # mutated by Flux.train! -(weight = Leaf(Momentum{Float64}(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), [-10.09]), σ = ()) +(weight = Leaf(Momentum(0.1, 0.9), [-2.018 3.027]), bias = Leaf(Momentum(0.1, 0.9), [-10.09]), σ = ()) ``` """ function setup(rule::Optimisers.AbstractRule, model) @@ -90,7 +86,7 @@ It adds only a few features to the loop above: !!! compat "New" This method was added in Flux 0.13.9. It has significant changes from the one used by Flux ≤ 0.13: - * It now takes the `model` itself, not the result of [`Flux.params`](@ref). + * It now takes the `model` itself, not the result of `Flux.params`. (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) * Instead of `loss` being a function which accepts only the data, now it must also accept the `model` itself, as the first argument. diff --git a/src/utils.jl b/src/utils.jl index 1f8230c522..8fa3889a11 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -61,10 +61,10 @@ julia> Flux.glorot_uniform(3, 4) |> summary "3×4 Matrix{Float32}" julia> round.(extrema(Flux.glorot_uniform(10, 100)), digits=3) -(-0.232f0, 0.234f0) +(-0.233f0, 0.233f0) julia> round.(extrema(Flux.glorot_uniform(100, 10)), digits=3) -(-0.233f0, 0.233f0) +(-0.234f0, 0.233f0) julia> round.(extrema(Flux.glorot_uniform(100, 100)), digits=3) (-0.173f0, 0.173f0) @@ -109,7 +109,7 @@ julia> round(std(Flux.glorot_normal(10, 1000)), digits=3) 0.044f0 julia> round(std(Flux.glorot_normal(1000, 10)), digits=3) -0.044f0 +0.045f0 julia> round(std(Flux.glorot_normal(1000, 1000)), digits=3) 0.032f0 @@ -146,10 +146,10 @@ This method is described in [1] and also known as He initialization. # Examples ```jldoctest; setup = :(using Random; Random.seed!(0)) julia> round.(extrema(Flux.kaiming_uniform(100, 10)), digits=3) -(-0.774f0, 0.774f0) +(-0.774f0, 0.773f0) julia> round.(extrema(Flux.kaiming_uniform(10, 100)), digits=3) -(-0.245f0, 0.244f0) +(-0.243f0, 0.245f0) julia> round.(extrema(Flux.kaiming_uniform(100, 100)), digits=3) (-0.245f0, 0.245f0) @@ -183,10 +183,10 @@ This method is described in [1] and also known as He initialization. julia> using Statistics julia> round(std(Flux.kaiming_normal(10, 1000)), digits=3) -0.045f0 +0.044f0 julia> round(std(Flux.kaiming_normal(1000, 10)), digits=3) -0.447f0 +0.449f0 julia> round(std(Flux.kaiming_normal(1000, 1000)), digits=3) 0.045f0 From 4966b586ac47f30705c2210d6a823803a484149f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 7 Apr 2024 20:14:39 +0200 Subject: [PATCH 4/6] fixes for mlp tutorial --- Project.toml | 6 ++---- docs/src/tutorials/mlp.md | 10 ++++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 7cc923c531..e256b39315 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.14.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -77,7 +78,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", - "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", - "Enzyme", "FiniteDifferences", "Tracker"] - +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU", "Enzyme", "FiniteDifferences", "Tracker"] diff --git a/docs/src/tutorials/mlp.md b/docs/src/tutorials/mlp.md index fed886b49e..d939685122 100644 --- a/docs/src/tutorials/mlp.md +++ b/docs/src/tutorials/mlp.md @@ -8,13 +8,9 @@ To run this example, we need the following packages: using Flux, Statistics using Flux: DataLoader using Flux: onehotbatch, onecold, logitcrossentropy -using CUDA +# using CUDA # Uncomment this line if you have a nvidia GPU. Also AMDGPU and Metal are supported. using MLDatasets: MNIST - -if has_cuda() # Check if CUDA is available - @info "CUDA is on" - CUDA.allowscalar(false) -end +using MLUtils ``` We set default values for learning rate, batch size, epochs, and the usage of a GPU (if available) for our model: @@ -153,6 +149,8 @@ function train(; kws...) @show accuracy(test_loader, model) end end + +train() ``` From fcf62367c61e4276539dc69b467fbd24550464e9 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 9 Apr 2024 06:05:30 +0200 Subject: [PATCH 5/6] compat --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e256b39315..eaa6f22073 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ version = "0.14.15" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -49,7 +48,7 @@ MacroTools = "0.5" Metal = "0.5, 1" NNlib = "0.9.1" OneHotArrays = "0.2.4" -Optimisers = "0.3.2" +Optimisers = "0.3.3" Preferences = "1" ProgressLogging = "0.1" Reexport = "1.0" From 15727cb771c1e07f7dc5944f006665607be5f062 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 10 Apr 2024 06:13:32 +0200 Subject: [PATCH 6/6] blitz update and model zoo page --- Project.toml | 2 + docs/make.jl | 6 +- ...p-learning-flux.md => 2024-04-10-blitz.md} | 140 ++++++++++-------- .../2024-04-10-mlp.md} | 0 docs/old_tutorials/README.md | 7 + docs/src/tutorials/model_zoo.md | 10 ++ src/Flux.jl | 2 +- 7 files changed, 101 insertions(+), 66 deletions(-) rename docs/old_tutorials/{2020-09-15-deep-learning-flux.md => 2024-04-10-blitz.md} (74%) rename docs/{src/tutorials/mlp.md => old_tutorials/2024-04-10-mlp.md} (100%) create mode 100644 docs/old_tutorials/README.md create mode 100644 docs/src/tutorials/model_zoo.md diff --git a/Project.toml b/Project.toml index eaa6f22073..b6f7417b5f 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" +ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" diff --git a/docs/make.jl b/docs/make.jl index 331e38fdac..6c7b483caa 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -47,14 +47,14 @@ makedocs( # Or perhaps those should just be trashed, model zoo versions are newer & more useful. "Linear Regression" => "tutorials/linear_regression.md", "Logistic Regression" => "tutorials/logistic_regression.md", - "Multi-layer Perceptron" => "tutorials/mlp.md", + "Model Zoo" => "tutorials/model_zoo.md", #= - "Julia & Flux: 60 Minute Blitz" => "tutorials/2020-09-15-deep-learning-flux.md", + # "Multi-layer Perceptron" => "tutorials/mlp.md", + # "Julia & Flux: 60 Minute Blitz" => "tutorials/blitz.md", "Simple ConvNet" => "tutorials/2021-02-07-convnet.md", "Generative Adversarial Net" => "tutorials/2021-10-14-vanilla-gan.md", "Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md", =# - # Not really sure where this belongs... some in Fluxperimental, aim to delete? ], ], format = Documenter.HTML( diff --git a/docs/old_tutorials/2020-09-15-deep-learning-flux.md b/docs/old_tutorials/2024-04-10-blitz.md similarity index 74% rename from docs/old_tutorials/2020-09-15-deep-learning-flux.md rename to docs/old_tutorials/2024-04-10-blitz.md index c386e5f3c4..5e461f4ffe 100755 --- a/docs/old_tutorials/2020-09-15-deep-learning-flux.md +++ b/docs/old_tutorials/2024-04-10-blitz.md @@ -78,7 +78,7 @@ We can see Julia tile the column vector `1:5` across all rows of the larger arra zeros(5,5) .+ (1:5) ``` -The x' syntax is used to transpose a column `1:5` into an equivalent row, and Julia will tile that across columns. +The `x'` syntax is used to transpose a column `1:5` into an equivalent row, and Julia will tile that across columns. ```julia zeros(5,5) .+ (1:5)' @@ -181,16 +181,19 @@ x = rand(Float32, 10) We can easily get the parameters of any layer or model with `trainables`. ```julia -trainables(m) +Flux.trainables(m) ``` It very easy to calculate the gradient for all parameters in a network, even if it has many parameters. The function `gradient` is not limited to array but can compute the gradient with respect to generic composite types. ```julia +using Flux +using Flux: logitcrossentropy, trainables, getkeypath + x = rand(Float32, 10) model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) -loss(model, x) = Flux.logitcrossentropy(model(x), [0.5, 0.5]) +loss(model, x) = logitcrossentropy(model(x), [0.5, 0.5]) grad = gradient(m -> loss(m, x), model)[1] for (k, p) in trainables(model, path=true) println("$k => $(getkeypath(grad, k))") @@ -203,8 +206,8 @@ The next step is to update our weights and perform optimisation. As you might be ```julia η = 0.1 -for (k, p) in trainables(m) - p .+= -η * getkeypath(grads, p) +for (k, p) in trainables(model, path=true) + p .+= -η * getkeypath(grad, p) end ``` @@ -220,22 +223,24 @@ Training a network reduces down to iterating on a dataset mulitple times, perfor ```julia data, labels = rand(10, 100), fill(0.5, 2, 100) -loss(m, x, y) = Flux.logitcrossentropy(m(x), y) -Flux.train!(loss, model, [(data, labels)], opt) +loss(m, x, y) = logitcrossentropy(m(x), y) +Flux.train!(loss, model, [(data, labels)], opt_state) ``` You don't have to use `train!`. In cases where arbitrary logic might be better suited, you could open up this training loop like so: ```julia - for d in training_set # assuming d looks like (data, labels) +for d in training_set # assuming d looks like (data, labels) # our super logic g = gradient(model) do model - l = loss(model, d...) - end + l = loss(model, d...) + end[1] Flux.update!(opt_state, model, g) - end +end ``` +The `do` block is a closure, which is a way of defining a function inline. It's a very powerful feature of Julia, and you can learn more about it [here](https://docs.julialang.org/en/v1/manual/functions/#Do-Block-Syntax-for-Function-Arguments). + ## Training a Classifier Getting a real classifier to work might help cement the workflow a bit more. [CIFAR10](https://https://www.cs.toronto.edu/~kriz/cifar.html) is a dataset of 50k tiny training images split into 10 classes. @@ -254,10 +259,14 @@ We will do the following steps in order: using Statistics using Flux using MLDatasets: CIFAR10 -using Images.ImageCore -using Flux: onehotbatch, onecold -using Base.Iterators: partition -using CUDA +using ImageCore: colorview, RGB +using Flux: onehotbatch, onecold, DataLoader +using Plots: plot +using MLUtils: splitobs, numobs + +# using CUDA # Uncomment if you have CUDA installed. Can also use AMDGPU or Metal instead +# using AMDGPU +# using Metal ``` This image will give us an idea of what we are dealing with. @@ -265,27 +274,26 @@ This image will give us an idea of what we are dealing with. ![title](https://pytorch.org/tutorials/_images/cifar10.png) ```julia -train_x, train_y = CIFAR10.traindata(Float32) +train_x, train_y = CIFAR10(:train)[:] labels = onehotbatch(train_y, 0:9) ``` The `train_x` contains 50000 images converted to 32 X 32 X 3 arrays with the third dimension being the 3 channels (R,G,B). Let's take a look at a random image from the train_x. For this, we need to permute the dimensions to 3 X 32 X 32 and use `colorview` to convert it back to an image. ```julia -using Plots image(x) = colorview(RGB, permutedims(x, (3, 2, 1))) plot(image(train_x[:,:,:,rand(1:end)])) ``` -We can now arrange the training data in batches of say, 1000 and keep a validation set to track our progress. This process is called minibatch learning, which is a popular method of training large neural networks. Rather that sending the entire dataset at once, we break it down into smaller chunks (called minibatches) that are typically chosen at random, and train only on them. It is shown to help with escaping [saddle points](https://en.wikipedia.org/wiki/Saddle_point). +We can now arrange the training data in batches of say, 256 and keep a validation set to track our progress. This process is called minibatch learning, which is a popular method of training large neural networks. Rather that sending the entire dataset at once, we break it down into smaller chunks (called minibatches) that are typically chosen at random, and train only on them. It is shown to help with escaping [saddle points](https://en.wikipedia.org/wiki/Saddle_point). -The first 49k images (in batches of 1000) will be our training set, and the rest is for validation. `partition` handily breaks down the set we give it in consecutive parts (1000 in this case). +The first 45k images (in batches of 256) will be our training set, and the rest is for validation. +The `DataLoader` function will help us load the data in batches. ```julia -train = ([(train_x[:,:,:,i], labels[:,i]) for i in partition(1:49000, 1000)]) |> gpu -valset = 49001:50000 -valX = train_x[:,:,:,valset] |> gpu -valY = labels[:, valset] |> gpu +trainset, valset = splitobs((train_x, labels), at = 45000) +trainloader = DataLoader(trainset, batchsize = 1000, shuffle = true) +valloader = DataLoader(trainset, batchsize = 1000) ``` ### Defining the Classifier @@ -295,30 +303,40 @@ Now we can define our Convolutional Neural Network (CNN). A convolutional neural network is one which defines a kernel and slides it across a matrix to create an intermediate representation to extract features from. It creates higher order features as it goes into deeper layers, making it suitable for images, where the strucure of the subject is what will help us determine which class it belongs to. ```julia -m = Chain( - Conv((5,5), 3=>16, relu), - MaxPool((2,2)), - Conv((5,5), 16=>8, relu), - MaxPool((2,2)), - x -> reshape(x, :, size(x, 4)), - Dense(200 => 120), - Dense(120 => 84), - Dense(84 => 10)) |> gpu +model = Chain( + Conv((5,5), 3 => 16, relu), + MaxPool((2, 2)), + Conv((5, 5), 16 => 8, relu), + MaxPool((2,2)), + x -> reshape(x, :, size(x, 4)), + Dense(200 => 120), + Dense(120 => 84), + Dense(84 => 10)) |> gpu ``` -We will use a crossentropy loss and an Momentum optimiser here. Crossentropy will be a good option when it comes to working with mulitple independent classes. Momentum gradually lowers the learning rate as we proceed with the training. It helps maintain a bit of adaptivity in our optimisation, preventing us from over shooting from our desired destination. +We will use a crossentropy loss and an `Momentum` optimiser here. Crossentropy will be a good option when it comes to working with mulitple independent classes. Momentum gradually lowers the learning rate as we proceed with the training. It helps maintain a bit of adaptivity in our optimisation, preventing us from over shooting from our desired destination. ```julia using Flux: logitcrossentropy, Momentum loss(m, x, y) = logitcrossentropy(m(x), y) -opt = Momentum(0.01) +opt_state = Flux.setup(Momentum(0.01), model) ``` -We can start writing our train loop where we will keep track of some basic accuracy numbers about our model. We can define an `accuracy` function for it like so. +We can start writing our train loop where we will keep track of some basic accuracy numbers about our model. We can define an `accuracy` function for it like so: ```julia -accuracy(x, y) = mean(onecold(m(x), 0:9) .== onecold(y, 0:9)) +function accuracy(model, loader) + n = 0 + acc = 0 + for batch in loader + x, y = batch |> gpu + ŷ = model(x) + acc += sum(onecold(ŷ) .== onecold(y)) + n += numobs(x) + end + return acc / n +end ``` ### Training the Classifier @@ -329,14 +347,15 @@ Training is where we do a bunch of the interesting operations we defined earlier ```julia epochs = 10 -for epoch = 1:epochs - for d in train - gs = gradient(params(m)) do - l = loss(d...) +for epoch in 1:epochs + for batch in trainloader + x, y = batch |> gpu + g = gradient(model) do m + loss(m, x, y) + end[1] + Flux.update!(opt_state, model, g) end - update!(opt, params(m), gs) - end - @show accuracy(valX, valY) + @show accuracy(model, valloader) end ``` @@ -355,10 +374,9 @@ We will check this by predicting the class label that the neural network outputs Okay, first step. Let us perform the exact same preprocessing on this set, as we did on our training set. ```julia -test_x, test_y = CIFAR10.testdata(Float32) +test_x, test_y = CIFAR10(:test)[:] test_labels = onehotbatch(test_y, 0:9) - -test = gpu.([(test_x[:,:,:,i], test_labels[:,i]) for i in partition(1:10000, 1000)]) +testloader = DataLoader((test_x, test_labels), batchsize = 1000, shuffle = true) ``` Next, display an image from the test set. @@ -367,7 +385,7 @@ Next, display an image from the test set. plot(image(test_x[:,:,:,rand(1:end)])) ``` -The outputs are energies for the 10 classes. Higher the energy for a class, the more the network thinks that the image is of the particular class. Every column corresponds to the output of one image, with the 10 floats in the column being the energies. +The outputs of the networks are (log)likelihoods for the 10 classes. Higher the energy for a class, the more the network thinks that the image is of the particular class. Every column corresponds to the output of one image, with the 10 floats in the column being the energies. Let's see how the model fared. @@ -375,13 +393,13 @@ Let's see how the model fared. ids = rand(1:10000, 5) rand_test = test_x[:,:,:,ids] |> gpu rand_truth = test_y[ids] -m(rand_test) +model(rand_test) ``` This looks similar to how we would expect the results to be. At this point, it's a good idea to see how our net actually performs on new data, that we have prepared. ```julia -accuracy(test[1]...) +accuracy(model, testloader) ``` This is much better than random chance set at 10% (since we only have 10 classes), and not bad at all for a small hand written network like ours. @@ -389,22 +407,20 @@ This is much better than random chance set at 10% (since we only have 10 classes Let's take a look at how the net performed on all the classes performed individually. ```julia -class_correct = zeros(10) -class_total = zeros(10) -for i in 1:10 - preds = m(test[i][1]) - lab = test[i][2] - for j = 1:1000 - pred_class = findmax(preds[:, j])[2] - actual_class = findmax(lab[:, j])[2] - if pred_class == actual_class - class_correct[pred_class] += 1 +confusion_matrix = zeros(Int, 10, 10) +m = model |> cpu +for batch in testloader + @show numobs(batch) + x, y = batch + preds = m(x) + ŷ = onecold(preds) + y = onecold(y) + for (yi, ŷi) in zip(y, ŷ) + confusion_matrix[yi, ŷi] += 1 end - class_total[actual_class] += 1 - end end -class_correct ./ class_total +confusion_matrix ``` The spread seems pretty good, with certain classes performing significantly better than the others. Why should that be? diff --git a/docs/src/tutorials/mlp.md b/docs/old_tutorials/2024-04-10-mlp.md similarity index 100% rename from docs/src/tutorials/mlp.md rename to docs/old_tutorials/2024-04-10-mlp.md diff --git a/docs/old_tutorials/README.md b/docs/old_tutorials/README.md new file mode 100644 index 0000000000..46a1ceb979 --- /dev/null +++ b/docs/old_tutorials/README.md @@ -0,0 +1,7 @@ +These tutorials are hard to mantain +and overlapping with model-zoo examples. + +Some of the tutorials are outdated. + +Mantainance would be simplified by moving them +to Literate.jl and CI testing them. diff --git a/docs/src/tutorials/model_zoo.md b/docs/src/tutorials/model_zoo.md new file mode 100644 index 0000000000..c4e87ab8ab --- /dev/null +++ b/docs/src/tutorials/model_zoo.md @@ -0,0 +1,10 @@ +# Model Zoo + +The [model zoo](https://github.com/FluxML/model-zoo) is a collection of examples that demonstrate how to build and train models using Flux. The examples are organised by domain and include vision, text, and audio. Each example includes a description of the model, the data used, and the training process. + +Some of the examples are pedagogical, see for instance +- [Multilayer Perceptron](https://github.com/FluxML/model-zoo/tree/master/vision/mlp_mnist) +- [Simple Convolutional Neural Network](https://github.com/FluxML/model-zoo/tree/master/vision/conv_mnist) + +Others are more advanced, see for instance +- [Variational Autoencoder](https://github.com/FluxML/model-zoo/blob/master/vision/vae_mnist) \ No newline at end of file diff --git a/src/Flux.jl b/src/Flux.jl index e9cc78b5dd..6b7302f8cb 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -63,7 +63,7 @@ include("functor.jl") # from OneHotArrays.jl onehot, onehotbatch, onecold, # from Functors.jl - functor, @functor, + functor, @functor, KeyPath, haskeypath, getkeypath, # from Optimise/Train/Optimisers.jl setup, update!, destructure, freeze!, adjust!, params, trainable, trainables ))