diff --git a/Project.toml b/Project.toml index 0ad3d0d2..40892f2c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJFlux" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" authors = ["Anthony D. Blaom ", "Ayush Shridhar "] -version = "0.2.2" +version = "0.2.3" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index 1aea02ed..5a909ba4 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -14,6 +14,7 @@ using ColorTypes using ComputationalResources using Random +include("penalized_losses.jl") include("core.jl") include("builders.jl") include("types.jl") diff --git a/src/common.jl b/src/common.jl index 47b77186..a585ae9e 100644 --- a/src/common.jl +++ b/src/common.jl @@ -43,34 +43,39 @@ end true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng function MLJModelInterface.fit(model::MLJFluxModel, - verbosity::Int, + verbosity, X, y) - data = collate(model, X, y) + move = Mover(model.acceleration) rng = true_rng(model) - shape = MLJFlux.shape(model, X, y) - chain = build(model, rng, shape) + chain = build(model, rng, shape) |> move + penalized_loss = PenalizedLoss(model, chain) + + data = move.(collate(model, X, y)) optimiser = deepcopy(model.optimiser) - chain, history = fit!(chain, + chain, history = fit!(penalized_loss, + chain, optimiser, - model.loss, model.epochs, - model.lambda, - model.alpha, verbosity, - model.acceleration, - data[1], + data[1], data[2]) # `optimiser` is now mutated - cache = (deepcopy(model), data, history, shape, optimiser, deepcopy(rng)) - fitresult = MLJFlux.fitresult(model, chain, y) + cache = (deepcopy(model), + data, + history, + shape, + optimiser, + deepcopy(rng), + move) + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) report = (training_losses=history, ) @@ -78,13 +83,13 @@ function MLJModelInterface.fit(model::MLJFluxModel, end function MLJModelInterface.update(model::MLJFluxModel, - verbosity::Int, + verbosity, 
old_fitresult, old_cache, X, y) - old_model, data, old_history, shape, optimiser, rng = old_cache + old_model, data, old_history, shape, optimiser, rng, move = old_cache old_chain = old_fitresult[1] optimiser_flag = model.optimiser_changes_trigger_retraining && @@ -94,15 +99,18 @@ function MLJModelInterface.update(model::MLJFluxModel, MLJModelInterface.is_same_except(model, old_model, :optimiser, :epochs) if keep_chain - chain = old_chain + chain = move(old_chain) epochs = model.epochs - old_model.epochs else + move = Mover(model.acceleration) rng = true_rng(model) - chain = build(model, rng, shape) - data = collate(model, X, y) + chain = build(model, rng, shape) |> move + data = move.(collate(model, X, y)) epochs = model.epochs end + penalized_loss = PenalizedLoss(model, chain) + # we only get to keep the optimiser "state" carried over from # previous training if we're doing a warm restart and the user has not # changed the optimiser hyper-parameter: @@ -112,14 +120,11 @@ function MLJModelInterface.update(model::MLJFluxModel, optimiser = deepcopy(model.optimiser) end - chain, history = fit!(chain, + chain, history = fit!(penalized_loss, + chain, optimiser, - model.loss, epochs, - model.lambda, - model.alpha, verbosity, - model.acceleration, data[1], data[2]) if keep_chain @@ -127,8 +132,14 @@ function MLJModelInterface.update(model::MLJFluxModel, history = vcat(old_history[1:end-1], history) end - fitresult = MLJFlux.fitresult(model, chain, y) - cache = (deepcopy(model), data, history, shape, optimiser, deepcopy(rng)) + fitresult = MLJFlux.fitresult(model, Flux.cpu(chain), y) + cache = (deepcopy(model), + data, + history, + shape, + optimiser, + deepcopy(rng), + move) report = (training_losses=history, ) return fitresult, cache, report diff --git a/src/core.jl b/src/core.jl index dbb450fa..2b5c375f 100644 --- a/src/core.jl +++ b/src/core.jl @@ -37,11 +37,15 @@ end (::Mover{<:CUDALibs})(data) = Flux.gpu(data) """ -Custom training loop. 
Here, `loss_func` is the objective -function to optimise, `parameters` are the model parameters, -`optimiser` is the optimizer to be used, `X` (input features)is a -vector of arrays where the last dimension is the batch size. `y` -is the target observation vector. + train!(loss_func, parameters, optimiser, X, y) + +A private method. + +Update the parameters of a Flux chain with parameters `parameters`, +given a Flux-style "loss function" `loss_func(x, y)`. Here `X` and `y` are +vectors of batches of the training data, as detailed in the +[`MLJFlux.fit!`](@ref) document string. + """ function train!(loss_func, parameters, optimiser, X, y) n_batches = length(y) @@ -59,102 +63,74 @@ end """ - fit!(chain, - optimiser, - loss, - epochs, - lambda, - alpha, - verbosity, - acceleration, - X, - y) - -Optimize a Flux model `chain` using the regularization parameters -`lambda` (strength) and `alpha` (l2/l1 mix), where `loss(yhat, y) ` is -the supervised loss for instances (or vectors of instances) of the -target predictions `yhat` and target observations `y`. + fit!(penalized_loss, chain, optimiser, epochs, verbosity, X, y) + +A private method. + +Optimize a Flux model `chain`, where `penalized_loss(xb, yb)` is the +penalized loss associated with a batch of training input features `xb` +and target observations `yb` (and generally depends on `chain`). Here `chain` is a `Flux.Chain` object, or other "Flux model" such that -`Flux.params(chain)` returns the parameters to be optimised. +`Flux.params(chain)` returns the parameters to be optimized. -The `X` argument is the training features and `y` argument is the -target: +`X` is the vector of input batches and `y` is the vector of target +batches. 
Specifically, it is expected that: - `X` and `y` have type `Vector{<:Array{<:AbstractFloat}}` -- the shape of each elment of `X` is `(n1, n2, ..., nk, batch_size)` +- The shape of each element of `X` is `(n1, n2, ..., nk, batch_size)` where `(n1, n2, ..., nk)` is the shape of the inputs of `chain` -- the shape of each element of `y` is `(m1, m2, ..., mk, batch_size)` +- The shape of each element of `y` is `(m1, m2, ..., mk, batch_size)` where `(m1, m2, ..., mk)` is the shape of the `chain` outputs (even if `batch_size == 1`). -- the vectors `X` and `y` have the same length, coinciding with the +- The vectors `X` and `y` have the same length, coinciding with the total number of training batches. -The contribution to the objective function of a single training batch -`(X[i], y[i])` is - - loss(chain(X[i]), y[i]) + lambda*(model.alpha*l1) + (1 - model.alpha)*l2 - -where `l1 = sum(norm, params(chain)` and `l2 = sum(norm, params(chain))`. +The [`MLJFlux.PenalizedLoss`](@ref) constructor is available for +defining an appropriate `penalized_loss` from an MLJFlux model and +chain (the model specifies the unpenalized Flux loss, such as `mse`, +and the regularization parameters). -One must have `acceleration isa CPU1` or `acceleration isa CUDALibs` -(for running on a GPU) where `CPU1` and `CUDALibs` are types defined -in `ComputationalResources.jl`. +The `chain` and the data `(X, y)` must either both live on a CPU or both +live on a GPU. This `fit!` method takes no responsibility for data +movement. ### Return value `(chain_trained, history)`, where `chain_trained` is a trained version -of `chain` (possibly moved to a gpu) and `history` is a vector of -losses - one intial loss, and one loss per epoch. The method may -mutate the argument `chain`, depending on cpu <-> gpu movements. +of `chain` and `history` is a vector of penalized losses - one initial +loss, and one loss per epoch. 
""" -function fit!(chain, optimiser, loss, epochs, - lambda, alpha, verbosity, acceleration, X, y) +function fit!(penalized_loss, chain, optimiser, epochs, verbosity, X, y) # intitialize and start progress meter: meter = Progress(epochs+1, dt=0, desc="Optimising neural net:", barglyphs=BarGlyphs("[=> ]"), barlen=25, color=:yellow) verbosity != 1 || next!(meter) - move = Mover(acceleration) - X = move(X) - y = move(y) - chain = move(chain) - - loss_func(x, y) = loss(chain(x), y) - # initiate history: n_batches = length(y) - training_loss = mean(loss_func(X[i], y[i]) for i in 1:n_batches) + training_loss = mean(penalized_loss(X[i], y[i]) for i in 1:n_batches) history = [training_loss,] for i in 1:epochs - # We're taking data in a Flux-fashion. -# @show i rand() - current_loss = train!(loss_func, Flux.params(chain), optimiser, X, y) + current_loss = train!(penalized_loss, + Flux.params(chain), + optimiser, + X, + y) verbosity < 2 || @info "Loss is $(round(current_loss; sigdigits=4))" - push!(history, current_loss) - - # Early stopping is to be externally controlled. - # So @ablaom has commented next 5 lines : - # if current_loss == prev_loss - # @info "Model has reached maximum possible accuracy."* - # "More training won't increase accuracy" - # break - # end - - prev_loss = current_loss verbosity != 1 || next!(meter) - + push!(history, current_loss) end - return Flux.cpu(chain), history + return chain, history end diff --git a/src/penalized_losses.jl b/src/penalized_losses.jl new file mode 100644 index 00000000..477bc913 --- /dev/null +++ b/src/penalized_losses.jl @@ -0,0 +1,63 @@ +# Note (1). See +# https://discourse.julialang.org/t/weight-regularisation-which-iterates-params-m-in-flux-mutating-arrays-is-not-supported/64314 + + +""" Penalizer(λ, α) + +Returns a callable object `penalizer` for evaluating regularization +penalties associated with some numerical array. 
Specifically, +`penalizer(A)` returns + + λ*(α*L1 + (1 - α)*L2), + +where `L1` is the sum of absolute values of the elements of `A` and +`L2` is the sum of squares of those elements. +""" +struct Penalizer{T} + lambda::T + alpha::T + function Penalizer(lambda, alpha) + lambda == 0 && return new{Nothing}(nothing, nothing) + T = promote_type(typeof.((lambda, alpha))...) + return new{T}(lambda, alpha) + end +end + +(::Penalizer{Nothing})(::Any) = 0 +function (p::Penalizer)(A) + λ = p.lambda + α = p.alpha + # avoiding broadcasting; see Note (1) above + L2 = sum(abs2, A) + L1 = sum(abs, A) + return λ*(α*L1 + (1 - α)*L2) +end + +""" + PenalizedLoss(model, chain) + +Returns a callable object `p`, for returning the penalized loss on +some batch of data `(x, y)`. Specifically, `p(x, y)` returns + + loss(chain(x), y) + sum(Penalizer(λ, α).(params(chain))) + +where `loss = model.loss`, `α = model.alpha`, `λ = model.lambda`. + +See also [`Penalizer`](@ref) + +""" +struct PenalizedLoss{P} + loss + penalizer::P + chain + params + function PenalizedLoss(model, chain) + loss = model.loss + penalizer = Penalizer(model.lambda, model.alpha) + params = Flux.params(chain) + return new{typeof(penalizer)}(loss, penalizer, chain, params) + end +end +(p::PenalizedLoss{Penalizer{Nothing}})(x, y) = p.loss(p.chain(x), y) +(p::PenalizedLoss)(x, y) = p.loss(p.chain(x), y) + + sum(p.penalizer(θ) for θ in p.params) diff --git a/src/types.jl b/src/types.jl index 86bef752..362262ff 100644 --- a/src/types.jl +++ b/src/types.jl @@ -3,6 +3,51 @@ abstract type MLJFluxDeterministic <: MLJModelInterface.Deterministic end const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic} +const doc_regressor(model_name) = """ + + $model_name(; hyperparameters...) + +Instantiate an MLJFlux model. 
Available hyperparameters: + +- `builder`: Default = `MLJFlux.Linear(σ=Flux.relu)` (regressors) or + `MLJFlux.Short(n_hidden=0, dropout=0.5, σ=Flux.σ)` (classifiers) + +- `optimiser`: The optimiser to use for training. Default = + `Flux.ADAM()` + +- `loss`: The loss function used for training. Default = `Flux.mse` + (regressors) and `Flux.crossentropy` (classifiers) + +- `epochs`: Number of epochs to train for. Default = `10` + +- `batch_size`: The batch_size for the data. Default = 1 + +- `lambda`: The regularization strength. Default = 0. Range = [0, ∞) + +- `alpha`: The L2/L1 mix of regularization (0 = pure L2, 1 = pure L1). Default = 0. Range = [0, 1] + +- `rng`: The random number generator (RNG) passed to builders, for + weight initialization, for example. Can be any `AbstractRNG` or + the seed (integer) for a `MersenneTwister` that is reset on every + cold restart of model (machine) training. Default = + `GLOBAL_RNG`. + +- `acceleration`: Use `CUDALibs()` for training on GPU; default is `CPU1()`. + +- `optimiser_changes_trigger_retraining`: True if fitting an + associated machine should trigger retraining from scratch whenever + the optimiser changes. Default = `false` + +""" + +doc_classifier(model_name) = doc_regressor(model_name)*""" +- `finaliser`: Operation applied to the unnormalized output of the + final layer to obtain probabilities (outputs summing to + one). The shape of the inputs and outputs + of this operator must match. Default = `Flux.softmax`. 
+ +""" + for Model in [:NeuralNetworkClassifier, :ImageClassifier] ex = quote @@ -51,6 +96,9 @@ for Model in [:NeuralNetworkClassifier, :ImageClassifier] return model end + + @doc doc_classifier($Model) $Model + end eval(ex) @@ -100,6 +148,9 @@ for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor] return model end + + @doc $doc_regressor($Model) $Model + end eval(ex) diff --git a/test/core.jl b/test/core.jl index d62e9707..cf5b2ab6 100644 --- a/test/core.jl +++ b/test/core.jl @@ -90,6 +90,8 @@ chain_yes_drop = Flux.Chain(Flux.Dense(5, 15), Flux.Dense(15, 8), Flux.Dense(8, 1)) +model = MLJFlux.NeuralNetworkRegressor() # any model will do here + chain_no_drop = deepcopy(chain_yes_drop) chain_no_drop.layers[2].p = 1.0 @@ -105,35 +107,29 @@ epochs = 10 move = MLJFlux.Mover(accel) Random.seed!(123) - - _chain_yes_drop, history = MLJFlux.fit!(chain_yes_drop, + penalized_loss = MLJFlux.PenalizedLoss(model, chain_yes_drop) + _chain_yes_drop, history = MLJFlux.fit!(penalized_loss, + chain_yes_drop, Flux.Optimise.ADAM(0.001), - Flux.mse, epochs, 0, - 0, - 0, - accel, data[1], data[2]) - println() Random.seed!(123) - - _chain_no_drop, history = MLJFlux.fit!(chain_no_drop, - Flux.Optimise.ADAM(0.001), - Flux.mse, - epochs, - 0, - 0, - 0, - accel, - data[1], - data[2]) + penalized_loss = MLJFlux.PenalizedLoss(model, chain_no_drop) + _chain_no_drop, history = MLJFlux.fit!(penalized_loss, + chain_no_drop, + Flux.Optimise.ADAM(0.001), + epochs, + 0, + data[1], + data[2]) # check chains have different behaviour after training: - @test !(_chain_yes_drop(test_input) ≈ _chain_no_drop(test_input)) + @test !(_chain_yes_drop(test_input) ≈ + _chain_no_drop(test_input)) # check chain with dropout is deterministic outside of training # (if we do not differentiate): @@ -143,5 +139,3 @@ epochs = 10 @test length(history) == epochs + 1 end - - diff --git a/test/integration.jl b/test/integration.jl new file mode 100644 index 00000000..70b45fd1 --- /dev/null +++ 
b/test/integration.jl @@ -0,0 +1,26 @@ +rng = StableRNGs.StableRNG(123) + +table = load_iris() +y, X = unpack(table, ==(:target), _->true, rng=rng) + +@testset_accelerated "regularization has an effect" accel begin + + model = MLJFlux.NeuralNetworkClassifier(acceleration=accel, + builder=MLJFlux.Linear(), + rng=rng) + model2 = deepcopy(model) + model3 = deepcopy(model) + model3.lambda = 0.1 + + e = evaluate(model, X, y, resampling=Holdout(), measure=LogLoss()) + loss1 = e.measurement[1] + + e = evaluate(model2, X, y, resampling=Holdout(), measure=LogLoss()) + loss2 = e.measurement[1] + + e = evaluate(model3, X, y, resampling=Holdout(), measure=LogLoss()) + loss3 = e.measurement[1] + + @test loss1 ≈ loss2 + @test !(loss2 ≈ loss3) +end diff --git a/test/penalized_losses.jl b/test/penalized_losses.jl new file mode 100644 index 00000000..bbc67720 --- /dev/null +++ b/test/penalized_losses.jl @@ -0,0 +1,39 @@ +using Statistics + +@testset "penalties" begin + A = [-1 2; -3 4] + lambda = 1 + + # 100% L2: + alpha = 0 + penalty = MLJFlux.Penalizer(lambda, alpha) + @test penalty(A) ≈ 1 + 4 + 9 + 16 + + # 100% L1: + alpha = 1 + penalty = MLJFlux.Penalizer(lambda, alpha) + @test penalty(A) ≈ 1 + 2 + 3 + 4 + + # no strength: + lambda = 0 + alpha = 42.324 + penalty = MLJFlux.Penalizer(lambda, alpha) + @test penalty(A) == 0 +end + +@testset "penalized_losses" begin + # construct a penalized loss function: + model = MLJFlux.NeuralNetworkRegressor(lambda=1, alpha=1, loss=Flux.mae) + chain = Flux.Dense(3, 1, identity) + p = MLJFlux.PenalizedLoss(model, chain) + + # construct a batch: + b = 5 + x = rand(Float32, 3, b) + y = rand(Float32, 1, b) + + # compare loss by hand and with penalized loss function: + penalty = (sum(abs.(chain.weight)) + abs(chain.bias[1])) + yhat = chain(x) + @test p(x, y) ≈ Flux.mae(yhat, y) + penalty +end diff --git a/test/runtests.jl b/test/runtests.jl index e3963daa..d72b13e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,6 +44,10 @@ seed!(123) 
include("test_utils.jl") +@testset "penalized_losses" begin + include("penalized_losses.jl") +end + @testset "core" begin include("core.jl") end @@ -68,3 +72,6 @@ end include("image.jl") end +@testset "integration" begin + include("integration.jl") +end