WeightDecay for L1 norm (#159)
* WeightDecay for L1 norm

* better words

* change to lambda alpha, add tests

* change to lambda, add tests

* tweaks

* stashed in October - makes two structs instead

* version with simple SignDecay instead

* change SignDecay penalty to be called kappa

* restore depwarn for WeightDecay, was called gamma

* change kappa back to lambda
mcabbott authored Feb 7, 2024
1 parent 6473c45 commit e60b71e
Showing 5 changed files with 67 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.3.1"
version = "0.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -14,7 +14,7 @@ export destructure
include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad

###
72 changes: 59 additions & 13 deletions src/rules.jl
@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
end

"""
AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
(no need to change default)
"""
AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
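
As an aside (not part of this diff): AdamW is still just a two-rule chain, and only the name of its decay argument changes here. A minimal usage sketch against the post-commit API, with illustrative numbers:

using Optimisers

rule  = AdamW(1e-3, (0.9, 0.999), 1e-2)   # decay strength is the third positional argument, now called λ / lambda
chain = OptimiserChain(Adam(1e-3, (0.9, 0.999)), WeightDecay(1e-2))   # the same thing, spelled out by hand

x  = randn(Float32, 3)
st = Optimisers.setup(rule, x)
st, x = Optimisers.update(st, x, ones(Float32, 3))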

"""
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
end

"""
WeightDecay(γ = 5e-4)
WeightDecay(λ = 5e-4)
Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
subtracted from `x`.
Implements ``L_2`` regularisation, also known as ridge regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
It does this by adding `λ .* x` to the gradient. This is equivalent to adding
`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.
See also [`SignDecay`] for ``L_1`` normalisation.
# Parameters
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct WeightDecay <: AbstractRule
gamma = 5e-4
lambda = 5e-4
end

init(o::WeightDecay, x::AbstractArray) = nothing

function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
γ = T(o.gamma)
dx′ = @lazy dx + γ * x
λ = T(o.lambda)
dx′ = @lazy dx + λ * x

return state, dx′
end
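
A quick numerical check of the new docstring's claim, as a sketch (illustrative values, zero loss gradient, WeightDecay used on its own): one update step subtracts exactly `λ .* x`, the gradient of the penalty `λ/2 * sum(abs2, x)`.

using Optimisers

λ  = 0.1
x  = [1.0, -2.0, 3.0]     # illustrative parameters
dx = zero(x)              # pretend the loss gradient vanishes

st = Optimisers.setup(WeightDecay(λ), x)
st, x′ = Optimisers.update(st, x, dx)

x′ ≈ x .- λ .* x          # true: the step is the L2 penalty's gradient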

function adjust(r::WeightDecay; gamma = nothing, kw...)
if isnothing(gamma)
return _adjust(r, NamedTuple(kw))
else
Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force=true)
nt = (; lambda = gamma, NamedTuple(kw)...)
return _adjust(r, nt)
end
end
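
The method above keeps the old spelling alive. A hedged sketch of both spellings, assuming the generic rule-level `Optimisers.adjust(rule; kw...)` that it specialises; the `gamma` path prints the deprecation warning and writes into the new `lambda` field:

using Optimisers

r  = WeightDecay()                        # lambda = 5e-4 by default
r2 = Optimisers.adjust(r; lambda = 1e-3)  # new keyword
r3 = Optimisers.adjust(r; gamma = 1e-3)   # deprecated keyword: same result, plus a warning

r2.lambda == r3.lambda == 1e-3            # true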

"""
SignDecay(λ = 1e-3)
Implements ``L_1`` regularisation, also known as LASSO regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.
See also [`WeightDecay`] for ``L_2`` normalisation.
They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.
# Parameters
- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct SignDecay <: AbstractRule
lambda = 1e-3
end

init(o::SignDecay, x::AbstractArray) = nothing

function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
λ = T(o.lambda)
dx′ = @lazy dx + λ * sign(x)

return state, dx′
end
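
And the matching check for the new rule, again a sketch with illustrative values and a zero loss gradient: one step subtracts `λ .* sign.(x)`, a subgradient of `λ * sum(abs, x)`, and the two decays compose exactly as the chain example in the docstring describes.

using Optimisers

λ  = 0.1
x  = [1.0, -10.0, 100.0]
dx = zero(x)

st = Optimisers.setup(SignDecay(λ), x)
st, x′ = Optimisers.update(st, x, dx)

x′ ≈ x .- λ .* sign.(x)   # true: every entry shrinks towards zero by λ

both = OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())   # L1 + L2 together, as in the docstring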


"""
ClipGrad(δ = 10)
Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
Typically composed with other rules using [`OptimiserChain`](@ref).
See also [`ClipNorm`](@ref).
"""
@def struct ClipGrad <: AbstractRule
@@ -591,6 +635,8 @@ to stay at this threshold (unless `p==0`).
Throws an error if the norm is infinite or `NaN`,
which you can turn off with `throw = false`.
Typically composed with other rules using [`OptimiserChain`](@ref).
See also [`ClipGrad`](@ref).
"""
struct ClipNorm <: AbstractRule
2 changes: 1 addition & 1 deletion test/rules.jl
@@ -10,7 +10,7 @@ RULES = [
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
# A few chained combinations:
OptimiserChain(WeightDecay(), Adam(0.001)),
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -137,6 +137,7 @@ end
@testset "OptimiserChain" begin
x = [1, 10, 100.0]; dx = [1, 2, 3.0];
@test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
@test Optimisers.update(Optimisers.setup(SignDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-0.1-2, 100-0.1-3]
@test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]

o2 = OptimiserChain(ClipGrad(2), WeightDecay(0.1))
@@ -154,6 +155,10 @@

o0 = OptimiserChain()
@test Optimisers.update(Optimisers.setup(o0, x), x, dx)[2] ≈ [1-1,10-2,100-3]

# L1 norm via sign
xm = [1, -10, 100.0]; dxm = [3, 2, -1];
@test Optimisers.update(Optimisers.setup(SignDecay(0.1), xm), xm, dxm)[2] ≈ [1-0.1-3, -10+0.1-2, 100-0.1+1]
end

@testset "trainable subset" begin
