Skip to content

Commit

Permalink
fix epsilon for Float16 (#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Nov 7, 2024
1 parent 2da6d7f commit 4a78a55
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ Manifest.toml
.vscode/
docs/build/
.DS_Store
/test.jl
2 changes: 2 additions & 0 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
AccumGrad

# Mark these names as `public` (part of the API without exporting them).
# The statement is built with `Meta.parse` from a string and guarded by a VERSION
# check so that it is only parsed and evaluated on Julia builds at or after
# 1.11.0-DEV.469; presumably older parsers do not accept the `public` keyword.
VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))

###
### one-array functions
###
Expand Down
21 changes: 10 additions & 11 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ RMSProp(; eta = 0.001, rho = 0.9, epsilon = 1e-8, kw...) = RMSProp(eta, rho, eps
# Initial RMSProp state for `x`, unpacked as `(quad, lin)` by `apply!`:
# a zero array like `x`, plus a second zero array only when `o.centred` is set
# (otherwise the placeholder `false`).
function init(o::RMSProp, x::AbstractArray)
  quad = zero(x)
  lin = o.centred ? zero(x) : false
  return (quad, lin)
end

function apply!(o::RMSProp, state, x::AbstractArray{T}, dx) where T
η, ρ, ϵ = T(o.eta), T(o.rho), T(o.epsilon)
η, ρ, ϵ = T(o.eta), T(o.rho), _eps(T, o.epsilon)
quad, lin = state

@.. quad = ρ * quad + (1 - ρ) * abs2(dx)
Expand Down Expand Up @@ -216,7 +216,7 @@ end
# Initial Adam state for `x`, unpacked as `(mt, vt, βt)` by `apply!`:
# two independent zero arrays like `x`, and `o.beta` converted elementwise to `T`.
function init(o::Adam, x::AbstractArray{T}) where T
  mt = zero(x)
  vt = zero(x)
  βt = T.(o.beta)
  return (mt, vt, βt)
end

function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, vt, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down Expand Up @@ -279,7 +279,7 @@ end
# Initial RAdam state for `x`, unpacked as `(mt, vt, βt, t)` by `apply!`:
# two independent zero arrays like `x`, `o.beta` converted elementwise to `T`,
# and the counter starting at 1.
function init(o::RAdam, x::AbstractArray{T}) where T
  mt = zero(x)
  vt = zero(x)
  βt = T.(o.beta)
  return (mt, vt, βt, 1)
end

function apply!(o::RAdam, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
ρ∞ = 2/(1-β[2]) - 1 |> real

mt, vt, βt, t = state
Expand Down Expand Up @@ -320,7 +320,7 @@ end
# Initial AdaMax state for `x`, unpacked as `(mt, ut, βt)` by `apply!`:
# two independent zero arrays like `x`, and `o.beta` converted elementwise to `T`.
function init(o::AdaMax, x::AbstractArray{T}) where T
  mt = zero(x)
  ut = zero(x)
  βt = T.(o.beta)
  return (mt, ut, βt)
end

function apply!(o::AdaMax, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, ut, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down Expand Up @@ -354,7 +354,7 @@ end
# Initial OAdam state for `x`, unpacked as `(mt, vt, βt, term)` by `apply!`:
# three independent zero arrays like `x` (first, second and last slot), with
# `o.beta` converted elementwise to `T` in the third slot.
function init(o::OAdam, x::AbstractArray{T}) where T
  mt = zero(x)
  vt = zero(x)
  βt = T.(o.beta)
  term = zero(x)
  return (mt, vt, βt, term)
end

function apply!(o::OAdam, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, vt, βt, term = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down Expand Up @@ -388,7 +388,7 @@ end
# Initial AdaGrad state for `x` (used as `acc` by `apply!`): whatever
# `onevalue(o.epsilon, x)` returns.
function init(o::AdaGrad, x::AbstractArray)
  return onevalue(o.epsilon, x)
end

function apply!(o::AdaGrad, state, x::AbstractArray{T}, dx) where T
η, ϵ = T(o.eta), T(o.epsilon)
η, ϵ = T(o.eta), _eps(T, o.epsilon)
acc = state

@.. acc = acc + abs2(dx)
Expand Down Expand Up @@ -418,7 +418,7 @@ end
# Initial AdaDelta state for `x`, unpacked as `(acc, Δacc)` by `apply!`:
# two independent zero arrays like `x`.
function init(o::AdaDelta, x::AbstractArray)
  acc = zero(x)
  Δacc = zero(x)
  return (acc, Δacc)
end

function apply!(o::AdaDelta, state, x::AbstractArray{T}, dx) where T
ρ, ϵ = T(o.rho), T(o.epsilon)
ρ, ϵ = T(o.rho), _eps(T, o.epsilon)
acc, Δacc = state

@.. acc = ρ * acc + (1 - ρ) * abs2(dx)
Expand Down Expand Up @@ -454,7 +454,7 @@ init(o::AMSGrad, x::AbstractArray) =
(onevalue(o.epsilon, x), onevalue(o.epsilon, x), onevalue(o.epsilon, x))

function apply!(o::AMSGrad, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, vt, v̂t = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down Expand Up @@ -489,8 +489,7 @@ end
# Initial NAdam state for `x`, unpacked as `(mt, vt, βt)` by `apply!`:
# two independent zero arrays like `x`, and `o.beta` converted elementwise to `T`.
function init(o::NAdam, x::AbstractArray{T}) where T
  mt = zero(x)
  vt = zero(x)
  βt = T.(o.beta)
  return (mt, vt, βt)
end

function apply!(o::NAdam, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)

η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, vt, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down Expand Up @@ -548,7 +547,7 @@ end
# Initial AdaBelief state for `x`, unpacked as `(mt, st, βt)` by `apply!`:
# two independent zero arrays like `x`, and `o.beta` converted elementwise to `T`.
function init(o::AdaBelief, x::AbstractArray{T}) where T
  mt = zero(x)
  st = zero(x)
  βt = T.(o.beta)
  return (mt, st, βt)
end

function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T
η, β, ϵ = T(o.eta), T.(o.beta), T(o.epsilon)
η, β, ϵ = T(o.eta), T.(o.beta), _eps(T, o.epsilon)
mt, st, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
Expand Down
6 changes: 6 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ foreachvalue(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
end

# Convert `y` to the floating-point version of the element type of `x`.
function ofeltype(x, y)
  T = float(eltype(x))
  return convert(T, y)
end

# `_eps(T, e)` converts the epsilon `e` into a real floating-point value suitable
# for arrays with element type `T`.

# Plain floating-point element types: a direct conversion.
_eps(::Type{T}, e) where {T<:AbstractFloat} = T(e)

# Any other number type (e.g. complex or integer): retry with the real
# floating-point type derived from `T`.
_eps(::Type{T}, e) where {T<:Number} = _eps(real(float(T)), e)

# Float16: a nonzero `e` is clamped from below by `Float16(1e-7)` so that a small
# epsilon is not rounded to zero; an `e` equal to zero stays zero.
function _eps(::Type{Float16}, e)
  e == 0 && return Float16(0)
  return max(Float16(1e-7), Float16(e))
end
10 changes: 10 additions & 0 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,13 @@ end
tree, x4 = Optimisers.update(tree, x3, g4)
@test x4 x3
end

@testset "Float16 epsilon" begin
  # issue https://github.com/FluxML/Optimisers.jl/issues/167
  # Regression test: one Adam step on Float16 parameters with small gradients
  # must produce the expected finite values (compared approximately, since the
  # arithmetic is done in Float16).
  x = Float16[0.579, -0.729, 0.5493]
  δx = Float16[-0.001497, 0.0001875, -0.013176]

  os = Optimisers.setup(Adam(1e-4), x);
  os, x = Optimisers.update(os, x, δx)
  @test x ≈ Float16[1.835, -0.886, 0.5493] rtol=1e-3
end

0 comments on commit 4a78a55

Please sign in to comment.