diff --git a/Project.toml b/Project.toml
index 02e13b7..7422f28 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes "]
-version = "0.3.1"
+version = "0.3.2"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 1451bc8..9254bb7 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -14,7 +14,7 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
+       WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
        AccumGrad
 
 ###
diff --git a/src/rules.jl b/src/rules.jl
index ecc5860..9fa5011 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -487,7 +487,7 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8)
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -497,12 +497,12 @@ weight decay regularization.
   the weights.
 - Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
+AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
@@ -538,35 +538,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    WeightDecay(γ = 5e-4)
+    WeightDecay(λ = 5e-4)
 
-Decay weights by ``γ``, that is, add `γ .* x` to the gradient `x̄` which will be
-subtracted from `x`.
+Implements ``L_2`` regularisation, also known as ridge regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
 
-Typically composed with other optimisers as the first transformation in an [`OptimiserChain`](@ref).
-This is equivalent to adding ``L_2`` regularization with coefficient ``γ`` to the loss.
+It does this by adding `λ .* x` to the gradient. This is equivalent to adding
+`λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss.
+
+See also [`SignDecay`](@ref) for ``L_1`` regularisation.
 
 # Parameters
-- Weight decay (`γ`): Decay applied to weights during optimisation.
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
 """
 @def struct WeightDecay <: AbstractRule
-  gamma = 5e-4
+  lambda = 5e-4
 end
 
 init(o::WeightDecay, x::AbstractArray) = nothing
 
 function apply!(o::WeightDecay, state, x::AbstractArray{T}, dx) where T
-  γ = T(o.gamma)
-  dx′ = @lazy dx + γ * x
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * x
 
   return state, dx′
 end
 
+function adjust(r::WeightDecay; gamma = nothing, kw...)
+  if isnothing(gamma)
+    return _adjust(r, NamedTuple(kw))
+  else
+    Base.depwarn("The strength of WeightDecay is now field :lambda, not :gamma", :adjust, force=true)
+    nt = (; lambda = gamma, NamedTuple(kw)...)
+    return _adjust(r, nt)
+  end
+end
+
+"""
+    SignDecay(λ = 1e-3)
+
+Implements ``L_1`` regularisation, also known as LASSO regression,
+when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).
+
+It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
+`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.
+
+See also [`WeightDecay`](@ref) for ``L_2`` regularisation.
+They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
+is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.
+
+# Parameters
+- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
+"""
+@def struct SignDecay <: AbstractRule
+  lambda = 1e-3
+end
+
+init(o::SignDecay, x::AbstractArray) = nothing
+
+function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
+  λ = T(o.lambda)
+  dx′ = @lazy dx + λ * sign(x)
+
+  return state, dx′
+end
+
+
 """
     ClipGrad(δ = 10)
 
 Restricts every gradient component to obey `-δ ≤ dx[i] ≤ δ`.
 
+Typically composed with other rules using [`OptimiserChain`](@ref).
+
 See also [`ClipNorm`](@ref).
 """
 @def struct ClipGrad <: AbstractRule
@@ -591,6 +635,8 @@ to stay at this threshold (unless `p==0`).
 Throws an error if the norm is infinite or `NaN`,
 which you can turn off with `throw = false`.
 
+Typically composed with other rules using [`OptimiserChain`](@ref).
+
 See also [`ClipGrad`](@ref).
 """
 struct ClipNorm <: AbstractRule
diff --git a/test/rules.jl b/test/rules.jl
index 0ac2648..52a3580 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -10,7 +10,7 @@ RULES = [
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), AdamW(), RAdam(), OAdam(), AdaBelief(),
   Lion(),
   # A few chained combinations:
-  OptimiserChain(WeightDecay(), Adam(0.001)),
+  OptimiserChain(SignDecay(0.001), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
diff --git a/test/runtests.jl b/test/runtests.jl
index 49b595a..82d43bd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -137,6 +137,7 @@ end
   @testset "OptimiserChain" begin
     x = [1, 10, 100.0]; dx = [1, 2, 3.0];
     @test Optimisers.update(Optimisers.setup(WeightDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-1-2, 100-10-3]
+    @test Optimisers.update(Optimisers.setup(SignDecay(0.1), x), x, dx)[2] ≈ [1-0.1-1, 10-0.1-2, 100-0.1-3]
     @test Optimisers.update(Optimisers.setup(ClipGrad(2), x), x, dx)[2] ≈ [1-1, 10-2, 100-2]
 
     o2 = OptimiserChain(ClipGrad(2), WeightDecay(0.1))
@@ -154,6 +155,10 @@ end
 
     o0 = OptimiserChain()
     @test Optimisers.update(Optimisers.setup(o0, x), x, dx)[2] ≈ [1-1,10-2,100-3]
+
+    # L1 norm via sign
+    xm = [1, -10, 100.0]; dxm = [3, 2, -1];
+    @test Optimisers.update(Optimisers.setup(SignDecay(0.1), xm), xm, dxm)[2] ≈ [1-0.1-3, -10+0.1-2, 100-0.1+1]
   end
 
   @testset "trainable subset" begin
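
For reference, a short usage sketch of the new rule (not part of the patch itself): it exercises `SignDecay` through the same `Optimisers.setup` / `Optimisers.update` calls as the tests above, first on its own and then chained with `WeightDecay` and `Adam` as in the docstring example; the variable names are illustrative only.

```julia
using Optimisers

x  = [1.0, -10.0, 100.0]   # parameters
dx = [3.0, 2.0, -1.0]      # gradient

# Used alone, SignDecay(0.1) adds 0.1 .* sign(x) to the gradient,
# and the bare rule then subtracts that sum from x:
st = Optimisers.setup(SignDecay(0.1), x)
st, xnew = Optimisers.update(st, x, dx)
xnew ≈ [1 - 0.1 - 3, -10 + 0.1 - 2, 100 - 0.1 + 1]   # true, matching the test above

# L1 plus L2 penalties feeding into Adam, as in the SignDecay docstring:
rule = OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam(0.001))
st2 = Optimisers.setup(rule, x)
st2, x2 = Optimisers.update(st2, x, dx)
```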