WeightDecay for L1 norm #159

Merged (10 commits) on Feb 7, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
authors = ["Mike J Innes <[email protected]>"]
version = "0.3.1"
version = "0.3.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -14,7 +14,7 @@ export destructure
include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad

###
72 changes: 59 additions & 13 deletions src/rules.jl
@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
end

"""
AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)

[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
(no need to change default)
"""
AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
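
Not part of the diff, just an illustrative sketch of the definition above: since `AdamW` is simply a chain, calling it should take exactly the same step as spelling the chain out. This assumes the exported `Optimisers.setup`/`Optimisers.update` API applied to a plain parameter array; the array and gradient values are made up.

```julia
using Optimisers

# Hypothetical parameter array and gradient, just to compare the two rules.
x  = [0.3, -1.2, 0.7]
dx = [0.1, 0.2, -0.3]

rule_a = AdamW(0.001, (0.9, 0.999), 0.1)                                    # λ = 0.1
rule_b = OptimiserChain(Adam(0.001, (0.9, 0.999), 1e-8), WeightDecay(0.1))  # the same chain, written out

_, xa = Optimisers.update(Optimisers.setup(rule_a, x), x, dx)
_, xb = Optimisers.update(Optimisers.setup(rule_b, x), x, dx)
xa ≈ xb   # expected to hold, since AdamW is defined as this OptimiserChain
```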

"""
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
end

"""
WeightDecay(γ = 5e-4)
WeightDecay(λ = 5e-4)

Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
subtracted from `x`.
Implements ``L_2`` regularisation, also known as ridge regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
It does this by adding `λ .* x` to the gradient. This is equivalent to adding
`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.

See also [`SignDecay`](@ref) for ``L_1`` regularisation.

# Parameters
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct WeightDecay <: AbstractRule
gamma = 5e-4
lambda = 5e-4
end

init(o::WeightDecay, x::AbstractArray) = nothing

function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
γ = T(o.gamma)
dx′ = @lazy dx + γ * x
λ = T(o.lambda)
dx′ = @lazy dx + λ * x

return state, dx′
end
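
A minimal sketch of what this `apply!` does, mirroring the existing `OptimiserChain` test further down: used standalone (no learning rate of its own), `WeightDecay(λ)` makes the update subtract `dx .+ λ .* x`, which is the gradient of adding `λ/2 * sum(abs2, x)` to the loss.

```julia
using Optimisers

x  = [1.0, 10.0, 100.0]
dx = [1.0, 2.0, 3.0]
st = Optimisers.setup(WeightDecay(0.1), x)
_, x′ = Optimisers.update(st, x, dx)
x′ ≈ x .- (dx .+ 0.1 .* x)   # i.e. [1 - 1 - 0.1, 10 - 2 - 1, 100 - 3 - 10]
```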

function adjust(r::WeightDecay; gamma = nothing, kw...)
if isnothing(gamma)
return _adjust(r, NamedTuple(kw))
else
Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force=true)
nt = (; lambda = gamma, NamedTuple(kw)...)
return _adjust(r, nt)
end
end
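
For illustration only (not from the diff), assuming the rule-level `Optimisers.adjust` method applies keywords to the matching field:

```julia
using Optimisers

r = WeightDecay()                      # default lambda = 5e-4
Optimisers.adjust(r; lambda = 1e-3)    # WeightDecay(0.001), via the new field name
Optimisers.adjust(r; gamma = 1e-3)     # still accepted, but emits the deprecation warning above
```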

"""
SignDecay(λ = 1e-3)

Implements ``L_1`` regularisation, also known as LASSO regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.

See also [`WeightDecay`](@ref) for ``L_2`` regularisation.
They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.

# Parameters
- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct SignDecay <: AbstractRule
lambda = 1e-3
end

init(o::SignDecay, x::AbstractArray) = nothing

function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
λ = T(o.lambda)
dx′ = @lazy dx + λ * sign(x)

return state, dx′
end
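
A hedged sketch (not part of the diff) of the behaviour this rule produces: because it adds `λ .* sign(x)` to the gradient, chaining it before a plain `Descent(1.0)` pulls every weight toward zero by a fixed `λ` per step when the loss gradient is zero. Values below are made up.

```julia
using Optimisers

x  = [0.5, -0.5, 2.0]
dx = [0.0, 0.0, 0.0]                   # zero loss gradient, so only the L1 penalty acts
st = Optimisers.setup(OptimiserChain(SignDecay(0.1), Descent(1.0)), x)
_, x′ = Optimisers.update(st, x, dx)
x′ ≈ [0.4, -0.4, 1.9]                  # each entry moved 0.1 toward zero
```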


"""
ClipGrad(δ = 10)

Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.

Typically composed with other rules using [`OptimiserChain`](@ref).

See also [`ClipNorm`](@ref).
"""
@def struct ClipGrad <: AbstractRule
@@ -591,6 +591,8 @@ to stay at this threshold (unless `p==0`).
Throws an error if the norm is infinite or `NaN`,
which you can turn off with `throw = false`.

Typically composed with other rules using [`OptimiserChain`](@ref).

See also [`ClipGrad`](@ref).
"""
struct ClipNorm <: AbstractRule
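
An illustrative sketch of the "typically composed with other rules" note added to the clipping docstrings above; not from the diff, and the numbers are made up. `ClipGrad(δ)` clamps each gradient component before the main rule sees it.

```julia
using Optimisers

x  = [1.0, 1.0, 1.0]
dx = [100.0, -100.0, 0.5]
st = Optimisers.setup(OptimiserChain(ClipGrad(1.0), Descent(0.1)), x)
_, x′ = Optimisers.update(st, x, dx)
x′ ≈ x .- 0.1 .* clamp.(dx, -1.0, 1.0)   # i.e. [0.9, 1.1, 0.95]
```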
2 changes: 1 addition & 1 deletion test/rules.jl
Expand Up @@ -10,7 +10,7 @@ RULES = [
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
# A few chained combinations:
OptimiserChain(WeightDecay(), Adam(0.001)),
OptimiserChain(SignDecay(0.001), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),
OptimiserChain(ClipGrad(0.5), Momentum()),
OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
5 changes: 5 additions & 0 deletions test/runtests.jl
@@ -137,6 +137,7 @@ end
@testset "OptimiserChain" begin
x = [1, 10, 100.0]; dx = [1, 2, 3.0];
@test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
@test Optimisers.update(Optimisers.setup(SignDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-0.1-2, 100-0.1-3]
@test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]

o2 = OptimiserChain(ClipGrad(2), WeightDecay(0.1))
@@ -154,6 +155,10 @@

o0 = OptimiserChain()
@test Optimisers.update(Optimisers.setup(o0, x), x, dx)[2] ≈ [1-1,10-2,100-3]

# L1 norm via sign
xm = [1, -10, 100.0]; dxm = [3, 2, -1];
@test Optimisers.update(Optimisers.setup(SignDecay(0.1), xm), xm, dxm)[2] ≈ [1-0.1-3, -10+0.1-2, 100-0.1+1]
end

@testset "trainable subset" begin