From 47ea05910d6665cfb117c59804de53580739f8ae Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 9 Feb 2024 18:42:30 -0500
Subject: [PATCH 1/3] add SignDecay

---
 Project.toml               |  2 +-
 src/Flux.jl                |  2 +-
 src/deprecations.jl        |  8 +++++++-
 src/optimise/Optimise.jl   |  2 +-
 src/optimise/optimisers.jl | 23 +++++++++++++++++++++++
 test/optimise.jl           |  2 +-
 6 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index 1ec2036e56..660bd0b865 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ MacroTools = "0.5"
 Metal = "0.5, 1"
 NNlib = "0.9.1"
 OneHotArrays = "0.2.4"
-Optimisers = "0.2.12, 0.3.0"
+Optimisers = "0.3.2"
 Preferences = "1"
 ProgressLogging = "0.1"
 Reexport = "1.0"
diff --git a/src/Flux.jl b/src/Flux.jl
index d158771ef8..d3ca611dbd 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -45,7 +45,7 @@ using .Optimise
 export Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, OAdam, AdamW, RAdam, AdaBelief,
   InvDecay, ExpDecay,
-  WeightDecay, ClipValue, ClipNorm
+  WeightDecay, SignDecay, ClipValue, ClipNorm
 export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser
 
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 3b65f00da0..f775deb23b 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -118,6 +118,7 @@ Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(_old_to_new(ru
 for T in [:Descent, :Adam, :Momentum, :Nesterov,
           :AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
           # :InvDecay, :ExpDecay,
+          :SignDecay,
           ]
   @eval function _old_to_new(rule::$T)
     args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
@@ -126,7 +127,7 @@ for T in [:Descent, :Adam, :Momentum, :Nesterov,
 end
 _old_to_new(rule::Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
 const OptimiserChain = Optimise.Optimiser  # lets you use new name with implicit params too.
-_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd)  # called gamma now
+_old_to_new(rule::WeightDecay) = Optimisers.WeightDecay(rule.wd)  # called lambda now
 _old_to_new(rule::ClipNorm) = Optimisers.ClipNorm(rule.thresh)  # called omega, and there are more fields
 _old_to_new(rule::ClipValue) = Optimisers.ClipGrad(rule.thresh)  # called delta now, and struct name differs
 const ClipGrad = Optimise.ClipValue
@@ -134,6 +135,11 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon
 
 _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
 
+# This allows you to mix and match, like Flux.setup(OptimiserChain(SignDecay(), Flux.Descent()), [1,2,3.])
+Optimisers.OptimiserChain(rules::Union{Optimisers.AbstractRule, Optimise.AbstractOptimiser}...) =
+  Optimisers.OptimiserChain(map(_old_to_new, rules))
+_old_to_new(rule::Optimisers.AbstractRule) = rule
+
 # Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot.
 # But let's make sure that such uses give a helpful error:
 import .Optimise: update!
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 3ca01e93fa..f637d83242 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -6,7 +6,7 @@ export train!, update!,
   Descent, Adam, Momentum, Nesterov, RMSProp,
   AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW,RAdam, OAdam, AdaBelief,
   InvDecay, ExpDecay, WeightDecay, Optimiser,
-  ClipValue, ClipNorm
+  ClipValue, ClipNorm, SignDecay
 
 include("optimisers.jl")
 include("train.jl")
diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 9da9b1472f..d2f69da884 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -698,6 +698,29 @@ function apply!(o::WeightDecay, x, Δ)
   @. Δ += wd * x
 end
 
+"""
+    SignDecay(λ = 1e-3)
+
+Version of `WeightDecay` which implements ``L_1`` regularisation,
+when composed with other optimisers as the first transformation to the gradient.
+
+# Examples
+
+```julia
+opt = Optimiser(SignDecay(1e-4), Adam())
+```
+"""
+mutable struct SignDecay <: AbstractOptimiser
+  lambda::Float32
+end
+
+WeightDecay() = SignDecay(1f-3)
+
+function apply!(o::SignDecay, x, Δ)
+  λ = o.lambda
+  @. Δ += λ * sign(x)
+end
+
 """
     ClipValue(thresh)
 
diff --git a/test/optimise.jl b/test/optimise.jl
index c79ce7f5e8..475bfba6c1 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -30,7 +30,7 @@ end
 @testset "Optimiser" begin
   Random.seed!(84)
   w = randn(10, 10)
-  @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
+  @testset for Opt in [InvDecay, WeightDecay, ExpDecay, SignDecay]
     Random.seed!(42)
     w′ = randn(10, 10)
     loss(x) = Flux.Losses.mse(w*x, w′*x)

From d65b76ff5bec5e4cc83dddad5ad21f4f15c97669 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 9 Feb 2024 21:02:09 -0500
Subject: [PATCH 2/3] typo

---
 src/optimise/optimisers.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index d2f69da884..18f9d3ddae 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -714,7 +714,7 @@ mutable struct SignDecay <: AbstractOptimiser
   lambda::Float32
 end
 
-WeightDecay() = SignDecay(1f-3)
+SignDecay() = SignDecay(1f-3)
 
 function apply!(o::SignDecay, x, Δ)
   λ = o.lambda

From adafb71126d0719e98e07adfa822c5b0533b4c48 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 12 Feb 2024 18:18:28 -0500
Subject: [PATCH 3/3] Update src/deprecations.jl

---
 src/deprecations.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/deprecations.jl b/src/deprecations.jl
index f775deb23b..4b094e5eb5 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -135,7 +135,7 @@ _old_to_new(rule::RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon
 
 _old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")
 
-# This allows you to mix and match, like Flux.setup(OptimiserChain(SignDecay(), Flux.Descent()), [1,2,3.])
+# This allows you to mix and match, like Flux.setup(OptimiserChain(Optimisers.SignDecay(), Flux.Descent()), [1,2,3.])
 Optimisers.OptimiserChain(rules::Union{Optimisers.AbstractRule, Optimise.AbstractOptimiser}...) =
   Optimisers.OptimiserChain(map(_old_to_new, rules))
 _old_to_new(rule::Optimisers.AbstractRule) = rule
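
Usage sketch (not part of the patches above). After these patches, `SignDecay` adds `λ * sign(x)` to the gradient before the main optimiser's step, i.e. an ``L_1`` (lasso-style) penalty, by analogy with `WeightDecay`'s `λ * x` for ``L_2``. A minimal training loop with the explicit `Flux.setup`/`Flux.update!` API; the toy `Dense` model, the random data, and the `1f-4` strength are illustrative assumptions only:

```julia
using Flux

# Hypothetical toy model and data, purely for illustration.
model = Dense(3 => 2)
data = [(randn(Float32, 3), randn(Float32, 2)) for _ in 1:10]

# SignDecay comes first in the chain, so the L1 term is added to the
# raw gradient before Adam rescales the combined quantity.
opt_state = Flux.setup(OptimiserChain(SignDecay(1f-4), Adam()), model)

for (x, y) in data
    grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end
```

Placing `SignDecay` first matters: composing it after the main rule would penalise the already-transformed update rather than the parameters' signs.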
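Mix-and-match sketch (not part of the patches). The `Optimisers.OptimiserChain` method added to `src/deprecations.jl` lets old-style `Flux.Optimise` rules and new-style Optimisers.jl rules share one chain, as the comment updated in PATCH 3/3 indicates. A small sketch under the assumption that the user also has `Optimisers` loaded; names are fully qualified because Flux and Optimisers both export `SignDecay`, `Descent`, and `OptimiserChain`, and the `0.1` values are arbitrary:

```julia
using Flux, Optimisers

x = [1.0, 2.0, 3.0]   # a bare parameter vector is enough to show the translation

# New-style Optimisers.SignDecay next to old-style Flux.Descent;
# the deprecation method converts the old rule into one explicit chain.
rule  = Optimisers.OptimiserChain(Optimisers.SignDecay(0.1), Flux.Descent(0.1))
state = Flux.setup(rule, x)

grad = [0.5, -0.5, 0.0]
state, x = Optimisers.update!(state, x, grad)   # sign-decay is applied, then descent
```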
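Isolation sketch (not part of the patches). For the old-style interface, the rule's effect on a gradient can be seen by calling the unexported `apply!` directly; the numbers below are made up for illustration:

```julia
using Flux.Optimise: SignDecay, apply!

x = Float32[0.5, -2.0, 0.0]   # parameters
Δ = zeros(Float32, 3)         # incoming gradient, zero here to isolate the decay term

apply!(SignDecay(0.1f0), x, Δ)
# Δ is now 0.1f0 .* sign.(x) == Float32[0.1, -0.1, 0.0],
# a subgradient of λ * sum(abs, x) with λ = 0.1f0.
```

This is the quantity that `Optimiser(SignDecay(1e-4), Adam())` from the docstring's example hands on to `Adam`.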