WeightDecay for L1 norm #159
Conversation
Looks good. It might be worth it to add examples to the docstring now that the rule is sufficiently complex.
src/rules.jl (Outdated)

```
# Parameters
- Weight decay (`γ`): Decay applied to weights during optimisation.
- Sign decay (`ζ`): umm
```
Suggested change:

```diff
- - Sign decay (`ζ`): umm
+ - Sign decay (`ζ`): Signed decay applied to weights during optimization.
```
Yeah, I meant to write some words! I think we can do better than "Weight decay (`γ`): Decay applied to weights" too, as this is pretty circular.
Though not 100% accurate, I think even "L1/L2 regularization coefficient" would be more informative.
An alternative API is to add …
I thought about that too, but this seems more straightforward if one wants to combine L1 and L2. We don't currently have a …
Yes, I wondered about an independent rule, but then thought precisely that you may want a bit of L1 and a bit of L2. And also, perhaps, that if you know about this trick for L2, then this proximity may help you discover the similar trick for L1. I gave it the next unused Greek letter. It's sort-of neat that each different rule you may wish to chain uses a different field name, as …
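(Not part of the thread: a minimal sketch of how distinct field names per rule pay off when adjusting hyperparameters in a chain, assuming the current `Optimisers.adjust!` API and the field names the PR eventually settles on, `eta` for `Descent` and `lambda` for `WeightDecay`.)

```julia
using Optimisers

x  = randn(4)
st = Optimisers.setup(OptimiserChain(WeightDecay(1e-4), Descent(0.1)), x)

Optimisers.adjust!(st, 0.05)            # positional form: changes the learning rate `eta`
Optimisers.adjust!(st; lambda = 1e-3)   # keyword form: targets only WeightDecay's `lambda`
```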
For what it's worth, I'm okay with a single rule. But just to push the other side a bit more, you don't need a …
Ah, you're right, I got my wires crossed there. FWIW, the AdamW paper uses …
Another option is to have …
Thanks for considering this contribution @mcabbott. Another convention, adopted in elastic net and elsewhere in statistics, is to have an overall penalty strength `λ` and a mixing parameter `α`. But I don't have a strong opinion.
Ah, that is a nice idea. It sounds like lambda is more standard. I don't know where we got gamma; possibly I just invented something other than Flux's … It only matters because of …
Yes, but I have also seen the roles of lambda and alpha reversed :-( |
I wish I was surprised... Now changed to lambda, alpha. This seems fairly natural to have as one struct, not two. Not easily accessible from Flux, but shouldn't break anything:

```julia
julia> Flux.setup(Flux.WeightDecay(0.1), [1,2.0]) |> dump
Optimisers.Leaf{WeightDecay, Nothing}
  rule: WeightDecay
    lambda: Float64 0.1
    alpha: Float64 0.0
  state: Nothing nothing
  frozen: Bool false
```
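(For reference, and not part of the thread: the penalty this `lambda`/`alpha` split corresponds to, written with the 1/2 on the L2 term so that its gradient matches the `ℓ2 * x` term in the code. As the discussion below shows, that factor is exactly where conventions disagree.)

```math
R(x) = \lambda\Big(\alpha\,\lVert x\rVert_1 + \tfrac{1-\alpha}{2}\,\lVert x\rVert_2^2\Big),
\qquad
\nabla R(x) = \lambda\alpha\,\operatorname{sign}(x) + \lambda(1-\alpha)\,x
```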
src/rules.jl (Outdated)

```julia
λ, α = T(o.lambda), T(o.alpha)
ℓ1 = λ * α
ℓ2 = λ * (1 - α)
dx′ = @lazy dx + ℓ2 * x + ℓ1 * sign(x)
```
I wonder if there is a factor of two missing here. Consider the ordinary scalar case: the derivative of `x^2` (the L2 penalty) is `2x`, while the derivative of `|x|` (the L1 penalty) is `sign(x)`. So either

```julia
dx′ = @lazy dx + ℓ2 * 2x + ℓ1 * sign(x)
```

or

```julia
dx′ = @lazy dx + ℓ2 * x + ℓ1 * sign(x) / 2
```

would be more correct. I think the first is better, but it is also breaking, I guess.
I'm not aware of any other implementations which add the factor of two. It's likely assumed that it will be folded into `λ`. Of course, none of them also try to use L1/L2 together!
> `x^2` (L2 penalty)

I'm sure all conventions exist, but the most common one seems to be to take `norm(x)^2/2` as the L2 penalty. I think the present code & docs should agree on this choice.

For L1 it surely has to be just `norm(x, 1)`.
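(For reference, not part of the review: the gradients under the conventions being compared, spelled out.)

```math
\nabla_x\!\left(\tfrac{\lambda}{2}\lVert x\rVert_2^2\right) = \lambda\,x,
\qquad
\nabla_x\!\left(\lambda\,\lVert x\rVert_2^2\right) = 2\lambda\,x,
\qquad
\nabla_x\!\left(\lambda\,\lVert x\rVert_1\right) = \lambda\,\operatorname{sign}(x)
```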
Yes, all conventions exist, but I'm going to push back one last time on what is "most common". The first place I looked just now, the Wikipedia page on regularisation, has no 1/2 in front of the L2 penalty when it is mixed in with L1. In fact, the formula and notation correspond exactly to my first case and to what we implement in MLJFlux presently.

For me the "Lp penalty" is the Lp norm to the power of p, which is always the sum of |x_i|^p over i.

I understand how the 1/2 started to appear in isolation (i.e. when ignoring L1 regularisation), because it simplifies the derivative. But in the context of using both, we should compare apples to apples, so the 1/2 makes no sense, unless you say the Lp penalty is "\frac{1}{p} |x|^p", which I have not seen.
I don't think we can have anything other than `lambda * x` as the L2 penalty's contribution to the gradient. That's so commonly accepted as the standard weight decay implementation that anything else would be unnecessarily surprising. So it's either a `/2` on the L1 portion or no change. I'm inclined to consider the `/2` only because we are using the elastic net coefficient convention. If we had separate coefficients for each term, I would prefer to avoid any extra factors and lump it all into the coefficients.
Agreed. Add 1/2 to the L1 part or separate into two optimisers.
@mcabbott Do you have some time to push this along? The project to update MLJFlux to use explicit parameters is waiting on this.

I had a go locally & will try to find the branch.
OK, dbcea29 pushes what I had locally, way back when. It leaves … Is this a good design? We could instead have a new struct which does only L1, and then (if we want to support a mixture) have some function which returns a chain of L1 and L2, using existing structs. Maybe that would be better.

Edit: and f70aa9c changes to a separate `SignDecay` struct. Maybe they should not have the same field name `lambda`; ideas for what might be better?
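(To make the "chain of L1 and L2" alternative concrete: a minimal sketch, not code from the PR, assuming the separate `SignDecay` rule from f70aa9c with its coefficient passed positionally like `WeightDecay`'s.)

```julia
using Optimisers

# Elastic-net-style regularisation by chaining rules instead of a mixture parameter:
# SignDecay adds λ₁ * sign(x) to the gradient (L1), WeightDecay adds λ₂ * x (L2),
# and Descent then applies the learning rate to the combined gradient.
rule = OptimiserChain(SignDecay(1e-3), WeightDecay(1e-4), Descent(0.1))

x  = randn(5)
st = Optimisers.setup(rule, x)
st, x = Optimisers.update(st, x, randn(5))   # one step; both penalties are added to the gradient
```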
PyTorch wasn't very helpful as inspiration, but optax uses the term "decay rate" in their implementation of weight decay. A little verbose, but pretty clear. Alternatively, the sklearn regression models call this L1/L2 coefficient …
The separate …
Another argument against having a mixture parameter: in most practical use, knowing that …

Last commits change the name of the L1 penalty coefficient to "kappa", because it's next door and not used in this package yet (hence …).
If you'll allow me to bikeshed names one more time: I feel like we should not be pulling out Greek characters that have not been used in the literature before, even if they are represented as English words instead of the original symbols. Are there no descriptive terms we can use instead of `kappa`?

Otherwise LGTM.
I think they could have the same name. They're both regularization parameters in separate structs. I think … In the same vein that …
Good point about … Maybe this is done? CI on Julia > 1.6 might be fixed later by #166.
If either of you clicks approve I can merge this, and then rebase #160.
Commits:
* WeightDecay for L1 norm
* better words
* change to lambda alpha, add tests
* change to lambda, add tests
* tweaks
* shashed in October - makes two structs instead
* version with simple SignDecay instead
* change SignDecay penalty to be called kappa
* restore depwarn for WeightDecay, was called gamma
* change kappa back to lambda
As I learned in FluxML/MLJFlux.jl#221 (comment), since the gradient of the L1 norm is even simpler than the gradient of the L2 norm, it can obviously be implemented as an optimisation rule.

This quick PR adds it to the same WeightDecay struct. Below is a check that this does what you expect.
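(That check is not reproduced in this thread; the following is a rough sketch, not the author's original code, of what such a sanity check might look like with the separate `SignDecay` rule the PR eventually lands on, using Zygote to compute the explicit penalty gradient.)

```julia
using Optimisers, Zygote

λ = 0.1
x = randn(5)

# Explicit gradient of the L1 penalty λ‖x‖₁ — elementwise λ * sign(x), away from zero:
g_penalty = Zygote.gradient(x -> λ * sum(abs, x), x)[1]

# The same contribution expressed as an optimisation rule: with a zero loss-gradient
# and unit learning rate, the step taken should equal the penalty gradient.
st = Optimisers.setup(OptimiserChain(SignDecay(λ), Descent(1.0)), x)
_, x′ = Optimisers.update(st, x, zero(x))

@assert x .- x′ ≈ g_penalty
```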