Skip to content

Commit

Permalink
change kappa back to lambda
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 6, 2024
1 parent 6a2b326 commit 9df7b35
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,30 +575,30 @@ function adjust(r::WeightDecay; gamma = nothing, kw...)
end

"""
    SignDecay(λ = 1e-3)

Implements ``L_1`` regularisation, also known as LASSO regression,
when composed with other rules as the first transformation in an [`OptimiserChain`](@ref).

It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding
`λ * sum(abs, x) == λ * norm(x, 1)` to the loss.

See also [`WeightDecay`](@ref) for ``L_2`` normalisation.
They can be used together: `OptimiserChain(SignDecay(0.012), WeightDecay(0.034), Adam())`
is equivalent to adding `0.012 * norm(x, 1) + 0.017 * norm(x, 2)^2` to the loss function.

# Parameters
- Penalty (`λ ≥ 0`): Controls the strength of the regularisation.
"""
@def struct SignDecay <: AbstractRule
  lambda = 1e-3
end

# SignDecay is stateless: there is no per-parameter state to set up.
init(::SignDecay, ::AbstractArray) = nothing

function apply!(o::SignDecay, state, x::AbstractArray{T}, dx) where T
  # Convert the penalty to the eltype of the parameters, to keep the
  # gradient update type-stable.
  λ = T(o.lambda)
  # L₁ penalty: lazily add λ .* sign(x) to the gradient (fused, no temporary).
  dx′ = @lazy dx + λ * sign(x)

  # Stateless rule: the (nothing) state is passed through unchanged.
  return state, dx′
end
Expand Down

0 comments on commit 9df7b35

Please sign in to comment.