diff --git a/docs/src/destructure.md b/docs/src/destructure.md index 1cdcad5ce7..16089380c4 100644 --- a/docs/src/destructure.md +++ b/docs/src/destructure.md @@ -49,20 +49,27 @@ julia> Flux.destructure(grad) # acts on non-models, too (Float32[10.339018, 11.379145, 22.845667, -29.565302, -37.644184], Restructure(Tuple, ..., 5)) ``` -!!! compat "Flux ≤ 0.12" - Old versions of Flux had an entirely different implementation of `destructure`, which - had many bugs (and almost no tests). Many comments online still refer to that now-deleted - function, or to memories of it. +In order to collect all parameters of a model into a list instead, you can use the `trainables` function: +```julia +julia> Flux.trainables(model) +5-element Vector{AbstractArray}: + [0.863101 1.2454957] + [0.0] + [1.290355429422727;;] + [0.0] +``` +Any mutation of the elements of the resulting list will affect the model's parameters. ### All Parameters -The function `destructure` now lives in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). -(Be warned this package is unrelated to the `Flux.Optimisers` sub-module! The confusion is temporary.) +The functions `destructure` and `trainables` live in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). + ```@docs Optimisers.destructure Optimisers.trainable +Optimisers.trainables Optimisers.isnumeric ``` diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index cf2d1fedb3..9569944b2e 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -26,7 +26,7 @@ Notice that we parameterized the type of the `chain` field. This is necessary fo You can then use the model like: ```julia -chain = Chain(Dense(10, 10)) +chain = Chain(Dense(10 => 10)) model = CustomModel(chain) model(rand(10)) ``` @@ -40,33 +40,37 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function: ```julia-repl -julia> @layer Affine +julia> struct Affine + W + b + end + +julia> Affine(in::Int, out::Int) = Affine(randn(out, in), randn(out)); + +julia> (m::Affine)(x) = m.W * x .+ m.b; + +julia> Flux.@layer Affine julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9]) Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]) -julia> Flux.params(a) # default behavior -Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]]) +julia> Flux.trainable(a) # default behavior +(W = Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], b = Float32[7.0, 8.0, 9.0]) julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name -julia> Flux.params(a) -Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]]) +julia> Flux.trainable(a) +(W = Float32[1.0 2.0; 3.0 4.0; 5.0 6.0],) ``` -Only the fields returned by `trainable` will be collected as trainable parameters of the layer when calling `Flux.params`, and only these fields will be seen by `Flux.setup` and `Flux.update!` for training. But all fields wil be seen by `gpu` and similar functions, for example: +Only the fields returned by `trainable` will be seen by `Flux.setup` and `Flux.update!` for training. But all fields wil be seen by `gpu` and similar functions, for example: ```julia-repl julia> a |> f16 Affine(Float16[1.0 2.0; 3.0 4.0; 5.0 6.0], Float16[7.0, 8.0, 9.0]) ``` -Note that there is no need to overload `trainable` to hide fields which do not contain trainable parameters. (For example, activation functions, or Boolean flags.) These are always ignored by `params` and by training: - -```julia-repl -julia> Flux.params(Affine(true, [10, 11, 12.0])) -Params([]) -``` +Note that there is no need to overload `trainable` to hide fields which do not contain numerical array (for example, activation functions, or Boolean flags). These are always ignored by training. The exact same method of `trainable` can also be defined using the macro, for convenience: @@ -76,52 +80,14 @@ Flux.@layer Affine trainable=(W,) There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that all no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument. - -## Freezing Layer Parameters - -When it is desired to not include all the model parameters (for e.g. transfer learning), we can simply not pass in those layers into our call to `params`. - -!!! compat "Flux ≤ 0.14" - The mechanism described here is for Flux's old "implicit" training style. - When upgrading for Flux 0.15, it should be replaced by [`freeze!`](@ref Flux.freeze!) and `thaw!`. - -Consider a simple multi-layer perceptron model where we want to avoid optimising the first two `Dense` layers. We can obtain -this using the slicing features `Chain` provides: - -```julia -m = Chain( - Dense(784 => 64, relu), - Dense(64 => 64, relu), - Dense(32 => 10) - ); - -ps = Flux.params(m[3:end]) -``` - -The `Zygote.Params` object `ps` now holds a reference to only the parameters of the layers passed to it. - -During training, the gradients will only be computed for (and applied to) the last `Dense` layer, therefore only that would have its parameters changed. - -`Flux.params` also takes multiple inputs to make it easy to collect parameters from heterogenous models with a single call. A simple demonstration would be if we wanted to omit optimising the second `Dense` layer in the previous example. It would look something like this: - -```julia -Flux.params(m[1], m[3:end]) -``` - -Sometimes, a more fine-tuned control is needed. -We can freeze a specific parameter of a specific layer which already entered a `Params` object `ps`, -by simply deleting it from `ps`: - -```julia -ps = Flux.params(m) -delete!(ps, m[2].bias) -``` - ## Custom multiple input or output layer Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf). -Naively, we could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path changes. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. +We could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. That would mean a new struct for each different block, +e.g. one would have a `TransformerBlock` struct for a transformer block, and a `ResNetBlock` struct for a ResNet block, each block being composed by smaller sub-blocks. This is often the simplest and cleanest way to implement complex models. + +This guide instead will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. ### Multiple inputs: a custom `Join` layer diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index cf83764349..4334f47d33 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -74,50 +74,24 @@ julia> Flux.withgradient(g, nt) (val = 1, grad = ((a = [0.0, 2.0], b = [-0.0, -2.0], c = nothing),)) ``` -!!! note "Implicit gradients" - Flux used to handle many parameters in a different way, using the [`params`](@ref Flux.params) function. - This uses a method of `gradient` which takes a zero-argument function, and returns a dictionary - through which the resulting gradients can be looked up: - - ```jldoctest basics - julia> x = [2, 1]; - - julia> y = [2, 0]; - - julia> gs = gradient(Flux.params(x, y)) do - f(x, y) - end - Grads(...) - - julia> gs[x] - 2-element Vector{Float64}: - 0.0 - 2.0 - - julia> gs[y] - 2-element Vector{Float64}: - -0.0 - -2.0 - ``` - - ## Building Simple Models Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. ```julia -W = rand(2, 5) -b = rand(2) -predict(x) = W*x .+ b +predict(W, b, x) = W*x .+ b -function loss(x, y) - ŷ = predict(x) +function loss(W, b, x, y) + ŷ = predict(W, b, x) sum((y .- ŷ).^2) end x, y = rand(5), rand(2) # Dummy data -loss(x, y) # ~ 3 +W = rand(2, 5) +b = rand(2) + +loss(W, b, x, y) # ~ 3 ``` To improve the prediction we can take the gradients of the loss with respect to `W` and `b` and perform gradient descent. @@ -125,17 +99,15 @@ To improve the prediction we can take the gradients of the loss with respect to ```julia using Flux -gs = gradient(() -> loss(x, y), Flux.params(W, b)) +dW, db = gradient((W, b) -> loss(W, b, x, y), W, b) ``` Now that we have gradients, we can pull them out and update `W` to train the model. ```julia -W̄ = gs[W] +W .-= 0.1 .* dW -W .-= 0.1 .* W̄ - -loss(x, y) # ~ 2.5 +loss(W, b, x, y) # ~ 2.5 ``` The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md). @@ -144,7 +116,7 @@ All deep learning in Flux, however complex, is a simple generalisation of this e ## Building Layers -It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) (`σ`) in between them. In the above style we could write this as: +It's common to create more complex models than the linear regression above. For example, we might want to have two linear layers with a nonlinearity like [sigmoid](https://en.wikipedia.org/wiki/Sigmoid_function) in between them. We could write this as: ```julia using Flux @@ -157,7 +129,7 @@ W2 = rand(2, 3) b2 = rand(2) layer2(x) = W2 * x .+ b2 -model(x) = layer2(σ.(layer1(x))) +model(x) = layer2(sigmoid.(layer1(x))) model(rand(5)) # => 2-element vector ``` @@ -174,7 +146,7 @@ end linear1 = linear(5, 3) # we can access linear1.W etc linear2 = linear(3, 2) -model(x) = linear2(σ.(linear1(x))) +model(x) = linear2(sigmoid.(linear1(x))) model(rand(5)) # => 2-element vector ``` @@ -188,7 +160,7 @@ struct Affine end Affine(in::Integer, out::Integer) = - Affine(randn(out, in), randn(out)) + Affine(randn(out, in), zeros(out)) # Overload call, so the object can be used as a function (m::Affine)(x) = m.W * x .+ m.b @@ -198,16 +170,16 @@ a = Affine(10, 5) a(rand(10)) # => 5-element vector ``` -Congratulations! You just built the `Dense` layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily. +Congratulations! You just built the [`Dense`](@ref) layer that comes with Flux. Flux has many interesting layers available, but they're all things you could have built yourself very easily. -(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, σ)`.) +(There is one small difference with `Dense` – for convenience it also takes an activation function, like `Dense(10 => 5, sigmoid)`.) ## Stacking It Up It's pretty common to write models that look something like: ```julia -layer1 = Dense(10 => 5, σ) +layer1 = Dense(10 => 5, relu) # ... model(x) = layer3(layer2(layer1(x))) ``` @@ -217,7 +189,7 @@ For long chains, it might be a bit more intuitive to have a list of layers, like ```julia using Flux -layers = [Dense(10 => 5, σ), Dense(5 => 2), softmax] +layers = [Dense(10 => 5, relu), Dense(5 => 2), softmax] model(x) = foldl((x, m) -> m(x), layers, init = x) @@ -228,7 +200,7 @@ Handily, this is also provided for in Flux: ```julia model2 = Chain( - Dense(10 => 5, σ), + Dense(10 => 5, relu), Dense(5 => 2), softmax) @@ -255,7 +227,7 @@ m(5) # => 26 ## Layer Helpers -There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: +There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref Flux.train!) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@layer`](@ref Flux.@layer) macro: ```julia Flux.@layer Affine @@ -263,14 +235,14 @@ Flux.@layer Affine Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias): -``` -function Affine((in, out)::Pair; bias=true, init=Flux.randn32) +```julia +function Affine((in, out)::Pair; bias=true, init=glorot_uniform) W = init(out, in) b = Flux.create_bias(W, bias, out) - Affine(W, b) + return Affine(W, b) end -Affine(3 => 1, bias=false, init=ones) |> gpu +Affine(3 => 1, bias=false) |> gpu ``` ```@docs diff --git a/docs/src/models/quickstart.md b/docs/src/models/quickstart.md index dfef1f0c04..e7e379f17f 100644 --- a/docs/src/models/quickstart.md +++ b/docs/src/models/quickstart.md @@ -16,11 +16,11 @@ truth = [xor(col[1]>0.5, col[2]>0.5) for col in eachcol(noisy)] # 1000-element model = Chain( Dense(2 => 3, tanh), # activation function inside layer BatchNorm(3), - Dense(3 => 2), - softmax) |> gpu # move model to GPU, if available + Dense(3 => 2)) |> gpu # move model to GPU, if available # The model encapsulates parameters, randomly initialised. Its initial output is: out1 = model(noisy |> gpu) |> cpu # 2×1000 Matrix{Float32} +probs1 = softmax(out1) # normalise to get probabilities # To train the model, we use batches of 64 samples, and one-hot encoding: target = Flux.onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix @@ -36,7 +36,7 @@ losses = [] loss, grads = Flux.withgradient(model) do m # Evaluate model and loss inside gradient context: y_hat = m(x) - Flux.crossentropy(y_hat, y) + Flux.logitcrossentropy(y_hat, y) end Flux.update!(optim, model, grads[1]) push!(losses, loss) # logging, outside gradient context @@ -45,8 +45,8 @@ end optim # parameters, momenta and output have all changed out2 = model(noisy |> gpu) |> cpu # first row is prob. of true, second row p(false) - -mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! +probs2 = softmax(out2) # normalise to get probabilities +mean((probs2[1,:] .> 0.5) .== truth) # accuracy 94% so far! ``` ![](../assets/quickstart/oneminute.png) @@ -55,8 +55,8 @@ mean((out2[1,:] .> 0.5) .== truth) # accuracy 94% so far! using Plots # to draw the above figure p_true = scatter(noisy[1,:], noisy[2,:], zcolor=truth, title="True classification", legend=false) -p_raw = scatter(noisy[1,:], noisy[2,:], zcolor=out1[1,:], title="Untrained network", label="", clims=(0,1)) -p_done = scatter(noisy[1,:], noisy[2,:], zcolor=out2[1,:], title="Trained network", legend=false) +p_raw = scatter(noisy[1,:], noisy[2,:], zcolor=probs1[1,:], title="Untrained network", label="", clims=(0,1)) +p_done = scatter(noisy[1,:], noisy[2,:], zcolor=probs2[1,:], title="Trained network", legend=false) plot(p_true, p_raw, p_done, layout=(1,3), size=(1000,330)) ``` @@ -87,7 +87,7 @@ Some things to notice in this example are: * The `model` can be called like a function, `y = model(x)`. Each layer like [`Dense`](@ref Flux.Dense) is an ordinary `struct`, which encapsulates some arrays of parameters (and possibly other state, as for [`BatchNorm`](@ref Flux.BatchNorm)). -* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.crossentropy`](@ref Flux.Losses.crossentropy) is an ordinary function. +* But the model does not contain the loss function, nor the optimisation rule. The momenta needed by [`Adam`](@ref Flux.Adam) are stored in the object returned by [setup](@ref Flux.Train.setup). And [`Flux.logitcrossentropy`](@ref Flux.Losses.logitcrossentropy) is an ordinary function that combines the [`softmax`](@ref Flux.softmax) and [`crossentropy`](@ref Flux.crossentropy) functions. * The `do` block creates an anonymous function, as the first argument of `gradient`. Anything executed within this is differentiated. @@ -97,21 +97,7 @@ Instead of calling [`gradient`](@ref Zygote.gradient) and [`update!`](@ref Flux. for epoch in 1:1_000 Flux.train!(model, loader, optim) do m, x, y y_hat = m(x) - Flux.crossentropy(y_hat, y) + Flux.logitcrossentropy(y_hat, y) end end ``` - -!!! compat "Implicit-style training, Flux ≤ 0.14" - Until recently Flux's training worked a bit differently. - Any code which looks like - ``` - gradient(() -> loss(model, x, y), Flux.params(model)) - ``` - (gradient of a zero-argument function) or - ``` - train!((x,y) -> loss(model, x, y), Flux.params(model), loader, opt) - ``` - (with `Flux.params`) is in the old "implicit" style. - This still works on Flux 0.14, but will be removed from Flux 0.15. - See the [training section](@ref man-training) for more details. diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md index dab24edff6..87cd944f4f 100644 --- a/docs/src/models/recurrence.md +++ b/docs/src/models/recurrence.md @@ -154,7 +154,7 @@ In such a model, only the last two outputs are used to compute the loss, hence t Alternatively, if one wants to perform some warmup of the sequence, it could be performed once, followed with a regular training where all the steps of the sequence would be considered for the gradient update: ```julia -function loss(x, y) +function loss(m, x, y) sum(mse(m(xi), yi) for (xi, yi) in zip(x, y)) end @@ -172,9 +172,8 @@ data = zip(X,Y) Flux.reset!(m) [m(x) for x in seq_init] -ps = Flux.params(m) -opt= Adam(1e-3) -Flux.train!(loss, ps, data, opt) +opt = Flux.setup(Adam(1e-3), m) +Flux.train!(loss, m, data, opt) ``` In this previous example, model's state is first reset with `Flux.reset!`. Then, there's a warmup that is performed over a sequence of length 1 by feeding it with `seq_init`, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (`seq_1` and `seq_2`) and all the timesteps outputs are considered for the loss. diff --git a/docs/src/saving.md b/docs/src/saving.md index 066795bfc5..0b1e4fc91b 100644 --- a/docs/src/saving.md +++ b/docs/src/saving.md @@ -18,7 +18,7 @@ julia> struct MyModel julia> Flux.@layer MyModel -julia> MyModel() = MyModel(Chain(Dense(10, 5, relu), Dense(5, 2))); +julia> MyModel() = MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))); julia> model = MyModel() MyModel(Chain(Dense(10 => 5, relu), Dense(5 => 2))) # 67 parameters @@ -113,7 +113,7 @@ Save a model: ```jldoctest saving julia> using Flux -julia> model = Chain(Dense(10, 5, NNlib.relu), Dense(5, 2)); +julia> model = Chain(Dense(10 => 5, NNlib.relu), Dense(5 => 2)); julia> using BSON: @save @@ -138,10 +138,3 @@ Chain( and across Flux versions if some of the Flux layers' internals are changed. It is therefore not recommended for long term storage, use [`Flux.state`](@ref) instead. -!!! warning - - Previous versions of Flux suggested saving only the model weights using - `@save "mymodel.bson" params(model)`. - This is no longer recommended and even strongly discouraged. - Saving models this way will only store the trainable parameters which - will result in incorrect behavior for layers like `BatchNorm`. diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index fc4e38eebe..bc6dc0628f 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -4,76 +4,63 @@ CurrentModule = Flux # [Optimisation Rules](@id man-optimisers) -Flux builds in many optimisation rules for use with [`train!`](@ref Flux.Optimise.train!) and +Any optimization rule from Optimisers.jl can be used with [`train!`](@ref) and other training functions. -The mechanism by which these work is gradually being replaced as part of the change -from "implicit" dictionary-based to "explicit" tree-like structures. -At present, the same struct (such as `Adam`) can be used with either form, -and will be automatically translated. - For full details of how the new interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/). -For full details on how the old "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface). - -## Optimiser Reference +## Optimisers Reference All optimisers return an object that, when passed to `train!`, will update the parameters passed to it. ```@docs -Descent -Momentum -Nesterov -RMSProp -Adam -RAdam -AdaMax -AdaGrad -AdaDelta -AMSGrad -NAdam -AdamW -OAdam -AdaBelief +Optimisers.Descent +Optimisers.Momentum +Optimisers.Nesterov +Optimisers.RMSProp +Optimisers.Adam +Optimisers.RAdam +Optimisers.AdaMax +Optimisers.AdaGrad +Optimisers.AdaDelta +Optimisers.AMSGrad +Optimisers.NAdam +Optimisers.AdamW +Optimisers.OAdam +Optimisers.AdaBelief ``` ## Composing Optimisers -Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient -that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Flux defines some basic decays including `ExpDecay`, `InvDecay` etc. +Flux (through Optimisers.jl) defines a special kind of optimiser called `OptimiserChain` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient +that will be fed into the next, and the resultant update will be applied to the parameter as usual. A classic use case is where adding decays is desirable. Optimisers.jl defines the basic decay corresponding to an $L_2$ regularization in the loss as `WeighDecay`. ```julia -opt = Optimiser(ExpDecay(1, 0.1, 1000, 1e-4), Descent()) +opt = OptimiserChain(WeightDecay(1e-4), Descent()) ``` -Here we apply exponential decay to the `Descent` optimiser. The defaults of `ExpDecay` say that its learning rate will be decayed every 1000 steps. -It is then applied like any optimiser. +Here we apply the weight decay to the `Descent` optimiser. +The resulting optimiser `opt` can be used as any optimiser. ```julia -w = randn(10, 10) -w1 = randn(10,10) -ps = Params([w, w1]) +w = [randn(10, 10), randn(10, 10)] +opt_state = Flux.setup(opt, w) -loss(x) = Flux.Losses.mse(w * x, w1 * x) +loss(w, x) = Flux.mse(w[1] * x, w[2] * x) -loss(rand(10)) # around 9 +loss(w, rand(10)) # around 0.9 for t = 1:10^5 - θ = Params([w, w1]) - θ̄ = gradient(() -> loss(rand(10)), θ) - Flux.Optimise.update!(opt, θ, θ̄) + g = gradient(w -> loss(w[1], w[2], rand(10)), w) + Flux.update!(opt_state, w, g) end -loss(rand(10)) # around 0.9 +loss(w, rand(10)) # around 0.9 ``` It is possible to compose optimisers for some added flexibility. -```@docs -Flux.Optimise.Optimiser -``` - ## Scheduling Optimisers In practice, it is fairly common to schedule the learning rate of an optimiser to obtain faster convergence. There are a variety of popular scheduling policies, and you can find implementations of them in [ParameterSchedulers.jl](http://fluxml.ai/ParameterSchedulers.jl/stable). The documentation for ParameterSchedulers.jl provides a more detailed overview of the different scheduling policies, and how to use them with Flux optimisers. Below, we provide a brief snippet illustrating a [cosine annealing](https://arxiv.org/pdf/1608.03983.pdf) schedule with a momentum optimiser. @@ -109,10 +96,8 @@ ParameterSchedulers.jl allows for many more scheduling policies including arbitr Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone. ```@docs -ExpDecay -InvDecay -WeightDecay SignDecay +WeightDecay ``` ## Gradient Clipping @@ -120,11 +105,11 @@ SignDecay Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is ```julia -opt = Optimiser(ClipValue(1e-3), Adam(1e-3)) +opt = OptimiserChain(ClipValue(1e-3), Adam(1e-3)) ``` ```@docs -ClipValue +ClipGrad ClipNorm ``` diff --git a/docs/src/training/reference.md b/docs/src/training/reference.md index 1bf0cfd1bf..67980831f9 100644 --- a/docs/src/training/reference.md +++ b/docs/src/training/reference.md @@ -10,9 +10,11 @@ Because of this: * Flux defines its own version of `setup` which checks this assumption. (Using instead `Optimisers.setup` will also work, they return the same thing.) +The available optimization rules are listed the [optimisation rules](@ref man-optimisers) page here. See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the rules work. + ```@docs Flux.Train.setup -Flux.Train.train!(loss, model, data, state; cb) +Flux.Train.train!(loss, model, data, state) Optimisers.update! ``` @@ -32,59 +34,3 @@ Optimisers.adjust! Optimisers.freeze! Optimisers.thaw! ``` - -## Implicit style (Flux ≤ 0.14) - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 and 0.14 are the transitional versions which support both; Flux 0.15 will remove the old. - -!!! compat "How to upgrade" - The blue-green boxes in the [training section](@ref man-training) describe - the changes needed to upgrade old code. - -The available rules are listed the [optimisation rules](@ref man-optimisers) page here. - -!!! compat "Old & new rules" - The new implementation of rules such as Adam in the Optimisers is quite different from the old one in `Flux.Optimise`. In Flux 0.14, `Flux.Adam()` still returns the old one, with supertype `Flux.Optimise.AbstractOptimiser`, but `setup` will silently translate it to its new counterpart. - -For full details on the interface for implicit-style optimisers, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). -See the [Optimisers documentation](https://fluxml.ai/Optimisers.jl/dev/) for details on how the new rules work. - -!!! compat "Flux ≤ 0.12" - Much earlier versions of Flux exported `params`, thus allowing unqualified `params(model)` - after `using Flux`. This conflicted with too many other packages, and was removed in Flux 0.13. - If you get an error `UndefVarError: params not defined`, this probably means that you are - following code for Flux 0.12 or earlier on a more recent version. - - -```@docs -Flux.params -Flux.Optimise.update!(opt::Flux.Optimise.AbstractOptimiser, xs::AbstractArray, gs) -Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb) -``` - -## Callbacks - -Implicit `train!` takes an additional argument, `cb`, that's used for callbacks so that you can observe the training process. For example: - -```julia -train!(objective, ps, data, opt, cb = () -> println("training")) -``` - -Callbacks are called for every batch of training data. You can slow this down using `Flux.throttle(f, timeout)` which prevents `f` from being called more than once every `timeout` seconds. - -A more typical callback might look like this: - -```julia -test_x, test_y = # ... create single batch of test data ... -evalcb() = @show(loss(test_x, test_y)) -throttled_cb = throttle(evalcb, 5) -for epoch in 1:20 - @info "Epoch $epoch" - Flux.train!(objective, ps, data, opt, cb = throttled_cb) -end -``` - -See the page about [callback helpers](@ref man-callback-helpers) for more. - diff --git a/docs/src/training/training.md b/docs/src/training/training.md index f516f4ace9..0407820794 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -64,16 +64,6 @@ in order for the influence of the model's parameters to be observed by Zygote. It is also important that every `update!` step receives a newly computed gradient, as it will change whenever the model's parameters are changed, and for each new data point. -!!! compat "Implicit gradients" - Flux ≤ 0.14 used Zygote's "implicit" mode, in which `gradient` takes a zero-argument function. - It looks like this: - ``` - pars = Flux.params(model) - grad = gradient(() -> loss(model(input), label), pars) - ``` - Here `pars::Params` and `grad::Grads` are two dictionary-like structures. - Support for this will be removed from Flux 0.15, and these blue (teal?) boxes - explain what needs to change. ## Loss Functions @@ -117,13 +107,13 @@ fmap(model, grads[1]) do p, g end ``` -A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref Flux.Optimise.update!)`(opt_state, model, grads[1])`. -And the learning rate is the only thing stored in the [`Descent`](@ref Flux.Optimise.Descent) struct. +A slightly more refined version of this loop to update all the parameters is wrapped up as a function [`update!`](@ref)`(opt_state, model, grads[1])`. +And the learning rate is the only thing stored in the [`Descent`](@ref) struct. However, there are many other optimisation rules, which adjust the step size and direction in various clever ways. Most require some memory of the gradients from earlier steps, rather than always -walking straight downhill -- [`Momentum`](@ref Flux.Optimise.Momentum) is the simplest. +walking straight downhill -- [`Momentum`](@ref) is the simplest. The function [`setup`](@ref Flux.Train.setup) creates the necessary storage for this, for a particular model. It should be called once, before training, and returns a tree-like object which is the first argument of `update!`. Like this: @@ -140,7 +130,7 @@ for data in train_set end ``` -Many commonly-used optimisation rules, such as [`Adam`](@ref Flux.Optimise.Adam), are built-in. +Many commonly-used optimisation rules, such as [`Adam`](@ref), are built-in. These are listed on the [optimisers](@ref man-optimisers) page. !!! compat "Implicit-style optimiser state" @@ -208,15 +198,6 @@ end Or explicitly writing the anonymous function which this `do` block creates, `train!((m,x,y) -> loss(m(x),y), model, train_set, opt_state)` is exactly equivalent. -!!! compat "Implicit-style `train!`" - This is a new method of `train!`, which takes the result of `setup` as its 4th argument. - The 1st argument is a function which accepts the model itself. - Flux versions ≤ 0.14 provided a method of `train!` for "implicit" parameters, - which works like this: - ``` - train!((x,y) -> loss(model(x), y), Flux.params(model), train_set, Adam()) - ``` - Real training loops often need more flexibility, and the best way to do this is just to write the loop. This is ordinary Julia code, without any need to work through some callback API. Here is an example, in which it may be helpful to note: @@ -284,12 +265,12 @@ A very simple model could be implemented as follows: grads = Flux.gradient(densemodel) do m result = m(input) penalty = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 - my_loss(result, label) + 0.42 * penalty + my_loss(result, label) + 0.42f0 * penalty end ``` Accessing each individual parameter array by hand won't work well for large models. -Instead, we can use [`Flux.params`](@ref) to collect all of them, +Instead, we can use [`Flux.trainables`](@ref Optimisers.trainables) to collect all of them, and then apply a function to each one, and sum the result: ```julia @@ -297,8 +278,8 @@ pen_l2(x::AbstractArray) = sum(abs2, x)/2 grads = Flux.gradient(model) do m result = m(input) - penalty = sum(pen_l2, Flux.params(m)) - my_loss(result, label) + 0.42 * penalty + penalty = sum(pen_l2, Flux.trainables(m)) + my_loss(result, label) + 0.42f0 * penalty end ``` @@ -317,7 +298,7 @@ decay_opt_state = Flux.setup(OptimiserChain(WeightDecay(0.42), Adam(0.1)), model ``` Flux's optimisers are really modifications applied to the gradient before using it to update -the parameters, and `OptimiserChain` applies two such modifications. +the parameters, and [`OptimiserChain`](@ref Optimisers.OptimiserChain) applies two such modifications. The first, [`WeightDecay`](@ref Flux.WeightDecay) adds `0.42` times the original parameter to the gradient, matching the gradient of the penalty above (with the same, unrealistically large, constant). After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. @@ -325,14 +306,14 @@ After that, in either case, [`Adam`](@ref Flux.Adam) computes the final update. The same trick works for *L₁ regularisation* (also called Lasso), where the penalty is `pen_l1(x::AbstractArray) = sum(abs, x)` instead. This is implemented by `SignDecay(0.42)`. -The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref Flux.Optimise.ClipValue) or [`ClipNorm`](@ref Flux.Optimise.ClipNorm). +The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref) or [`ClipNorm`](@ref). Besides L1 / L2 / weight decay, another common and quite different kind of regularisation is provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the previous layer during training. It should switch automatically, but see [`trainmode!`](@ref Flux.trainmode!) / [`testmode!`](@ref Flux.testmode!) to manually enable or disable this layer. -## Freezing & Schedules +## Learning Rate Schedules Finer control of training, you may wish to alter the learning rate mid-way through training. This can be done with [`adjust!`](@ref Flux.adjust!), like this: @@ -348,10 +329,6 @@ for epoch in 1:1000 end ``` -!!! compat "Flux ≤ 0.14" - With the old "implicit" optimiser, `opt = Adam(0.1)`, the equivalent was to - directly mutate the `Adam` struct, `opt.eta = 0.001`. - Other hyper-parameters can also be adjusted, such as `Flux.adjust!(opt_state, beta = (0.8, 0.99))`. And such modifications can be applied to just one part of the model. For instance, this sets a different learning rate for the encoder and the decoder: @@ -367,6 +344,8 @@ opt_state = Flux.setup(Adam(0.02), bimodel) Flux.adjust!(opt_state.layers.enc, 0.03) ``` +## Freezing layer parameters + To completely disable training of some part of the model, use [`freeze!`](@ref Flux.freeze!). This is a temporary modification, reversed by `thaw!`: @@ -380,21 +359,7 @@ train!(loss, bimodel, data, opt_state) Flux.thaw!(opt_state) ``` -!!! compat "Flux ≤ 0.14" - The earlier "implicit" equivalent was to pass to `gradient` an object referencing only - part of the model, such as `Flux.params(bimodel.layers.enc)`. - While `adjust!` and `freeze!`/`thaw!` make temporary modifications to the optimiser state, permanently removing some fields of a new layer type from training is usually done when defining the layer, by calling for example [`@layer`](@ref Flux.@layer)` NewLayer trainable=(weight,)`. -## Implicit or Explicit? - -Flux used to handle gradients, training, and optimisation rules quite differently. -The new style described above is called "explicit" by Zygote, and the old style "implicit". -Flux 0.13 and 0.14 are the transitional versions which support both. - -The blue-green boxes above describe the changes. -For more details on training in the implicit style, see [Flux 0.13.6 documentation](https://fluxml.ai/Flux.jl/v0.13.6/training/training/). - -For details about the two gradient modes, see [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1). diff --git a/docs/src/training/zygote.md b/docs/src/training/zygote.md index 385e7dde7b..33d30d6ee8 100644 --- a/docs/src/training/zygote.md +++ b/docs/src/training/zygote.md @@ -18,22 +18,6 @@ Zygote.hessian_reverse Zygote.diaghessian ``` -## Implicit style (Flux ≤ 0.14) - -Flux used to use what Zygote calls "implicit" gradients, [described here](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) in its documentation. -However, support for this will be removed from Flux 0.15. - -!!! compat "Training" - The blue-green boxes in the [training section](@ref man-training) describe - the changes needed to upgrade old code from implicit to explicit style. - -```@docs -Zygote.gradient(loss, ::Params) -Zygote.Params -Zygote.Grads -Zygote.jacobian(loss, ::Params) -``` - ## ChainRules Sometimes it is necessary to exclude some code, or a whole function, from automatic differentiation. This can be done using [ChainRules](https://github.com/JuliaDiff/ChainRules.jl): diff --git a/docs/src/tutorials/2020-09-15-deep-learning-flux.md b/docs/src/tutorials/2020-09-15-deep-learning-flux.md index 7cb2a366b6..c386e5f3c4 100755 --- a/docs/src/tutorials/2020-09-15-deep-learning-flux.md +++ b/docs/src/tutorials/2020-09-15-deep-learning-flux.md @@ -167,52 +167,33 @@ gradient(myloss, W, b, x) Now we get gradients for each of the inputs `W`, `b` and `x`, which will come in handy when we want to train models. -Because ML models can contain hundreds of parameters, Flux provides a slightly different way of writing `gradient`. We instead mark arrays with `param` to indicate that we want their derivatives. `W` and `b` represent the weight and bias respectively. - -```julia -using Flux: params - -W = randn(3, 5) -b = zeros(3) -x = rand(5) - -y(x) = sum(W * x .+ b) - -grads = gradient(()->y(x), params([W, b])) - -grads[W], grads[b] -``` - - -We can now grab the gradients of `W` and `b` directly from those parameters. - -This comes in handy when working with *layers*. A layer is just a handy container for some parameters. For example, `Dense` does a linear transform for you. +ML models can contain hundreds of parameter arrays, therefore it is handy to group them into **layers**. +A layer is just a handy container for some parameters. For example, `Dense` does a linear transform for you. ```julia using Flux -m = Dense(10, 5) +m = Dense(10 => 5) x = rand(Float32, 10) ``` -We can easily get the parameters of any layer or model with params with `params`. +We can easily get the parameters of any layer or model with `trainables`. ```julia -params(m) +trainables(m) ``` -This makes it very easy to calculate the gradient for all parameters in a network, even if it has many parameters. +It very easy to calculate the gradient for all parameters in a network, even if it has many parameters. +The function `gradient` is not limited to array but can compute the gradient with respect to generic composite types. ```julia x = rand(Float32, 10) -m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax) -l(x) = sum(Flux.crossentropy(m(x), [0.5, 0.5])) -grads = gradient(params(m)) do - l(x) -end -for p in params(m) - println(grads[p]) +model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) +loss(model, x) = Flux.logitcrossentropy(model(x), [0.5, 0.5]) +grad = gradient(m -> loss(m, x), model)[1] +for (k, p) in trainables(model, path=true) + println("$k => $(getkeypath(grad, k))") end ``` @@ -221,27 +202,26 @@ You don't have to use layers, but they can be convient for many simple kinds of The next step is to update our weights and perform optimisation. As you might be familiar, *Gradient Descent* is a simple algorithm that takes the weights and steps using a learning rate and the gradients. `weights = weights - learning_rate * gradient`. ```julia -using Flux.Optimise: update!, Descent η = 0.1 -for p in params(m) - update!(p, -η * grads[p]) +for (k, p) in trainables(m) + p .+= -η * getkeypath(grads, p) end ``` While this is a valid way of updating our weights, it can get more complicated as the algorithms we use get more involved. -Flux comes with a bunch of pre-defined optimisers and makes writing our own really simple. We just give it the learning rate η: +Flux comes with a bunch of pre-defined optimisers and makes writing our own really simple. We just give it the learning rate `η`: ```julia -opt = Descent(0.01) +opt_state = Flux.setup(Descent(η), model) ``` -`Training` a network reduces down to iterating on a dataset mulitple times, performing these steps in order. Just for a quick implementation, let’s train a network that learns to predict `0.5` for every input of 10 floats. `Flux` defines the `train!` function to do it for us. +Training a network reduces down to iterating on a dataset mulitple times, performing these steps in order. Just for a quick implementation, let’s train a network that learns to predict `0.5` for every input of 10 floats. `Flux` defines the `train!` function to do it for us. ```julia data, labels = rand(10, 100), fill(0.5, 2, 100) -loss(x, y) = sum(Flux.crossentropy(m(x), y)) -Flux.train!(loss, params(m), [(data,labels)], opt) +loss(m, x, y) = Flux.logitcrossentropy(m(x), y) +Flux.train!(loss, model, [(data, labels)], opt) ``` You don't have to use `train!`. In cases where arbitrary logic might be better suited, you could open up this training loop like so: @@ -249,10 +229,10 @@ You don't have to use `train!`. In cases where arbitrary logic might be better s ```julia for d in training_set # assuming d looks like (data, labels) # our super logic - gs = gradient(params(m)) do #m is our model - l = loss(d...) + g = gradient(model) do model + l = loss(model, d...) end - update!(opt, params(m), gs) + Flux.update!(opt_state, model, g) end ``` @@ -272,7 +252,7 @@ We will do the following steps in order: ```julia using Statistics -using Flux, Flux.Optimise +using Flux using MLDatasets: CIFAR10 using Images.ImageCore using Flux: onehotbatch, onecold @@ -321,18 +301,17 @@ m = Chain( Conv((5,5), 16=>8, relu), MaxPool((2,2)), x -> reshape(x, :, size(x, 4)), - Dense(200, 120), - Dense(120, 84), - Dense(84, 10), - softmax) |> gpu + Dense(200 => 120), + Dense(120 => 84), + Dense(84 => 10)) |> gpu ``` We will use a crossentropy loss and an Momentum optimiser here. Crossentropy will be a good option when it comes to working with mulitple independent classes. Momentum gradually lowers the learning rate as we proceed with the training. It helps maintain a bit of adaptivity in our optimisation, preventing us from over shooting from our desired destination. ```julia -using Flux: crossentropy, Momentum +using Flux: logitcrossentropy, Momentum -loss(x, y) = sum(crossentropy(m(x), y)) +loss(m, x, y) = logitcrossentropy(m(x), y) opt = Momentum(0.01) ``` diff --git a/docs/src/tutorials/2021-01-26-mlp.md b/docs/src/tutorials/2021-01-26-mlp.md index 2af8d3645c..763f711195 100644 --- a/docs/src/tutorials/2021-01-26-mlp.md +++ b/docs/src/tutorials/2021-01-26-mlp.md @@ -80,8 +80,8 @@ We define our model with the `build_model` function: ```julia function build_model(; imgsize=(28,28,1), nclasses=10) return Chain( - Dense(prod(imgsize), 32, relu), - Dense(32, nclasses)) + Dense(prod(imgsize) => 32, relu), + Dense(32 => nclasses)) end ``` diff --git a/docs/src/tutorials/2021-10-08-dcgan-mnist.md b/docs/src/tutorials/2021-10-08-dcgan-mnist.md index 4da32e5f2c..a746935eb6 100644 --- a/docs/src/tutorials/2021-10-08-dcgan-mnist.md +++ b/docs/src/tutorials/2021-10-08-dcgan-mnist.md @@ -109,7 +109,7 @@ dcgan_init(shape...) = randn(Float32, shape) * 0.02f0 ```julia function Generator(latent_dim) Chain( - Dense(latent_dim, 7*7*256, bias=false), + Dense(latent_dim => 7*7*256, bias=false), BatchNorm(7*7*256, relu), x -> reshape(x, 7, 7, 256, :), diff --git a/docs/src/tutorials/2021-10-14-vanilla-gan.md b/docs/src/tutorials/2021-10-14-vanilla-gan.md index f92ae54a8b..f07c7757bd 100644 --- a/docs/src/tutorials/2021-10-14-vanilla-gan.md +++ b/docs/src/tutorials/2021-10-14-vanilla-gan.md @@ -96,13 +96,13 @@ calling the model in a gradient context. As a final non-linearity, we use the `sigmoid` activation function. ```julia -discriminator = Chain(Dense(n_features, 1024, x -> leakyrelu(x, 0.2f0)), +discriminator = Chain(Dense(n_features => 1024, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(1024, 512, x -> leakyrelu(x, 0.2f0)), + Dense(1024 => 512, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(512, 256, x -> leakyrelu(x, 0.2f0)), + Dense(512 => 256, x -> leakyrelu(x, 0.2f0)), Dropout(0.3), - Dense(256, 1, sigmoid)) |> gpu + Dense(256 => 1, sigmoid)) |> gpu ``` Let's define the generator in a similar fashion. This network maps a latent @@ -113,9 +113,9 @@ the training data onto. ```julia generator = Chain(Dense(latent_dim, 256, x -> leakyrelu(x, 0.2f0)), - Dense(256, 512, x -> leakyrelu(x, 0.2f0)), - Dense(512, 1024, x -> leakyrelu(x, 0.2f0)), - Dense(1024, n_features, tanh)) |> gpu + Dense(256 => 512, x -> leakyrelu(x, 0.2f0)), + Dense(512 => 1024, x -> leakyrelu(x, 0.2f0)), + Dense(1024 => n_features, tanh)) |> gpu ``` diff --git a/perf/dense.jl b/perf/dense.jl index 005d9360ba..58bbbf01d4 100644 --- a/perf/dense.jl +++ b/perf/dense.jl @@ -1,6 +1,6 @@ for n in [2, 20, 200, 2000] x = randn(Float32, n, n) - model = Dense(n, n) + model = Dense(n => n) println("CPU n=$n") run_benchmark(model, x, cuda=false) println("CUDA n=$n") diff --git a/perf/vgg.jl b/perf/vgg.jl index 33b7bfd61d..5a9b5e1f5f 100644 --- a/perf/vgg.jl +++ b/perf/vgg.jl @@ -39,11 +39,11 @@ function vgg16() BatchNorm(512), MaxPool((2,2)), Flux.flatten, - Dense(512, 4096, relu), + Dense(512 => 4096, relu), Dropout(0.5), - Dense(4096, 4096, relu), + Dense(4096 => 4096, relu), Dropout(0.5), - Dense(4096, 10) + Dense(4096 => 10) ) end diff --git a/src/Flux.jl b/src/Flux.jl index a8720b7905..7c51aa1d10 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -10,17 +10,15 @@ using MacroTools: @forward @reexport using NNlib using MLUtils const stack = MLUtils.stack # now exported by Base -import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions -using Optimisers: freeze!, thaw!, adjust! +@reexport using Optimisers +import Optimisers: trainable +using Optimisers: update!, trainables using Random: default_rng using Zygote, ChainRulesCore -using Zygote: Params, @adjoint, gradient, pullback +using Zygote: @adjoint, gradient, pullback using Zygote.ForwardDiff: value export gradient -# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.) -Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`") - export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion, RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, @@ -41,18 +39,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion outputsize, state, create_bias, @layer, )) -include("optimise/Optimise.jl") -using .Optimise -export Descent, Adam, Momentum, Nesterov, RMSProp, - AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam, - AdamW, RAdam, AdaBelief, InvDecay, ExpDecay, - WeightDecay, SignDecay, ClipValue, ClipNorm - -export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser - include("train.jl") using .Train -using .Train: setup +using .Train: setup, train! using Adapt, Functors, OneHotArrays include("utils.jl") @@ -63,8 +52,8 @@ include("functor.jl") onehot, onehotbatch, onecold, # from Functors.jl functor, @functor, - # from Optimise/Train/Optimisers.jl - setup, update!, destructure, freeze!, adjust!, params, trainable + # from Train/Optimisers.jl + setup, update!, destructure, freeze!, thaw!, adjust!, trainable, trainables )) # Pirate error to catch a common mistake. diff --git a/src/deprecations.jl b/src/deprecations.jl index 5acdec5455..6b0c064cca 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,225 +1,57 @@ - -# v0.13 deprecations - -function Broadcast.broadcasted(f::Recur, args...) - # This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12 - Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order. - Re-writing this as a comprehension would be better.""", :broadcasted) - map(f, args...) # map isn't really safe either, but -end - -@deprecate frequencies(xs) group_counts(xs) - -struct Zeros - function Zeros() - Base.depwarn("Flux.Zeros is no more, has ceased to be, is bereft of life, is an ex-boondoggle... please use bias=false instead", :Zeros) - false - end -end -Zeros(args...) = Zeros() # was used both Dense(10, 2, initb = Zeros) and Dense(rand(2,10), Zeros()) - -function Optimise.update!(x::AbstractArray, x̄) - Base.depwarn("`Flux.Optimise.update!(x, x̄)` was not used internally and has been removed. Please write `x .-= x̄` instead.", :update!) - x .-= x̄ -end - -function Diagonal(size::Integer...; kw...) - Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal) - Scale(size...; kw...) -end -function Diagonal(size::Tuple; kw...) - Base.depwarn("Flux.Diagonal is now Flux.Scale, and also allows an activation function.", :Diagonal) - Scale(size...; kw...) -end - -# Deprecate this eventually once saving models w/o structure is no more -function loadparams!(m, xs) - Base.depwarn("loadparams! will be deprecated eventually. Use loadmodel! instead.", :loadparams!) - for (p, x) in zip(params(m), xs) - size(p) == size(x) || - error("Expected param size $(size(p)), got $(size(x))") - copyto!(p, x) - end -end - # Channel notation: Changed to match Conv, but very softly deprecated! -# Perhaps change to @deprecate for v0.15, but there is no plan to remove these. Dense(in::Integer, out::Integer, σ = identity; kw...) = - Dense(in => out, σ; kw...) +Dense(in => out, σ; kw...) + Bilinear(in1::Integer, in2::Integer, out::Integer, σ = identity; kw...) = Bilinear((in1, in2) => out, σ; kw...) + Embedding(in::Integer, out::Integer; kw...) = Embedding(in => out; kw...) RNNCell(in::Integer, out::Integer, σ = tanh; kw...) = RNNCell(in => out, σ; kw...) + LSTMCell(in::Integer, out::Integer; kw...) = LSTMCell(in => out; kw...) GRUCell(in::Integer, out::Integer; kw...) = GRUCell(in => out; kw...) -GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...) - -# Optimisers with old naming convention -Base.@deprecate_binding ADAM Adam -Base.@deprecate_binding NADAM NAdam -Base.@deprecate_binding ADAMW AdamW -Base.@deprecate_binding RADAM RAdam -Base.@deprecate_binding OADAM OAdam -Base.@deprecate_binding ADAGrad AdaGrad -Base.@deprecate_binding ADADelta AdaDelta - -# Remove sub-module Data, while making sure Flux.Data.DataLoader keeps working -Base.@deprecate_binding Data Flux false "Sub-module Flux.Data has been removed. The only thing it contained may be accessed as Flux.DataLoader" - -@deprecate paramtype(T,m) _paramtype(T,m) false # internal method, renamed to make this clear - -@deprecate rng_from_array() Random.default_rng() - -function istraining() - Base.depwarn("Flux.istraining() is deprecated, use NNlib.within_gradient(x) instead", :istraining) - false -end -ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),) - -function _isactive(m) - Base.depwarn("_isactive(m) is deprecated, use _isactive(m,x)", :_isactive, force=true) - _isactive(m, 1:0) -end - -#= - # Valid method in Optimise, old implicit style, is: - train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) - # Valid methods in Train, new explict style, are: - train!(loss, model, data, opt) # preferred - train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup - - # Provide friendly errors for what happens if you mix these up: -=# -import .Optimise: train! - -train!(loss, ps::Params, data, opt; cb=nothing) = error( - """can't mix implict Params with explict state! - To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. - But better to use the new explicit style, in which `m` itself is the 2nd argument. - """) - -train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error( - """can't mix implict Params with explict rule from Optimisers.jl - To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module. - But better to use the new explicit style, in which `m` itself is the 2nd argument. - """) - -train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) = train!(loss, model, data, _old_to_new(opt); cb) +GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...) -# Next, to use the new `setup` with the still-exported old-style `Adam` etc: -import .Train: setup -setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) -# ... and allow accidental use of `Optimisers.setup` to do the same: -Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(rule), model) +# v0.15 deprecations -for T in [:Descent, :Adam, :Momentum, :Nesterov, - :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief, - # :InvDecay, :ExpDecay, - :SignDecay, - ] - @eval function _old_to_new(rule::$T) - args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T)) - Optimisers.$T(args...) +Train.train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( + """On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`. + Instead of `train!(loss_xy, Flux.params(model), data, Adam())` + it now needs `opt_state = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt_state)` + where `loss_mxy` accepts the model as its first argument. + """ +)) + + +function params!(p::Params, x, seen = IdSet()) + # @depwarn "Implicit use of `params` is deprecated. TODO." + + if x isa AbstractArray{<:Number} && Functors.isleaf(x) + return push!(p, x) + elseif x in seen + nothing + else + _check_new_macro(x) # complains if you used @functor not @layer + push!(seen, x) + for child in trainable(x) + params!(p, child, seen) + end end end -_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...) -const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too. -_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called lambda now -_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields -_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs -const ClipGrad = Optimise.ClipValue -_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred - -_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule") - -# This allows you to mix and match, like Flux.setup(OptimiserChain(Optimisers.SignDecay(), Flux.Descent()), [1,2,3.]) -Optimisers.OptimiserChain(rules::Union{Optimisers.AbstractRule, Optimise.AbstractOptimiser}...) = - Optimisers.OptimiserChain(map(_old_to_new, rules)) -_old_to_new(rule::Optimisers.AbstractRule) = rule -# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot. -# But let's make sure that such uses give a helpful error: -import .Optimise: update! - -function update!(opt::Optimise.AbstractOptimiser, model, grad) - # This error method requires narrowing the main worker method of Flux.Optimise - # to accept only arrays. Remove if this causes problems! - # update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄) - error("""Invalid input to `update!`. - * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` - * For the explicit style, `update(state, model, grad)` needs `state = Flux.setup(opt, model)`. - """) -end - -# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1] -# Can't catch every case, but can catch many simple Flux models: - -function update!(opt, model::Chain, grads::Tuple) - # Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent - @warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone, - not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`.""" - update!(opt, model, grads[1]) +function params(m...) + # @depwarn "Implicit use of `params` is deprecated. TODO." + ps = Params() + params!(ps, m) + return ps end -function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity - update!(opt, model, grads[1]) # calls error case "Invalid input" just above -end - -# One more easy error to catch is using explicit gradient with `params(m)`: - -function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple}) - error("""can't mix implicit Params with explicit gradients! - * For the implicit style, this needs `update(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient. - * For the explicit style, `update(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`. - """) -end - -""" - trainmode!(m, active) - -!!! warning - This two-argument method is deprecated. - -Possible values of `active` are: -- `true` for training, or -- `false` for testing, same as [`testmode!`](@ref)`(m)` -- `:auto` or `nothing` for Flux to detect training automatically. -""" -function trainmode!(m, active::Bool) - Base.depwarn("trainmode!(m, active::Bool) is deprecated", :trainmode) - testmode!(m, !active) -end - -# Greek-letter keywords deprecated in Flux 0.13 -# Arguments (old => new, :function, "β" => "beta") -function _greek_ascii_depwarn(βbeta::Pair, func = :loss, names = "" => "") - Base.depwarn(LazyString("function ", func, " no longer accepts greek-letter keyword ", names.first, """ - please use ascii """, names.second, " instead"), func) - βbeta.first -end -_greek_ascii_depwarn(βbeta::Pair{Nothing}, _...) = βbeta.second - -ChainRulesCore.@non_differentiable _greek_ascii_depwarn(::Any...) - - -# v0.14 deprecations -@deprecate default_rng_value() Random.default_rng() - -Base.@deprecate_binding FluxAMDAdaptor FluxAMDGPUAdaptor - -# v0.15 deprecations - -# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc: -# Base.@deprecate_binding Optimiser OptimiserChain -# Base.@deprecate_binding ClipValue ClipGrad +# Allows caching of the parameters when params is called within gradient() to fix #2040. +# @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054 +# That speeds up implicit use, and silently breaks explicit use. +# From @macroexpand Zygote.@non_differentiable params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248 +Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing -# train!(loss::Function, ps::Zygote.Params, data, opt) = throw(ArgumentError( -# """On Flux 0.15, `train!` no longer accepts implicit `Zygote.Params`. -# Instead of `train!(loss_xy, Flux.params(model), data, Adam())` -# it now needs `opt = Flux.setup(Adam(), model); train!(loss_mxy, model, data, opt)` -# where `loss_mxy` accepts the model as its first argument. -# """ -# )) diff --git a/src/functor.jl b/src/functor.jl index e0168edf6b..6238df7350 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -75,65 +75,6 @@ function testmode!(m, mode) m end -function params!(p::Params, x, seen = IdSet()) - if x isa AbstractArray{<:Number} && Functors.isleaf(x) - return push!(p, x) - elseif x in seen - nothing - else - _check_new_macro(x) # complains if you used @functor not @layer - push!(seen, x) - for child in trainable(x) - params!(p, child, seen) - end - end -end - -""" - params(model) - params(layers...) - -Given a model or specific layers from a model, create a `Params` object pointing to its trainable parameters. - -This can be used with the `gradient` function, see the [training section of the manual](@ref man-training), or as input to the [`Flux.train!`](@ref Flux.train!) function. - -The behaviour of `params` on custom types can be customized using [`Functors.@functor`](@ref) or [`Flux.trainable`](@ref). - -# Examples -```jldoctest -julia> using Flux: params - -julia> params(Chain(Dense(ones(2,3)), softmax)) # unpacks Flux models -Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]]) - -julia> bn = BatchNorm(2, relu) -BatchNorm(2, relu) # 4 parameters, plus 4 non-trainable - -julia> params(bn) # only the trainable parameters -Params([Float32[0.0, 0.0], Float32[1.0, 1.0]]) - -julia> params([1, 2, 3], [4]) # one or more arrays of numbers -Params([[1, 2, 3], [4]]) - -julia> params([[1, 2, 3], [4]]) # unpacks array of arrays -Params([[1, 2, 3], [4]]) - -julia> params(1, [2 2], (alpha=[3,3,3], beta=Ref(4), gamma=sin)) # ignores scalars, unpacks NamedTuples -Params([[2 2], [3, 3, 3]]) -``` -""" -function params(m...) - ps = Params() - params!(ps, m) - return ps -end - -# Allows caching of the parameters when params is called within gradient() to fix #2040. -# @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054 -# That speeds up implicit use, and silently breaks explicit use. -# From @macroexpand Zygote.@non_differentiable params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248 -Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing - struct FluxCPUAdaptor end # define rules for handling structured arrays @@ -315,7 +256,7 @@ See also [`f32`](@ref) and [`f64`](@ref). # Example ```jldoctest -julia> m = Chain(Dense(784, 2048, relu), Dense(2048, 10)) # all Float32 +julia> m = Chain(Dense(784 => 2048, relu), Dense(2048 => 10)) # all Float32 Chain( Dense(784 => 2048, relu), # 1_607_680 parameters Dense(2048 => 10), # 20_490 parameters diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ef81c30872..868dd474ec 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -129,25 +129,26 @@ The weight matrix and/or the bias vector (of length `out`) may also be provided # Examples ```jldoctest -julia> d = Dense(5 => 2) +julia> model = Dense(5 => 2) Dense(5 => 2) # 12 parameters -julia> d(rand32(5, 64)) |> size +julia> model(rand32(5, 64)) |> size (2, 64) -julia> d(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions +julia> model(rand32(5, 6, 4, 64)) |> size # treated as three batch dimensions (2, 6, 4, 64) -julia> d1 = Dense(ones(2, 5), false, tanh) # using provided weight matrix +julia> model2 = Dense(ones(2, 5), false, tanh) # using provided weight matrix Dense(5 => 2, tanh; bias=false) # 10 parameters -julia> d1(ones(5)) +julia> model2(ones(5)) 2-element Vector{Float64}: 0.9999092042625951 0.9999092042625951 -julia> Flux.params(d1) # no trainable bias -Params([[1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0]]) + julia> trainables(model2) # no trainable bias + 1-element Vector{AbstractArray}: + [1.0 1.0 … 1.0 1.0; 1.0 1.0 … 1.0 1.0] ``` """ struct Dense{F, M<:AbstractMatrix, B} @@ -218,24 +219,27 @@ Used by [`LayerNorm`](@ref) with `affine=true`. julia> a = Flux.Scale(2) Scale(2) # 4 parameters -julia> Flux.params(a) -Params([Float32[1.0, 1.0], Float32[0.0, 0.0]]) +julia> Flux.trainables(a) +2-element Vector{AbstractArray}: + Float32[1.0, 1.0] + Float32[0.0, 0.0] julia> a([1 2 3]) 2×3 Matrix{Float32}: 1.0 2.0 3.0 1.0 2.0 3.0 -julia> b = Flux.Scale([1 2 3 4], false, abs2) +julia> b = Flux.Scale(Float32[1 2 3 4], false, abs2) Scale(1, 4, abs2; bias=false) # 4 parameters julia> b([1, 10]) -2×4 Matrix{Int64}: - 1 4 9 16 - 100 400 900 1600 +2×4 Matrix{Float32}: + 1.0 4.0 9.0 16.0 + 100.0 400.0 900.0 1600.0 -julia> Flux.params(b) -Params([[1 2 3 4]]) +julia> Flux.trainables(b) +1-element Vector{AbstractArray}: + Float32[1.0 2.0 3.0 4.0] ``` """ struct Scale{F, A<:AbstractArray, B} @@ -490,7 +494,7 @@ julia> model = Chain(Dense(3 => 5), julia> model(rand32(3)) |> size (17,) -julia> model2 = Parallel(+; α = Dense(10, 2, tanh), β = Dense(5, 2)) +julia> model2 = Parallel(+; α = Dense(10 => 2, tanh), β = Dense(5 => 2)) Parallel( +, α = Dense(10 => 2, tanh), # 22 parameters diff --git a/src/layers/conv.jl b/src/layers/conv.jl index fdf3c756e9..aa0635989f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -145,7 +145,7 @@ Conv((3,), 4 => 5, σ) # 65 parameters julia> layer(randn(100, 4, 64)) |> size (98, 5, 64) -julia> Flux.params(layer) |> length +julia> Flux.trainables(layer) |> length 2 ``` """ @@ -286,7 +286,7 @@ ConvTranspose((3,), 5 => 4, σ) # 64 parameters julia> layer(randn(100, 5, 64)) |> size # transposed convolution will increase the dimension size (upsampling) (102, 4, 64) -julia> Flux.params(layer) |> length +julia> Flux.trainables(layer) |> length 2 ``` """ diff --git a/src/layers/macro.jl b/src/layers/macro.jl index dcebe551e3..4e9528ec86 100644 --- a/src/layers/macro.jl +++ b/src/layers/macro.jl @@ -35,7 +35,7 @@ julia> Flux.destructure(tri) # parameters are not yet visible to Flux julia> Flux.@layer :expand Trio -julia> Flux.destructure(tri) # now gpu, params, train!, etc will see inside too +julia> Flux.destructure(tri) # now gpu, train!, etc will see inside too ([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4)) julia> tri # and layer is printed like Chain diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index c6663cca88..cf85c995ef 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -191,10 +191,9 @@ struct LayerNorm{F,D,T,N} affine::Bool end -function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5, ϵ=nothing) - ε = _greek_ascii_depwarn(ϵ => eps, :LayerNorm, "ϵ" => "eps") +function LayerNorm(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5) diag = affine ? Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity - return LayerNorm(λ, diag, ε, size, affine) + return LayerNorm(λ, diag, eps, size, affine) end LayerNorm(size::Integer...; kw...) = LayerNorm(Int.(size); kw...) LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]; kw...) @@ -330,15 +329,13 @@ function BatchNorm(chs::Int, λ=identity; affine::Bool=true, track_stats::Bool=true, active::Union{Bool,Nothing}=nothing, eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - ε = _greek_ascii_depwarn(ϵ => eps, :BatchNorm, "ϵ" => "eps") - β = affine ? initβ(chs) : nothing γ = affine ? initγ(chs) : nothing μ = track_stats ? zeros32(chs) : nothing σ² = track_stats ? ones32(chs) : nothing return BatchNorm(λ, β, γ, - μ, σ², ε, momentum, + μ, σ², eps, momentum, affine, track_stats, active, chs) end @@ -421,9 +418,7 @@ end function InstanceNorm(chs::Int, λ=identity; initβ=zeros32, initγ=ones32, affine::Bool=false, track_stats::Bool=false, active::Union{Bool,Nothing}=nothing, - eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - - ε = _greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps") + eps::Real=1f-5, momentum::Real=0.1f0) β = affine ? initβ(chs) : nothing γ = affine ? initγ(chs) : nothing @@ -431,7 +426,7 @@ function InstanceNorm(chs::Int, λ=identity; σ² = track_stats ? ones32(chs) : nothing return InstanceNorm(λ, β, γ, - μ, σ², ε, momentum, + μ, σ², eps, momentum, affine, track_stats, active, chs) end @@ -454,7 +449,7 @@ function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(l.chs)") l.λ == identity || print(io, ", $(l.λ)") hasaffine(l) || print(io, ", affine=false") - l.active == nothing || print(io, ", active=", l.active) + l.active === nothing || print(io, ", active=", l.active) print(io, ")") end @@ -520,9 +515,7 @@ end function GroupNorm(chs::Int, G::Int, λ=identity; initβ=zeros32, initγ=ones32, affine::Bool=true, active::Union{Bool,Nothing}=nothing, - eps::Real=1f-5, momentum::Real=0.1f0, ϵ=nothing) - - ε = _greek_ascii_depwarn(ϵ => eps, :GroupNorm, "ϵ" => "eps") + eps::Real=1f-5, momentum::Real=0.1f0) chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)") @@ -535,7 +528,7 @@ function GroupNorm(chs::Int, G::Int, λ=identity; return GroupNorm(G, λ, β, γ, μ, σ², - ε, momentum, + eps, momentum, affine, track_stats, active, chs) end @@ -561,7 +554,7 @@ function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(l.chs), $(l.G)") l.λ == identity || print(io, ", ", l.λ) hasaffine(l) || print(io, ", affine=false") - l.active == nothing || print(io, ", active=", l.active) + l.active === nothing || print(io, ", active=", l.active) print(io, ")") end diff --git a/src/layers/show.jl b/src/layers/show.jl index a03ddf3754..95e7d8746b 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -90,21 +90,21 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing) _str = isnothing(name) ? "" : "$name = " str = _str * sprint(show, layer, context=io) print(io, " "^indent, str, indent==0 ? "" : ",") - if !isempty(params(layer)) + if !isempty(trainables(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters"; + printstyled(io, "# ", underscorise(sum(length, trainables(layer); init=0)), " parameters"; color=:light_black) - nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0) + nonparam = _childarray_sum(length, layer) - sum(length, trainables(layer), init=0) if nonparam > 0 printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) end - _nan_show(io, params(layer)) + _nan_show(io, trainables(layer)) end indent==0 || println(io) end function _big_finale(io::IO, m) - ps = params(m) + ps = trainables(m) if length(ps) > 2 pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 2565ea2e84..2a8a16b0e6 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -35,10 +35,9 @@ true ``` """ @inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5), ϵ=nothing) - ε = _greek_ascii_depwarn(ϵ => eps, :InstanceNorm, "ϵ" => "eps") μ = mean(x, dims=dims) σ = std(x, dims=dims, mean=μ, corrected=false) - return @. (x - μ) / (σ + ε) + return @. (x - μ) / (σ + eps) end """ diff --git a/src/loading.jl b/src/loading.jl index 8238a19cde..0f21cb94e1 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -129,12 +129,12 @@ The state can be passed to [`loadmodel!`](@ref) to restore the model. ## Copy the state into another model ```jldoctest -julia> m1 = Chain(Dense(1, 2, tanh; init=ones), Dense(2, 1; init=ones)); +julia> m1 = Chain(Dense(1 => 2, tanh; init=ones), Dense(2 => 1; init=ones)); julia> s = Flux.state(m1) (layers = ((weight = [1.0; 1.0;;], bias = [0.0, 0.0], σ = ()), (weight = [1.0 1.0], bias = [0.0], σ = ())),) -julia> m2 = Chain(Dense(1, 2, tanh), Dense(2, 1; bias=false)); # weights are random numbers +julia> m2 = Chain(Dense(1 => 2, tanh), Dense(2 => 1; bias=false)); # weights are random numbers julia> Flux.loadmodel!(m2, s); diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 34315baadd..5b4a1d697b 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -4,7 +4,7 @@ using Statistics using Zygote using Zygote: @adjoint using ChainRulesCore -using ..Flux: ofeltype, epseltype, _greek_ascii_depwarn +using ..Flux: ofeltype, epseltype using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss import Base.Broadcast: broadcasted diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 7897cc5754..f84ca22186 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -66,10 +66,9 @@ julia> Flux.msle(Float32[0.9, 1.8, 2.7], 1:3) 0.011100831f0 ``` """ -function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :msle, "ϵ" => "eps") +function msle(ŷ, y; agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg((log.((ŷ .+ ϵ) ./ (y .+ ϵ))) .^2 ) + agg((log.((ŷ .+ eps) ./ (y .+ eps))) .^2 ) end function _huber_metric(abs_error, δ) @@ -101,9 +100,8 @@ julia> Flux.huber_loss(ŷ, 1:3, delta=0.05) # changes behaviour as |ŷ - y| > 0.003750000000000005 ``` """ -function huber_loss(ŷ, y; agg = mean, delta::Real = 1, δ = nothing) - delta_tmp = _greek_ascii_depwarn(δ => delta, :huber_loss, "δ" => "delta") - δ = ofeltype(ŷ, delta_tmp) +function huber_loss(ŷ, y; agg = mean, delta::Real = 1) + δ = ofeltype(ŷ, delta) _check_sizes(ŷ, y) abs_error = abs.(ŷ .- y) @@ -230,10 +228,9 @@ julia> Flux.crossentropy(y_model, y_smooth) 1.5776052f0 ``` """ -function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :crossentropy, "ϵ" => "eps") +function crossentropy(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims)) + agg(.-sum(xlogy.(y, ŷ .+ eps); dims = dims)) end """ @@ -319,10 +316,9 @@ julia> Flux.crossentropy(y_prob, y_hot) 0.43989f0 ``` """ -function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :binarycrossentropy, "ϵ" => "eps") +function binarycrossentropy(ŷ, y; agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) - agg(@.(-xlogy(y, ŷ + ϵ) - xlogy(1 - y, 1 - ŷ + ϵ))) + agg(@.(-xlogy(y, ŷ + eps) - xlogy(1 - y, 1 - ŷ + eps))) end """ @@ -390,11 +386,10 @@ julia> Flux.kldivergence(p1, p2; eps = 0) # about 17.3 with the regulator Inf ``` """ -function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ), ϵ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :kldivergence, "ϵ" => "eps") +function kldivergence(ŷ, y; dims = 1, agg = mean, eps::Real = epseltype(ŷ)) _check_sizes(ŷ, y) entropy = agg(sum(xlogx.(y); dims = dims)) - cross_entropy = crossentropy(ŷ, y; dims, agg, eps=ϵ) + cross_entropy = crossentropy(ŷ, y; dims, agg, eps) return entropy + cross_entropy end @@ -530,14 +525,13 @@ Calculated as: 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1) """ -function tversky_loss(ŷ, y; beta::Real = 0.7, β = nothing) - beta_temp = _greek_ascii_depwarn(β => beta, :tversky_loss, "β" => "beta") - β = ofeltype(ŷ, beta_temp) +function tversky_loss(ŷ, y; beta::Real = 0.7) + β = ofeltype(ŷ, beta) _check_sizes(ŷ, y) #TODO add agg num = sum(y .* ŷ) + 1 den = sum(y .* ŷ + β * (1 .- y) .* ŷ + (1 - β) * y .* (1 .- ŷ)) + 1 - 1 - num / den + return 1 - num / den end """ @@ -568,17 +562,15 @@ julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 true ``` """ -function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ = nothing, γ = nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :binary_focal_loss, "ϵ" => "eps") - gamma_temp = _greek_ascii_depwarn(γ => gamma, :binary_focal_loss, "γ" => "gamma") - γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp) +function binary_focal_loss(ŷ, y; agg=mean, gamma=2, eps::Real=epseltype(ŷ)) + γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma) _check_sizes(ŷ, y) - ŷϵ = ŷ .+ ϵ + ŷϵ = ŷ .+ eps p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ) ce = .-log.(p_t) weight = (1 .- p_t) .^ γ loss = weight .* ce - agg(loss) + return agg(loss) end """ @@ -615,12 +607,10 @@ true See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels """ -function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ), ϵ=nothing, γ=nothing) - ϵ = _greek_ascii_depwarn(ϵ => eps, :focal_loss, "ϵ" => "eps") - gamma_temp = _greek_ascii_depwarn(γ => gamma, :focal_loss, "γ" => "gamma") - γ = gamma_temp isa Integer ? gamma_temp : ofeltype(ŷ, gamma_temp) +function focal_loss(ŷ, y; dims=1, agg=mean, gamma=2, eps::Real=epseltype(ŷ)) + γ = gamma isa Integer ? gamma : ofeltype(ŷ, gamma) _check_sizes(ŷ, y) - ŷϵ = ŷ .+ ϵ + ŷϵ = ŷ .+ eps agg(sum(@. -y * (1 - ŷϵ)^γ * log(ŷϵ); dims)) end diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl deleted file mode 100644 index f637d83242..0000000000 --- a/src/optimise/Optimise.jl +++ /dev/null @@ -1,14 +0,0 @@ -module Optimise - -using LinearAlgebra - -export train!, update!, - Descent, Adam, Momentum, Nesterov, RMSProp, - AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW,RAdam, OAdam, AdaBelief, - InvDecay, ExpDecay, WeightDecay, Optimiser, - ClipValue, ClipNorm, SignDecay - -include("optimisers.jl") -include("train.jl") - -end diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl deleted file mode 100644 index 18f9d3ddae..0000000000 --- a/src/optimise/optimisers.jl +++ /dev/null @@ -1,753 +0,0 @@ -using Flux -using MacroTools: @forward - -abstract type AbstractOptimiser end - -const EPS = 1e-8 - -# TODO: should use weak refs - -""" - Descent(η = 0.1) - -Classic gradient descent optimiser with learning rate `η`. -For each parameter `p` and its gradient `δp`, this runs `p -= η*δp` - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = Descent() - -opt = Descent(0.3) - -ps = Flux.params(model) - -gs = gradient(ps) do - loss(x, y) -end - -Flux.Optimise.update!(opt, ps, gs) -``` -""" -mutable struct Descent <: AbstractOptimiser - eta::Float64 -end - -Descent() = Descent(0.1) - -function apply!(o::Descent, x, Δ) - Δ .*= o.eta -end - -""" - Momentum(η = 0.01, ρ = 0.9) - -Gradient descent optimiser with learning rate `η` and momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect damping oscillations. - -# Examples -```julia -opt = Momentum() - -opt = Momentum(0.01, 0.99) -``` -""" -mutable struct Momentum <: AbstractOptimiser - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict()) - -function apply!(o::Momentum, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - @. v = ρ * v - η * Δ - @. Δ = -v -end - -""" - Nesterov(η = 0.001, ρ = 0.9) - -Gradient descent optimiser with learning rate `η` and Nesterov momentum `ρ`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Nesterov momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect damping oscillations. - -# Examples -```julia -opt = Nesterov() - -opt = Nesterov(0.003, 0.95) -``` -""" -mutable struct Nesterov <: AbstractOptimiser - eta::Float64 - rho::Float64 - velocity::IdDict -end - -Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict()) - -function apply!(o::Nesterov, x, Δ) - η, ρ = o.eta, o.rho - v = get!(() -> zero(x), o.velocity, x)::typeof(x) - d = @. ρ^2 * v - (1+ρ) * η * Δ - @. v = ρ*v - η*Δ - @. Δ = -d -end - -""" - RMSProp(η = 0.001, ρ = 0.9, ϵ = $EPS) - -Optimizer using the -[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -algorithm. Often a good choice for recurrent networks. Parameters other than learning rate -generally don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Momentum (`ρ`): Controls the acceleration of gradient descent in the - prominent direction, in effect damping oscillations. - -# Examples -```julia -opt = RMSProp() - -opt = RMSProp(0.002, 0.95) -``` -""" -mutable struct RMSProp <: AbstractOptimiser - eta::Float64 - rho::Float64 - epsilon::Float64 - acc::IdDict -end -RMSProp(η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = EPS) = RMSProp(η, ρ, ϵ, IdDict()) -RMSProp(η::Real, ρ::Real, acc::IdDict) = RMSProp(η, ρ, EPS, acc) - -function apply!(o::RMSProp, x, Δ) - η, ρ = o.eta, o.rho - acc = get!(() -> zero(x), o.acc, x)::typeof(x) - @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ) - @. Δ *= η / (√acc + o.epsilon) -end - -""" - Adam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -[Adam](https://arxiv.org/abs/1412.6980) optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = Adam() - -opt = Adam(0.001, (0.9, 0.8)) -``` -""" -mutable struct Adam <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -Adam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = Adam(η, β, ϵ, IdDict()) -Adam(η::Real, β::Tuple, state::IdDict) = Adam(η, β, EPS, state) - -function apply!(o::Adam, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ) - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η - βp .= βp .* β - - return Δ -end - -""" - RAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -[Rectified Adam](https://arxiv.org/abs/1908.03265) optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = RAdam() - -opt = RAdam(0.001, (0.9, 0.8)) -``` -""" -mutable struct RAdam <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -RAdam(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = RAdam(η, β, ϵ, IdDict()) -RAdam(η::Real, β::Tuple, state::IdDict) = RAdam(η, β, EPS, state) - -function apply!(o::RAdam, x, Δ) - η, β = o.eta, o.beta - ρ∞ = 2/(1-β[2])-1 - - mt, vt, βp, t = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]], Ref(1)) - end :: Tuple{typeof(x),typeof(x),Vector{Float64},Base.RefValue{Int}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ) - ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2]) - if ρ > 4 - r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ)) - @. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η * r - else - @. Δ = mt / (1 - βp[1]) * η - end - βp .= βp .* β - t[] += 1 - - return Δ -end - -""" - AdaMax(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -[AdaMax](https://arxiv.org/abs/1412.6980) is a variant of Adam based on the ∞-norm. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaMax() - -opt = AdaMax(0.001, (0.9, 0.995)) -``` -""" -mutable struct AdaMax <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -AdaMax(η::Real = 0.001, β::Tuple = (0.9, 0.999), ϵ::Real = EPS) = AdaMax(η, β, ϵ, IdDict()) -AdaMax(η::Real, β::Tuple, state::IdDict) = AdaMax(η, β, EPS, state) - -function apply!(o::AdaMax, x, Δ) - η, β = o.eta, o.beta - - mt, ut, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. ut = max(β[2] * ut, abs(Δ)) - @. Δ = (η/(1 - βp[1])) * mt/(ut + o.epsilon) - βp .= βp .* β - - return Δ -end - -""" - OAdam(η = 0.0001, β::Tuple = (0.5, 0.9), ϵ = $EPS) - -[OAdam](https://arxiv.org/abs/1711.00141) (Optimistic Adam) -is a variant of Adam adding an "optimistic" term suitable for adversarial training. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = OAdam() - -opt = OAdam(0.001, (0.9, 0.995)) -``` -""" -mutable struct OAdam <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -OAdam(η::Real = 0.001, β::Tuple = (0.5, 0.9), ϵ::Real = EPS) = OAdam(η, β, ϵ, IdDict()) -OAdam(η::Real, β::Tuple, state::IdDict) = RMSProp(η, β, EPS, state) - -function apply!(o::OAdam, x, Δ) - η, β = o.eta, o.beta - - mt, vt, Δ_, βp = get!(o.state, x) do - (zero(x), zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x),typeof(x),typeof(x),Vector{Float64}} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ) - @. Δ = -Δ_ - @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) - @. Δ += 2Δ_ - βp .= βp .* β - - return Δ -end - -""" - AdaGrad(η = 0.1, ϵ = $EPS) - -[AdaGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. It has -parameter specific learning rates based on how frequently it is updated. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. - -# Examples -```julia -opt = AdaGrad() - -opt = AdaGrad(0.001) -``` -""" -mutable struct AdaGrad <: AbstractOptimiser - eta::Float64 - epsilon::Float64 - acc::IdDict -end -AdaGrad(η::Real = 0.1, ϵ::Real = EPS) = AdaGrad(η, ϵ, IdDict()) -AdaGrad(η::Real, state::IdDict) = AdaGrad(η, EPS, state) - -function apply!(o::AdaGrad, x, Δ) - η = o.eta - acc = get!(() -> fill!(similar(x), o.epsilon), o.acc, x)::typeof(x) - @. acc += Δ * conj(Δ) - @. Δ *= η / (√acc + o.epsilon) -end - -""" - AdaDelta(ρ = 0.9, ϵ = $EPS) - -[AdaDelta](https://arxiv.org/abs/1212.5701) is a version of AdaGrad adapting its learning -rate based on a window of past gradient updates. -Parameters don't need tuning. - -# Parameters -- Rho (`ρ`): Factor by which the gradient is decayed at each time step. - -# Examples -```julia -opt = AdaDelta() - -opt = AdaDelta(0.89) -``` -""" -mutable struct AdaDelta <: AbstractOptimiser - rho::Float64 - epsilon::Float64 - state::IdDict{Any, Any} -end -AdaDelta(ρ::Real = 0.9, ϵ::Real = EPS) = AdaDelta(ρ, ϵ, IdDict()) -AdaDelta(ρ::Real, state::IdDict) = AdaDelta(ρ, EPS, state) - -function apply!(o::AdaDelta, x, Δ) - ρ = o.rho - acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)} - @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ) - # DON'T remove epsilon from numerator - # or even out of the square roots - @. Δ *= √(Δacc + o.epsilon) / √(acc + o.epsilon) - @. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ) - return Δ -end - -""" - AMSGrad(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -The [AMSGrad](https://openreview.net/forum?id=ryQu7f-RZ) version of the Adam -optimiser. Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AMSGrad() - -opt = AMSGrad(0.001, (0.89, 0.995)) -``` -""" -mutable struct AMSGrad <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64, Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -AMSGrad(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = AMSGrad(η, β, ϵ, IdDict()) -AMSGrad(η::Real, β::Tuple, state::IdDict) = AMSGrad(η, β, EPS, state) - -function apply!(o::AMSGrad, x, Δ) - η, β = o.eta, o.beta - - mt, vt, v̂t = get!(o.state, x) do - (fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon)) - end :: NTuple{3,typeof(x)} - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2 - @. v̂t = max(v̂t, vt) - @. Δ = η * mt / (√v̂t + o.epsilon) -end - -""" - NAdam(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -[NAdam](https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ) is a Nesterov variant of Adam. -Parameters don't need tuning. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = NAdam() - -opt = NAdam(0.002, (0.89, 0.995)) -``` -""" -mutable struct NAdam <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64, Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -NAdam(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = NAdam(η, β, ϵ, IdDict()) -NAdam(η::Real, β::Tuple, state::IdDict) = NAdam(η, β, EPS, state) - -function apply!(o::NAdam, x, Δ) - η, β = o.eta, o.beta - - mt, vt, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[o.beta[1], o.beta[2]]) - end :: Tuple{typeof(x),typeof(x),Vector{Float64}} - β1p, β2p = βp - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ) - @. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + o.epsilon) * η - βp .= βp .* β - - return Δ -end - -""" - AdamW(η = 0.001, β::Tuple = (0.9, 0.999), decay = 0) - -[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its -weight decay regularization. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. -- `decay`: Decay applied to weights during optimisation. - -# Examples -```julia -opt = AdamW() - -opt = AdamW(0.001, (0.89, 0.995), 0.1) -``` -""" -AdamW(η = 0.001, β = (0.9, 0.999), decay = 0) = - Optimiser(Adam(η, β), WeightDecay(decay)) - -""" - AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999), ϵ = $EPS) - -The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known -Adam optimiser. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the - second (β2) momentum estimate. - -# Examples -```julia -opt = AdaBelief() - -opt = AdaBelief(0.001, (0.9, 0.8)) -``` -""" -mutable struct AdaBelief <: AbstractOptimiser - eta::Float64 - beta::Tuple{Float64,Float64} - epsilon::Float64 - state::IdDict{Any, Any} -end -AdaBelief(η::Real = 0.001, β = (0.9, 0.999), ϵ::Real = EPS) = AdaBelief(η, β, ϵ, IdDict()) -AdaBelief(η::Real, β::Tuple, state::IdDict) = AdaBelief(η, β, EPS, state) - -function apply!(o::AdaBelief, x, Δ) - η, β = o.eta, o.beta - - mt, st, βp = get!(o.state, x) do - (zero(x), zero(x), Float64[β[1], β[2]]) - end :: Tuple{typeof(x), typeof(x), Vector{Float64}} - - #= st is a variance and can go to zero. This is in contrast to Adam, which uses the - second moment which is usually far enough from zero. This is problematic, since st - can be slightly negative due to numerical error, and the square root below will fail. - Also, if we want to differentiate through the optimiser, √0 is not differentiable. - To protect against this, we add a small number, st -> st + eps2. - The original implementation (https://github.com/juntang-zhuang/Adabelief-Optimizer) - uses the square of Adam's epsilon, which we do here. - See also: https://github.com/juntang-zhuang/Adabelief-Optimizer/issues/61 =# - eps2 = o.epsilon^2 # TODO: make epsilon^2 the default in next breaking release - - @. mt = β[1] * mt + (1 - β[1]) * Δ - @. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt) + eps2 - @. Δ = η * mt / (1 - βp[1]) / (√(st / (1 - βp[2])) + eps2) - βp .= βp .* β - - return Δ -end - - -# Compose optimisers - -""" - Optimiser(a, b, c...) - -Combine several optimisers into one; each optimiser produces a modified gradient -that will be fed into the next, and this is finally applied to the parameter as -usual. - -!!! note - This will be replaced by `Optimisers.OptimiserChain` in Flux 0.15. -""" -mutable struct Optimiser <: AbstractOptimiser - os::Vector{Any} -end - -Optimiser(opts::AbstractOptimiser...) = Optimiser(Any[opts...]) - -@forward Optimiser.os Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!, Base.setindex! -@forward Optimiser.os Base.iterate - -Base.getindex(c::Optimiser, i::AbstractArray) = Optimiser(c.os[i]...) - -function apply!(o::Optimiser, x, Δ) - for opt in o.os - Δ = apply!(opt, x, Δ) - end - return Δ -end - -""" - InvDecay(γ = 0.001) - -Apply inverse time decay to an optimiser, so that the effective step size at -iteration `n` is `eta / (1 + γ * n)` where `eta` is the initial step size. -The wrapped optimiser's step size is not modified. - -See also the [Scheduling Optimisers](@ref) section of the docs -for more general scheduling techniques. - -# Examples - -`InvDecay` is typically composed with other optimisers -as the last transformation of the gradient: - -```julia -# Inverse decay of the learning rate -# with starting value 0.001 and decay coefficient 0.01. -opt = Optimiser(Adam(1f-3), InvDecay(1f-2)) -``` -""" -mutable struct InvDecay <: AbstractOptimiser - gamma::Float64 - state::IdDict{Any, Int} -end - -InvDecay(γ = 0.001) = InvDecay(γ, IdDict{Any, Int}()) - -function apply!(o::InvDecay, x, Δ) - γ = o.gamma - n = get!(o.state, x, 1) - Δ .*= 1 / (1 + γ * n) - o.state[x] = n + 1 - return Δ -end - -""" - ExpDecay(η = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 1) - -Discount the learning rate `η` by the factor `decay` every `decay_step` steps till -a minimum of `clip`. - -# Parameters -- Learning rate (`η`): Amount by which gradients are discounted before updating - the weights. -- `decay`: Factor by which the learning rate is discounted. -- `decay_step`: Schedule decay operations by setting the number of steps between - two decay operations. -- `clip`: Minimum value of learning rate. -- 'start': Step at which the decay starts. - - -See also the [Scheduling Optimisers](@ref) section of the docs -for more general scheduling techniques. - -# Examples - -`ExpDecay` is typically composed with other optimisers -as the last transformation of the gradient: -```julia -opt = Optimiser(Adam(), ExpDecay(1.0)) -``` -Note: you may want to start with `η=1` in `ExpDecay` when combined with other -optimisers (`Adam` in this case) that have their own learning rate. -""" -mutable struct ExpDecay <: AbstractOptimiser - eta::Float64 - decay::Float64 - step::Int64 - clip::Float64 - start::Int64 - current::IdDict -end - -ExpDecay(opt = 0.001, decay = 0.1, decay_step = 1000, clip = 1e-4, start = 0) = - ExpDecay(opt, decay, decay_step, clip, start, IdDict()) - -function apply!(o::ExpDecay, x, Δ) - η, s, decay, start = o.eta, o.step, o.decay, o.start - n = o.current[x] = get(o.current, x, 0) + 1 - if n > start && n % s == 0 && count(x -> x > start && x % s == 0, values(o.current)) == 1 - η = max(η * decay, o.clip) - o.eta = η - end - @. Δ *= η -end - -""" - WeightDecay(λ = 0) - -Decay weights by ``λ``. -Typically composed with other optimisers as the first transformation to the gradient, -making it equivalent to adding ``L_2`` regularization -with coefficient ``λ`` to the loss. - -# Examples - -```julia -opt = Optimiser(WeightDecay(1f-4), Adam()) -``` -""" -mutable struct WeightDecay <: AbstractOptimiser - wd::Real -end - -WeightDecay() = WeightDecay(0) - -function apply!(o::WeightDecay, x, Δ) - wd = o.wd - @. Δ += wd * x -end - -""" - SignDecay(λ = 1e-3) - -Version of `WeightDecay` which implements ``L_1`` regularisation, -when composed with other optimisers as the first transformation to the gradient. - -# Examples - -```julia -opt = Optimiser(SignDecay(1e-4), Adam()) -``` -""" -mutable struct SignDecay <: AbstractOptimiser - lambda::Float32 -end - -SignDecay() = SignDecay(1f-3) - -function apply!(o::SignDecay, x, Δ) - λ = o.lambda - @. Δ += λ * sign(x) -end - -""" - ClipValue(thresh) - -Clip gradients when their absolute value exceeds `thresh`. - -!!! note - This will be replaced by `Optimisers.ClipGrad` in Flux 0.15. -""" -mutable struct ClipValue{T} <: AbstractOptimiser - thresh::T -end - -apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh) - -""" - ClipNorm(thresh) - -Clip gradients when their L2 norm exceeds `thresh`. -""" -mutable struct ClipNorm{T} <: AbstractOptimiser - thresh::T -end - -function apply!(o::ClipNorm, x, Δ) - Δnrm = norm(Δ) - if Δnrm > o.thresh - rmul!(Δ, o.thresh / Δnrm) - end - return Δ -end diff --git a/src/optimise/train.jl b/src/optimise/train.jl deleted file mode 100644 index 111207d479..0000000000 --- a/src/optimise/train.jl +++ /dev/null @@ -1,102 +0,0 @@ -using ProgressLogging: @progress, @withprogress, @logprogress -import Zygote: Params, gradient, withgradient - -# Add methods to Optimisers.jl's function, so that there is just one Flux.update! -# for both explicit and implicit parameters. -import Optimisers.update! - -""" - update!(opt, p, g) - update!(opt, ps::Params, gs) - -Perform an update step of the parameters `ps` (or the single parameter `p`) -according to optimiser `opt::AbstractOptimiser` and the gradients `gs` (the gradient `g`). - -As a result, the parameters are mutated and the optimiser's internal state may change. -The gradient could be mutated as well. - -!!! compat "Deprecated" - This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.15. - The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain. -""" -function update!(opt::AbstractOptimiser, x::AbstractArray, x̄) - x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not - # safe due to aliasing, nor guaranteed to be possible, e.g. Fill. - x .-= apply!(opt, x, x̄r) -end - -function update!(opt::AbstractOptimiser, xs::Params, gs) - for x in xs - isnothing(gs[x]) && continue - update!(opt, x, gs[x]) - end -end - -# Callback niceties -call(f, xs...) = f(xs...) -runall(f) = f -runall(fs::AbstractVector) = () -> foreach(call, fs) - - -batchmemaybe(x) = tuple(x) -batchmemaybe(x::Tuple) = x - -""" - train!(loss, pars::Params, data, opt::AbstractOptimiser; [cb]) - -Uses a `loss` function and training `data` to improve the -model's parameters according to a particular optimisation rule `opt`. - -!!! compat "Deprecated" - This method with implicit `Params` will be removed from Flux 0.15. - It should be replaced with the explicit method `train!(loss, model, data, opt)`. - -For each `d in data`, first the gradient of the `loss` is computed like this: -``` - gradient(() -> loss(d...), pars) # if d isa Tuple - gradient(() -> loss(d), pars) # otherwise -``` -Here `pars` is produced by calling [`Flux.params`](@ref) on your model. -(Or just on the layers you want to train, like `train!(loss, params(model[1:end-2]), data, opt)`.) -This is the "implicit" style of parameter handling. - -This gradient is then used by optimiser `opt` to update the parameters: -``` - update!(opt, pars, grads) -``` -The optimiser should be from the `Flux.Optimise` module (see [Optimisers](@ref)). -Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux.Optimiser). - -This training loop iterates through `data` once. -It will stop with a `DomainError` if the loss is `NaN` or infinite. - -You can use use `train!` inside a for loop to do this several times, or -use for instance `Itertools.ncycle` to make a longer `data` iterator. - -## Callbacks - -[Callbacks](@ref) are given with the keyword argument `cb`. -For example, this will print "training" every 10 seconds (using [`Flux.throttle`](@ref)): -``` - train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10)) -``` - -Multiple callbacks can be passed to `cb` as array. -""" -function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ()) - cb = runall(cb) - itrsz = Base.IteratorSize(typeof(data)) - n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0 - @withprogress for (i, d) in enumerate(data) - l, gs = withgradient(ps) do - loss(batchmemaybe(d)...) - end - if !isfinite(l) - throw(DomainError(lazy"Loss is $l on data item $i, stopping training")) - end - update!(opt, ps, gs) - cb() - - @logprogress iszero(n) ? nothing : i / n - end -end diff --git a/src/outputsize.jl b/src/outputsize.jl index 5d6132d059..c413405048 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -302,8 +302,6 @@ function ChainRulesCore.rrule(::typeof(striplazy), m) striplazy(m), _ -> error("striplazy should never be used within a gradient") end -params!(p::Params, x::LazyLayer, seen = IdSet()) = error("LazyLayer should never be used within params(m). Call striplazy(m) first.") - Functors.functor(::Type{<:LazyLayer}, x) = error("LazyLayer should not be walked with Functors.jl, as the arrays which Flux.gpu wants to move may not exist yet.") function Base.show(io::IO, l::LazyLayer) diff --git a/src/train.jl b/src/train.jl index fd21e53f17..ffde2c640b 100644 --- a/src/train.jl +++ b/src/train.jl @@ -4,26 +4,18 @@ using LinearAlgebra using Optimisers: Optimisers using Functors: fmap, fmapstructure using ..Flux: Flux # used only in docstring -import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions export setup, train! using ProgressLogging: @progress, @withprogress, @logprogress -using Zygote: Zygote, Params +using Zygote: Zygote """ opt_state = setup(rule, model) This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!). -It differs from `Optimisers.setup` in that it: -* has one extra check for mutability (since Flux expects to mutate the model in-place, - while Optimisers.jl is designed to return an updated model) -* has methods which accept Flux's old optimisers, and convert them. - (The old `Flux.Optimise.Adam` and new `Optimisers.Adam` are distinct types.) - -!!! compat "New" - This function was added in Flux 0.13.9. It was not used by the old "implicit" - interface, using `Flux.Optimise` module and [`Flux.params`](@ref). +It differs from `Optimisers.setup` in that it has one extra check for mutability (since Flux expects to mutate the model in-place, + while Optimisers.jl is designed to return an updated model). # Example ```jldoctest @@ -53,7 +45,7 @@ function setup(rule::Optimisers.AbstractRule, model) Optimisers.maywrite(x) || error("""model must be fully mutable for `train!` to work, got `x::$(typeof(x))`. If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::$(typeof(x))) = true`""") end - state + return state end """ @@ -86,22 +78,9 @@ It adds only a few features to the loop above: * Stop with a `DomainError` if the loss is infinite or `NaN` at any point. * Show a progress bar using [`@withprogress`](https://github.com/JuliaLogging/ProgressLogging.jl). - -!!! compat "New" - This method was added in Flux 0.13.9. - It has significant changes from the one used by Flux ≤ 0.13: - * It now takes the `model` itself, not the result of [`Flux.params`](@ref). - (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.) - * Instead of `loss` being a function which accepts only the data, - now it must also accept the `model` itself, as the first argument. - * `opt_state` should be the result of [`Flux.setup`](@ref). Using an optimiser - such as `Adam()` without this step should give you a warning. - * Callback functions are not supported. - (But any code can be included in the above `for` loop.) """ -function train!(loss, model, data, opt; cb = nothing) - isnothing(cb) || error("""train! does not support callback functions. - For more control use a loop with `gradient` and `update!`.""") +function train!(loss, model, data, opt) + @withprogress for (i,d) in enumerate(data) d_splat = d isa Tuple ? d : (d,) l, gs = Zygote.withgradient(m -> loss(m, d_splat...), model) @@ -114,8 +93,8 @@ function train!(loss, model, data, opt; cb = nothing) end # This method let you use Optimisers.Descent() without setup, when there is no state -function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing) - train!(loss, model, data, _rule_to_state(model, rule); cb) +function train!(loss, model, data, rule::Optimisers.AbstractRule) + return train!(loss, model, data, _rule_to_state(model, rule)) end function _rule_to_state(model, rule::Optimisers.AbstractRule) @@ -127,7 +106,7 @@ function _rule_to_state(model, rule::Optimisers.AbstractRule) Please run `opt = Flux.setup($name(), model)` and pass this `opt` to `train!`.""" leaf maxlog=1 _id=warn_id leaf end - state + return state end end # module Train diff --git a/src/utils.jl b/src/utils.jl index 1f8230c522..f49b900f2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,7 +13,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r # Examples ```jldoctest -julia> layer = Dense(10, 20); +julia> layer = Dense(10 => 20); julia> Flux.nfan(size(layer.weight)) (10, 20) @@ -580,9 +580,9 @@ over specific modules or subsets of the parameters # Examples ```jldoctest -julia> m1 = Chain(Dense(28^2, 64), BatchNorm(64, relu)); +julia> m1 = Chain(Dense(28^2 => 64), BatchNorm(64, relu)); -julia> m2 = Chain(m1, Dense(64, 10)) +julia> m2 = Chain(m1, Dense(64 => 10)) Chain( Chain( Dense(784 => 64), # 50_240 parameters diff --git a/test/data.jl b/test/data.jl index b97c4dae80..1274bdcf1d 100644 --- a/test/data.jl +++ b/test/data.jl @@ -80,18 +80,20 @@ using Random # test interaction with `train!` θ = ones(2) X = zeros(2, 10) - loss(x) = sum((x .- θ).^2) + loss(θ, x) = sum((x .- θ).^2) d = DataLoader(X) - Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1)) + opt = Flux.setup(Descent(0.1), θ) + Flux.train!(loss, θ, ncycle(d, 10), opt) @test norm(θ) < 1e-4 # test interaction with `train!` θ = zeros(2) X = ones(2, 10) Y = fill(2, 10) - loss(x, y) = sum((y - x'*θ).^2) + loss(θ, x, y) = sum((y - x'*θ).^2) d = DataLoader((X, Y)) - Flux.train!(loss, Params([θ]), ncycle(d, 10), Descent(0.1)) + opt = Flux.setup(Descent(0.1), θ) + Flux.train!(loss, θ, ncycle(d, 10), opt) @test norm(θ .- 1) < 1e-10 # specify the rng diff --git a/test/ext_amdgpu/basic.jl b/test/ext_amdgpu/basic.jl index b7bbb286e5..0758a9d88a 100644 --- a/test/ext_amdgpu/basic.jl +++ b/test/ext_amdgpu/basic.jl @@ -19,7 +19,7 @@ end end @testset "Chain of Dense layers" begin - m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32 + m = Chain(Dense(10 => 5, tanh), Dense(5 => 2), softmax) |> f32 x = rand(Float32, 10, 10) gpu_autodiff_test(m, x) end @@ -69,7 +69,7 @@ end end @testset "Restructure" begin - m = Dense(1, 1) |> Flux.gpu + m = Dense(1 => 1) |> Flux.gpu θ, m̂ = Flux.destructure(m) foo(x) = sum(re(p)(x)) diff --git a/test/ext_cuda/cuda.jl b/test/ext_cuda/cuda.jl index bbfd2854ba..6d34af9c65 100644 --- a/test/ext_cuda/cuda.jl +++ b/test/ext_cuda/cuda.jl @@ -16,10 +16,10 @@ using SparseArrays: sparse, SparseMatrixCSC, AbstractSparseArray @test cx isa Flux.OneHotMatrix && cx.indices isa CuArray @test (cx .+ 1) isa CuArray - m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) + m = Chain(Dense(10 => 5, tanh), Dense(5 => 2), softmax) cm = gpu(m) - @test all(p isa CuArray for p in Flux.params(cm)) + @test all(p isa CuArray for p in Flux.trainables(cm)) @test cm(gpu(rand(10, 10))) isa CuArray{Float32,2} xs = rand(5, 5) @@ -70,7 +70,7 @@ end end @testset "restructure gpu" begin - dudt = Dense(1,1) |> gpu + dudt = Dense(1 => 1) |> gpu p,re = Flux.destructure(dudt) foo(x) = sum(re(p)(x)) @test gradient(foo, cu(rand(1)))[1] isa CuArray diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl index e59ff35aa4..cdfefc972f 100644 --- a/test/ext_cuda/layers.jl +++ b/test/ext_cuda/layers.jl @@ -247,7 +247,7 @@ end @testset "Two-streams Bilinear" begin x = zeros(Float32,10,9) |> gpu y = zeros(Float32,2,9) |> gpu - b = Flux.Bilinear(10, 2, 3) |> gpu + b = Flux.Bilinear((10, 2) => 3) |> gpu @test size(b(x,y)) == (3,9) @test sum(abs2, b(x,y)) ≈ 0f0 gs_gpu = gradient(() -> sum(abs2.(b(x, y))), params(b)) @@ -268,7 +268,7 @@ end @testset "vararg input" begin inputs = (randn(10), randn(5), randn(4)) .|> gpu - layer = Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2)) |> gpu + layer = Parallel(+, Dense(10 => 2), Dense(5 => 2), Dense(4 => 2)) |> gpu @test size(layer(inputs)) == (2,) end diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl index 36212bb10f..f444889a6c 100644 --- a/test/ext_enzyme/enzyme.jl +++ b/test/ext_enzyme/enzyme.jl @@ -106,9 +106,9 @@ end end models_xs = [ - (Dense(2, 4), randn(Float32, 2), "Dense"), - (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(Float32, 2), "Chain(Dense, Dense)"), - (f64(Chain(Dense(2, 4), Dense(4, 2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), + (Dense(2 => 4), randn(Float32, 2), "Dense"), + (Chain(Dense(2 => 4, relu), Dense(4 => 3)), randn(Float32, 2), "Chain(Dense, Dense)"), + (f64(Chain(Dense(2 => 4), Dense(4 => 2))), randn(Float64, 2, 1), "f64(Chain(Dense, Dense))"), (Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"), (Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"), (Chain(Conv((3, 3), 2 => 3, relu), Conv((3, 3), 3 => 1, relu)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"), @@ -124,7 +124,7 @@ end for (model, x, name) in models_xs @testset "check grad $name" begin - println("testing $name") + # println("testing $name") test_enzyme_grad(loss, model, x) end end @@ -170,7 +170,7 @@ end for (model, x, name) in models_xs @testset "check grad $name" begin - println("testing $name") + # println("testing $name") broken = false try test_enzyme_grad(loss, model, x) diff --git a/test/ext_metal/basic.jl b/test/ext_metal/basic.jl index 9e4a9ef9cb..323145c7ad 100644 --- a/test/ext_metal/basic.jl +++ b/test/ext_metal/basic.jl @@ -20,7 +20,7 @@ end end @testset "Chain of Dense layers" begin - m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) + m = Chain(Dense(10 => 5, tanh), Dense(5 => 2), softmax) x = rand(Float32, 10, 10) @test (m|>gpu)(x|>gpu) isa MtlArray{Float32, 2} gpu_autodiff_test(m, x) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 95da13f0c9..964aa929a8 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -16,41 +16,41 @@ using Flux: activations end @testset "Chain" begin - @test_nowarn Chain(Dense(10, 5, σ), Dense(5, 2))(randn32(10)) - @test_throws DimensionMismatch Chain(Dense(10, 5, σ),Dense(2, 1))(randn32(10)) + @test_nowarn Chain(Dense(10 => 5, σ), Dense(5 => 2))(randn32(10)) + @test_throws DimensionMismatch Chain(Dense(10 => 5, σ),Dense(2 => 1))(randn32(10)) # numeric test should be put into testset of corresponding layer - @test_nowarn Chain(first = Dense(10, 5, σ), second = Dense(5, 2))(randn32(10)) - m = Chain(first = Dense(10, 5, σ), second = Dense(5, 2)) + @test_nowarn Chain(first = Dense(10 => 5, σ), second = Dense(5 => 2))(randn32(10)) + m = Chain(first = Dense(10 => 5, σ), second = Dense(5 => 2)) @test m[:first] == m[1] @test m[1:2] == m @test m == m @test m == fmap(identity, m) # does not forget names - @test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name + @test_throws ArgumentError Chain(layers = Dense(10 => 10), two = identity) # reserved name - @test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers + @test_nowarn Chain([Dense(10 => 5, σ), Dense(5 => 2)])(randn(Float32, 10)) # vector of layers - c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu)) + c = Chain(Dense(10 => 5, σ), Dense(5 => 2), Dense(2 => 1, relu)) @test c[1] == c[begin] @test c[3] == c[end] end @testset "Activations" begin - c = Chain(Dense(3,5,relu), Dense(5,1,relu)) + c = Chain(Dense(3 => 5,relu), Dense(5 => 1,relu)) X = Float32.([1.0; 1.0; 1.0]) - @test_nowarn gradient(()->Flux.activations(c, X)[2][1], Flux.params(c)) + @test_nowarn gradient(c -> Flux.activations(c, X)[2][1], c) c2 = Chain(enc = c[1], dec = c[2]) @test Flux.activations(c, X) == Flux.activations(c2, X) - @test_nowarn gradient(()->Flux.activations(c2, X)[2][1], Flux.params(c2)) + @test_nowarn gradient(c -> Flux.activations(c, X)[2][1], c2) end @testset "Dense" begin @testset "constructors" begin - @test size(Dense(10, 100).weight) == (100, 10) - @test size(Dense(10, 100).bias) == (100,) + @test size(Dense(10 => 100).weight) == (100, 10) + @test size(Dense(10 => 100).bias) == (100,) @test Dense(rand(100,10), rand(100)).σ == identity @test Dense(rand(100,10)).σ == identity @@ -60,12 +60,12 @@ using Flux: activations @test Dense(rand(Float16, 100,10), true).bias isa Vector{Float16} # creates matching type @test Dense(rand(Float16, 100,10), rand(100)).bias isa Vector{Float16} # converts to match - @test Dense(3,4; init=Base.randn, bias=true).bias isa Vector{Float64} - @test Dense(3,4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64} + @test Dense(3 => 4; init=Base.randn, bias=true).bias isa Vector{Float64} + @test Dense(3 => 4; init=Base.randn, bias=[1,2,3,4]).bias isa Vector{Float64} - @test_throws MethodError Dense(10, 10.5) - @test_throws MethodError Dense(10, 10.5, tanh) - @test_throws DimensionMismatch Dense(3,4; bias=rand(5)) + @test_throws MethodError Dense(10 => 10.5) + @test_throws MethodError Dense(10 => 10.5, tanh) + @test_throws DimensionMismatch Dense(3 => 4; bias=rand(5)) @test_throws DimensionMismatch Dense(rand(4,3), rand(5)) @test_throws MethodError Dense(rand(5)) @test_throws MethodError Dense(rand(5), rand(5)) @@ -76,18 +76,18 @@ using Flux: activations @test_throws DimensionMismatch Dense(10 => 5)(randn32(1)) @test_throws MethodError Dense(10 => 5)(1) # avoid broadcasting @test_throws MethodError Dense(10 => 5).(randn32(10)) # avoid broadcasting - @test size(Dense(10, 5)(randn(10))) == (5,) - @test size(Dense(10, 5)(randn(10,2))) == (5,2) - @test size(Dense(10, 5)(randn(10,2,3))) == (5,2,3) - @test size(Dense(10, 5)(randn(10,2,3,4))) == (5,2,3,4) - @test_throws DimensionMismatch Dense(10, 5)(randn(11,2,3)) + @test size(Dense(10 => 5)(randn(10))) == (5,) + @test size(Dense(10 => 5)(randn(10,2))) == (5,2) + @test size(Dense(10 => 5)(randn(10,2,3))) == (5,2,3) + @test size(Dense(10 => 5)(randn(10,2,3,4))) == (5,2,3,4) + @test_throws DimensionMismatch Dense(10 => 5)(randn(11,2,3)) end @testset "zeros" begin - @test Dense(10, 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1) - @test Dense(10, 1, identity, init = ones)(ones(10,2)) == 10*ones(1, 2) - @test Dense(10, 2, identity, init = ones)(ones(10,1)) == 10*ones(2, 1) - @test Dense(10, 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] - @test Dense(10, 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] + @test Dense(10 => 1, identity, init = ones)(ones(10,1)) == 10*ones(1, 1) + @test Dense(10 => 1, identity, init = ones)(ones(10,2)) == 10*ones(1, 2) + @test Dense(10 => 2, identity, init = ones)(ones(10,1)) == 10*ones(2, 1) + @test Dense(10 => 2, identity, init = ones)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] + @test Dense(10 => 2, identity, init = ones, bias = false)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20] end @testset "type matching" begin d1 = Dense(2 => 3) @@ -156,9 +156,9 @@ using Flux: activations @test mo(input) == target end - @testset "params" begin - mo = Maxout(()->Dense(32, 64), 4) - ps = Flux.params(mo) + @testset "trainables" begin + mo = Maxout(()->Dense(32 => 64), 4) + ps = Flux.trainables(mo) @test length(ps) == 8 #4 alts, each with weight and bias end end @@ -171,14 +171,14 @@ using Flux: activations @testset "concat size" begin input = randn(10, 2) - @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) + @test size(SkipConnection(Dense(10 => 10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end @testset "Bilinear" begin @testset "SkipConnection recombinator" begin - d = Dense(10, 10) - b = Flux.Bilinear(10, 10, 5) + d = Dense(10 => 10) + b = Flux.Bilinear((10, 10) => 5) x = randn(Float32,10,9) sc = SkipConnection(d, b) @test size(sc(x)) == (5,9) @@ -187,16 +187,16 @@ using Flux: activations @testset "Two-streams zero sum" begin x = zeros(Float32,10,9) y = zeros(Float32,2,9) - b = Flux.Bilinear(10, 2, 3) + b = Flux.Bilinear((10, 2) => 3) @test size(b(x,y)) == (3,9) @test sum(abs2, b(x,y)) == 0f0 end @testset "Inner interactions" begin x = randn(Float32,11,7) - b = Flux.Bilinear(11, 11, 3) + b = Flux.Bilinear((11, 11) => 3) @test size(b(x)) == (3,7) - @test_nowarn gs = gradient(() -> sum(abs2.(b(x))), params(b)) + @test_nowarn gs = gradient(b -> sum(abs2.(b(x))), b) end @testset "constructors" begin @@ -212,7 +212,7 @@ using Flux: activations @test b3.bias isa Vector{Float16} @test size(b3(rand(4), rand(5))) == (3,) - b4 = Flux.Bilinear(3,3,7; bias=1:7, init=Flux.zeros32) + b4 = Flux.Bilinear((3,3) => 7; bias=1:7, init=Flux.zeros32) @test_skip b4.bias isa Vector{Float32} @test_throws ArgumentError Flux.Bilinear(rand(3)) # expects a 3-array @@ -229,25 +229,25 @@ using Flux: activations @testset "concat size" begin input = randn(10, 2) - @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) - @test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4) + @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10 => 10), identity)(input)) == (10, 4) + @test size(Parallel(hcat, one = Dense(10 => 10), two = identity)(input)) == (10, 4) end @testset "vararg input" begin inputs = randn(10), randn(5), randn(4) - @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) - @test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,) + @test size(Parallel(+, Dense(10 => 2), Dense(5 => 2), Dense(4 => 2))(inputs)) == (2,) + @test size(Parallel(+; a = Dense(10 => 2), b = Dense(5 => 2), c = Dense(4 => 2))(inputs)) == (2,) @test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs @test Parallel(+, sin, cos)(pi/2) ≈ 1 end @testset "named access" begin - m = Parallel(hcat, one = Dense(10, 10), two = identity) + m = Parallel(hcat, one = Dense(10 => 10), two = identity) @test m[1] == m[:one] @test m[1:2] == m - @test_throws ArgumentError Parallel(hcat, layers = Dense(10, 10), two = identity) # reserved names - @test_throws ArgumentError Parallel(hcat, connection = Dense(10, 10), two = identity) + @test_throws ArgumentError Parallel(hcat, layers = Dense(10 => 10), two = identity) # reserved names + @test_throws ArgumentError Parallel(hcat, connection = Dense(10 => 10), two = identity) @test m == fmap(identity, m) # does not forget names @@ -306,7 +306,7 @@ using Flux: activations @testset "Embedding" begin vocab_size, embed_size = 10, 4 - m = Embedding(vocab_size, embed_size) + m = Embedding(vocab_size => embed_size) @test size(m.weight) == (embed_size, vocab_size) # one index @@ -416,7 +416,7 @@ using Flux: activations end @testset "second derivatives" begin - m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) + m1 = Chain(Dense(3 => 4,tanh; bias=false), Dense(4 => 2)) @test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3]) m1v = Chain([m1[1], m1[2]]) # vector of layers @@ -424,28 +424,28 @@ end @test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3]) # NNlib's softmax gradient writes in-place - m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax) + m2 = Chain(Dense(3 => 4,tanh), Dense(4 => 2), softmax) @test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3]) # https://github.com/FluxML/NNlib.jl/issues/362 - m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2)) + m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2 => 2)) x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3) @test Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3) end @testset "gradients of Chain{Vector}" begin - m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2)) + m1 = Chain(Dense(3 => 4,tanh; bias=false), Dense(4 => 2)) m1v = Chain([m1[1], m1[2]]) - @test sum(length, params(m1)) == sum(length, params(m1v)) + @test sum(length, trainables(m1)) == sum(length, trainables(m1v)) x1 = randn(Float32,3,5) @test m1(x1) ≈ m1v(x1) y1 = rand(Bool,2,5) - g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1)) - g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v)) - @test g1[m1[1].weight] ≈ g1v[m1v[1].weight] - @test g1[m1[2].bias] ≈ g1v[m1v[2].bias] + g1 = gradient(m1 -> Flux.Losses.logitcrossentropy(m1(x1), y1), m1)[1] + g1v = gradient(m1v -> Flux.Losses.logitcrossentropy(m1v(x1), y1), m1v)[1] + @test g1.layers[1].weight ≈ g1v.layers[1].weight + @test g1.layers[1].bias === g1v.layers[1].bias === nothing @test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1] z1 = rand(22); @@ -455,14 +455,14 @@ end @testset "PairwiseFusion" begin x = (rand(1, 10), rand(30, 10)) - layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10)) + layer = PairwiseFusion(+, Dense(1 => 30), Dense(30 => 10)) y = layer(x) @test length(y) == 2 @test size(y[1]) == (30, 10) @test size(y[2]) == (10, 10) x = rand(1, 10) - layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1)) + layer = PairwiseFusion(.+, Dense(1 => 10), Dense(10 => 1)) y = layer(x) @test length(y) == 2 @test size(y[1]) == (10, 10) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 5b4b80d918..a72e8ffb28 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -31,40 +31,40 @@ end Conv((2, 2), 16 => 8, relu), MaxPool((2,2)), x -> reshape(x, :, size(x, 4)), - Dense(288, 10), softmax) + Dense(288 => 10), softmax) @test size(m(r)) == (10, 5) # Test bias switch - bias = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3)) + m2 = Conv(ones(Float32, 2, 2, 1, 3), ones(Float32, 3)) ip = zeros(Float32, 28,28,1,1) - op = bias(ip) + op = m2(ip) @test sum(op) == prod(size(op)) @testset "No bias mapped through $lmap" for lmap in (identity, cpu, f32) - bias = Conv((2,2), 1=>3, bias = false) |> lmap - op = bias(ip) + m3 = Conv((2,2), 1=>3, bias = false) |> lmap + op = m3(ip) @test sum(op) ≈ 0.f0 - gs = gradient(() -> sum(bias(ip)), Flux.params(bias)) - @test bias.bias ∉ gs.params + gs = gradient(m -> sum(m(ip)), m3)[1] + @test gs.bias === nothing end # Train w/o bias and make sure no convergence happens # when only bias can be converged - bias = Conv((2, 2), 1=>3, bias = false); + m4 = Conv((2, 2), 1=>3, bias = false); ip = zeros(Float32, 28,28,1,1) op = zeros(Float32, 27,27,3,1) .+ 2.f0 - opt = Descent() + opt_state = Flux.setup(Descent(), m4) for _ = 1:10^3 - gs = gradient(Flux.params(bias)) do - Flux.Losses.mse(bias(ip), op) - end - Flux.Optimise.update!(opt, params(bias), gs) + gs = gradient(m4) do m + Flux.mse(m(ip), op) + end[1] + Flux.update!(opt_state, m4, gs) end - @test Flux.Losses.mse(bias(ip), op) ≈ 4.f0 + @test Flux.Losses.mse(m4(ip), op) ≈ 4.f0 @testset "Grouped Conv" begin ip = rand(Float32, 28, 100, 2) @@ -164,11 +164,11 @@ end m = ConvTranspose((3,3), 1=>1) # Test that the gradient call does not throw: #900 - @test gradient(()->sum(m(x)), Flux.params(m)) isa Flux.Zygote.Grads + gradient(m -> sum(m(x)), m) x = zeros(Float32, 5, 5, 2, 4) - m = ConvTranspose((3,3), 2=>3) - @test gradient(()->sum(m(x)), params(m)) isa Flux.Zygote.Grads + m = ConvTranspose((3, 3), 2 => 3) + gradient(m -> sum(m(x)), m) # test ConvTranspose supports groups argument x = randn(Float32, 10, 10, 2, 3) @@ -178,7 +178,7 @@ end m2 = ConvTranspose((3,3), 2=>4, groups=2, pad=SamePad()) @test size(m2.weight) == (3,3,2,2) @test size(m1(x)) == size(m2(x)) - @test gradient(()->sum(m2(x)), params(m2)) isa Flux.Zygote.Grads + gradient(m2 -> sum(m2(x)), m2) x = randn(Float32, 10, 2,1) m = ConvTranspose((3,), 2=>4, pad=SamePad(), groups=2) @@ -213,7 +213,7 @@ end CrossCor((2, 2), 16=>8, relu; bias=false), MaxPool((2,2)), x -> reshape(x, :, size(x, 4)), - Dense(288, 10), softmax) + Dense(288 => 10), softmax) @test size(m(r)) == (10, 5) @test y(x) != Conv(w, [0.0])(x) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 35f11a4adc..8cccaa7d05 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -32,7 +32,7 @@ evalwgrad(f, x...) = pullback(f, x...)[1] @test count(iszero, y2) == 0 x = rand(Float32, 100) - m = Chain(Dense(100,100), + m = Chain(Dense(100 => 100), Dropout(0.9; rng_kwargs...)) y = evalwgrad(m, x) @test count(a->a == 0, y) > 50 @@ -129,7 +129,7 @@ end 2.0 4.0 6.0] @test Flux.hasaffine(m) == true - @test length(Flux.params(m)) == 2 + @test length(Flux.trainables(m)) == 2 @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) @@ -211,9 +211,9 @@ end @inferred m(x) end - @test length(Flux.params(BatchNorm(10))) == 2 - @test length(Flux.params(BatchNorm(10, affine=true))) == 2 - @test length(Flux.params(BatchNorm(10, affine=false))) == 0 + @test length(Flux.trainables(BatchNorm(10))) == 2 + @test length(Flux.trainables(BatchNorm(10, affine=true))) == 2 + @test length(Flux.trainables(BatchNorm(10, affine=false))) == 0 @test BatchNorm(5; active=true).active === true @test_throws Exception BatchNorm(5; active=:something_else) @@ -224,7 +224,7 @@ end let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) - @test length(Flux.params(m)) == 2 + @test length(Flux.trainables(m)) == 2 x = Float32.(x) @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) @@ -287,7 +287,7 @@ end x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == true - @test length(Flux.params(m)) == 2 + @test length(Flux.trainables(m)) == 2 x = Float64.(x) y = m(x) μ = mean(x, dims=1) @@ -300,7 +300,7 @@ end let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) @test Flux.hasaffine(m) == false - @test length(Flux.params(m)) == 0 + @test length(Flux.trainables(m)) == 0 x = Float64.(x) y = m(x) @@ -345,9 +345,9 @@ end @inferred m(x) end - @test length(Flux.params(InstanceNorm(10))) == 0 - @test length(Flux.params(InstanceNorm(10, affine=true))) == 2 - @test length(Flux.params(InstanceNorm(10, affine=false))) == 0 + @test length(Flux.trainables(InstanceNorm(10))) == 0 + @test length(Flux.trainables(InstanceNorm(10, affine=true))) == 2 + @test length(Flux.trainables(InstanceNorm(10, affine=false))) == 0 @test InstanceNorm(5; active=true).active === true @test_throws Exception InstanceNorm(5; active=:something_else) @@ -370,10 +370,10 @@ end m = LayerNorm((2,3,4)) @test Flux.hasaffine(m) == true - @test length(Flux.params(m)) == 2 + @test length(Flux.trainables(m)) == 2 m = LayerNorm((2,3,4), affine=false) @test Flux.hasaffine(m) == false - @test length(Flux.params(m)) == 0 + @test length(Flux.trainables(m)) == 0 end @testset "GroupNorm" begin @@ -383,7 +383,7 @@ end let m = GroupNorm(4,2), sizes = (3,4,2), x = reshape(collect(1:prod(sizes)), sizes) - @test length(Flux.params(m)) == 2 + @test length(Flux.trainables(m)) == 2 x = Float32.(x) @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 7df8b0d4c2..1fd18266ef 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,39 +1,8 @@ using LinearAlgebra -@testset "RNN gradients-implicit" begin - layer = Flux.Recur(Flux.RNNCell(1, 1, identity)) - layer.cell.Wi .= 5.0 - layer.cell.Wh .= 4.0 - layer.cell.b .= 0.0f0 - layer.cell.state0 .= 7.0 - x = [[2.0f0], [3.0f0]] - - # theoretical primal gradients - primal = - layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+ - x[2] .* layer.cell.Wi - ∇Wi = x[1] .* layer.cell.Wh .+ x[2] - ∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi - ∇b = layer.cell.Wh .+ 1 - ∇state0 = layer.cell.Wh .^ 2 - Flux.reset!(layer) - ps = Flux.params(layer) - e, g = Flux.withgradient(ps) do - out = [layer(xi) for xi in x] - sum(out[2]) - end - - @test primal[1] ≈ e - @test ∇Wi ≈ g[ps[1]] - @test ∇Wh ≈ g[ps[2]] - @test ∇b ≈ g[ps[3]] - @test ∇state0 ≈ g[ps[4]] - -end - -@testset "RNN gradients-explicit" begin - layer = Flux.Recur(Flux.RNNCell(1, 1, identity)) +@testset "RNN gradients" begin + layer = Flux.Recur(Flux.RNNCell(1 => 1, identity)) layer.cell.Wi .= 5.0f0 layer.cell.Wh .= 4.0f0 layer.cell.b .= 0.0f0 @@ -70,9 +39,10 @@ end for r ∈ [RNN,] rnn = r(2 => 3) Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do + grads_seq = gradient(rnn) do rnn sum([rnn(s) for s in seq][3]) - end + end[1] + Flux.reset!(rnn); bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh * tanh.(rnn.cell.Wi * seq[2] + Wh * @@ -82,7 +52,7 @@ end + rnn.cell.b) + rnn.cell.b)), rnn.cell.Wh) - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] + @test grads_seq.cell.Wh ≈ bptt[1] end end @@ -92,9 +62,9 @@ end for r ∈ [RNN,] rnn = r(2 => 3) Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do + grads_seq = gradient(rnn) do rnn sum([rnn(s) for s in seq][3]) - end + end[1] Flux.reset!(rnn); bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh * tanh.(rnn.cell.Wi * seq[2] + Wh * @@ -104,7 +74,7 @@ end + rnn.cell.b) + rnn.cell.b)), rnn.cell.Wh) - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] + @test grads_seq.cell.Wh ≈ bptt[1] end end @@ -112,9 +82,9 @@ end seq = rand(Float32, (2, 1, 3)) rnn = RNN(2 => 3) Flux.reset!(rnn) - grads_seq = gradient(Flux.params(rnn)) do + grads_seq = gradient(rnn) do rnn sum(rnn(seq)[:, :, 3]) - end + end[1] Flux.reset!(rnn); bptt = gradient(rnn.cell.Wh) do Wh # calculate state 1 @@ -131,26 +101,22 @@ end rnn.cell.b) sum(s3) # loss is sum of state 3 end - @test grads_seq[rnn.cell.Wh] ≈ bptt[1] + @test grads_seq.cell.Wh ≈ bptt[1] end @testset "RNN-shapes" begin @testset for R in [RNN, GRU, LSTM, GRUv3] m1 = R(3 => 5) m2 = R(3 => 5) - m3 = R(3, 5) # leave one to test the silently deprecated "," not "=>" notation x1 = rand(Float32, 3) x2 = rand(Float32, 3, 1) x3 = rand(Float32, 3, 1, 2) Flux.reset!(m1) Flux.reset!(m2) - Flux.reset!(m3) @test size(m1(x1)) == (5,) @test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape @test size(m2(x2)) == (5, 1) @test size(m2(x2)) == (5, 1) - @test size(m3(x3)) == (5, 1, 2) - @test size(m3(x3)) == (5, 1, 2) end end diff --git a/test/layers/show.jl b/test/layers/show.jl index 6910e5fa08..18e36201d1 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -1,46 +1,46 @@ @testset "layer printing" begin # 2-arg show, defined with layes - @test repr(Dense(2,3)) == "Dense(2 => 3)" - @test repr(Chain(Dense(2,3))) == "Chain(Dense(2 => 3))" - @test repr(Chain(lay=Dense(2,3))) == "Chain(lay = Dense(2 => 3))" + @test repr(Dense(2 => 3)) == "Dense(2 => 3)" + @test repr(Chain(Dense(2 => 3))) == "Chain(Dense(2 => 3))" + @test repr(Chain(lay=Dense(2 => 3))) == "Chain(lay = Dense(2 => 3))" end @testset "nested model printing" begin # 3-arg show, defined in show.jl # Dense -- has parameter count, but not when inside a matrix: - toplevel_dense = repr("text/plain", Dense(2,3)) + toplevel_dense = repr("text/plain", Dense(2 => 3)) @test occursin("Dense(2 => 3)", toplevel_dense) @test occursin("# 9 parameters", toplevel_dense) @test Meta.isexpr(Meta.parse(toplevel_dense), :call) # comment is ignored - vector_dense = repr("text/plain", [Dense(2,3), Dense(2,3)]) + vector_dense = repr("text/plain", [Dense(2 => 3), Dense(2 => 3)]) @test occursin("Dense(2 => 3)", vector_dense) @test occursin("# 9 parameters", vector_dense) - matrix_dense = repr("text/plain", fill(Dense(2,3), 3, 3)) + matrix_dense = repr("text/plain", fill(Dense(2 => 3), 3, 3)) @test occursin("Dense(2 => 3)", matrix_dense) @test !occursin("# 9 parameters", matrix_dense) - tuple_dense = repr("text/plain", tuple(Dense(2,3))) + tuple_dense = repr("text/plain", tuple(Dense(2 => 3))) @test occursin("Dense(2 => 3)", tuple_dense) @test !occursin("# 9 parameters", tuple_dense) # Chain -- gets split over lines at top level only - toplevel_chain = repr("text/plain", Chain(Dense(2,3))) + toplevel_chain = repr("text/plain", Chain(Dense(2 => 3))) @test occursin("Chain(\n Dense(2 => 3)", toplevel_chain) @test occursin("# 9 parameters", toplevel_chain) @test !occursin("# Total:", toplevel_chain) - vector_chain = repr("text/plain", [Chain(Dense(2,3)), Chain(Dense(2,3))]) + vector_chain = repr("text/plain", [Chain(Dense(2 => 3)), Chain(Dense(2 => 3))]) @test occursin("Chain(Dense(2 => 3))", vector_chain) @test occursin("# 9 parameters", vector_chain) @test !occursin("# Total:", vector_chain) - matrix_chain = repr("text/plain", fill(Chain(Dense(2,3)), 3,3)) + matrix_chain = repr("text/plain", fill(Chain(Dense(2 => 3)), 3,3)) @test occursin("Chain(Dense(2 => 3))", matrix_chain) @test !occursin("# 9 parameters", matrix_chain) @test !occursin("# Total:", matrix_chain) diff --git a/test/loading.jl b/test/loading.jl index 06bc412d31..c4de6055f5 100644 --- a/test/loading.jl +++ b/test/loading.jl @@ -16,12 +16,12 @@ end @testset "loadmodel!(dst, src)" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dense(5, 2)) - m3 = Chain(Conv((3, 3), 3 => 16), Dense(5, 2)) - m4 = Chain(Dense(10, 6), Dense(6, 2)) - m5 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5, 2))) - m6 = Chain(Dense(10, 5), Parallel(+, Dense(5, 2), Dense(5, 2))) + m1 = Chain(Dense(10 => 5), Dense(5 => 2, relu)) + m2 = Chain(Dense(10 => 5), Dense(5 => 2)) + m3 = Chain(Conv((3, 3), 3 => 16), Dense(5 => 2)) + m4 = Chain(Dense(10 => 6), Dense(6 => 2)) + m5 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) + m6 = Chain(Dense(10 => 5), Parallel(+, Dense(5 => 2), Dense(5 => 2))) loadmodel!(m1, m2) # trainable parameters copy over @@ -73,7 +73,7 @@ end Dropout(0.2), x -> reshape(x, :, size(x, 4)), Dropout(0.2), - Dense(90, 10), + Dense(90 => 10), softmax) chain2 = Chain([Dropout(0.1), Conv((3, 3), 1 => 32, relu), @@ -88,7 +88,7 @@ end Dropout(0.1), x -> reshape(x, :, size(x, 4)), Dropout(0.1), - Dense(90, 10), + Dense(90 => 10), softmax]) chain2[3].μ .= 5f0 chain2[3].σ² .= 2f0 @@ -143,9 +143,9 @@ end @test_throws ErrorException loadmodel!(m1, m2) @testset "loadmodel! & filter" begin - m1 = Chain(Dense(10, 5), Dense(5, 2, relu)) - m2 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2)) - m3 = Chain(Dense(10, 5), Dense(5, 2, relu)) + m1 = Chain(Dense(10 => 5), Dense(5 => 2, relu)) + m2 = Chain(Dense(10 => 5), Dropout(0.2), Dense(5 => 2)) + m3 = Chain(Dense(10 => 5), Dense(5 => 2, relu)) # this will not error cause Dropout is skipped loadmodel!(m1, m2; filter = x -> !(x isa Dropout)) @@ -191,8 +191,8 @@ end end @testset "state" begin - m1 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) - m2 = Chain(Dense(10, 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) + m1 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.ones32(2, 5), false), Dense(5 => 2))) + m2 = Chain(Dense(10 => 5), Parallel(+, Dense(Flux.zeros32(2, 5), Flux.ones32(2)), Dense(5 => 2))) s = Flux.state(m1) @test s isa NamedTuple @test fieldnames(typeof(s)) == (:layers,) @@ -217,7 +217,7 @@ end end @testset "track active state and batch norm params" begin - m3 = Chain(Dense(10, 5), Dropout(0.2), Dense(5, 2), BatchNorm(2)) + m3 = Chain(Dense(10 => 5), Dropout(0.2), Dense(5 => 2), BatchNorm(2)) trainmode!(m3) s = Flux.state(m3) @test s.layers[2].active == true diff --git a/test/losses.jl b/test/losses.jl index a5ce1139df..1d745c9d20 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -76,7 +76,7 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:] @test crossentropy(ŷ, y_smoothed) ≈ lossvalue_smoothed @test crossentropy(ylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*log.(ylp)) @test crossentropy(ylp, yl) ≈ -sum(yl.*log.(ylp)) - @test iszero(crossentropy(y_same, ya, ϵ=0)) # ε is deprecated + @test iszero(crossentropy(y_same, ya, eps=0)) # ε is deprecated @test iszero(crossentropy(ya, ya, eps=0)) @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed) @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed) @@ -92,15 +92,15 @@ logŷ, y = randn(3), rand(3) yls = y.*(1-2sf).+sf @testset "binarycrossentropy" begin - @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); ϵ=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ)) + @test binarycrossentropy.(σ.(logŷ), label_smoothing(y, 2sf; dims=0); eps=0) ≈ -yls.*log.(σ.(logŷ)) - (1 .- yls).*log.(1 .- σ.(logŷ)) @test binarycrossentropy(σ.(logŷ), y; eps=0) ≈ mean(-y.*log.(σ.(logŷ)) - (1 .- y).*log.(1 .- σ.(logŷ))) @test binarycrossentropy(σ.(logŷ), y) ≈ mean(-y.*log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - (1 .- y).*log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) @test binarycrossentropy([0.1,0.2,0.9], 1) ≈ -mean(log, [0.1,0.2,0.9]) # constant label end @testset "logitbinarycrossentropy" begin - @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); ϵ=0) - @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; ϵ=0) + @test logitbinarycrossentropy.(logŷ, label_smoothing(y, 0.2)) ≈ binarycrossentropy.(σ.(logŷ), label_smoothing(y, 0.2); eps=0) + @test logitbinarycrossentropy(logŷ, y) ≈ binarycrossentropy(σ.(logŷ), y; eps=0) end y = onehotbatch([1], 0:1) @@ -152,7 +152,7 @@ end @testset "tversky_loss" begin @test Flux.tversky_loss(ŷ, y) ≈ -0.06772009029345383 - @test Flux.tversky_loss(ŷ, y, β=0.8) ≈ -0.09490740740740744 + @test Flux.tversky_loss(ŷ, y, beta=0.8) ≈ -0.09490740740740744 @test Flux.tversky_loss(y, y) ≈ -0.5576923076923075 end @@ -180,7 +180,7 @@ end 0.4 0.7] @test Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385 @test Flux.binary_focal_loss(ŷ1, y1) ≈ 0.05691642237852222 - @test Flux.binary_focal_loss(ŷ, y; γ=0.0) ≈ Flux.binarycrossentropy(ŷ, y) + @test Flux.binary_focal_loss(ŷ, y; gamma=0.0) ≈ Flux.binarycrossentropy(ŷ, y) end @testset "focal_loss" begin diff --git a/test/optimise.jl b/test/optimise.jl deleted file mode 100644 index c63ba85727..0000000000 --- a/test/optimise.jl +++ /dev/null @@ -1,222 +0,0 @@ -using Flux.Optimise -using Flux.Optimise: runall -using Flux: Params, gradient -import FillArrays, ComponentArrays -import Optimisers -using Test -using Random - -@testset "Optimise" begin - # Ensure rng has different state inside and outside the inner @testset - # so that w and w' are different - Random.seed!(84) - w = randn(10, 10) - @testset for opt in [AdamW(), AdaGrad(0.1), AdaMax(), AdaDelta(0.9), AMSGrad(), - NAdam(), RAdam(), Descent(0.1), Adam(), OAdam(), AdaBelief(), - Nesterov(), RMSProp(), Momentum()] - Random.seed!(42) - w′ = randn(10, 10) - b = false - loss(x) = Flux.Losses.mse(w*x, w′*x .+ b) - for t = 1: 10^5 - θ = params([w′, b]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) - end - @test loss(rand(10, 10)) < 0.01 - end -end - -@testset "Optimiser" begin - Random.seed!(84) - w = randn(10, 10) - @testset for Opt in [InvDecay, WeightDecay, ExpDecay, SignDecay] - Random.seed!(42) - w′ = randn(10, 10) - loss(x) = Flux.Losses.mse(w*x, w′*x) - opt = Optimiser(Opt(), Adam(0.001)) - for t = 1:10^5 - θ = Params([w′]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - Optimise.update!(opt, θ, θ̄) - end - @test loss(rand(10, 10)) < 0.01 - end -end - -@testset "Training Loop" begin - - # Test multiple callbacks - x = 0 - fs = [() -> (), () -> x = 1] - cbs = runall(fs) - cbs() - @test x == 1 - - r = rand(3, 3) - loss(x) = sum(x .* x) - Flux.train!(loss, Flux.params(r), (r,), Descent()) -end - -@testset "Stop on NaN" begin - m = Dense(1 => 1) - m.weight .= 0 - CNT = 0 - @test_throws DomainError Flux.train!(Flux.params(m), 1:100, Descent(0.1)) do i - CNT += 1 - (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) - end - @test CNT == 51 # stopped early - @test m.weight[1] ≈ -5 # did not corrupt weights -end - -@testset "ExpDecay" begin - - @testset "Sanity Check" begin - o = ExpDecay(0.2, 0.5, 1, 1e-3) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ steps, o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - @testset "starting step" begin - start = 4 - o = ExpDecay(0.2, 0.5, 1, 1e-3, start) - p = [0.0] - steps = 1:8 - eta_expected = @. max(o.eta * 0.5 ^ max(steps - start, 0), o.clip) - eta_actual = [Optimise.apply!(o, p, [1.0])[1] for _ in steps] - @test eta_actual == eta_expected - end - - w = randn(10, 10) - o = ExpDecay(0.1, 0.1, 1000, 1e-4) - w1 = randn(10,10) - loss(x) = Flux.Losses.mse(w*x, w1*x) - flag = 1 - decay_steps = [] - for t = 1:10^5 - prev_eta = o.eta - θ = Params([w1]) - x = rand(10) - θ̄ = gradient(() -> loss(x), θ) - prev_grad = collect(θ̄[w1]) - delta = Optimise.apply!(o, w1, θ̄[w1]) - w1 .-= delta - new_eta = o.eta - if new_eta != prev_eta - push!(decay_steps, t) - end - array = fill(o.eta, size(prev_grad)) - if array .* prev_grad != delta - flag = 0 - end - end - @test flag == 1 - # Test to check if decay happens at decay steps. Eta reaches clip value (1e-4) after 4000 steps (decay by 0.1 every 1000 steps starting at 0.1). - ground_truth = [] - for i in 1:4 - push!(ground_truth, 1000*i) # Expected decay steps for this example. - end - @test decay_steps == ground_truth - @test o.eta == o.clip -end - -@testset "Clipping" begin - w = randn(10, 10) - loss(x) = sum(w * x) - θ = Params([w]) - x = 1000 * randn(10) - w̄ = gradient(() -> loss(x), θ)[w] - w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄)) - @test all(w̄_value .<= 1) - w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄)) - @test norm(w̄_norm) <= 1 -end - -@testset "update!: handle Fills from Zygote" begin - w = randn(10,10) - wold = copy(w) - g = FillArrays.Ones(size(w)) - opt = Descent(0.1) - Flux.update!(opt, w, g) - @test w ≈ wold .- 0.1 - - w = randn(3) - wold = copy(w) - θ = Flux.params([w]) - gs = gradient(() -> w[1], θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w[1] ≈ wold[1] .- 0.1 - @test w[2:3] ≈ wold[2:3] - - ## Issue #1510 - w = randn(10,10) - wold = copy(w) - θ = Flux.params([w]) - gs = gradient(() -> sum(w), θ) - opt = Descent(0.1) - Flux.update!(opt, θ, gs) - @test w ≈ wold .- 0.1 -end - -@testset "update!: handle ComponentArrays" begin - w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) - wold = deepcopy(w) - opt_state = Optimisers.setup(Optimisers.Descent(0.1), w) - gs = gradient(w -> w.a + sum(w.c.b), w)[1] - Flux.update!(opt_state, w, gs) - @test w.a ≈ wold.a - 0.1 - @test w.b ≈ wold.b - @test w.c.b ≈ wold.c.b .- 0.1 - @test w.c.a ≈ wold.c.a - - w = ComponentArrays.ComponentArray(a=1.0, b=[2, 1, 4], c=(a=2, b=[1, 2])) - wold = deepcopy(w) - opt_state = Optimisers.setup(Optimisers.Descent(0.1), w) - gs = gradient(w -> sum(w), w)[1] - Flux.update!(opt_state, w, gs) - @test w ≈ wold .- 0.1 -end - -# Flux PR #1776 -# We need to test that optimisers like Adam that maintain an internal momentum -# estimate properly calculate the second-order statistics on the gradients as -# the flow backward through the model. Previously, we would calculate second- -# order statistics via `Δ^2` rather than the complex-aware `Δ * conj(Δ)`, which -# wreaks all sorts of havoc on our training loops. This test ensures that -# a simple optimization is montonically decreasing (up to learning step effects) -@testset "Momentum Optimisers and complex values" begin - # Test every optimiser that has momentum internally - for opt_ctor in [Adam, RMSProp, RAdam, OAdam, AdaGrad, AdaDelta, NAdam, AdaBelief] - # Our "model" is just a complex number - w = zeros(ComplexF32, 1) - - # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x` - function loss() - # Deterministic training data is the best training data - x = ones(1, 1) + 1im*ones(1, 1) - - # Manually implement `mse()` to allow demonstration of brokenness - # on older Flux builds that don't have a fixed `mse()` - return sum(abs2.(w * x .- conj(x))) - end - - params = Flux.Params([w]) - opt = opt_ctor(1e-2) - - # Train for 10 iterations, enforcing that loss is monotonically decreasing - last_loss = Inf - for idx in 1:10 - grads = Flux.gradient(loss, params) - @test loss() < last_loss - last_loss = loss() - Flux.update!(opt, params, grads) - end - end -end diff --git a/test/outputsize.jl b/test/outputsize.jl index 0eab572eb7..fdf19127a5 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -2,16 +2,16 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) @test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1) - m = Dense(10, 5) + m = Dense(10 => 5) @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1) @test outputsize(m, (10,); padbatch=true) == (5, 1) - m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) + m = Chain(Dense(10 => 8, σ), Dense(8 => 5), Dense(5 => 2)) @test outputsize(m, (10,); padbatch=true) == (2, 1) @test outputsize(m, (10, 30)) == (2, 30) @info "Don't mind the following error, it's for testing purpose." - m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) + m = Chain(Dense(10 => 8, σ), Dense(8 => 4), Dense(5 => 2)) @test_throws DimensionMismatch outputsize(m, (10,)) m = Flux.Scale(10) @@ -26,11 +26,11 @@ m = Flux.unsqueeze(dims=3) @test outputsize(m, (5, 7, 13)) == (5, 7, 1, 13) - m = Flux.Bilinear(10, 10, 7) + m = Flux.Bilinear((10, 10) => 7) @test outputsize(m, (10,)) == (7,) @test outputsize(m, (10, 32)) == (7, 32) - m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024 => 10)) @test outputsize(m, (10, 10, 3, 50)) == (10, 50) @test outputsize(m, (10, 10, 3, 2)) == (10, 2) @@ -42,13 +42,13 @@ end @testset "multiple inputs" begin - m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu)) + m = Parallel(vcat, Dense(2 => 4, relu), Dense(3 => 6, relu)) @test outputsize(m, (2,), (3,)) == (10,) @test outputsize(m, ((2,), (3,))) == (10,) @test outputsize(m, (2,), (3,); padbatch=true) == (10, 1) @test outputsize(m, (2,7), (3,7)) == (10, 7) - m = Chain(m, Dense(10, 13, tanh), softmax) + m = Chain(m, Dense(10 => 13, tanh), softmax) @test outputsize(m, (2,), (3,)) == (13,) @test outputsize(m, ((2,), (3,))) == (13,) @test outputsize(m, (2,), (3,); padbatch=true) == (13, 1) @@ -60,7 +60,7 @@ end leakyrelu, lisht, logcosh, logσ, mish, relu, relu6, rrelu, selu, σ, softplus, softshrink, softsign, swish, tanhshrink, trelu] - @test outputsize(Dense(10, 5, f), (10, 1)) == (5, 1) + @test outputsize(Dense(10 => 5, f), (10, 1)) == (5, 1) end end @@ -168,7 +168,7 @@ end m = @autosize (3,) Dense(_ => 4) @test randn(3) |> m |> size == (4,) - m = @autosize (3, 1) Chain(Dense(_, 4), Dense(4 => 10), softmax) + m = @autosize (3, 1) Chain(Dense(_ => 4), Dense(4 => 10), softmax) @test randn(3, 5) |> m |> size == (10, 5) m = @autosize (2, 3, 4, 5) Dense(_ => 10) # goes by first dim, not 2nd-last @@ -249,7 +249,6 @@ end @test string(ld) == "LazyLayer(Dense(2 => 3, relu))" @test Flux.striplazy(ld) isa Dense - @test_throws Exception Flux.params(lm) @test_throws Exception gradient(x -> sum(abs2, lm(x)), [1,2]) @test_throws Exception gradient(m -> sum(abs2, Flux.striplazy(m)([1,2])), ld) diff --git a/test/runtests.jl b/test/runtests.jl index 39013b84e6..10eb0c8481 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using Flux using Flux: OneHotArray, OneHotMatrix, OneHotVector -using Flux: params using Test using Random, Statistics, LinearAlgebra using IterTools: ncycle @@ -8,7 +7,7 @@ using Zygote # ENV["FLUX_TEST_AMDGPU"] = "true" # ENV["FLUX_TEST_CUDA"] = "true" -# ENV["FLUX_TEST_METAL"] = "true" +ENV["FLUX_TEST_METAL"] = "true" # ENV["FLUX_TEST_CPU"] = "false" include("test_utils.jl") @@ -17,25 +16,30 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin if get(ENV, "FLUX_TEST_CPU", "true") == "true" + @testset "Utils" begin + @info "testing Utils" include("utils.jl") end @testset "Loading" begin + @info "testing Loading" include("loading.jl") end - @testset "Optimise / Train" begin - include("optimise.jl") + @testset "Train" begin + @info "testing Train" include("train.jl") include("tracker.jl") end @testset "Data" begin + @info "testing Data" include("data.jl") end @testset "Losses" begin + @info "testing Losses" include("losses.jl") include("ctc.jl") end @@ -53,11 +57,13 @@ Random.seed!(0) end @testset "outputsize" begin + @info "testing outputsize" using Flux: outputsize include("outputsize.jl") end @testset "functors" begin + @info "testing functors" include("functors.jl") end @@ -118,6 +124,7 @@ Random.seed!(0) end @testset "Enzyme" begin + @info "testing Enzyme" import Enzyme include("ext_enzyme/enzyme.jl") end diff --git a/test/train.jl b/test/train.jl index 1d938649d0..a921c79c62 100644 --- a/test/train.jl +++ b/test/train.jl @@ -17,8 +17,8 @@ using Random model = (weight=copy(w2), bias=zeros(10), ignore=nothing) @test loss(model, rand(10, 10)) > 1 - opt = Flux.setup(rule, model) - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + opt_state = Flux.setup(rule, model) + Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt_state) @test loss(model, rand(10, 10)) < 0.01 end @@ -39,7 +39,7 @@ end CNT = 0 @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i CNT += 1 - (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) + (i == 51 ? NaN32 : 1f0) * sum(m([1f0])) end @test CNT == 51 # stopped early @test m1.weight[1] ≈ -5 # did not corrupt weights @@ -54,49 +54,8 @@ end Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end - - @testset "callbacks give helpful error" begin - m1 = Dense(1 => 1) - cb = () -> println("this should not be printed") - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) - end end -@testset "Explicit Flux.update! features" begin - m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) - x = rand(2) - y1 = m(x) # before - - # Implicit gradient - gold = gradient(() -> m(x), Flux.params(m)) - @test gold isa Flux.Zygote.Grads - @test_throws ErrorException Flux.update!(Flux.Adam(), m, gold) # friendly - Flux.update!(Flux.Adam(), Flux.params(m), gold) - y2 = m(x) - @test y2 < y1 - - # Explicit gradient - gs = gradient(marg -> marg(x), m) - @test gs isa Tuple - @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly - @test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly - @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly - @test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly - s = Flux.setup(Adam(), m) - @info "ignore this warning, just testing an upgrade path:" - Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out - y3 = m(x) - @test y3 < y2 - Flux.update!(s, m, gs[1]) # finally, this is the correct thing - y4 = m(x) - @test y4 < y3 - - # Also check that if you import the new Adam, then Flux.setup does still work! - s2 = Flux.setup(Optimisers.Adam(), m) - Flux.update!(s2, m, gs[1]) - y5 = m(x) - @test y5 < y4 -end @testset "L2 regularisation" begin # New docs claim an exact equivalent. It's a bit long to put the example in there, @@ -115,14 +74,14 @@ end end diff1 = model.weight .- init_weight - # Take 2: the same, but with Flux.params. Was broken for a bit, no tests! + # Take 2: the same, but with Flux.trainables. model.weight .= init_weight model.bias .= 0 pen2(x::AbstractArray) = sum(abs2, x)/2 opt = Flux.setup(Adam(0.1), model) Flux.train!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) - l2 = sum(pen2, Flux.params(m)) + l2 = sum(pen2, Flux.trainables(m)) err + 0.33 * l2 end diff2 = model.weight .- init_weight @@ -143,6 +102,6 @@ end # https://github.com/FluxML/Flux.jl/issues/2144 @test Flux.setup(Flux.Adam(), Embedding(3 => 1)).weight isa Optimisers.Leaf # Typo in 0.13.9's deprecation - @test Flux.setup(Flux.ClipValue(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad + @test Flux.setup(Flux.ClipGrad(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipGrad @test Flux.setup(Flux.ClipNorm(1), Dense(2 => 3)).weight.rule isa Optimisers.ClipNorm end diff --git a/test/utils.jl b/test/utils.jl index e175eb1f5b..17d266e602 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -220,7 +220,7 @@ end @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) end @testset "Dense ID mapping" begin - l = Dense(3,3, init = identity_init) + l = Dense(3 => 3, init = identity_init) indata = reshape(collect(Float32, 1:9), 3, 3) @test l(indata) == indata @@ -250,42 +250,37 @@ end end end -@testset "Params" begin - m = Dense(10, 5) - @test size.(params(m)) == [(5, 10), (5,)] - m = RNN(10, 5) - @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] +@testset "Trainables" begin + m = Dense(10 => 5) + @test size.(Flux.trainables(m)) == [(5, 10), (5,)] + m = RNN(10 => 5) + @test size.(Flux.trainables(m)) == [(5, 10), (5, 5), (5,), (5, 1)] # Layer duplicated in same chain, params just once pls. c = Chain(m, m) - @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)] + @test size.(Flux.trainables(c)) == [(5, 10), (5, 5), (5,), (5, 1)] # Self-referential array. Just want params, no stack overflow pls. - r = Any[nothing,m] + r = Any[nothing, m] r[1] = r - @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] + @test_broken size.(Flux.trainables(r)) == [(5, 10), (5, 5), (5,), (5, 1)] # Ensure functor explores inside Transpose but not SubArray m = (x = view([1,2,3]pi, 1:2), y = transpose([4 5]pi)) - @test size.(Flux.params(m)) == [(2,), (1, 2)] + @test size.(Flux.trainables(m)) == [(2,), (1, 2)] end -@testset "params gradient" begin +@testset "trainables gradient" begin m = (x=[1,2.0], y=[3.0]); # Explicit -- was broken by #2054 - gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1] + gnew = gradient(m -> (sum(norm, Flux.trainables(m))), m)[1] @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159] @test gnew.y ≈ [1.0] - - # Implicit - gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m)) - @test gold[m.x] ≈ [0.4472135954999579, 0.8944271909999159] - @test gold[m.y] ≈ [1.0] end @testset "Precision" begin - m = Chain(Dense(10, 5, relu; bias=false), Dense(5, 2)) + m = Chain(Dense(10 => 5, relu; bias=false), Dense(5 => 2)) x64 = rand(Float64, 10) x32 = rand(Float32, 10) i64 = rand(Int64, 10) @@ -340,28 +335,12 @@ end o = ones(s) z = zeros(s) - @testset "Explicit" begin - gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) - g = gfun(o, z) - @test gfun(o, false) == (g[1], nothing) - - g = gfun(z, o) - @test gfun(false, o) == (nothing, g[2]) - end - - @testset "Implicit" begin - gfun(args...) = gradient(() -> sum(op.(args...)), params(collect(args))) - g = gfun(o, z) - - gres = gfun(o, false) - @test gres[o] == g[o] - @test false ∉ gres.params + gfun(args...) = gradient((x, y) -> sum(op.(x,y)), args...) + g = gfun(o, z) + @test gfun(o, false) == (g[1], nothing) - g = gfun(z, o) - gres = gfun(false, o) - @test gres[o] == g[o] - @test false ∉ gres.params - end + g = gfun(z, o) + @test gfun(false, o) == (nothing, g[2]) end end @@ -466,10 +445,10 @@ end @test modules[5] === m2 @test modules[6] === m3 - mod_par = Flux.modules(Parallel(Flux.Bilinear(2,2,2,cbrt), Dense(2,2,abs), Dense(2,2,abs2))) + mod_par = Flux.modules(Parallel(Flux.Bilinear((2,2) => 2,cbrt), Dense(2 => 2,abs), Dense(2 => 2,abs2))) @test length(mod_par) == 5 - mod_rnn = Flux.modules(Chain(Dense(2,3), BatchNorm(3), LSTM(3,4))) + mod_rnn = Flux.modules(Chain(Dense(2 => 3), BatchNorm(3), LSTM(3 => 4))) @test length(mod_rnn) == 6 @test mod_rnn[end] isa Flux.LSTMCell @@ -561,10 +540,10 @@ end @testset "Shared parameters" begin mat = [1 2; 3 4.0] simple = ((nothing, mat, (3, mat, 4))) - @test length(Flux.params(simple)) == 1 + @test length(Flux.trainables(simple)) == 1 oneadj = (nt = (m = mat, a = mat')) - @test length(Flux.params(oneadj)) == 1 # needs Functors@0.3 + @test length(Flux.trainables(oneadj)) == 1 # needs Functors@0.3 @test Flux.destructure(simple)[1] == Flux.destructure(oneadj)[1] == [1, 3, 2, 4] end @@ -583,8 +562,8 @@ end end model = TwoDenses( - Dense(3,1), - Dense(3,2) + Dense(3 => 1), + Dense(3 => 2) ) p, re = Flux.destructure(model) @@ -619,20 +598,18 @@ end Flux.@layer Model (m::Model)(x) = m.a(x) .+ m.b(x) - d = Dense(1, 1) + d = Dense(1 => 1) x = rand(Float32, 1, 1) # Sharing the parameters model = Model(d, d) - # Works - g1 = Flux.gradient(() -> sum(model(x)), Flux.params(model)) + g1 = Flux.gradient(model -> sum(model(x)), model)[1] p, re = Flux.destructure(model) - # Fails - g2 = Flux.gradient(p -> sum(re(p)(x)), p) + g2 = Flux.gradient(p -> sum(re(p)(x)), p)[1] - @test g2[1] ≈ vcat(g1[d.weight], g1[d.bias]) + @test g2 ≈ vcat(vec(g1.a.weight) + vec(g1.b.weight), g1.a.bias + g1.b.bias) end @testset "issue 1826" begin @@ -648,8 +625,8 @@ end data = rand(Float32, n_input, n_batch) model = Chain( - Dense(n_input, n_shared), - Split(Dense(n_shared, n_outputs[1]), Dense(n_shared, n_outputs[2])) + Dense(n_input => n_shared), + Split(Dense(n_shared => n_outputs[1]), Dense(n_shared => n_outputs[2])) ) pvec, re = Flux.destructure(model) @@ -662,6 +639,6 @@ end # make sure rng_from_array is non_differentiable @testset "rng_from_array" begin - m(x) = (rand(rng_from_array(x)) * x)[1] + m(x) = (rand(Flux.rng_from_array(x)) * x)[1] gradient(m, ones(2)) end