Add eager updating
murrellb authored Dec 13, 2024
1 parent 9b6dd08 commit 2d423c9
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/lib/grad.jl
@@ -27,6 +27,42 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
    return y, pullback_checkpointed
end


"""
    eager_update(f, update, state, xs...)

Allows training large models when the gradients cannot all fit in memory simultaneously.
Combines gradient checkpointing with eager parameter updates: each gradient is applied to the parameters as soon as it is computed, and is then discarded.
Assumes that `f` is a callable struct, `state` is the optimization state (e.g. from Optimisers.jl) matching `f`, and
`update` is the function that updates the parameters of `f` from the state and the gradients, called as `update(state, f, grads)`.
If e.g. `model.layers[i]` is a layer in a transformer, then:
```
for i in 1:length(model.layers)
    h = eager_update(model.layers[i], Optimisers.update!, opt_state.layers[i], h, other_args)
end
```
!!! warning
    If different layers share trainable parameters, then `eager_update` will likely give wrong results.
"""
eager_update(f, update, state, xs...) = f(xs...)

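function Zygote._pullback(ctx::Zygote.AContext, ::typeof(eager_update), f, update, state, xs...)
    y = f(xs...)
    function pullback_eager_update(Δy)
        # Recompute the forward pass to obtain the pullback (gradient checkpointing).
        y, pb = Zygote._pullback(ctx, f, xs...)
        ret = pb(Δy)
        # ret[1] is the gradient with respect to `f`; apply it immediately.
        update(state, f, ret[1])
        # No gradients for `eager_update`, `f`, `update`, or `state`; pass on those for `xs`.
        return (nothing, nothing, nothing, nothing, ret[2:end]...)
    end
    return y, pullback_eager_update
end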
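
The idea behind the pullback above can be illustrated without Zygote. The following is a minimal sketch in plain Julia, not the code from this commit: `ScaleLayer`, `backward`, `sgd!`, and `eager_train_step!` are hypothetical names invented for the illustration, the backward rule is written by hand, and the layer inputs are stored rather than recomputed. It shows each layer's parameter gradient being applied as soon as it is known, so only one parameter gradient is alive at a time.

```julia
# Hypothetical toy layer: y = w * x, with a single trainable scalar weight.
mutable struct ScaleLayer
    w::Float64
end
(l::ScaleLayer)(x) = l.w * x

# Hand-written backward rule (stands in for an AD pullback).
# Returns (gradient w.r.t. the weight, gradient w.r.t. the input).
backward(l::ScaleLayer, x, Δy) = (Δy * x, Δy * l.w)

# Plain SGD step, standing in for `update(state, f, grads)`.
sgd!(l::ScaleLayer, gw; lr = 0.1) = (l.w -= lr * gw; l)

function eager_train_step!(layers, x, target)
    # Forward pass: remember only the input of each layer.
    inputs = Float64[]
    h = x
    for l in layers
        push!(inputs, h)
        h = l(h)
    end
    loss = (h - target)^2
    Δ = 2 * (h - target)  # derivative of the loss w.r.t. the final output
    # Backward pass: update each layer as soon as its gradient is available.
    for i in length(layers):-1:1
        # Both gradients are computed with the weight as it was before the update.
        gw, Δ = backward(layers[i], inputs[i], Δ)
        sgd!(layers[i], gw)  # gw is no longer needed after this line
    end
    return loss
end

layers = [ScaleLayer(1.0), ScaleLayer(2.0)]
loss = eager_train_step!(layers, 1.0, 3.0)
```

As in `pullback_eager_update`, the gradient with respect to the input is computed before the layer's parameters are changed, so earlier layers still receive gradients consistent with the forward pass.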


"""
    hessian(f, x)