Add all-keyword constructors, much like @kwdef (#160)
* add all-keyword constructors

* update a few docstrings

* docstrings

* add tests

* one lost γ should be λ
mcabbott authored Feb 8, 2024
1 parent e60b71e commit 1908a1c
Showing 3 changed files with 37 additions and 16 deletions.
12 changes: 9 additions & 3 deletions src/interface.jl
@@ -241,7 +241,8 @@ like this:
struct Rule
eta::Float64
beta::Tuple{Float64, Float64}
Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end
```
Any field called `eta` is assumed to be a learning rate, and cannot be negative.
@@ -259,12 +260,17 @@ macro def(expr)
lines[i] = :($name::$typeof($val))
end
rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
params = [Expr(:kw, nv...) for nv in zip(names,vals)]
check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
# Positional-argument method, has defaults for all but the first arg:
inner = :(function $rule($(names[1]), $(params[2:end]...))
$check
new($(names...))
end)
push!(lines, inner)
# Keyword-argument method. (Made an inner constructor only to allow
# resulting structs to be @doc... cannot if macro returns a block.)
kwmethod = :($rule(; $(params...)) = $rule($(names...)))
push!(lines, inner, kwmethod)
esc(expr)
end
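For reference, here is roughly what the macro should now generate for the docstring's `Rule` example above (a hand-expanded sketch, not the literal macro output):

```julia
# Sketch of the expansion of:  @def struct Rule
#                                  eta = 0.1
#                                  beta = (0.7, 0.8)
#                              end
struct Rule
    eta::Float64
    beta::Tuple{Float64, Float64}
    # Positional method: the first field has no default, the rest keep theirs,
    # and any field called `eta` is checked to be non-negative.
    function Rule(eta, beta = (0.7, 0.8))
        eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))
        new(eta, beta)
    end
    # All-keyword method, forwarding to the positional one.
    Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end

Rule(0.2) == Rule(eta = 0.2) == Rule(0.2, (0.7, 0.8))  # true
```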

35 changes: 22 additions & 13 deletions src/rules.jl
@@ -8,18 +8,19 @@

"""
Descent(η = 1f-1)
Descent(; eta)
Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
"""
struct Descent{T} <: AbstractRule
eta::T
end
Descent() = Descent(1f-1)
Descent(; eta = 1f-1) = Descent(eta)

init(o::Descent, x::AbstractArray) = nothing

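`Descent` is written by hand rather than via `@def`, so its keyword path is spelled out explicitly above. Usage should then look roughly like this (a sketch, assuming the constructors shown in this hunk):

```julia
using Optimisers

Descent()           # default learning rate 1f-1
Descent(0.1)        # positional learning rate
Descent(eta = 0.1)  # same rule via the new keyword constructor
```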
@@ -37,13 +38,14 @@ end

"""
Momentum(η = 0.01, ρ = 0.9)
Momentum(; [eta, rho])
Gradient descent optimizer with learning rate `η` and momentum `ρ`.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
"""
@def struct Momentum <: AbstractRule
@@ -89,6 +91,7 @@ end

"""
RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
RMSProp(; [eta, rho, epsilon, centred])
Optimizer using the
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
gradients by an estimate of their variance, instead of their second moment.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
(no need to change default)
- Keyword `centred` (or `centered`): Indicates whether to use centred variant
of the algorithm.
@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
centred::Bool
end

function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
RMSProp(η, ρ, ϵ, centred | centered)
end
RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)

init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)

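The new keyword method forwards everything it does not name (here `centred`/`centered`) through `kw...` to the positional method, so the two spellings should agree; roughly:

```julia
using Optimisers

RMSProp(0.001, 0.9, 1e-8; centred = true)  # positional form
RMSProp(eta = 0.001, centred = true)       # keyword form, same rule
RMSProp(epsilon = 1e-7)                    # override a single field, keep other defaults
```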
@@ -488,22 +492,27 @@ end

"""
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
AdamW(; [eta, beta, lambda, epsilon])
[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
weight decay regularization.
Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref).
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
(no need to change default)
"""
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))

AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))

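Because `AdamW` builds an `OptimiserChain` rather than a single struct, its keyword form simply mirrors the positional one. The new test below checks exactly this; for example:

```julia
using Optimisers

o = AdamW(lambda = 0.3)
o.opts[1] == Adam()            # true: default Adam settings
o.opts[2] == WeightDecay(0.3)  # true: the weight decay given above
```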
"""
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -330,6 +330,12 @@ end
@test_throws ArgumentError Optimisers.thaw!(m)
end

@testset "keyword arguments" begin
@test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
@test AdamW(lambda=0.3).opts[1] == Adam()
@test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
end

@testset "forgotten gradient" begin
x = [1.0, 2.0]
sx = Optimisers.setup(Descent(), x)

2 comments on commit 1908a1c

@mcabbott
Member Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/100446

Tip: Release Notes

Did you know you can add release notes too? Just add markdown-formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR; if TagBot is installed, it will also be added to the
release that TagBot creates. For example:

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.2 -m "<description of version>" 1908a1cd599f656b15304a9722328bf9b2eed360
git push origin v0.3.2
