From 38c9d622c4a9979190b9c4c000604267aac39239 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 7 Nov 2024 08:14:53 +0100
Subject: [PATCH] Add the option couple to AdamW and set the default to match
 pytorch (#188)

---
 Project.toml     |  2 +-
 README.md        |  6 ++++++
 src/rules.jl     | 56 ++++++++++++++++++++++++++++++++++++++++++------
 test/rules.jl    |  2 +-
 test/runtests.jl |  3 +--
 5 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/Project.toml b/Project.toml
index 41c9709..0a19f49 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
 authors = ["Mike J Innes "]
-version = "0.3.4"
+version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/README.md b/README.md
index fa318b2..e15155a 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,12 @@ This was written as the new training system for [Flux.jl](https://github.com/Flu
 and also used by [Lux.jl](https://github.com/avik-pal/Lux.jl).
 But it can be used separately on any array, or anything else understood by
 [Functors.jl](https://github.com/FluxML/Functors.jl).
+
+> [!WARNING]
+> With version 0.4 the default update rule for AdamW has changed to match the pytorch implementation.
+> The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
+> See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
+
 ## Installation
 
 ```julia
diff --git a/src/rules.jl b/src/rules.jl
index b4fbd2a..0063d70 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -501,8 +501,8 @@ function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
 end
 
 """
-    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8)
-    AdamW(; [eta, beta, lambda, epsilon])
+    AdamW(η = 0.001, β = (0.9, 0.999), λ = 0, ϵ = 1e-8; couple = true)
+    AdamW(; [eta, beta, lambda, epsilon, couple])
 
 [AdamW](https://arxiv.org/abs/1711.05101) is a variant of Adam fixing (as in repairing) its
 weight decay regularization.
@@ -516,12 +516,54 @@ Implemented as an [`OptimiserChain`](@ref) of [`Adam`](@ref) and [`WeightDecay`]
 - Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation.
 - Machine epsilon (`ϵ == epsilon`): Constant to prevent division by zero
   (no need to change default)
-"""
-AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8) =
-  OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ))
+- Keyword `couple`: If `true`, the weight decay is coupled with the learning rate, as in pytorch's AdamW.
+  This corresponds to an update of the form `x = x - η * (dx + λ * x)`, where `dx` is the
+  update from Adam with learning rate 1.
+  If `false`, the weight decay is decoupled from the learning rate, in the spirit of the original paper.
+  This corresponds to an update of the form `x = x - η * dx - λ * x`.
+  Default is `true`.
+
+!!! warning "Breaking change in v0.4"
+    With version 0.4 the default update rule for AdamW has changed to match the pytorch implementation.
+    The previous rule, which is closer to the original paper, can be obtained by setting `AdamW(..., couple=false)`.
+    See [this issue](https://github.com/FluxML/Flux.jl/issues/2433) for more details.
+"""
+struct AdamW{T1,T2,T3,T4} <: AbstractRule
+  eta::T1
+  beta::T2
+  lambda::T3
+  epsilon::T4
+  couple::Bool
+end
+
+function AdamW(η, β = (0.9, 0.999), λ = 0.0, ϵ = 1e-8; couple::Bool = true)
+  η < 0 && throw(DomainError(η, "the learning rate cannot be negative"))
+  AdamW(η, β, λ, ϵ, couple)
+end
+
+AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0.0, epsilon = 1e-8, kw...) =
+  AdamW(eta, beta, lambda, epsilon; kw...)
+
+init(o::AdamW, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
+
+function apply!(o::AdamW, state, x::AbstractArray{T}, dx) where T
+  η, β, ϵ, λ = T(o.eta), T.(o.beta), T(o.epsilon), T(o.lambda)
+  mt, vt, βt = state
 
-AdamW(; eta = 0.001, beta = (0.9, 0.999), lambda = 0, epsilon = 1e-8) =
-  OptimiserChain(Adam(eta, beta, epsilon), WeightDecay(lambda))
+  # standard Adam update with learning rate eta=1
+  @.. mt = β[1] * mt + (1 - β[1]) * dx
+  @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
+  dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
+
+  # apply learning rate and weight decay
+  if o.couple
+    dx′′ = @lazy η * (dx′ + λ * x)
+  else
+    dx′′ = @lazy η * dx′ + λ * x
+  end
+
+  return (mt, vt, βt .* β), dx′′
+end
 
 """
     AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = 1e-16)
diff --git a/test/rules.jl b/test/rules.jl
index 9068fa1..499902c 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -15,7 +15,7 @@ RULES = [
   OptimiserChain(ClipGrad(0.5), Momentum()),
   OptimiserChain(WeightDecay(), OAdam(), ClipGrad(1)),
   # Not the default:
-  RMSProp(centred = true),
+  RMSProp(centred = true), AdamW(couple=false),
 ]
 
 name(o) = typeof(o).name.name  # just for printing testset headings
diff --git a/test/runtests.jl b/test/runtests.jl
index fc0fe57..ae2d9d0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -332,8 +332,7 @@ end
 
   @testset "keyword arguments" begin
     @test Nesterov(rho=0.8, eta=0.1) === Nesterov(0.1, 0.8)
-    @test AdamW(lambda=0.3).opts[1] == Adam()
-    @test AdamW(lambda=0.3).opts[2] == WeightDecay(0.3)
+    @test AdamW(lambda=0.3, eta=0.1) == AdamW(0.1, (0.9, 0.999), 0.3, 1.0e-8)
   end
 
   @testset "forgotten gradient" begin
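As a quick way to try out the new keyword, here is a minimal usage sketch. It is not part of the patch: the array `x`, gradient `dx`, the hyperparameter values, and the result names `x_coupled`/`x_decoupled` are arbitrary illustrations, using the existing `Optimisers.setup`/`Optimisers.update` API on a plain array.

```julia
using Optimisers

# Toy parameter array and gradient; values are arbitrary.
x  = ones(Float32, 3)
dx = fill(0.1f0, 3)

# Default (couple = true): the decay term is scaled by the learning rate,
# i.e. x - η * (dx′ + λ * x), matching pytorch's AdamW.
coupled = Optimisers.setup(Optimisers.AdamW(0.01, (0.9, 0.999), 0.1), x)
coupled, x_coupled = Optimisers.update(coupled, x, dx)

# couple = false: decoupled decay, x - η * dx′ - λ * x, closer to the original
# paper and to the pre-0.4 OptimiserChain(Adam(η, β, ϵ), WeightDecay(λ)) rule.
decoupled = Optimisers.setup(Optimisers.AdamW(0.01, (0.9, 0.999), 0.1; couple = false), x)
decoupled, x_decoupled = Optimisers.update(decoupled, x, dx)
```

With `λ = 0` the two modes produce identical updates, so only runs with nonzero weight decay distinguish them.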