add all-keyword constructors
mcabbott committed Sep 12, 2023
Parent: 1cd1e87 · Commit: 2f333c9
Showing 2 changed files with 16 additions and 6 deletions.
src/interface.jl: 12 changes (9 additions, 3 deletions)

@@ -241,7 +241,8 @@ like this:
 struct Rule
   eta::Float64
   beta::Tuple{Float64, Float64}
-  Rule(eta = 0.1, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
+  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
 end
 ```
 Any field called `eta` is assumed to be a learning rate, and cannot be negative.
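With this pattern each rule gets two constructors: a positional one whose first argument is required, and an all-keyword one with a default for every field. A minimal standalone sketch, copying the example struct above:

```
struct Rule
  eta::Float64
  beta::Tuple{Float64, Float64}
  Rule(eta, beta = (0.7, 0.8)) = eta < 0 ? error() : new(eta, beta)
  Rule(; eta = 0.1, beta = (0.7, 0.8)) = Rule(eta, beta)
end

Rule(0.2)        # positional: eta required, beta defaulted
Rule(eta = 0.2)  # keyword: same result, routed through the positional method
Rule()           # keywords supply a default for every field
```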
@@ -259,12 +260,17 @@ macro def(expr)
     lines[i] = :($name::$typeof($val))
   end
   rule = Meta.isexpr(expr.args[2], :<:) ? expr.args[2].args[1] : expr.args[2]
+  params = [Expr(:kw, nv...) for nv in zip(names,vals)]
   check = :eta in names ? :(eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))) : nothing
-  inner = :(function $rule($([Expr(:kw, nv...) for nv in zip(names,vals)]...))
+  # Positional-argument method, has defaults for all but the first arg:
+  inner = :(function $rule($(names[1]), $(params[2:end]...))
     $check
     new($(names...))
   end)
-  push!(lines, inner)
+  # Keyword-argument method. (Made an inner constructor only to allow
+  # resulting structs to be @doc... cannot if macro returns a block.)
+  kwmethod = :($rule(; $(params...)) = $rule($(names...)))
+  push!(lines, inner, kwmethod)
   esc(expr)
 end
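To make the generated code concrete, here is a sketch of what the revised macro expands to for a rule written in the rules.jl style. The `Momentum` name and defaults are illustrative, and `abstract type AbstractRule end` stands in for the package's own type:

```
# Given a definition like:
#     @def struct Momentum <: AbstractRule
#       eta = 0.01
#       rho = 0.9
#     end
# the macro now produces, roughly:
abstract type AbstractRule end   # stand-in for the package's type

struct Momentum <: AbstractRule
  eta::Float64   # from :($name::$typeof($val))
  rho::Float64
  # positional method: defaults for every field except the first
  function Momentum(eta, rho = 0.9)
    eta < 0 && throw(DomainError(eta, "the learning rate cannot be negative"))
    new(eta, rho)
  end
  # keyword method: defaults for every field
  Momentum(; eta = 0.01, rho = 0.9) = Momentum(eta, rho)
end
```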

src/rules.jl: 10 changes (7 additions, 3 deletions)

@@ -19,7 +19,7 @@ For each parameter `p` and its gradient `dp`, this runs `p -= η*dp`.
 struct Descent{T} <: AbstractRule
   eta::T
 end
-Descent() = Descent(1f-1)
+Descent(; eta = 1f-1) = Descent(eta)
 
 init(o::Descent, x::AbstractArray) = nothing
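The effect on `Descent`: the old zero-argument method becomes an all-keyword method, so the default still applies but a named spelling now works too. A standalone sketch:

```
abstract type AbstractRule end   # stand-in, as in the sketch above

struct Descent{T} <: AbstractRule
  eta::T
end
Descent(; eta = 1f-1) = Descent(eta)

Descent()             # keyword method with its default: eta = 1f-1
Descent(0.01)         # positional constructor, unchanged
Descent(eta = 0.01)   # new spelling, equivalent to the line above
```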

@@ -115,10 +115,11 @@ struct RMSProp <: AbstractRule
   centred::Bool
 end
 
-function RMSProp(η = 0.001, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
+function RMSProp(η, ρ = 0.9, ϵ = 1e-8; centred::Bool = false, centered::Bool = false)
   η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
   RMSProp(η, ρ, ϵ, centred | centered)
 end
+RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, epsilon; kw...)
 
 init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
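Here the keyword method uses ASCII names and slurps any remaining keywords into `kw...`, so both spellings of `centred` still reach the positional method. Example calls, assuming this is Optimisers.jl (as the file paths suggest):

```
using Optimisers

RMSProp()                              # keyword method, all defaults
RMSProp(2f-3)                          # positional: eta required, rest defaulted
RMSProp(eta = 0.002, centred = true)   # `centred` is forwarded via kw...
RMSProp(eta = 0.002, centered = true)  # alternative spelling, merged with |
```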

@@ -501,9 +502,12 @@ weight decay regularization.
 - Machine epsilon (`ϵ`): Constant to prevent division by zero
   (no need to change default)
 """
-AdamW(η = 0.001, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
+AdamW(η, β = (0.9, 0.999), γ = 0, ϵ = 1e-8) =
   OptimiserChain(Adam(η, β, ϵ), WeightDecay(γ))
 
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
+  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
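For `AdamW`, the positional method keeps the Unicode names (decay is `γ`) while the new keyword method uses ASCII names (decay is `lambda`); both build the same chain. A sketch, again assuming Optimisers.jl:

```
using Optimisers

AdamW(0.001, (0.9, 0.999), 0.1)   # positional: η, β, γ (decay), ϵ defaulted
AdamW(eta = 0.001, lambda = 0.1)  # keyword spelling of the same rule
# both produce OptimiserChain(Adam(0.001, (0.9, 0.999), 1.0e-8), WeightDecay(0.1))
```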
