From 4a78a55f55e098a71fc96b2c2d91bb75b7a926cb Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 7 Nov 2024 08:13:54 +0100
Subject: [PATCH] fix epsilon for Float16 (#190)

---
 .gitignore        |  1 +
 src/Optimisers.jl |  2 ++
 src/rules.jl      | 21 ++++++++++-----------
 src/utils.jl      |  6 ++++++
 test/rules.jl     | 10 ++++++++++
 5 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/.gitignore b/.gitignore
index 952f7ce..763dd6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ Manifest.toml
 .vscode/
 docs/build/
 .DS_Store
+/test.jl
\ No newline at end of file
diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 2e115c4..99fc162 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -25,6 +25,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
        AccumGrad
 
+VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
+
 ###
 ### one-array functions
 ###
diff --git a/src/rules.jl b/src/rules.jl
index bc9c099..b4fbd2a 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -130,7 +130,7 @@ RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, eps
 init(o::RMSProp, x::AbstractArray) = (zero(x), o.centred ? zero(x) : false)
 
 function apply!(o::RMSProp, state, x::AbstractArray{T}, dx) where T
-  η, ρ, ϵ = T(o.eta), T(o.rho), T(o.epsilon)
+  η, ρ, ϵ = T(o.eta), T(o.rho), _eps(T, o.epsilon)
   quad, lin = state
 
   @.. quad = ρ * quad + (1 - ρ) * abs2(dx)
@@ -216,7 +216,7 @@ end
 init(o::Adam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
 
 function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, vt, βt = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -279,7 +279,7 @@ end
 init(o::RAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), 1)
 
 function apply!(o::RAdam, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   ρ∞ = 2/(1-β[2]) - 1 |> real
 
   mt, vt, βt, t = state
@@ -320,7 +320,7 @@ end
 init(o::AdaMax, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
 
 function apply!(o::AdaMax, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, ut, βt = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -354,7 +354,7 @@ end
 init(o::OAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta), zero(x))
 
 function apply!(o::OAdam, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, vt, βt, term = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -388,7 +388,7 @@ end
 init(o::AdaGrad, x::AbstractArray) = onevalue(o.epsilon, x)
 
 function apply!(o::AdaGrad, state, x::AbstractArray{T}, dx) where T
-  η, ϵ = T(o.eta), T(o.epsilon)
+  η, ϵ = T(o.eta), _eps(T, o.epsilon)
   acc = state
 
   @.. acc = acc + abs2(dx)
@@ -418,7 +418,7 @@ end
 init(o::AdaDelta, x::AbstractArray) = (zero(x), zero(x))
 
 function apply!(o::AdaDelta, state, x::AbstractArray{T}, dx) where T
-  ρ, ϵ = T(o.rho), T(o.epsilon)
+  ρ, ϵ = T(o.rho), _eps(T, o.epsilon)
   acc, Δacc = state
 
   @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
@@ -454,7 +454,7 @@ init(o::AMSGrad, x::AbstractArray) = (onevalue(o.epsilon, x), onevalue(o.epsilon, x),
                                       onevalue(o.epsilon, x))
 
 function apply!(o::AMSGrad, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, vt, v̂t = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -489,8 +489,7 @@ end
 init(o::NAdam, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
 
 function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
-
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, vt, βt = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
@@ -548,7 +547,7 @@ end
 init(o::AdaBelief, x::AbstractArray{T}) where T = (zero(x), zero(x), T.(o.beta))
 
 function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
-  η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
+  η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
   mt, st, βt = state
 
   @.. mt = β[1] * mt + (1 - β[1]) * dx
diff --git a/src/utils.jl b/src/utils.jl
index 12a19dd..8f66746 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -14,3 +14,9 @@ foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
 end
 
 ofeltype(x, y) = convert(float(eltype(x)), y)
+
+_eps(T::Type{<:AbstractFloat}, e) = T(e)
+# catch complex and integers
+_eps(T::Type{<:Number}, e) = _eps(real(float(T)), e)
+# avoid small e being rounded to zero
+_eps(T::Type{Float16}, e) = e == 0 ? T(0) : max(T(1e-7), T(e))
diff --git a/test/rules.jl b/test/rules.jl
index 52a3580..9068fa1 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -267,3 +267,13 @@ end
   tree, x4 = Optimisers.update(tree, x3, g4)
   @test x4 ≈ x3
 end
+
+@testset "Float16 epsilon" begin
+  # issue https://github.com/FluxML/Optimisers.jl/issues/167
+  x = Float16[0.579, -0.729, 0.5493]
+  δx = Float16[-0.001497, 0.0001875, -0.013176]
+
+  os = Optimisers.setup(Adam(1e-4), x);
+  os, x = Optimisers.update(os, x, δx)
+  @test x ≈ Float16[1.835, -0.886, 0.5493] rtol=1e-3
+end
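
A minimal sketch (not part of the patch) of why the Float16 floor matters: Float16 cannot represent the default epsilon of 1e-8 (it rounds to zero), and for small gradients the second-moment term (1 - β2) * dx^2 also underflows to zero, so Adam's denominator sqrt(vt) + ϵ used to be exactly zero and the update blew up to Inf/NaN (issue #167 linked in the new test). The _eps helper added in src/utils.jl floors epsilon at a representable value of about 1e-7 for Float16 only, which is why the new test expects a large but finite first step. _eps is an internal, unexported helper; the snippet below assumes Optimisers.jl with this patch applied and uses it only to illustrate the behaviour.

    # Sketch, not part of the patch: assumes Optimisers.jl with this change applied.
    # _eps is the internal helper added in src/utils.jl above.
    using Optimisers

    @assert Float16(1e-8) == 0        # the default epsilon underflows to zero in Float16
    @assert Float16(1e-7) > 0         # the floor value is still representable (subnormal)

    @assert Optimisers._eps(Float16, 1e-8) == Float16(1e-7)  # floored for Float16
    @assert Optimisers._eps(Float32, 1e-8) == 1f-8            # wider types unchanged
    @assert Optimisers._eps(Float16, 0) == 0                  # zero passes through as zero

    # For a tiny gradient the second moment (1 - β2) * dx^2 also underflows in Float16,
    # so before this patch sqrt(vt) + ϵ was exactly zero and the step was Inf:
    x  = Float16[0.579]
    dx = Float16[-0.001497]
    tree = Optimisers.setup(Adam(1e-4), x)
    tree, x2 = Optimisers.update(tree, x, dx)
    @assert all(isfinite, x2)          # finite after the fix (≈ 1.835, as in the new test)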