diff --git a/Project.toml b/Project.toml
index 528d5a25ff..fa0eafb613 100644
--- a/Project.toml
+++ b/Project.toml
@@ -44,7 +44,7 @@ MacroTools = "0.5"
 Metal = "0.5"
 NNlib = "0.9.1"
 OneHotArrays = "0.2.4"
-Optimisers = "0.2.12, 0.3.0"
+Optimisers = "0.3.2"
 Preferences = "1"
 ProgressLogging = "0.1"
 Reexport = "1.0"
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 3b65f00da0..ee432f2425 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -126,9 +126,9 @@ for T in [:Descent, :Adam, :Momentum, :Nesterov,
 end
 _old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
 const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
-_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd) # called gamma now
 _old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
 _old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
+_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd, rule.alpha) # wd is called lambda now
 const ClipGrad = Optimise.ClipValue
 _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 9da9b1472f..763852c1e3 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -681,6 +681,12 @@ Typically composed with other optimisers as the first transformation to the gra
 making it equivalent to adding ``L_2`` regularization with coefficient ``λ`` to the loss.
 
 
+    WeightDecay(λ, α)
+
+Second argument turns on ``L_1`` regularization, adding `λ * (α * sign(x) + (1 - α) * x)`
+to the gradient. The case `α = 1` is equivalent to adding `sum(abs, x) == norm(x, 1)` to the
+loss function, while `0 < α < 1` mixes L1 and L2.
+
 # Examples
 
 ```julia
@@ -689,13 +695,16 @@ opt = Optimiser(WeightDecay(1f-4), Adam())
 """
 mutable struct WeightDecay <: AbstractOptimiser
   wd::Real
+  alpha::Real
 end
 
-WeightDecay() = WeightDecay(0)
+WeightDecay(λ = 0f0) = WeightDecay(λ, 0f0)
 
 function apply!(o::WeightDecay, x, Δ)
-  wd = o.wd
-  @. Δ += wd * x
+  wd, α = o.wd, o.alpha
+  l2 = (1-α) * wd
+  l1 = α * wd
+  @. Δ += l2 * x + l1 * sign(x)
 end
 
 """