Add Apollo optimizer (https://arxiv.org/pdf/2412.05270) #196
@@ -599,6 +599,140 @@
    return (mt, st, βt .* β), dx′
end

"""
    GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)

Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
with Optimisers.jl this applies per-tensor instead of per-model, so the defaults differ. `γ` controls the maximum factor by which the gradient norm may grow
from one step to the next. This implementation also introduces `m`, a hard minimum on the gradient norm threshold: gradients whose norm is below it are never rescaled, preventing a tensor
from getting "trapped" near zero. This minimum can be fixed, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
[Review] Can this explain what it does do, mathematically, before explaining that it's different to some paper?
[Review] I don't know what this means without reading the code. Can you write like…
[Reply] Done.
"""
struct GradNormGrowthLimiter <: AbstractRule
    γ::Float64
    m::Float64  # min grad norm, to stop a tensor getting stuck near zero
    ϵ::Float64
[Review] We don't allow unicode field names, suggest: …
[Reply] Done, but changed the variable names to avoid e.g. gamma.
    throw::Bool
    paramscale_min::Bool
end

GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) = GradNormGrowthLimiter(γ, m, ϵ, throw, paramscale_min)
[Review] We don't have greek-letter keyword options, nor field names; the API should never ask the user to type these. They are used only in documentation / as local variables. Probably the first 3 should be positional. Bikeshedding names a bit, to avoid overly long things, the constructor could be: …
[Reply] I went with NormGrowthCap here.
init(o::GradNormGrowthLimiter, x::AbstractArray{T}) where T = T(0)

function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where T
    current_norm = _norm(dx, 2)
    if o.throw && !isfinite(current_norm)
        throw(DomainError("gradient has L2-norm $current_norm, for array $(summary(x))"))
    end
    if state == 0
        return current_norm, dx
    else
        # if you're below the hard min, then don't scale
        if o.paramscale_min
            minthresh = o.m * sqrt(length(dx))
        else
            minthresh = o.m
        end
        if current_norm < minthresh
            return current_norm, dx
        end
        ratio = current_norm / (state + o.ϵ)
        if ratio > o.γ
            λ = T((o.γ * state) / (current_norm + o.ϵ))
            return current_norm * λ, dx * λ
        else
            return current_norm, dx
        end
    end
end
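The rescaling rule implemented above can be sketched dependency-free: keep the previous step's gradient norm as state, and when the new norm exceeds `γ` times it, multiply the gradient by `γ·prev_norm / norm`. The Python below is an illustrative re-derivation, not the package's API; `clip_growth` and its list-of-floats gradient are hypothetical stand-ins.

```python
import math

def clip_growth(dx, prev_norm, gamma=1.1, min_thresh=1e-3, eps=1e-8):
    """Cap the growth of the gradient's L2 norm at a factor of `gamma` per
    step; gradients with norm below `min_thresh` pass through unscaled."""
    norm = math.sqrt(sum(g * g for g in dx))
    if prev_norm == 0 or norm < min_thresh:
        return dx, norm  # first step, or below the hard floor: no rescaling
    if norm / (prev_norm + eps) > gamma:
        lam = gamma * prev_norm / (norm + eps)
        return [g * lam for g in dx], norm * lam
    return dx, norm

# a gradient whose norm doubled is scaled back to roughly gamma * prev_norm,
# while one that grew by less than gamma passes through untouched
scaled, new_norm = clip_growth([2.0, 0.0], prev_norm=1.0)
```

The returned norm plays the role of the rule's state, so the next call sees the post-clipping norm, matching `return current_norm * λ, dx * λ` above.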
nonfirstdims(x) = prod(size(x)[2:end])
"""
    Apollo(η::Real, rank::Int; u = 100, sort_dims = true)
    Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true)
    Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true)
    Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true)

Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks moments in a low-rank subspace, aiming for Adam-like behavior with minimal additional memory usage.
The first argument can be an AdamW optimizer, or a learning rate (in which case a default AdamW optimizer with that learning rate is used). The second argument can be a rank, or a function
that computes the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
"""
struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule
    opt::T1
    eta::T2
[Review] Why store opt and eta?
[Reply] I originally just stored… Edit: another reason for storing an AdamW is that the AdamW is used instead of Apollo on regular arrays. But I just realized that now "adjust" won't work for regular arrays. I'll try figuring this out...
[Review] Storing an AdamW seems fine, surely we can make…
[Reply] I've made adjust work on the inner Adam now, so have dropped the additional eta.
    r::T3  # maps non-first dims to rank
    u::T4  # subspace update frequency (T in paper)
    sort_dims::T5  # whether to swap the dims of x and dx when the second dim is smaller than the first
[Review] These have fixed types, right? Suggested change: …
[Reply] Yup.
end
Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
[Review] Can't this method just be created by giving a default to eta in the next one?
[Reply] This is fixed via a different route.
Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
[Review] Are you sure you want…
[Reply] Thanks for catching this.
Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims)
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims)
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims)

# use the base init and apply for 1D arrays
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
apply!(o::Apollo, state, x::AbstractArray{T,1}, dx) where T = apply!(o.opt, state, x, dx)
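To make the default rank rule concrete: with `rank_function = dim -> ceil(Int, sqrt(dim))`, the state for an m×n weight holds two r×n moment buffers plus an r×m projection, rather than Adam's two full m×n buffers. A quick Python check of that bookkeeping (illustrative only; `default_rank` is a hypothetical helper mirroring the docstring's default):

```python
import math

def default_rank(dim):
    # mirrors the docstring default: dim -> ceil(Int, sqrt(dim))
    return math.ceil(math.sqrt(dim))

# moment storage for one m x n weight matrix
m, n = 4096, 4096
r = default_rank(n)
adam_floats = 2 * m * n              # first and second moments, full size
apollo_floats = 2 * r * n + r * m    # subspace moments plus projection P
```

For a 4096×4096 weight this gives r = 64, so the Apollo state is a small fraction of Adam's, which is the "minimal additional memory usage" claim in the docstring.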
function init(o::Apollo, x::AbstractArray{T}) where T
    first_dim, second_dim = size(x,1), nonfirstdims(x)
    if o.sort_dims && second_dim < first_dim
        first_dim, second_dim = second_dim, first_dim
    end
    rank = o.r(second_dim)
    P = similar(x, rank, first_dim)
    randn!(P)
    P .*= T(sqrt(1/rank))
    ((similar(x, rank, second_dim) .= 0, similar(x, rank, second_dim) .= 0, o.opt.beta), 1, P)
end
function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
    swapped = false
    original_size = size(x)
    x = reshape(x, size(x,1), nonfirstdims(x))

    dx = Broadcast.materialize(dx)  # this is to stop the "gradient type" @lazy test from failing due to reshape
    dx = reshape(dx, size(x,1), nonfirstdims(x))
[Review] (lines +707 to +708) Do you need to materialize in matrix case?
[Reply] For everything except whatever comes in during the "gradient type" test you don't need materialize. I wasn't 100% sure exactly what is coming in during those tests, so wasn't sure how to separate them from regular matrices/tensors. What do you suggest here?
    first_dim, second_dim = size(x,1), size(x,2)
    if o.sort_dims && second_dim < first_dim
        first_dim, second_dim = second_dim, first_dim
        x = x'
        dx = dx'
        swapped = true
    end
    (mt, vt, βt), t, P = state
    η = T(o.eta)  # this is what will get modified by adjust
    λ = T(o.opt.lambda)
    β = T.(o.opt.beta)
    ϵ = T(o.opt.epsilon)
    βt = T.(βt)
    if mod(t, o.u) == 0
        rank = o.r(second_dim)
        randn!(P)
        P .*= T(sqrt(1/rank))
    end
    R = P * dx
    @.. mt = β[1] * mt + (1 - β[1]) * R
    @.. vt = β[2] * vt + (1 - β[2]) * abs2(R)
    Rhat = @. mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
    s = sqrt.(sum(abs2.(Rhat), dims=1))[:] ./ (sqrt.(sum(abs2.(R), dims=1))[:] .+ ϵ)
    dx′′ = η * (dx .* reshape(s, 1, :)) + λ * x
[Review] These lines allocate a lot. Rhat isn't used? For the rest maybe it can be something like…
[Reply] I got something like this working, but the @lazy breaks things, so omitted for now.
    if swapped
        dx′′ = dx′′'
    end

[Review] Suggested change: … Also, this sort of branching introduces type instability. IDK if we care but perhaps worth some thought. Maybe there's a nicer way to just store everything transposed?
[Reply] Maybe an optimization we can figure out later if it becomes an issue?

    return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size)
end
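The core of the step above — project the gradient to rank r, run Adam-style moments in the subspace, then rescale each gradient column by the ratio of the bias-corrected Adam direction's norm to the raw projection's norm — can be sketched dependency-free. This Python version is an illustrative re-derivation using nested lists, not the package's API; `apollo_step` and its argument layout are hypothetical.

```python
import math

def apollo_step(x, dx, P, mt, vt, t, eta=1e-3, beta=(0.9, 0.999), eps=1e-8, lam=0.0):
    """One Apollo-style update for an m-by-n gradient `dx` (nested lists),
    with rank-r projection `P` (r-by-m) and moment buffers mt, vt (r-by-n)."""
    r, m, n = len(P), len(dx), len(dx[0])
    # R = P * dx: project the gradient into the rank-r subspace
    R = [[sum(P[i][k] * dx[k][j] for k in range(m)) for j in range(n)]
         for i in range(r)]
    # Adam-style first/second moment updates on the projected gradient
    for i in range(r):
        for j in range(n):
            mt[i][j] = beta[0] * mt[i][j] + (1 - beta[0]) * R[i][j]
            vt[i][j] = beta[1] * vt[i][j] + (1 - beta[1]) * R[i][j] ** 2
    # per-column scale: norm of the bias-corrected Adam direction in the
    # subspace, divided by the norm of the raw projected gradient
    s = []
    for j in range(n):
        num = math.sqrt(sum((mt[i][j] / (1 - beta[0] ** t)
                             / (math.sqrt(vt[i][j] / (1 - beta[1] ** t)) + eps)) ** 2
                            for i in range(r)))
        den = math.sqrt(sum(R[i][j] ** 2 for i in range(r))) + eps
        s.append(num / den)
    # scale each gradient column, then add decoupled weight decay
    return [[eta * dx[k][j] * s[j] + lam * x[k][j] for j in range(n)]
            for k in range(m)]
```

This mirrors the `R = P * dx`, `@.. mt`/`@.. vt`, `Rhat`, and `s` lines of the Julia code; the subspace refresh (`randn!(P)` every `u` steps) and the dimension swap are omitted for brevity.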
# Notes: chuck the AdamW from the struct, so that adjust will just work.
"""
    WeightDecay(λ = 5e-4)
    WeightDecay(; [lambda])
[Review] Should the default value for `m` correspond to the original paper (i.e. m = 0, I suppose)?
[Reply] m = 0 makes sense when this is applied to the entire model, but could be fatal when applied tensor-wise. I think it is better to have non-footgun defaults, and make it clearer that this isn't a faithful reproduction.
[Reply] I've kept a non-zero default, but I've tweaked the docs to clarify that this method isn't quite the same as in those papers. (I also switched the "scaling m by the number of parameters" to using sqrt.)