From 1908a1cd599f656b15304a9722328bf9b2eed360 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 7 Feb 2024 23:08:43 -0500
Subject: [PATCH] Add all-keyword constructors, much like `@kwdef` (#160)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add all-keyword constructors

* update a few docstrings

* docstrings

* add tests

* one lost γ should be λ
---
 src/interface.jl | 12 +++++++++---
 src/rules.jl     | 35 ++++++++++++++++++++++-------------
 test/runtests.jl |  6 ++++++
 3 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/src/interface.jl b/src/interface.jl
index 7583958..29c1db6 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -241,7 +241,8 @@ like this:
 struct Rule
   eta::Float64
   beta::Tuple{Float64, Float64}
-  Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
 end
 ```
 Any field called `eta` is assumed to be a learning rate, and cannot be negative.
@@ -259,12 +260,17 @@ macro def(expr)
     lines[i] = :($name::$typeof($val))
   end
   rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
+  params = [Expr(:kw, nv...) for nv in zip(names,vals)]
   check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
-  inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
+  # Positional-argument method, has defaults for all but the first arg:
+  inner = :(function $rule($(names[1]), $(params[2:end]...))
     $check
     new($(names...))
   end)
-  push!(lines, inner)
+  # Keyword-argument method. (Made an inner constructor only to allow
+  # resulting structs to be @doc... cannot if macro returns a block.)
+  kwmethod = :($rule(; $(params...)) = $rule($(names...)))
+  push!(lines, inner, kwmethod)
   esc(expr)
 end
 
diff --git a/src/rules.jl b/src/rules.jl
index 9fa5011..f3df9d6 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -8,18 +8,19 @@
 
 """
     Descent(η = 1f-1)
+    Descent(; eta)
 
 Classic gradient descent optimiser with learning rate `η`.
 For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
 """
 struct Descent{T} <: AbstractRule
   eta::T
 end
-Descent() = Descent(1f-1)
+Descent(; eta = 1f-1) = Descent(eta)
 
 init(o::Descent, x::AbstractArray) = nothing
 
@@ -37,13 +38,14 @@ end
 
 """
     Momentum(η = 0.01, ρ = 0.9)
+    Momentum(; [eta, rho])
 
 Gradient descent optimizer with learning rate `η` and momentum `ρ`.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
 """
 @def struct Momentum <: AbstractRule
@@ -89,6 +91,7 @@ end
 
 """
     RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred = false)
+    RMSProp(; [eta, rho, epsilon, centred])
 
 Optimizer using the
 [RMSProp](https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
@@ -99,11 +102,11 @@ generally don't need tuning.
 gradients by an estimate their variance, instead of their second moment.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Momentum (`ρ`): Controls the acceleration of gradient descent in the
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the
   prominent direction, in effect dampening oscillations.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 - Keyword `centred` (or `centered`): Indicates whether to use centred variant
   of the algorithm.
@@ -115,10 +118,11 @@ struct RMSProp <: AbstractRule
   centred::Bool
 end
 
-function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
+function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
   RMSProp(η, ρ, ϵ, centred | centered)
 end
+RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
 
 init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
 
@@ -488,22 +492,27 @@ end
 
 """
     AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
+    AdamW(; [eta, beta, lambda, epsilon])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
+Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`](@ref)`.
 
 # Parameters
-- Learning rate (`η`): Amount by which gradients are discounted before updating
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating
   the weights.
-- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+- Decay of momentums (`β::Tuple == beta`): Exponential decay for the first (β1) and the
   second (β2) momentum estimate.
-- Weight decay (`λ`): Controls the strength of ``L_2`` regularisation.
-- Machine epsilon (`ϵ`): Constant to prevent division by zero
+- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
+- Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8) =
+AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
   OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
 
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
+  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 82d43bd..e4a1401 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -330,6 +330,12 @@ end
     @test_throws ArgumentError Optimisers.thaw!(m)
   end
 
+  @testset "keyword arguments" begin
+    @test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
+    @test AdamW(lambda=0.3).opts[1] == Adam()
+    @test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
+  end
+
  @testset "forgotten gradient" begin
     x = [1.0, 2.0]
     sx = Optimisers.setup(Descent(), x)
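For a rule defined with the patched `@def` macro, such as `Momentum` above, the generated struct is roughly equivalent to the following hand-written version. This is an illustrative sketch of the expansion, not the macro's literal output; `AbstractRule` is the rule supertype from Optimisers, and the field names and defaults are those documented in the diff.

```julia
using Optimisers: AbstractRule

# Roughly what `@def struct Momentum <: AbstractRule; eta = 0.01; rho = 0.9; end`
# produces after this patch (hand-expanded for illustration):
struct Momentum <: AbstractRule
  eta::Float64
  rho::Float64
  # Positional method: no default for the first field, defaults for the rest,
  # plus the learning-rate check because one field is called `eta`.
  function Momentum(eta, rho = 0.9)
    eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))
    new(eta, rho)
  end
  # All-keyword method added by this patch, kept inside the struct so that
  # the docstring can still be attached to the whole definition with @doc.
  Momentum(; eta = 0.01, rho = 0.9) = Momentum(eta, rho)
end
```

Calling `Momentum()` with no arguments therefore goes through the keyword method and uses all defaults, while the positional form still requires at least the learning rate.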
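With this patch in place, the keyword and positional constructors build identical rule objects, which is what the new test set checks. A short usage sketch, assuming a version of Optimisers.jl that includes this change:

```julia
using Optimisers

# Keyword and positional forms agree (this is the new test case):
Nesterov(rho = 0.8, eta = 0.1) === Nesterov(0.1, 0.8)   # true

# Hand-written keyword methods cover rules not defined via @def as well:
Descent(eta = 0.1)                   # same as Descent(0.1)
RMSProp(eta = 0.01, centred = true)  # same as RMSProp(0.01, 0.9, 1e-8; centred = true)

# AdamW is an OptimiserChain of Adam and WeightDecay, so only the decay differs here:
rule = AdamW(lambda = 0.3)
rule.opts[1] == Adam()            # true
rule.opts[2] == WeightDecay(0.3)  # true
```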