
WeightDecay for L1 norm #159

Merged: 10 commits into FluxML:master on Feb 7, 2024
Conversation

@mcabbott (Member) commented Sep 6, 2023

As I learned in FluxML/MLJFlux.jl#221 (comment), the gradient of the L1 norm is even simpler than the gradient of the L2 norm, so it can, obviously, be implemented as an optimisation rule too.

This quick PR adds it to the same WeightDecay struct. Below is a check that this does what you expect.

using Flux: Flux, Dense, gradient, state
using Optimisers
using Optimisers: setup, update

input = [1,2]
model = Dense([1 -2; 3 -4.0])

grads = Flux.gradient(model) do m
  result = m(input)
  sum(result)
end

# Check L2 norm via WeightDecay (nothing new!)

pen_l2(x::AbstractArray) = sum(abs2, x)/2

grads_L2 = Flux.gradient(model) do m
  result = m(input)
  penalty = sum(pen_l2, Flux.params(m))
  sum(result) + 0.42 * penalty
end

update(
  setup(Descent(0.1), model),
  model, grads_L2[1])[2] |> Flux.state

update(
  setup(OptimiserChain(WeightDecay(0.42), Descent(0.1)), model),
  model, grads[1])[2] |> Flux.state

# Do exactly the same thing for L1 (needs this PR)

pen_l1(x::AbstractArray) = sum(abs, x)

grads_L1 = Flux.gradient(model) do m
  result = m(input)
  penalty = sum(pen_l1, Flux.params(m))
  sum(result) + 0.42 * penalty
end

update(
  setup(Descent(0.1), model),
  model, grads_L1[1])[2] |> Flux.state

update(
  setup(OptimiserChain(WeightDecay(0.0, 0.42), Descent(0.1)), model),
  model, grads[1])[2] |> Flux.state
  
# Both give (weight = [0.858 -2.158; 2.858 -4.158], bias = [-0.1, -0.1], σ = ())

PR Checklist

  • Tests are added
  • Documentation, if applicable

@darsnack (Member) left a comment:

Looks good. It might be worth adding examples to the docstring now that the rule is sufficiently complex.

src/rules.jl (outdated):

# Parameters
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Sign decay (`ζ`): umm
Member:

Suggested change:
  before: - Sign decay (`ζ`): umm
  after:  - Sign decay (`ζ`): Signed decay applied to weights during optimization.

@mcabbott (Member, Author):

Yea I meant to write some words! I think we can do better than "Weight decay (γ): Decay applied to weights" too, as this is pretty circular.

Member:

Though not 100% accurate, I think even "L1/L2 regularization coefficient" would be more informative.

@darsnack (Member) commented Sep 6, 2023

An alternative API is to add SignedDecay (or something) if we find WeightDecay(0.0, 0.004) too weird.

@ToucheSir (Member):

I thought about that too, but this seems more straightforward if one wants to combine L1 and L2. We don't currently have a Parallel-esque rule which feeds the same gradient into two different rules, though now that I say it, such a composite rule could be a nice addition.

@mcabbott (Member, Author) commented Sep 6, 2023

Yes, I wondered about an independent rule, but then thought precisely that you may want a bit of L1 and a bit of L2. And also, perhaps, that if you know about this trick for L2, this proximity may help you discover the similar trick for L1.

I gave it the next unused Greek letter. It's sort-of neat that each different rule you may wish to chain uses a different field name, as adjust!(..., zeta=0.1) etc. never modifies two unrelated things.

@darsnack (Member) commented Sep 6, 2023

For what it's worth, I'm okay with a single rule. But just to push the other side a bit more: you don't need a Parallel-esque construct for these rules to compose. OptimiserChain(WeightDecay(0.004), SignedDecay(0.004), Descent(0.1)) works just fine (since each depends on x, not dx).
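
A sketch of that composition (not from the PR), written with the SignDecay rule this PR eventually merged in place of the hypothetical SignedDecay; the coefficients are arbitrary placeholders:

using Optimisers

rule = OptimiserChain(WeightDecay(0.004), SignDecay(0.004), Descent(0.1))

x = [1.0, -2.0, 3.0]
st = Optimisers.setup(rule, x)
dx = ones(3)

# Each decay rule reads the parameter x, not the incoming gradient, so chaining them
# just sums the two penalty gradients before Descent scales the step; the update below
# computes x .- 0.1 .* (dx .+ 0.004 .* x .+ 0.004 .* sign.(x))
st, x = Optimisers.update(st, x, dx)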

@ToucheSir (Member):

Ah you're right, I got my wires crossed there.

FWIW, the AdamW paper uses λ for the weight decay term, which PyTorch borrows for its optimizer documentation but does not use in any API.

@darsnack (Member) commented Sep 6, 2023

Another option is to have SignedDecay(zeta) = WeightDecay(0, zeta). I'm okay with all options, just throwing things out for consideration.

@ablaom commented Sep 7, 2023

Thanks for considering this contribution @mcabbott.

Another convention, adopted in elastic net and elsewhere in statistics, is to have an overall lambda parameter and an L1/L2 mixture parameter alpha. This is what we do in MLJFlux.

https://github.com/FluxML/MLJFlux.jl/blob/b449d80d1d5606298bae0ded1992ee35c5c099c0/src/penalizers.jl#L11

But I don't have a strong opinion.
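
A rough sketch of that elastic-net parameterisation (illustrative only; the names and exact form are not MLJFlux's code): one overall strength lambda and a mixture alpha, with alpha = 1 giving pure L1 and alpha = 0 pure L2.

# Illustrative elastic-net convention. The bare (1 - alpha) * sum(abs2, x) term, and hence
# the 2x in its gradient, is exactly the factor-of-two convention debated further down.
elastic_penalty(x; lambda, alpha) = lambda * (alpha * sum(abs, x) + (1 - alpha) * sum(abs2, x))

# Per-element term that an optimisation rule would add to the gradient dx:
elastic_grad(xi; lambda, alpha) = lambda * (alpha * sign(xi) + (1 - alpha) * 2xi)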

@mcabbott (Member, Author) commented Sep 7, 2023

Ah that is a nice idea.

It sounds like lambda is more standard. I don't know where we got gamma; possibly I just invented something other than Flux's .wd:

https://github.com/FluxML/Flux.jl/blob/95737ffc9aa989f31d5fecd9a887a9c25f4fd865/src/optimise/optimisers.jl#L690-L692

It only matters because of adjust!, but I guess we can add a deprecation.
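
A hedged illustration of the adjust! point, using the field names as they ended up after this PR (lambda for WeightDecay, eta for Descent); the numbers are arbitrary:

using Optimisers

st = Optimisers.setup(OptimiserChain(WeightDecay(0.42), Descent(0.1)), [1.0, -2.0])

Optimisers.adjust!(st, lambda = 0.1)  # touches only rules with a `lambda` field (WeightDecay here)
Optimisers.adjust!(st, eta = 0.01)    # touches only rules with an `eta` field (Descent here)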

@ablaom commented Sep 7, 2023

Yes, but I have also seen the roles of lambda and alpha reversed :-(

@mcabbott (Member, Author) commented Sep 8, 2023

I wish I was surprised...

Now changed to lambda and alpha. This seems fairly natural to have as one struct, not two.

Not easily accessible from Flux, but shouldn't break anything:

julia> Flux.setup(Flux.WeightDecay(0.1), [1,2.0]) |> dump
Optimisers.Leaf{WeightDecay, Nothing}
  rule: WeightDecay
    lambda: Float64 0.1
    alpha: Float64 0.0
  state: Nothing nothing
  frozen: Bool false

mcabbott marked this pull request as ready for review on September 8, 2023 at 03:30.
src/rules.jl (outdated):
λ, α = T(o.lambda), T(o.alpha)
ℓ1 = λ * α
ℓ2 = λ * (1 - α)
dx′ = @lazy dx + ℓ2 * x + ℓ1 * sign(x)
@ablaom commented Sep 8, 2023:

I wonder if there is a factor of two missing here. Consider the ordinary scalar case: the derivative of x^2 (the L2 penalty) is 2x, while the derivative of |x| (the L1 penalty) is sign(x). So either

dx′ = @lazy dx + ℓ2 * 2x + ℓ1 * sign(x)

or

dx′ = @lazy dx + ℓ2 * x + ℓ1 * sign(x) / 2

would be more correct. I think the first is better, but it is also breaking, I guess.
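
A quick numerical check of those derivatives (using ForwardDiff purely for illustration; it is not part of this PR):

using ForwardDiff  # any AD package would do for this check

x = [1.5, -0.5]
ForwardDiff.gradient(v -> sum(abs2, v), x)      # [3.0, -1.0] == 2 .* x
ForwardDiff.gradient(v -> sum(abs2, v) / 2, x)  # [1.5, -0.5] == x, the /2 convention
ForwardDiff.gradient(v -> sum(abs, v), x)       # [1.0, -1.0] == sign.(x)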

Member:

I'm not aware of any other implementations which add the factor of two. It's likely assumed that it will be folded into λ. Of course, none of them try to use L1 and L2 together either!

@mcabbott (Member, Author):

> x^2 (l2 penalty)

I'm sure all conventions exist, but the most common one seems to be to take norm(x)^2/2 as the L2 penalty. I think the present code & docs should agree on this choice.

For L1 it surely has to be just norm(x,1).

@ablaom:

Yes, all conventions exist, but I'm going to push back one last time on what is "most common". The first place I looked just now, the Wikipedia page on regularisation, has no 1/2 in front of the L2 penalty when it is mixed with L1. In fact, the formula and notation there correspond exactly to my first case and to what we implement in MLJFlux presently.

For me, the "Lp penalty" is the Lp norm to the power of p, which is always the sum of |x_i|^p over i.

I understand how the 1/2 started to appear in isolation (i.e. when ignoring L1 regularisation), because it simplifies the derivative. But in the context of using both, we should compare apples to apples, so the 1/2 makes no sense, unless you say the Lp penalty is \frac{1}{p} |x|^p, which I have not seen.

Member:

I don't think we can have anything other than lambda * x as the L2 penalty. That's so commonly accepted as the standard weight decay implementation that anything else would be unnecessarily surprising. So it's either /2 for the L1 portion or no change. I'm inclined to consider the /2 only because we are using the elastic net coefficient convention. If we had separate coefficients for each term, I would prefer to avoid any extra factors and lump it all into the coefficients.

@ablaom commented Sep 11, 2023:

Agreed. Add 1/2 to the L1 part or separate into two optimisers.

@ablaom commented Jan 16, 2024

@mcabbott Do you have some time to push this along? The project to update MLJFlux to use explicit parameters is waiting on this.

@mcabbott (Member, Author):

I had a go locally & will try to find the branch

@mcabbott (Member, Author) commented Feb 2, 2024

OK, dbcea29 pushes what I had locally, from way back when. It leaves WeightDecay alone and makes a new struct for the combined L1 and L2 story. I called this NormReg, although perhaps there's a better name.

Is this a good design? We could instead have a new struct which does only L1. And then (if we want to support a mixture) have some function which returns a chain of L1 and L2, using existing structs. Maybe that would be better.

Edit: And f70aa9c changes to a separate SignDecay struct for L1 alone. No function for a combination. Maybe that's the minimal thing.

Maybe they should not have the same field name lambda; ideas for what might be better?
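
A minimal sketch of what an L1-only rule looks like in the Optimisers.jl rule interface; the name MySignDecay and its details are illustrative and may differ from the SignDecay committed here:

using Optimisers

struct MySignDecay{T} <: Optimisers.AbstractRule
  lambda::T
end

Optimisers.init(o::MySignDecay, x::AbstractArray) = nothing  # the rule keeps no state

function Optimisers.apply!(o::MySignDecay, state, x::AbstractArray{T}, dx) where T
  κ = T(o.lambda)
  dx′ = @. dx + κ * sign(x)  # gradient of κ * norm(x, 1); the PR's own code uses a lazy broadcast
  return state, dx′
end

# Composes with other rules in the usual way, e.g.
# Optimisers.setup(OptimiserChain(MySignDecay(0.01), Descent(0.1)), model)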

@ToucheSir (Member) commented Feb 3, 2024

PyTorch wasn't very helpful as inspiration, but optax uses the term "decay rate" in their implementation of weight decay. A little verbose but pretty clear.

Alternatively, the sklearn regression models call this L1/L2 coefficient alpha. The ElasticNet page specifically refers to it as a "penalty (term)", which is another idea for a plain English word.

@ablaom commented Feb 4, 2024

The separate SignDecay option, as currently implemented here, would suit me fine. This way, I can confidently use the two decays without looking up documentation to sort out the notation and the convention about 1 versus 1/2. (In elastic net I have seen the roles of alpha and lambda reversed in some implementations.)

@mcabbott (Member, Author) commented Feb 5, 2024

Another argument against having a mixture parameter: in most practical use, knowing that λ = 1e-3 is a useful amount of L2 for your problem does not imply that this is the right amount of L1... you are going to have to search, in which case just changing the mixture/angle parameter isn't really better than changing some other κ instead.

The last commits change the name of the L1 penalty coefficient to "kappa", because it's next door and not used in this package yet (hence adjust(st, kappa=0.1) will hit exactly one thing).

@ToucheSir (Member):

If you'll allow me to bikeshed names one more time: I feel like we should not be pulling out Greek characters that have not been used in the literature before, even if they are represented as English words instead of the original symbols. Are there no descriptive terms we can use instead of kappa (and maybe lambda too)?

Otherwise LGTM.

@ablaom commented Feb 6, 2024

I think they could have the same name. They're both regularization parameters in separate structs. I think lambda (or its Unicode equivalent) is pretty standard for a generic regularization parameter.

In the same vein, eta is used for the learning rate in all the variations of Optimisers' gradient descent.

@mcabbott (Member, Author) commented Feb 6, 2024

Good point about eta being used everywhere; maybe just re-using lambda is best.

Maybe this is done? CI on Julia > 1.6 might be fixed later by #166.

@mcabbott (Member, Author) commented Feb 7, 2024

If either of you clicks approve, I can merge this and then rebase #160.

mcabbott merged commit e60b71e into FluxML:master on Feb 7, 2024 (3 of 4 checks passed).
mcabbott deleted the l1norm branch on February 7, 2024.
mashu pushed a commit to mashu/Optimisers.jl that referenced this pull request on Nov 14, 2024:
* WeightDecay for L1 norm

* better words

* change to lambda alpha, add tests

* change to lambda, add tests

* tweaks

* shashed in October - makes two structs instead

* version with simple SignDecay instead

* change SignDecay penalty to be called kappa

* restore depwarn for WeightDecay, was called gamma

* change kappa back to lambda