diff --git a/Project.toml b/Project.toml index 401403c6..c93bdb4b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -22,6 +23,7 @@ ComputationalResources = "0.3.2" Flux = "0.14" MLJModelInterface = "1.1.1" Metalhead = "0.9.3" +Optimisers = "0.3.2" ProgressMeter = "1.7.1" StatisticalMeasures = "0.1" Statistics = "<0.0.1, 1" @@ -30,7 +32,6 @@ julia = "1.9" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -38,6 +39,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [targets] test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"] diff --git a/README.md b/README.md index ee51cada..1879b283 100644 --- a/README.md +++ b/README.md @@ -250,15 +250,6 @@ to builders for the purposes of weight initialization. This can be any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that will be reset on every cold restart of model (machine) training. -Until there is a [mechanism for -doing so](https://github.com/FluxML/Flux.jl/issues/1617) `rng` is *not* -passed to dropout layers and one must manually seed the `GLOBAL_RNG` -for reproducibility purposes, when using a builder that includes -`Dropout` (such as `MLJFlux.Short`). If training models on a -GPU (i.e., `acceleration isa CUDALibs`) one must additionally call -`CUDA.seed!(...)`. - - ### Built-in builders The following builders are provided out-of-the-box. Query their diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 5091d798..1445ee78 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -1,4 +1,4 @@ -module MLJFlux +module MLJFlux export CUDALibs, CPU1 @@ -14,11 +14,11 @@ using ColorTypes using ComputationalResources using Random import Metalhead +import Optimisers include("utilities.jl") const MMI=MLJModelInterface -include("penalizers.jl") include("builders.jl") include("metalhead.jl") include("types.jl") @@ -32,6 +32,7 @@ export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor export NeuralNetworkClassifier, ImageClassifier export CUDALibs, CPU1 +include("deprecated.jl") end #module diff --git a/src/builders.jl b/src/builders.jl index b106058a..742f3686 100644 --- a/src/builders.jl +++ b/src/builders.jl @@ -14,13 +14,12 @@ abstract type Builder <: MLJModelInterface.MLJType end """ - Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG) + Linear(; σ=Flux.relu) -MLJFlux builder that constructs a fully connected two layer network -with activation function `σ`. The number of input and output nodes is -determined from the data. The bias and coefficients are initialized -using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is -instead used as the seed for a `MersenneTwister`. 
+MLJFlux builder that constructs a fully connected two layer network with activation +function `σ`. The number of input and output nodes is determined from the data. Weights +are initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from the `rng` +field of the MLJFlux model. """ mutable struct Linear <: Builder @@ -31,7 +30,7 @@ build(builder::Linear, rng, n::Integer, m::Integer) = Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng))) """ - Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=GLOBAL_RNG) + Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid) MLJFlux builder that constructs a full-connected three-layer network using `n_hidden` nodes in the hidden layer and the specified `dropout` @@ -40,9 +39,8 @@ hidden and final layers. If `n_hidden=0` (the default) then `n_hidden` is the geometric mean of the number of input and output nodes. The number of input and output nodes is determined from the data. -The each layer is initialized using `Flux.glorot_uniform(rng)`. If -`rng` is an integer, it is instead used as the seed for a -`MersenneTwister`. +Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from +the `rng` field of the MLJFlux model. """ mutable struct Short <: Builder @@ -57,22 +55,19 @@ function build(builder::Short, rng, n, m) init=Flux.glorot_uniform(rng) Flux.Chain( Flux.Dense(n, n_hidden, builder.σ, init=init), - # TODO: fix next after https://github.com/FluxML/Flux.jl/issues/1617 - Flux.Dropout(builder.dropout), + Flux.Dropout(builder.dropout; rng), Flux.Dense(n_hidden, m, init=init)) end """ - MLP(; hidden=(100,), σ=Flux.relu, rng=GLOBAL_RNG) + MLP(; hidden=(100,), σ=Flux.relu) -MLJFlux builder that constructs a Multi-layer perceptron network. The -ith element of `hidden` represents the number of neurons in the ith -hidden layer. An activation function `σ` is applied between each -layer. +MLJFlux builder that constructs a Multi-layer perceptron network. The ith element of +`hidden` represents the number of neurons in the ith hidden layer. An activation function +`σ` is applied between each layer. -The each layer is initialized using `Flux.glorot_uniform(rng)`. If -`rng` is an integer, it is instead used as the seed for a -`MersenneTwister`. +Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from +the `rng` field of the MLJFlux model. """ mutable struct MLP{N} <: MLJFlux.Builder @@ -110,6 +105,7 @@ Creates a builder for `neural_net`. The variables `rng`, `n_in`, `n_out` and input and output sizes `n_in` and `n_out` and number of input channels `n_channels`. # Examples + ```jldoctest julia> import MLJFlux: @builder; @@ -132,4 +128,5 @@ macro builder(ex) end) end -build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) = b.apply(rng, n_in, n_out, n_channels) +build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) = + b.apply(rng, n_in, n_out, n_channels) diff --git a/src/core.jl b/src/core.jl index cca5a145..bd49933b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -2,7 +2,7 @@ # make the optimiser structs "transparent" so that their field values # are exposed by calls to MLJ.params: -MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true +MLJModelInterface.istransparent(m::Optimisers.AbstractRule) = true ## GENERAL METHOD TO OPTIMIZE A CHAIN @@ -15,47 +15,71 @@ end (::Mover{<:CUDALibs})(data) = Flux.gpu(data) """ - train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) - -A private method that can be overloaded for custom models. 
+ train_epoch( + model, + chain, + optimiser, + optimiser_state, + X, + y, + ) -> updated_chain, updated_optimiser_state, training_loss Update the parameters of a Flux `chain`, where: +- `model` is typically an `MLJFluxModel` instance, but could be any object such that + `model.loss` is a Flux.jl loss function. + - the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the `model` -- `params -> penalty(params)` is a regularization penalty function - - `X` and `y` are vectors of batches of the training data, as detailed - in the [`MLJFlux.fit!`](@ref) document string. + in the [`MLJFlux.train`](@ref) document string. """ -function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) +function train_epoch( + model, + chain, + optimiser, + optimiser_state, + X, + y, + ) + loss = model.loss n_batches = length(y) training_loss = zero(Float32) + for i in 1:n_batches - parameters = Flux.params(chain) - gs = Flux.gradient(parameters) do - yhat = chain(X[i]) - batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches - training_loss += batch_loss - return batch_loss + batch_loss, gs = Flux.withgradient(chain) do m + yhat = m(X[i]) + loss(yhat, y[i]) end - Flux.update!(optimiser, parameters, gs) + training_loss += batch_loss + # The `do` syntax above means `gs` is a tuple of length one we need to unwrap to + # get the actual gradient: + ∇ = first(gs) + optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇) end - return training_loss / n_batches + + return chain, optimiser_state, training_loss / n_batches end """ - fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y) - -A private method that can be overloaded for custom models. - -Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is -the loss function inferred from the `model`, and `parameters -> penalty(parameters)` is the -regularization penalty function. + train( + model, + chain, + optimiser, + optimiser_state, + epochs, + verbosity, + X, + y, + ) -> (updated_chain, updated_optimiser_state, history) + +Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function +inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it +could be any object such that `model.loss` is a Flux.jl loss function. Here `chain` is a `Flux.Chain` object, or other Flux model such that `Flux.params(chain)` returns the parameters to be optimized. @@ -76,17 +100,26 @@ batches. Specifically, it is expected that: total number of training batches. Both the `chain` and the data `(X, y)` must both live on a CPU or both -live on a GPU. This `fit!` method takes no responsibility for data +live on a GPU. This `train` method takes no responsibility for data movement. -### Return value +# Return value -`(chain_trained, history)`, where `chain_trained` is a trained version -of `chain` and `history` is a vector of penalized losses - one initial -loss, and one loss per epoch. +Returns `(updated_chain, updated_optimiser_state, history)`, where `updated_chain` is a +trained version of `chain` and `history` is a vector of losses, including the +initial (no-train) loss. 
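+
+# Example
+
+A minimal sketch of typical usage; the variable names and sizes here are illustrative
+only, and any object with a `loss` field can stand in for `model`:
+
+    chain = Flux.Chain(Flux.Dense(3 => 1))
+    optimiser = Optimisers.Adam()
+    optimiser_state = Optimisers.setup(optimiser, chain)
+    X = [rand(Float32, 3, 10) for _ in 1:5]    # 5 input batches
+    y = [rand(Float32, 1, 10) for _ in 1:5]    # 5 target batches
+    model = (; loss=Flux.mse)                  # only `model.loss` is used
+    chain, optimiser_state, history = MLJFlux.train(
+        model, chain, optimiser, optimiser_state, 10, 0, X, y,
+    )
+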
""" -function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y) +function train( + model, + chain, + optimiser, + optimiser_state, + epochs, + verbosity, + X, + y, + ) loss = model.loss @@ -98,20 +131,25 @@ function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, ve # initiate history: n_batches = length(y) - parameters = Flux.params(chain) - losses = (loss(chain(X[i]), y[i]) + - penalty(parameters) / n_batches for i in 1:n_batches) + losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches) history = [mean(losses),] for i in 1:epochs - current_loss = train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) + chain, optimiser_state, current_loss = train_epoch( + model, + chain, + optimiser, + optimiser_state, + X, + y, + ) verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" verbosity != 1 || next!(meter) push!(history, current_loss) end - return chain, history + return chain, optimiser_state, history end @@ -221,7 +259,9 @@ _get(X::AbstractArray{<:Any,4}, b) = X[:, :, :, b] """ collate(model, X, y) -Return the Flux-friendly data object required by `MLJFlux.fit!`, given +**Private method** + +Return the Flux-friendly data object required by `MLJFlux.train`, given input `X` and target `y` in the form required by `MLJModelInterface.input_scitype(X)` and `MLJModelInterface.target_scitype(y)`. (The batch size used is given diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 00000000..9b09cabd --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1,19 @@ +Base.@deprecate fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y) train( + model::MLJFlux.MLJFluxModel, + chain, + optimiser, + optimiser_state, + epochs, + verbosity, + X, + y, +) false + +Base.@deprecate train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) train_epoch( + model::MLJFlux.MLJFluxModel, + chain, + optimiser, + optimiser_state, + X, + y, +) false \ No newline at end of file diff --git a/src/metalhead.jl b/src/metalhead.jl index 29b2b6fb..05d1766d 100644 --- a/src/metalhead.jl +++ b/src/metalhead.jl @@ -130,7 +130,7 @@ function VGGHack( depth in keys(Metalhead.VGG_CONFIGS), "depth must be from one in $(sort(collect(keys(Metalhead.VGG_CONFIGS))))" ) - model = Metalhead.VGG(imsize; + model = Metalhead.vgg(imsize; config = Metalhead.VGG_CONFIGS[depth], inchannels, batchnorm, diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl index 5ffe903f..f25c137e 100644 --- a/src/mlj_model_interface.jl +++ b/src/mlj_model_interface.jl @@ -30,10 +30,16 @@ function MLJModelInterface.clean!(model::MLJFluxModel) warning *= "`acceleration isa CUDALibs` "* "but no CUDA device (GPU) currently live. " end - if ! (model.acceleration isa CUDALibs || model.acceleration isa CPU1) + if !(model.acceleration isa CUDALibs || model.acceleration isa CPU1) warning *= "`Undefined acceleration, falling back to CPU`" model.acceleration = CPU1() end + if model.acceleration isa CUDALibs && model.rng isa Integer + warning *= "Specifying an RNG seed when "* + "`acceleration isa CUDALibs()` may fail for layers depending "* + "on an RNG during training, such as `Dropout`. Consider using "* + " `Random.default_rng()` instead. `" + end return warning end @@ -43,7 +49,38 @@ end const ERR_BUILDER = "Builder does not appear to build an architecture compatible with supplied data. " -true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng +true_rng(model) = model.rng isa Integer ? 
Random.Xoshiro(model.rng) : model.rng + +# Models implement L1/L2 regularization by chaining the chosen optimiser with weight/sign +# decay. Note that the weight/sign decay must be scaled down by the number of batches to +# ensure penalization over an epoch does not scale with the choice of batch size; see +# https://github.com/FluxML/MLJFlux.jl/issues/213. + +function regularized_optimiser(model, nbatches) + model.lambda == 0 && return model.optimiser + λ_L1 = model.alpha*model.lambda + λ_L2 = (1 - model.alpha)*model.lambda + λ_sign = λ_L1/nbatches + λ_weight = 2*λ_L2/nbatches + + # recall components in an optimiser chain are executed from left to right: + if model.alpha == 0 + return Optimisers.OptimiserChain( + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + elseif model.alpha == 1 + return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + model.optimiser, + ) + else return Optimisers.OptimiserChain( + Optimisers.SignDecay(λ_sign), + Optimisers.WeightDecay(λ_weight), + model.optimiser, + ) + end +end function MLJModelInterface.fit(model::MLJFluxModel, verbosity, @@ -62,9 +99,7 @@ function MLJModelInterface.fit(model::MLJFluxModel, rethrow() end - penalty = Penalty(model) data = move.(collate(model, X, y)) - x = data[1][1] try @@ -74,26 +109,31 @@ function MLJModelInterface.fit(model::MLJFluxModel, throw(ex) end - optimiser = deepcopy(model.optimiser) - - chain, history = fit!(model, - penalty, - chain, - optimiser, - model.epochs, - verbosity, - data[1], - data[2]) - - # `optimiser` is now mutated - - cache = (deepcopy(model), - data, - history, - shape, - optimiser, - deepcopy(rng), - move) + nbatches = length(data[2]) + regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) + optimiser_state = Optimisers.setup(regularized_optimiser, chain) + + chain, optimiser_state, history = train( + model, + chain, + regularized_optimiser, + optimiser_state, + model.epochs, + verbosity, + data[1], + data[2], + ) + + cache = ( + deepcopy(model), + data, + history, + shape, + regularized_optimiser, + optimiser_state, + deepcopy(rng), + move, + ) fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) report = (training_losses=history, ) @@ -108,7 +148,8 @@ function MLJModelInterface.update(model::MLJFluxModel, X, y) - old_model, data, old_history, shape, optimiser, rng, move = old_cache + old_model, data, old_history, shape, regularized_optimiser, + optimiser_state, rng, move = old_cache old_chain = old_fitresult[1] optimiser_flag = model.optimiser_changes_trigger_retraining && @@ -120,46 +161,45 @@ function MLJModelInterface.update(model::MLJFluxModel, if keep_chain chain = move(old_chain) epochs = model.epochs - old_model.epochs + # (`optimiser_state` is not reset) else move = Mover(model.acceleration) rng = true_rng(model) chain = build(model, rng, shape) |> move + # reset `optimiser_state`: data = move.(collate(model, X, y)) + nbatches = length(data[2]) + regularized_optimiser = MLJFlux.regularized_optimiser(model, nbatches) + optimiser_state = Optimisers.setup(regularized_optimiser, chain) epochs = model.epochs end - penalty = Penalty(model) - - # we only get to keep the optimiser "state" carried over from - # previous training if we're doing a warm restart and the user has not - # changed the optimiser hyper-parameter: - if !keep_chain || - !MLJModelInterface._equal_to_depth_one(model.optimiser, - old_model.optimiser) - optimiser = deepcopy(model.optimiser) - end - - chain, history = fit!(model, - penalty, - chain, - optimiser, - epochs, - verbosity, - 
data[1], - data[2]) + chain, optimiser_state, history = train( + model, + chain, + regularized_optimiser, + optimiser_state, + epochs, + verbosity, + data[1], + data[2], + ) if keep_chain # note: history[1] = old_history[end] history = vcat(old_history[1:end-1], history) end fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) - cache = (deepcopy(model), - data, - history, - shape, - optimiser, - deepcopy(rng), - move) + cache = ( + deepcopy(model), + data, + history, + shape, + regularized_optimiser, + optimiser_state, + deepcopy(rng), + move, + ) report = (training_losses=history, ) return fitresult, cache, report diff --git a/src/penalizers.jl b/src/penalizers.jl deleted file mode 100644 index ccea2d2d..00000000 --- a/src/penalizers.jl +++ /dev/null @@ -1,62 +0,0 @@ -# Note (1). See -# https://discourse.julialang.org/t/weight-regularisation-which-iterates-params-m-in-flux-mutating-arrays-is-not-supported/64314 - - -""" Penalizer(λ, α) - -Returns a callable object `penalizer` for evaluating regularization -penalties associated with some numerical array. Specifically, -`penalizer(A)` returns - - λ*(α*L1 + (1 - α)*L2), - -where `L1` is the sum of absolute values of the elments of `A` and -`L2` is the sum of squares of those elements. -""" -struct Penalizer{T} - lambda::T - alpha::T - function Penalizer(lambda, alpha) - lambda == 0 && return new{Nothing}(nothing, nothing) - T = promote_type(typeof.((lambda, alpha))...) - return new{T}(lambda, alpha) - end -end - -(::Penalizer{Nothing})(::Any) = 0 -function (p::Penalizer)(A) - λ = p.lambda - α = p.alpha - # avoiding broadcasting; see Note (1) above - L2 = sum(abs2, A) - L1 = sum(abs, A) - return λ*(α*L1 + (1 - α)*L2) -end - - - -""" - Penalty(model) - -Returns a callable object `p`, for returning the regularization -penalty `p(w)` associated with some collection of parameters `w`. For -example, `w = params(chain)` where `chain` is some Flux -model. Here `model` is an MLJFlux model ("model" in the MLJ -sense, not the Flux sense). Specifically, `p(w)` returns - - sum(Penalizer(λ, α).w) - -where `α = model.alpha`, `λ = model.lambda`. - -See also [`Penalizer`](@ref) - -""" -struct Penalty{P} - penalizer::P - function Penalty(model) - penalizer = Penalizer(model.lambda, model.alpha) - return new{typeof(penalizer)}(penalizer) - end -end -(p::Penalty{Penalizer{Nothing}})(w) = zero(Float32) -(p::Penalty)(w) = sum(p.penalizer(wt) for wt in w) diff --git a/src/regressor.jl b/src/regressor.jl index 9960c90b..222560b7 100644 --- a/src/regressor.jl +++ b/src/regressor.jl @@ -33,13 +33,16 @@ MLJModelInterface.metadata_model(NeuralNetworkRegressor, # # MULTITARGET NEURAL NETWORK REGRESSOR + ncols(X::AbstractMatrix) = size(X, 2) ncols(X) = Tables.columns(X) |> Tables.columnnames |> length """ shape(model::MultitargetNeuralNetworkRegressor, X, y) -A private method that returns the shape of the input and output of the model for given data `X` and `y`. +A private method that returns the shape of the input and output of the model for given +data `X` and `y`. 
+ """ shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y)) @@ -49,7 +52,7 @@ build(model::MultitargetNeuralNetworkRegressor, rng, shape) = function fitresult(model::MultitargetNeuralNetworkRegressor, chain, y) if y isa Matrix target_column_names = nothing - else + else target_column_names = Tables.schema(y).names end return (chain, target_column_names) diff --git a/src/types.jl b/src/types.jl index c608abbf..45886171 100644 --- a/src/types.jl +++ b/src/types.jl @@ -23,10 +23,32 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end - function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1() - ) where {B,F,O,L} - - model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration + function $Model( + ;builder::B=$default_builder_ex, + finaliser::F=Flux.softmax, + optimiser::O=Optimisers.Adam(), + loss::L=Flux.crossentropy, + epochs=10, + batch_size=1, + lambda=0, + alpha=0, + rng=Random.default_rng(), + optimiser_changes_trigger_retraining=false, + acceleration=CPU1(), + ) where {B,F,O,L} + + model = $Model{B,F,O,L}( + builder, + finaliser, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, ) message = clean!(model) @@ -57,10 +79,31 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] acceleration::AbstractResource # eg, `CPU1()` or `CUDALibs()` end - function $Model(; builder::B=Linear(), optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.mse, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1() + function $Model( + ; builder::B=Linear(), + optimiser::O=Optimisers.Adam(), + loss::L=Flux.mse, + epochs=10, + batch_size=1, + lambda=0, + alpha=0, + rng=Random.default_rng(), + optimiser_changes_trigger_retraining=false, + acceleration=CPU1(), ) where {B,O,L} - model = $Model{B,O,L}(builder, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration) + model = $Model{B,O,L}( + builder, + optimiser, + loss, + epochs, + batch_size, + lambda, + alpha, + rng, + optimiser_changes_trigger_retraining, + acceleration, + ) message = clean!(model) isempty(message) || @warn message @@ -126,11 +169,10 @@ Train the machine with `fit!(mach, rows=...)`. MLJFlux.jl documentation for examples of user-defined builders. See also `finaliser` below. -- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the - updating of the weights of the network. For further reference, see [the Flux optimiser - documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a - learning rate (the update rate of the optimizer), a good rule of thumb is to start out - at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `optimiser::Optimisers.Adam()`: An Optimisers.jl optimiser. The optimiser performs the + updating of the weights of the network. To choose a learning rate (the update rate of + the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers + of 10 between `1` and `1e-7`. 
- `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are @@ -163,13 +205,13 @@ Train the machine with `fit!(mach, rows=...)`. GPU is available. - `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value - in the range `[0, ∞)`. + in the range `[0, ∞)`. Note the history reports unpenalized losses. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during - training. + training. The default is `Random.default_rng()`. - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting a machine if the associated optimiser has changed. If `true`, the associated machine @@ -316,11 +358,10 @@ Train the machine with `fit!(mach, rows=...)`. below for a user-specified builder. A convenience macro `@builder` is also available. See also `finaliser` below. -- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the - updating of the weights of the network. For further reference, see [the Flux optimiser - documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a - learning rate (the update rate of the optimizer), a good rule of thumb is to start out - at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `optimiser::Optimisers.Adam()`: An Optimisers.jl optimiser. The optimiser performs the + updating of the weights of the network. To choose a learning rate (the update rate of + the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers + of 10 between `1` and `1e-7`. - `loss=Flux.crossentropy`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are @@ -353,13 +394,13 @@ Train the machine with `fit!(mach, rows=...)`. GPU is available. - `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value - in the range `[0, ∞)`. + in the range `[0, ∞)`. Note the history reports unpenalized losses. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during - training. + training. The default is `Random.default_rng()`. - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting a machine if the associated optimiser has changed. If `true`, the associated machine @@ -560,11 +601,10 @@ Train the machine with `fit!(mach, rows=...)`. `MLJFlux.MLP`. See MLJFlux documentation for more on builders, and the example below for using the `@builder` convenience macro. -- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the - updating of the weights of the network. For further reference, see [the Flux optimiser - documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a - learning rate (the update rate of the optimizer), a good rule of thumb is to start out - at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `optimiser::Optimisers.Adam()`: An Optimisers.jl optimiser. The optimiser performs the + updating of the weights of the network. 
To choose a learning rate (the update rate of + the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers + of 10 between `1` and `1e-7`. - `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in @@ -590,13 +630,13 @@ Train the machine with `fit!(mach, rows=...)`. GPU is available. - `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value - in the range `[0, ∞)`. + in the range `[0, ∞)`. Note the history reports unpenalized losses. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during - training. + training. The default is `Random.default_rng()`. - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting a machine if the associated optimiser has changed. If `true`, the associated machine @@ -772,11 +812,15 @@ In MLJ or MLJBase, bind an instance `model` to data with Here: -- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations. +- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose + columns are of scitype `Continuous`; check column scitypes with `schema(X)`. If `X` is a + `Matrix`, it is assumed to have columns corresponding to features and rows corresponding + to observations. -- `y` is the target, which can be any table or matrix of output targets whose element scitype is - `Continuous`; check column scitypes with `schema(y)`. If `y` is a `Matrix`, it is assumed to have columns corresponding to variables and rows corresponding to observations. +- `y` is the target, which can be any table or matrix of output targets whose element + scitype is `Continuous`; check column scitypes with `schema(y)`. If `y` is a `Matrix`, + it is assumed to have columns corresponding to variables and rows corresponding to + observations. # Hyper-parameters @@ -786,11 +830,10 @@ Here: documentation for more on builders, and the example below for using the `@builder` convenience macro. -- `optimiser::Flux.Adam()`: A `Flux.Optimise` optimiser. The optimiser performs the - updating of the weights of the network. For further reference, see [the Flux optimiser - documentation](https://fluxml.ai/Flux.jl/stable/training/optimisers/). To choose a - learning rate (the update rate of the optimizer), a good rule of thumb is to start out - at `10e-3`, and tune using powers of 10 between `1` and `1e-7`. +- `optimiser::Optimisers.Adam()`: An Optimisers.jl optimiser. The optimiser performs the + updating of the weights of the network. To choose a learning rate (the update rate of + the optimizer), a good rule of thumb is to start out at `10e-3`, and tune using powers + of 10 between `1` and `1e-7`. - `loss=Flux.mse`: The loss function which the network will optimize. Should be a function which can be called in the form `loss(yhat, y)`. Possible loss functions are listed in @@ -816,13 +859,13 @@ Here: GPU is available. - `lambda::Float64=0`: The strength of the weight regularization penalty. Can be any value - in the range `[0, ∞)`. + in the range `[0, ∞)`. 
Note the history reports unpenalized losses. - `alpha::Float64=0`: The L2/L1 mix of regularization, in the range `[0, 1]`. A value of 0 represents L2 regularization, and a value of 1 represents L1 regularization. - `rng::Union{AbstractRNG, Int64}`: The random number generator or seed used during - training. + training. The default is `Random.default_rng()`. - `optimizer_changes_trigger_retraining::Bool=false`: Defines what happens when re-fitting a machine if the associated optimiser has changed. If `true`, the associated machine diff --git a/test/builders.jl b/test/builders.jl index ba7df095..e717a2ed 100644 --- a/test/builders.jl +++ b/test/builders.jl @@ -24,7 +24,7 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) = mach = machine(model, X, y) fit!(mach, verbosity=0) - # extract the pre-training loss computed in the `fit!(chain, ...)` method: + # extract the pre-training loss computed in the `MLJFlux.train(...)` method: pretraining_loss = report(mach).training_losses[1] # compute by hand: @@ -40,12 +40,12 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) = end @testset_accelerated "Short" accel begin - builder = MLJFlux.Short(n_hidden=4, σ=Flux.relu, dropout=0) + builder = MLJFlux.Short(n_hidden=4, σ=Flux.relu, dropout=0.5) chain = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3) ps = Flux.params(chain) @test size.(ps) == [(4, 5), (4,), (3, 4), (3,)] - # reproducibility (without dropout): + # reproducibility: chain2 = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3) x = rand(Float32, 5) @test chain(x) ≈ chain2(x) diff --git a/test/classifier.jl b/test/classifier.jl index 82de9500..81ca2023 100644 --- a/test/classifier.jl +++ b/test/classifier.jl @@ -16,15 +16,17 @@ y = map(ycont) do η end end |> categorical; -# TODO: replace Short2 -> Short when -# https://github.com/FluxML/Flux.jl/issues/1372 is resolved: -builder = Short2() -optimiser = Flux.Optimise.Adam(0.03) +# In the tests below we want to check GPU and CPU give similar results. We use the `MLP` +# builer instead of the default `Short()` because `Dropout()` in `Short()` does not appear +# to behave the same on GPU as on a CPU, even when we use `default_rng()` for both. 
+ +builder = MLJFlux.MLP(hidden=(8,)) +optimiser = Optimisers.Adam(0.03) losses = [] @testset_accelerated "NeuralNetworkClassifier" accel begin - Random.seed!(123) + # Table input: @testset "Table input" begin basictest(MLJFlux.NeuralNetworkClassifier, @@ -35,6 +37,7 @@ losses = [] 0.85, accel) end + # Matrix input: @testset "Matrix input" begin basictest(MLJFlux.NeuralNetworkClassifier, @@ -59,14 +62,16 @@ losses = [] StatisticalMeasures.cross_entropy(fill(dist, length(test)), y[test]) |> mean # check flux model is an improvement on predicting constant - # distribution: - stable_rng = StableRNGs.StableRNG(123) + # distribution + # (GPUs only support `default_rng`): + rng = Random.default_rng() + seed!(rng, 123) model = MLJFlux.NeuralNetworkClassifier(epochs=50, builder=builder, optimiser=optimiser, acceleration=accel, batch_size=10, - rng=stable_rng) + rng=rng) @time mach = fit!(machine(model, X, y), rows=train, verbosity=0) first_last_training_loss = MLJBase.report(mach)[1][[1, end]] push!(losses, first_last_training_loss[2]) diff --git a/test/core.jl b/test/core.jl index 6d6f7007..4bfc400b 100644 --- a/test/core.jl +++ b/test/core.jl @@ -4,7 +4,7 @@ stable_rng = StableRNGs.StableRNG(123) rowvec(y) = y rowvec(y::Vector) = reshape(y, 1, length(y)) -@test MLJFlux.MLJModelInterface.istransparent(Flux.Adam(0.1)) +@test MLJBase.MLJModelInterface.istransparent(Optimisers.Adam(0.05)) @testset "nrows" begin Xmatrix = rand(stable_rng, Float32, 10, 3) @@ -103,32 +103,38 @@ test_input = rand(stable_rng, Float32, 5, 1) epochs = 10 -@testset_accelerated "fit! and dropout" accel begin +@testset_accelerated "train and dropout" accel begin move = MLJFlux.Mover(accel) Random.seed!(123) - penalty = MLJFlux.Penalty(model) - _chain_yes_drop, history = MLJFlux.fit!(model, - penalty, - chain_yes_drop, - Flux.Optimise.Adam(0.001), - epochs, - 0, - data[1], - data[2]) + opt = Optimisers.Adam(0.001) + opt_state = Optimisers.setup(opt, chain_yes_drop) + _chain_yes_drop, _, history = MLJFlux.train( + model, + chain_yes_drop, + opt, + opt_state, + epochs, + 0, + data[1], + data[2], + ) println() Random.seed!(123) - penalty = MLJFlux.Penalty(model) - _chain_no_drop, history = MLJFlux.fit!(model, - penalty, - chain_no_drop, - Flux.Optimise.Adam(0.001), - epochs, - 0, - data[1], - data[2]) + opt = Optimisers.Adam(0.001) + opt_state = Optimisers.setup(opt, chain_no_drop) + _chain_no_drop, _, history = MLJFlux.train( + model, + chain_no_drop, + opt, + opt_state, + epochs, + 0, + data[1], + data[2], + ) # check chains have different behaviour after training: @test !(_chain_yes_drop(test_input) ≈ diff --git a/test/image.jl b/test/image.jl index 1260c75f..203cee71 100644 --- a/test/image.jl +++ b/test/image.jl @@ -1,8 +1,5 @@ # # BASIC IMAGE TESTS GREY -Random.seed!(123) -stable_rng = StableRNGs.StableRNG(123) - mutable struct MyNeuralNetwork <: MLJFlux.Builder kernel1 kernel2 @@ -30,18 +27,16 @@ function MLJFlux.build(builder::MyNeuralNetwork, rng, ip, op, n_channels) end builder = MyNeuralNetwork((2,2), (2,2)) -images, labels = MLJFlux.make_images(stable_rng); +images, labels = MLJFlux.make_images(StableRNG(123)); losses = [] @testset_accelerated "ImageClassifier basic tests" accel begin - Random.seed!(123) - stable_rng = StableRNGs.StableRNG(123) - + rng = StableRNG(123) model = MLJFlux.ImageClassifier(builder=builder, epochs=10, acceleration=accel, - rng=stable_rng) + rng=rng) fitresult, cache, _report = MLJBase.fit(model, 0, images, labels) @@ -57,8 +52,8 @@ losses = [] epochs=10, batch_size=2, acceleration=accel, - 
rng=stable_rng) - model.optimiser.eta = 0.005 + rng=rng) + model.optimiser = clonewith(model.optimiser, 0.005) # changes the learning rate @time fitresult, cache, _report = MLJBase.fit(model, 0, images, labels); first_last_training_loss = _report[1][[1, end]] push!(losses, first_last_training_loss[2]) @@ -67,10 +62,6 @@ losses = [] # tests update logic, etc (see test_utililites.jl): @test basictest(MLJFlux.ImageClassifier, images, labels, model.builder, model.optimiser, 0.95, accel) - - @test optimisertest(MLJFlux.ImageClassifier, images, labels, - model.builder, model.optimiser, accel) - end # check different resources (CPU1, CUDALibs) give about the same loss: @@ -82,18 +73,17 @@ reference = losses[1] # # BASIC IMAGE TESTS COLOR builder = MyNeuralNetwork((2,2), (2,2)) -images, labels = MLJFlux.make_images(stable_rng, color=true) +images, labels = MLJFlux.make_images(StableRNG(123), color=true) losses = [] @testset_accelerated "ColorImages" accel begin - Random.seed!(123) - stable_rng = StableRNGs.StableRNG(123) + rng = StableRNG(123) model = MLJFlux.ImageClassifier(builder=builder, epochs=10, acceleration=accel, - rng=stable_rng) + rng=rng) # tests update logic, etc (see test_utililites.jl): @test basictest(MLJFlux.ImageClassifier, images, labels, model.builder, model.optimiser, 0.95, accel) @@ -108,12 +98,8 @@ losses = [] builder=builder, batch_size=2, acceleration=accel, - rng=stable_rng) + rng=rng) fitresult, cache, _report = MLJBase.fit(model, 0, images, labels); - - @test optimisertest(MLJFlux.ImageClassifier, images, labels, - model.builder, model.optimiser, accel) - end # check different resources (CPU1, CUDALibs, etc)) give about the same loss: @@ -124,14 +110,19 @@ reference = losses[1] # # SMOKE TEST FOR DEFAULT BUILDER -images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12, +images, labels = MLJFlux.make_images(StableRNG(123), image_size=(32, 32), n_images=12, noise=0.2, color=true); @testset_accelerated "ImageClassifier basic tests" accel begin + + # GPUs only support `default_rng`: + rng = Random.default_rng() + seed!(rng, 123) + model = MLJFlux.ImageClassifier(epochs=5, batch_size=4, acceleration=accel, - rng=stable_rng) + rng=rng) fitresult, _, _ = MLJBase.fit(model, 0, images, labels); predict(model, fitresult, images) end diff --git a/test/integration.jl b/test/integration.jl index 39f4f79c..8e0bdfdc 100644 --- a/test/integration.jl +++ b/test/integration.jl @@ -2,6 +2,7 @@ rng = StableRNGs.StableRNG(123) table = load_iris() y, X = unpack(table, ==(:target), _->true, rng=rng) +X = Tables.table(Float32.(Tables.matrix(X))) @testset_accelerated "regularization has an effect" accel begin @@ -10,7 +11,9 @@ y, X = unpack(table, ==(:target), _->true, rng=rng) rng=rng) model2 = deepcopy(model) model3 = deepcopy(model) + model4 = deepcopy(model) model3.lambda = 0.1 + model4.alpha = 0.1 # still no regularization here because `lambda=0`. 
e = evaluate(model, X, y, resampling=Holdout(), measure=StatisticalMeasures.LogLoss()) loss1 = e.measurement[1] @@ -21,6 +24,10 @@ y, X = unpack(table, ==(:target), _->true, rng=rng) e = evaluate(model3, X, y, resampling=Holdout(), measure=StatisticalMeasures.LogLoss()) loss3 = e.measurement[1] + e = evaluate(model4, X, y, resampling=Holdout(), measure=StatisticalMeasures.LogLoss()) + loss4 = e.measurement[1] + @test loss1 ≈ loss2 - @test !(loss2 ≈ loss3) + @test !(loss1 ≈ loss3) + @test loss1 ≈ loss4 end diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl index 52452ea5..7b929050 100644 --- a/test/mlj_model_interface.jl +++ b/test/mlj_model_interface.jl @@ -4,7 +4,7 @@ ModelType = MLJFlux.NeuralNetworkRegressor model = MLJFlux.ImageClassifier() clone = deepcopy(model) @test model == clone - clone.optimiser.eta *= 10 + clone.optimiser = Optimisers.Adam(model.optimiser.eta*10) @test model != clone end @@ -37,6 +37,78 @@ end end end +@testset "regularization: logic" begin + optimiser = Optimisers.Momentum() + + # lambda = 0: + model = MLJFlux.NeuralNetworkRegressor(; alpha=0.3, lambda=0, optimiser) + chain = MLJFlux.regularized_optimiser(model, 1) + @test chain == optimiser + + # alpha = 0: + model = MLJFlux.NeuralNetworkRegressor(; alpha=0, lambda=0.3, optimiser) + chain = MLJFlux.regularized_optimiser(model, 1) + @test chain isa Optimisers.OptimiserChain{ + Tuple{Optimisers.WeightDecay, Optimisers.Momentum} + } + + # alpha = 1: + model = MLJFlux.NeuralNetworkRegressor(; alpha=1, lambda=0.3, optimiser) + chain = MLJFlux.regularized_optimiser(model, 1) + @test chain isa Optimisers.OptimiserChain{ + Tuple{Optimisers.SignDecay, Optimisers.Momentum} + } + + # general case: + model = MLJFlux.NeuralNetworkRegressor(; alpha=0.4, lambda=0.3, optimiser) + chain = MLJFlux.regularized_optimiser(model, 1) + @test chain isa Optimisers.OptimiserChain{ + Tuple{Optimisers.SignDecay, Optimisers.WeightDecay, Optimisers.Momentum} + } +end + +@testset "regularization: integration" begin + rng = StableRNG(123) + nobservations = 12 + Xuser = rand(Float32, nobservations, 3) + yuser = rand(Float32, nobservations) + alpha = rand(rng) + lambda = rand(rng) + optimiser = Optimisers.Momentum() + builder = MLJFlux.Linear() + epochs = 1 # don't change this + opts = (; alpha, lambda, optimiser, builder, epochs) + + for batch_size in [1, 2, 3] + + # (1) train using weight/sign decay, as implemented in MLJFlux: + model = MLJFlux.NeuralNetworkRegressor(; batch_size, rng=StableRNG(123), opts...); + mach = machine(model, Xuser, yuser); + fit!(mach, verbosity=0); + w1 = Optimisers.trainables(fitted_params(mach).chain) + + # (2) manually train for one epoch explicitly adding a loss penalty: + chain = MLJFlux.build(builder, StableRNG(123), 3, 1); + penalty = Penalizer(lambda, alpha); # defined in test_utils.jl + X, y = MLJFlux.collate(model, Xuser, yuser); + loss = model.loss; + n_batches = div(nobservations, batch_size) + optimiser_state = Optimisers.setup(optimiser, chain); + for i in 1:n_batches + batch_loss, gs = Flux.withgradient(chain) do m + yhat = m(X[i]) + loss(yhat, y[i]) + sum(penalty, Optimisers.trainables(m))/n_batches + end + ∇ = first(gs) + optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇) + end + w2 = Optimisers.trainables(chain) + + # (3) compare the trained weights + @test w1 ≈ w2 + end +end + @testset "iteration api" begin model = MLJFlux.NeuralNetworkRegressor(epochs=10) @test MLJBase.supports_training_losses(model) diff --git a/test/penalizers.jl b/test/penalizers.jl 
deleted file mode 100644 index 68ee6b99..00000000 --- a/test/penalizers.jl +++ /dev/null @@ -1,35 +0,0 @@ -using Statistics -import MLJFlux -import Flux - -@testset "Penalizer" begin - A = [-1 2; -3 4] - lambda = 1 - - # 100% L2: - alpha = 0 - penalty = MLJFlux.Penalizer(lambda, alpha) - @test penalty(A) ≈ 1 + 4 + 9 + 16 - - # 100% L1: - alpha = 1 - penalty = MLJFlux.Penalizer(lambda, alpha) - @test penalty(A) ≈ 1 + 2 + 3 + 4 - - # no strength: - lambda = 0 - alpha = 42.324 - penalty = MLJFlux.Penalizer(lambda, alpha) - @test penalty(A) == 0 -end - -@testset "Penalty" begin - model = MLJFlux.NeuralNetworkRegressor(lambda=1, alpha=1, loss=Flux.mae) - chain = Flux.Dense(3, 1, identity) - w = Flux.params(chain) - p = MLJFlux.Penalty(model) - - # compare loss by hand and with penalized loss function: - penalty = (sum(abs.(chain.weight)) + abs(chain.bias[1])) - @test p(w) ≈ penalty -end diff --git a/test/regressor.jl b/test/regressor.jl index 1345125f..d17a3760 100644 --- a/test/regressor.jl +++ b/test/regressor.jl @@ -3,12 +3,8 @@ Random.seed!(123) N = 200 X = MLJBase.table(randn(Float32, N, 5)); -# TODO: replace Short2 -> Short when -# https://github.com/FluxML/Flux.jl/pull/1618 is resolved: -builder = Short2(σ=identity) -optimiser = Flux.Optimise.Adam() - -losses = [] +builder = MLJFlux.Short(σ=identity) +optimiser = Optimisers.Adam() Random.seed!(123) y = 1 .+ X.x1 - X.x2 .- 2X.x4 + X.x5 @@ -16,113 +12,96 @@ train, test = MLJBase.partition(1:N, 0.7) @testset_accelerated "NeuralNetworkRegressor" accel begin - Random.seed!(123) - # Table input: @testset "Table input" begin - basictest(MLJFlux.NeuralNetworkRegressor, - X, - y, - builder, - optimiser, - 0.7, - accel) + basictest( + MLJFlux.NeuralNetworkRegressor, + X, + y, + builder, + optimiser, + 0.7, + accel, + ) end - + # Matrix input: @testset "Matrix input" begin - basictest(MLJFlux.NeuralNetworkRegressor, - matrix(X), - y, - builder, - optimiser, - 0.7, - accel) + @test basictest( + MLJFlux.NeuralNetworkRegressor, + matrix(X), + y, + builder, + optimiser, + 0.7, + accel, + ) end # test model is a bit better than constant predictor: - stable_rng = StableRNGs.StableRNG(123) + # (GPUs only support `default_rng` when there's `Dropout`): + rng = Random.default_rng() + seed!(rng, 123) model = MLJFlux.NeuralNetworkRegressor(builder=builder, acceleration=accel, - rng=stable_rng) + rng=rng) @time fitresult, _, rpt = fit(model, 0, MLJBase.selectrows(X, train), y[train]) first_last_training_loss = rpt[1][[1, end]] - push!(losses, first_last_training_loss[2]) # @show first_last_training_loss yhat = predict(model, fitresult, selectrows(X, test)) truth = y[test] goal = 0.9*model.loss(truth .- mean(truth), 0) @test model.loss(yhat, truth) < goal - - optimisertest(MLJFlux.NeuralNetworkRegressor, - X, - y, - builder, - optimiser, - accel) - end -# check different resources (CPU1, CUDALibs, etc)) give about the same loss: -reference = losses[1] -@test all(x->abs(x - reference)/reference < 1e-6, losses[2:end]) - Random.seed!(123) ymatrix = hcat(1 .+ X.x1 - X.x2, 1 .- 2X.x4 + X.x5); y = MLJBase.table(ymatrix); -losses = [] - @testset_accelerated "MultitargetNeuralNetworkRegressor" accel begin - Random.seed!(123) - # Table input: @testset "Table input" begin - basictest(MLJFlux.MultitargetNeuralNetworkRegressor, - X, - y, - builder, - optimiser, - 1.0, - accel) + @test basictest( + MLJFlux.MultitargetNeuralNetworkRegressor, + X, + y, + builder, + optimiser, + 1.0, + accel, + ) end # Matrix input: @testset "Matrix input" begin - 
basictest(MLJFlux.MultitargetNeuralNetworkRegressor, - matrix(X), - ymatrix, - builder, - optimiser, - 1.0, - accel) + @test basictest( + MLJFlux.MultitargetNeuralNetworkRegressor, + matrix(X), + ymatrix, + builder, + optimiser, + 1.0, + accel, + ) end # test model is a bit better than constant predictor - model = MLJFlux.MultitargetNeuralNetworkRegressor(acceleration=accel, - builder=builder) + # (GPUs only support `default_rng` when there's `Dropout`): + rng = Random.default_rng() + seed!(rng, 123) + model = MLJFlux.MultitargetNeuralNetworkRegressor( + acceleration=accel, + builder=builder, + rng=rng, + ) @time fitresult, _, rpt = fit(model, 0, MLJBase.selectrows(X, train), selectrows(y, train)) first_last_training_loss = rpt[1][[1, end]] - push!(losses, first_last_training_loss[2]) -# @show first_last_training_loss yhat = predict(model, fitresult, selectrows(X, test)) truth = ymatrix[test,:] - goal = 0.8*model.loss(truth .- mean(truth), 0) + goal = 0.85*model.loss(truth .- mean(truth), 0) @test model.loss(Tables.matrix(yhat), truth) < goal - - optimisertest(MLJFlux.MultitargetNeuralNetworkRegressor, - X, - y, - builder, - optimiser, - accel) - end -# check different resources (CPU1, CUDALibs, etc)) give about the same loss: -reference = losses[1] -@test all(x->abs(x - reference)/reference < 1e-6, losses[2:end]) - true diff --git a/test/runtests.jl b/test/runtests.jl index d8df35be..b7b11d66 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,7 @@ import StatsBase using StableRNGs using CUDA, cuDNN import StatisticalMeasures +import Optimisers using ComputationalResources using ComputationalResources: CPU1, CUDALibs @@ -26,23 +27,6 @@ MLJFlux.gpu_isdead() && push!(EXCLUDED_RESOURCE_TYPES, CUDALibs) "these types, as unavailable:\n$EXCLUDED_RESOURCE_TYPES\n"* "Excluded tests marked as \"broken\"." -# alternative version of Short builder with no dropout; see -# https://github.com/FluxML/Flux.jl/issues/1372 and -# https://github.com/FluxML/Flux.jl/issues/1372 -mutable struct Short2 <: MLJFlux.Builder - n_hidden::Int # if zero use geometric mean of input/output - σ -end -Short2(; n_hidden=0, σ=Flux.sigmoid) = Short2(n_hidden, σ) -function MLJFlux.build(builder::Short2, rng, n, m) - n_hidden = - builder.n_hidden == 0 ? round(Int, sqrt(n*m)) : builder.n_hidden - init = Flux.glorot_uniform(rng) - return Flux.Chain( - Flux.Dense(n, n_hidden, builder.σ, init=init), - Flux.Dense(n_hidden, m, init=init)) -end - seed!(123) include("test_utils.jl") @@ -58,9 +42,6 @@ macro conditional_testset(name, expr) end end) end -@conditional_testset "penalizers" begin - include("penalizers.jl") -end @conditional_testset "core" begin include("core.jl") diff --git a/test/test_utils.jl b/test/test_utils.jl index 7a2728b4..3dbeda8a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -8,6 +8,11 @@ macro testset_accelerated(name::String, var, opts::Expr, ex) testset_accelerated(name, var, ex; eval(opts)...) end +clonewith(optimiser, args...) = + error("`basictest` and `optimisertest` only support `Adam` optimiser. ") +clonewith(optimiser::Optimisers.Adam, args...) = + Optimisers.Adam(args...) + # To exclude a resource, say, CPU1, do like # `@test_accelerated "cool test" accel (exclude=[CPU1,],) begin ... end` function testset_accelerated(name::String, var, ex; exclude=[]) @@ -57,12 +62,14 @@ function basictest(ModelType, X, y, builder, optimiser, threshold, accel) eval(quote - stable_rng = StableRNGs.StableRNG(123) + # GPUs only support `default_rng`: + rng = accel == CPU1() ? 
StableRNGs.StableRNG(123) : Random.default_rng() + seed!(rng, 123) model = $ModelType_ex(builder=$builder, optimiser=$optimiser, acceleration=$accel_ex, - rng=stable_rng) + rng=rng) fitresult, cache, _report = MLJBase.fit(model, 0, $X, $y); @@ -93,7 +100,7 @@ function basictest(ModelType, X, y, builder, optimiser, threshold, accel) optimiser=$optimiser, epochs=2, acceleration=$accel_ex, - rng=stable_rng) + rng=rng) fitresult, cache, _report = MLJBase.fit(model, 0, $X, $y); @@ -105,14 +112,14 @@ function basictest(ModelType, X, y, builder, optimiser, threshold, accel) MLJBase.update(model, 2, fitresult, cache, $X, $y )); # change learning rate and check it does *not* restart: - model.optimiser.eta /= 2 + model.optimiser = clonewith(model.optimiser, model.optimiser.eta/2) fitresult, cache, _report = @test_logs(MLJBase.update(model, 2, fitresult, cache, $X, $y)); # set `optimiser_changes_trigger_retraining = true` and change # learning rate and check it does restart: model.optimiser_changes_trigger_retraining = true - model.optimiser.eta /= 2 + model.optimiser = clonewith(model.optimiser, model.optimiser.eta/2) @test_logs((:info, r""), # one line of :info per extra epoch (:info, r""), MLJBase.update(model, 2, fitresult, cache, $X, $y)); @@ -132,14 +139,14 @@ function optimisertest(ModelType, X, y, builder, optimiser, accel) eval(quote - model = $ModelType_ex(builder=$builder, - optimiser=$optimiser, - acceleration=$accel_ex, - epochs=1) + model = $ModelType_ex(builder=$builder, + optimiser=$optimiser, + acceleration=$accel_ex, + epochs=1) mach = machine(model, $X, $y); - # USING GLOBAL RNG + # USING DEFAULT RNG # two epochs in stages: Random.seed!(123) # chains are always initialized on CPU @@ -156,31 +163,65 @@ function optimisertest(ModelType, X, y, builder, optimiser, accel) if accel isa CPU1 @test isapprox(l1, l2) else - @test_broken isapprox(l1, l2, rtol=1e-8) + @test isapprox(l1, l2, rtol=1e-8) end - # USING USER SPECIFIED RNG SEED + # USING USER SPECIFIED RNG SEED (unsupported on GPU) - # two epochs in stages: - model.rng = 1234 - mach = machine(model, $X, $y); + if !(accel isa CUDALibs) + # two epochs in stages: + model.rng = 1234 + mach = machine(model, $X, $y); - fit!(mach, verbosity=0, force=true); - model.epochs = model.epochs + 1 - fit!(mach, verbosity=0); # update - l1 = MLJBase.report(mach).training_losses[end] + fit!(mach, verbosity=0, force=true); + model.epochs = model.epochs + 1 + fit!(mach, verbosity=0); # update + l1 = MLJBase.report(mach).training_losses[end] - # two epochs in one go: - fit!(mach, verbosity=1, force=true) - l2 = MLJBase.report(mach).training_losses[end] + # two epochs in one go: + fit!(mach, verbosity=1, force=true) + l2 = MLJBase.report(mach).training_losses[end] - if accel isa CPU1 @test isapprox(l1, l2) - else - @test_broken isapprox(l1, l2, rtol=1e-8) end end) return true end + + +# # LOSS PENALIZERS + +""" + Penalizer(λ, α) + +Returns a callable object `penalizer` for evaluating regularization +penalties associated with some numerical array. Specifically, +`penalizer(A)` returns + + λ*(α*L1 + (1 - α)*L2), + +where `L1` is the sum of absolute values of the elments of `A` and +`L2` is the sum of squares of those elements. + +""" +struct Penalizer{T} + lambda::T + alpha::T + function Penalizer(lambda, alpha) + lambda == 0 && return new{Nothing}(nothing, nothing) + T = promote_type(typeof.((lambda, alpha))...) 
+ return new{T}(lambda, alpha) + end +end + +(::Penalizer{Nothing})(::Any) = 0 +function (p::Penalizer)(A) + λ = p.lambda + α = p.alpha + # avoiding broadcasting; see Note (1) above + L2 = sum(abs2, A) + L1 = sum(abs, A) + return λ*(α*L1 + (1 - α)*L2) +end
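
For reference, `MLJFlux.regularized_optimiser` (defined in src/mlj_model_interface.jl above) composes the user's optimiser with `Optimisers.SignDecay` (L1) and/or `Optimisers.WeightDecay` (L2), with coefficients scaled down by the number of batches. The sketch below spells this out for the general case `0 < alpha < 1`; the hyperparameter values are illustrative only:

```julia
import MLJFlux, Optimisers

model = MLJFlux.NeuralNetworkRegressor(
    optimiser = Optimisers.Adam(0.01),
    lambda = 0.3,  # overall regularization strength
    alpha = 0.4,   # L1/L2 mix
)
nbatches = 10
opt = MLJFlux.regularized_optimiser(model, nbatches)

# `opt` is equivalent to the following chain, whose components are applied left to right:
#
#   Optimisers.OptimiserChain(
#       Optimisers.SignDecay(0.4*0.3/nbatches),      # alpha*lambda/nbatches          (L1)
#       Optimisers.WeightDecay(2*0.6*0.3/nbatches),  # 2*(1 - alpha)*lambda/nbatches  (L2)
#       Optimisers.Adam(0.01),
#   )
```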
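
Relatedly, Optimisers.jl rules such as `Adam` are immutable, so code that previously mutated `model.optimiser.eta` in place now replaces the rule wholesale, as the `clonewith` test helper above does. A minimal sketch, with illustrative values:

```julia
import MLJFlux, Optimisers

model = MLJFlux.NeuralNetworkRegressor(optimiser=Optimisers.Adam(0.001))

# Previously: model.optimiser.eta *= 2
# Now the rule itself is replaced:
model.optimiser = Optimisers.Adam(2 * model.optimiser.eta)
```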