add L1 regularisation to WeightDecay
mcabbott committed Sep 8, 2023
1 parent 95737ff commit 4c98c6b
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -44,7 +44,7 @@ MacroTools = "0.5"
Metal = "0.5"
NNlib = "0.9.1"
OneHotArrays = "0.2.4"
Optimisers = "0.2.12, 0.3.0"
Optimisers = "0.3.2"
Preferences = "1"
ProgressLogging = "0.1"
Reexport = "1.0"
2 changes: 1 addition & 1 deletion src/deprecations.jl
@@ -126,9 +126,9 @@ for T in [:Descent, :Adam, :Momentum, :Nesterov,
end
_old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
-_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd)  # called gamma now
+_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd, rule.alpha)  # wd is called lambda now
_old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh)  # called omega, and there are more fields
_old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh)  # called delta now, and struct name differs
const ClipGrad = Optimise.ClipValue
_old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon)  # RMSProp has no field centred

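To see what this translation produces, here is a minimal sketch. It assumes the two-argument `Optimisers.WeightDecay(λ, α)` that the compat bump above requires, and it calls `Flux._old_to_new`, which is an internal helper rather than public API:

```julia
using Flux, Optimisers

# An old-style rule carrying the new alpha field (see the optimisers.jl diff below):
old = Flux.Optimise.WeightDecay(0.01, 0.5)

# The internal conversion applied when an old rule meets the new explicit-style API:
new = Flux._old_to_new(old)
new isa Optimisers.WeightDecay  # true; both wd (now lambda) and alpha carry over
```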
15 changes: 12 additions & 3 deletions src/optimise/optimisers.jl
@@ -681,6 +681,12 @@ Typically composed with other optimisers as the first transformation to the gradient,
making it equivalent to adding ``L_2`` regularization
with coefficient ``λ`` to the loss.
+
+    WeightDecay(λ, α)
+
+Second argument turns on ``L_1`` regularization, adding `λ * (α * sign(x) + (1 - α) * x)`
+to the gradient. The case `α = 1` is equivalent to adding `sum(abs, x) == norm(x, 1)` to the
+loss function, while `0 < α < 1` mixes L1 and L2.
# Examples
```julia
@@ -689,13 +695,16 @@ opt = Optimiser(WeightDecay(1f-4), Adam())
"""
mutable struct WeightDecay <: AbstractOptimiser
  wd::Real
+ alpha::Real
end

-WeightDecay() = WeightDecay(0)
+WeightDecay(λ = 0f0) = WeightDecay(λ, 0f0)

function apply!(o::WeightDecay, x, Δ)
-  wd = o.wd
-  @. Δ += wd * x
+  wd, α = o.wd, o.alpha
+  l2 = (1-α) * wd
+  l1 = α * wd
+  @. Δ += l2 * x + l1 * sign(x)
end

"""
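As a quick sanity check of the new rule, the sketch below exercises only what the diff above defines: `apply!` adds `λ * ((1 - α) * x + α * sign(x))` to the gradient in place.

```julia
using Flux
using Flux.Optimise: WeightDecay, apply!

x = Float32[1.0, -2.0, 3.0]  # parameters
Δ = zero(x)                  # zero gradient, to isolate the decay term

apply!(WeightDecay(0.1f0, 0.5f0), x, Δ)  # λ = 0.1, α = 0.5: an even L1/L2 mix
@assert Δ ≈ 0.1f0 .* (0.5f0 .* x .+ 0.5f0 .* sign.(x))

# α = 1 is pure L1: the gradient of λ * norm(x, 1) is λ * sign(x).
Δ1 = zero(x)
apply!(WeightDecay(0.1f0, 1f0), x, Δ1)
@assert Δ1 ≈ 0.1f0 .* sign.(x)
```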
