Add all-keyword constructors, much like @kwdef #160

Merged: 5 commits, Feb 8, 2024
12 changes: 9 additions & 3 deletions src/interface.jl
@@ -241,7 +241,8 @@ like this:
struct Rule
eta::Float64
beta::Tuple{Float64, Float64}
Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end
```
Any field called `eta` is assumed to be a learning rate, and cannot be negative.
@@ -259,12 +260,17 @@ macro def(expr)
lines[i] = :($name::$typeof($val))
end
rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
params = [Expr(:kw, nv...) for nv in zip(names,vals)]
check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
# Positional-argument method, has defaults for all but the first arg:
inner = :(function $rule($(names[1]), $(params[2:end]...))
$check
new($(names...))
end)
push!(lines, inner)
# Keyword-argument method. (Made an inner constructor only to allow
# resulting structs to be @doc'd... cannot if macro returns a block.)
kwmethod = :($rule(; $(params...)) = $rule($(names...)))
push!(lines, inner, kwmethod)
esc(expr)
end
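To make the change concrete, here is a hand-written sketch of the code the revised `@def` above is meant to generate, using a made-up `MyRule` (not a rule from the package) and omitting the `<: AbstractRule` supertype so the snippet stands on its own:

```julia
# Roughly what `@def struct MyRule; eta = 0.01; rho = 0.9; end` should expand to
# after this PR: typed fields, the positional inner constructor with defaults for
# everything but the first field, and the new all-keyword method.
struct MyRule
    eta::Float64
    rho::Float64
    function MyRule(eta, rho = 0.9)
        # any field called `eta` gets the learning-rate check
        eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))
        new(eta, rho)
    end
    MyRule(; eta = 0.01, rho = 0.9) = MyRule(eta, rho)
end

MyRule(0.1, 0.8) === MyRule(rho = 0.8, eta = 0.1)  # true: both construct MyRule(0.1, 0.8)
```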

35 changes: 22 additions & 13 deletions src/rules.jl
@@ -8,18 +8,19 @@

"""
Descent(η = 1f-1)
Descent(; eta)

Classic gradient descent optimiser with learning rate `η`.
For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
"""
struct Descent{T} <: AbstractRule
eta::T
end
Descent() = Descent(1f-1)
Descent(; eta = 1f-1) = Descent(eta)

init(o::Descent, x::AbstractArray) = nothing
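A brief usage sketch of the constructors above; the keyword form assumes this PR, and the element types simply follow from the literals used:

```julia
using Optimisers

Descent()            # Descent{Float32}(0.1f0), from the default eta = 1f-1
Descent(0.01)        # Descent{Float64}(0.01), positional learning rate
Descent(eta = 0.01)  # Descent{Float64}(0.01), the same rule via the new keyword form
```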

@@ -37,13 +38,14 @@ end

"""
Momentum(η = 0.01, ρ = 0.9)
Momentum(; [eta, rho])

Gradient descent optimizer with learning rate `η` and momentum `ρ`.

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
"""
@def struct Momentum <: AbstractRule
@@ -89,6 +91,7 @@ end

"""
RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
RMSProp(; [eta, rho, epsilon, centred])

Optimizer using the
[RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
gradients by an estimate of their variance, instead of their second moment.

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Momentum (`ρ`): Controls the acceleration of gradient descent in the
- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
prominent direction, in effect dampening oscillations.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
(no need to change default)
- Keyword `centred` (or `centered`): Indicates whether to use the centred variant
of the algorithm.
@mcabbott (Member Author) commented on Oct 9, 2023:
How should these be documented? 7706ffd is one suggestion.

We could also remove all greek glyphs from the docstrings, in favour of ascii field names. Some formulae would get a bit longer, or perhaps those few could show both.

A Member replied:
Maybe keep the ascii version in the part of the docstring that lists the actual parameter name, then use latex-ified greek characters in the description of what the parameter does, if applicable?

@mcabbott (Member Author) replied:
Not sure I follow.

Goal of the current style is partly to be usable even if you don't know the names of greek letters. This was true before, as long as you can identify where the same scribble shows up.

@ToucheSir (Member) commented on Feb 8, 2024:
I was thinking of how other libraries use embedded latex in their extended docstrings to describe what optimization rules are doing. e.g. https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html#torch.optim.AdamW. Not something required for every field of every rule, but we don't have to purge greek characters from the docstrings completely.

@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
centred::Bool
end

function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
RMSProp(η, ρ, ϵ, centred | centered)
end
RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)

init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
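For illustration, a sketch of how the new keyword method is expected to forward to the positional constructor above; the particular values are arbitrary:

```julia
using Optimisers

# The keyword form passes eta/rho/epsilon positionally and forwards `centred` as a keyword:
r1 = RMSProp(eta = 3e-3, centred = true)
r2 = RMSProp(3e-3, 0.9, 1e-8; centred = true)
r1 == r2                # true: identical field values

# The learning-rate check applies through either path:
# RMSProp(eta = -1.0)   # would throw DomainError: "the learning rate cannot be negative"
```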

@@ -488,22 +492,27 @@ end

"""
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
AdamW(; [eta, beta, lambda, epsilon])

[AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
weight decay regularization.
Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref).

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ`): Constant to prevent division by zero
- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
(no need to change default)
"""
AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))

AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
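And a short end-to-end sketch of the keyword form together with the package's `setup`/`update!` API; the parameter vector and gradient below are made up purely for illustration:

```julia
using Optimisers

x = [1.0, 2.0]                           # parameters to be trained
opt = AdamW(eta = 1e-3, lambda = 1e-4)   # same as OptimiserChain(Adam(0.001, (0.9, 0.999), 1.0e-8), WeightDecay(0.0001))
state = Optimisers.setup(opt, x)
state, x = Optimisers.update!(state, x, [0.1, 0.2])  # one step, with a hand-written gradient
```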

"""
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)

6 changes: 6 additions & 0 deletions test/runtests.jl
@@ -330,6 +330,12 @@ end
@test_throws ArgumentError Optimisers.thaw!(m)
end

@testset "keyword arguments" begin
@test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
@test AdamW(lambda=0.3).opts[1] == Adam()
@test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
end

@testset "forgotten gradient" begin
x = [1.0, 2.0]
sx = Optimisers.setup(Descent(), x)