diff --git a/src/rules.jl b/src/rules.jl index 4ba550d..08216df 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -603,10 +603,10 @@ end """ GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true) -Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270). -With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary. -This also introduces `m` a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero. -This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`). +Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but +with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum that the gradient norm can grow +from one step to the next. This implementation also introduces `m` a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor +from getting "trapped" near zero. This can be a fixed min, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`). """ struct GradNormGrowthLimiter <: AbstractRule γ::Float64 @@ -630,7 +630,7 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where else #If you're below the hard min, then don't scale if o.paramscale_min - minthresh = o.m * length(dx) + minthresh = o.m * sqrt(length(dx)) else minthresh = o.m end @@ -659,19 +659,20 @@ Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks mome First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor). """ -struct Apollo{T1} <: AbstractRule +struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule opt::T1 - r::Function #Maps non-first dims to rank - u::Int #Subspace update frequency (T in paper) - sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first + eta::T2 + r::T3 #Maps non-first dims to rank + u::T4 #Subspace update frequency (T in paper) + sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first end -Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true) -Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) -Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims) -Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims) -Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims) +Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true) +Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims) +Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims) +Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims) +Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims) #Use the base init and apply for 1D arrays init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x) @@ -706,7 +707,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T swapped = true end (mt, vt, βt), t, P = state - η = T(o.opt.eta) + η = T(o.eta) #This is what will get modified by adjust λ = T(o.opt.lambda) β = T.(o.opt.beta) ϵ = T(o.opt.epsilon) @@ -728,6 +729,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size) end +#Notes: chuck the AdamW from the struct, so that adjust will just work. + + """ WeightDecay(λ = 5e-4)