From 4af84e57e8f0f7213693da96c0c27a3ecf890680 Mon Sep 17 00:00:00 2001
From: "Anthony D. Blaom"
Date: Thu, 30 May 2024 10:29:23 +1200
Subject: [PATCH] address review: add logic to regularized_optimiser()

---
 src/mlj_model_interface.jl  | 18 +++++++++++++++---
 test/mlj_model_interface.jl | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl
index 70dadf9c..f25c137e 100644
--- a/src/mlj_model_interface.jl
+++ b/src/mlj_model_interface.jl
@@ -62,12 +62,24 @@ function regularized_optimiser(model, nbatches)
     λ_L2 = (1 - model.alpha)*model.lambda
     λ_sign = λ_L1/nbatches
     λ_weight = 2*λ_L2/nbatches
-    # components in an optimiser chain are executed from left to right:
-    return Optimisers.OptimiserChain(
+
+    # recall components in an optimiser chain are executed from left to right:
+    if model.alpha == 0
+        return Optimisers.OptimiserChain(
+            Optimisers.WeightDecay(λ_weight),
+            model.optimiser,
+        )
+    elseif model.alpha == 1
+        return Optimisers.OptimiserChain(
+            Optimisers.SignDecay(λ_sign),
+            model.optimiser,
+        )
+    else return Optimisers.OptimiserChain(
         Optimisers.SignDecay(λ_sign),
         Optimisers.WeightDecay(λ_weight),
         model.optimiser,
-    )
+        )
+    end
 end
 
 function MLJModelInterface.fit(model::MLJFluxModel,
diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl
index d09c0a62..7b929050 100644
--- a/test/mlj_model_interface.jl
+++ b/test/mlj_model_interface.jl
@@ -37,7 +37,37 @@ end
     end
 end
 
-@testset "regularization" begin
+@testset "regularization: logic" begin
+    optimiser = Optimisers.Momentum()
+
+    # lambda = 0:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0.3, lambda=0, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain == optimiser
+
+    # alpha = 0:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.WeightDecay, Optimisers.Momentum}
+    }
+
+    # alpha = 1:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=1, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.SignDecay, Optimisers.Momentum}
+    }
+
+    # general case:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0.4, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.SignDecay, Optimisers.WeightDecay, Optimisers.Momentum}
+    }
+end
+
+@testset "regularization: integration" begin
     rng = StableRNG(123)
     nobservations = 12
     Xuser = rand(Float32, nobservations, 3)