address review: add logic to regularized_optimiser()
ablaom committed May 29, 2024
1 parent b3b41ac commit 4af84e5
Showing 2 changed files with 46 additions and 4 deletions.
18 changes: 15 additions & 3 deletions src/mlj_model_interface.jl
@@ -62,12 +62,24 @@ function regularized_optimiser(model, nbatches)
     λ_L2 = (1 - model.alpha)*model.lambda
     λ_sign = λ_L1/nbatches
     λ_weight = 2*λ_L2/nbatches
-    # components in an optimiser chain are executed from left to right:
-    return Optimisers.OptimiserChain(
+
+    # recall components in an optimiser chain are executed from left to right:
+    if model.alpha == 0
+        return Optimisers.OptimiserChain(
+            Optimisers.WeightDecay(λ_weight),
+            model.optimiser,
+        )
+    elseif model.alpha == 1
+        return Optimisers.OptimiserChain(
+            Optimisers.SignDecay(λ_sign),
+            model.optimiser,
+        )
+    else return Optimisers.OptimiserChain(
         Optimisers.SignDecay(λ_sign),
         Optimisers.WeightDecay(λ_weight),
         model.optimiser,
-    )
+        )
+    end
 end
 
 function MLJModelInterface.fit(model::MLJFluxModel,
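For orientation, not part of the commit: `regularized_optimiser` splits the elastic-net penalty `lambda` into an L1 part `alpha*lambda`, applied through `Optimisers.SignDecay`, and an L2 part `(1 - alpha)*lambda`, applied through `Optimisers.WeightDecay`, with both rates divided by `nbatches` so the full penalty is spread across one epoch of batch updates. (The factor of 2 in `λ_weight` reflects that `WeightDecay(γ)` adds `γ*w` to the gradient, i.e. the gradient of `(γ/2)*norm(w)^2`.) The new branches simply avoid inserting a zero-strength decay term when `alpha` is 0 or 1. A minimal sketch of the resulting chains, with illustrative hyperparameter values; the commented return values assume `Optimisers.Momentum()` defaults:

    using MLJFlux, Optimisers

    optimiser = Optimisers.Momentum()

    # alpha = 0 is pure L2: WeightDecay but no SignDecay;
    # lambda = 0.3 and nbatches = 10 give λ_weight = 2*0.3/10 = 0.06:
    model = MLJFlux.NeuralNetworkRegressor(; alpha=0, lambda=0.3, optimiser)
    MLJFlux.regularized_optimiser(model, 10)
    # OptimiserChain(WeightDecay(0.06), Momentum(0.01, 0.9))

    # alpha = 1 is pure L1: SignDecay but no WeightDecay; λ_sign = 0.3/10 = 0.03:
    model = MLJFlux.NeuralNetworkRegressor(; alpha=1, lambda=0.3, optimiser)
    MLJFlux.regularized_optimiser(model, 10)
    # OptimiserChain(SignDecay(0.03), Momentum(0.01, 0.9))

    # any other alpha keeps both decay components, as before:
    model = MLJFlux.NeuralNetworkRegressor(; alpha=0.4, lambda=0.3, optimiser)
    MLJFlux.regularized_optimiser(model, 10)
    # OptimiserChain(SignDecay(0.012), WeightDecay(0.036), Momentum(0.01, 0.9))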
32 changes: 31 additions & 1 deletion test/mlj_model_interface.jl
@@ -37,7 +37,37 @@ end
 end
 end
 
-@testset "regularization" begin
+@testset "regularization: logic" begin
+    optimiser = Optimisers.Momentum()
+
+    # lambda = 0:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0.3, lambda=0, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain == optimiser
+
+    # alpha = 0:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.WeightDecay, Optimisers.Momentum}
+    }
+
+    # alpha = 1:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=1, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.SignDecay, Optimisers.Momentum}
+    }
+
+    # general case:
+    model = MLJFlux.NeuralNetworkRegressor(; alpha=0.4, lambda=0.3, optimiser)
+    chain = MLJFlux.regularized_optimiser(model, 1)
+    @test chain isa Optimisers.OptimiserChain{
+        Tuple{Optimisers.SignDecay, Optimisers.WeightDecay, Optimisers.Momentum}
+    }
+end
+
+@testset "regularization: integration" begin
     rng = StableRNG(123)
     nobservations = 12
     Xuser = rand(Float32, nobservations, 3)
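The "regularization: integration" testset continues beyond the truncation above; it exercises these optimiser chains through a full MLJ fit. For context, a hypothetical minimal sketch of that kind of end-to-end run; the data shapes echo the test setup, but the hyperparameter values and `epochs` here are illustrative only, not the test's actual content:

    using MLJ, MLJFlux, Optimisers, StableRNGs

    rng = StableRNG(123)
    nobservations = 12
    X = MLJ.table(rand(rng, Float32, nobservations, 3))
    y = rand(rng, Float32, nobservations)

    model = MLJFlux.NeuralNetworkRegressor(;
        alpha=0.4,
        lambda=0.3,
        optimiser=Optimisers.Momentum(),
        epochs=2,
    )

    # fitting builds the regularized optimiser chain internally, via
    # regularized_optimiser(model, nbatches):
    mach = machine(model, X, y)
    fit!(mach, verbosity=0)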
