
Eager parameter updating #1541

Open · wants to merge 5 commits into master

Conversation

murrellb

This adds something like Zygote's checkpointed, but additionally accepts an optimizer state and an update function. The model parameters are updated during the backward pass and then the gradients are discarded, allowing you to train models when you can't fit both the model weights and the full gradients in memory together.
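
Roughly, usage looks like this (a sketch; the exact call appears in the full example later in this thread, and the interface may still change):

# Sketch: wrap a layer call so that, during the backward pass, the layer's
# gradient is handed to the update function together with that layer's
# optimizer state, the parameters are updated in place, and the gradient
# is then discarded instead of being kept in memory.
h = Zygote.eager_update!((model.layers[i], h), (Optimisers.update!, opt_state.layers[i]))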

I wasn't quite sure if this should be PR'd to Flux instead?

PR Checklist

  • Tests are added
  • Documentation, if applicable

@CarloLucibello
Member

Is there precedent for this in other libraries?

@CarloLucibello
Member

CarloLucibello commented Dec 13, 2024

There is a PyTorch tutorial describing a per-parameter version here:
https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html

# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
    optimizer_dict[parameter].step()
    optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

It would be great if we could have a function like this

grad_and_update!(m -> loss(m, x, y), model, opt_state)

in an Optimisers extension, also interacting nicely with checkpointed.
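
Semantically, such a helper would compute the following (a rough reference sketch only; the whole point would be to run update! eagerly during the backward pass, per parameter or per layer, so the full gradient never has to be materialized):

function grad_and_update!(loss, model, opt_state)
    # Naive, non-eager reference: materializes the full gradient first.
    val, grads = Zygote.withgradient(loss, model)
    Optimisers.update!(opt_state, model, grads[1])
    return val
end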

@murrellb
Author

Is there precedent for this in other libraries?

I've never used any other libraries :)

There is a PyTorch tutorial describing a per-parameter version here ...

Nice! I figure it must be a standard trick, but I was quite happy with how easy it was to get such a big gain in this ecosystem. And this version luckily doesn't require you to split your optimizer up across the different layers, because you can just pass the right part of the larger opt_state into it. It is now ~one line to halve your model's memory requirements.

One thing I might want to tweak before merging: this works when f is a callable struct that stores its parameters, but if the function and the model weights are different things then I think this won't work as-is. Let me see if I can generalize it slightly?

@murrellb
Author

It would be great if we could have a function like this: grad_and_update!(m -> loss(m, x, y), model, opt_state) in an Optimisers extension, also interacting nicely with checkpointed.

Do you mean automatically tracking which bit of the optimizer state would go into eager_update!? Yes, that would be neat, but a little beyond my abilities.

@murrellb murrellb marked this pull request as draft December 14, 2024 01:18
@ToucheSir
Member

ToucheSir commented Dec 14, 2024

The most appropriate place for this may be Optimisers.jl, as the technique could be applicable to ADs beyond Zygote. That said, I'm not quite sure I understand how it's meant to work. The abridged example in the docstring does not look like the PyTorch one Carlo shared.

Is there a complete, working minimal example that demonstrates this functionality? The main thing I'd like to understand is how it would pick up on the final accumulated gradients for a parameter.

@CarloLucibello
Member

CarloLucibello commented Dec 14, 2024

I think Flux is the appropriate place for this. And I would simplify the interface to

eager_update!(opt_state, model, xs...) = model(xs...)
eager_update!(f::Function, opt_state, model, xs...) = f(model, xs...)

since we can assume that Optimisers.update! will be the update function and opt_state is never a Function.
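
For illustration, a wrapped layer call with this signature (using the names from the Transformer example below) would look like:

# Sketch of a call with the simplified signature:
h = eager_update!(opt_state.layers[i], model.layers[i], h, model.pos, rope, mask)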

@murrellb
Author

I think Flux is the appropriate place for this. And I would simplify the interface to

Perhaps, but it is possible that this might be useful for non-Optimisers.jl optimizers, e.g. if someone rolls their own (not using Optimisers) or is outside of the Flux ecosystem entirely. Maybe Flux should have something with your simplified interface that calls this version? Note: rolling your own optimizer is very likely when working with very large models.
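
One possible shape for that (just a sketch, assuming the tuple-based Zygote version from this PR stays as it is): a Flux method with your simplified signature that forwards to the Zygote one and fills in Optimisers.update!:

# Hypothetical Flux-side wrapper; names and placement are illustrative.
eager_update!(opt_state, model, xs...) =
    Zygote.eager_update!((model, xs...), (Optimisers.update!, opt_state))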

The most appropriate place for this may be Optimisers.jl, as the technique could be applicable to ADs beyond Zygote.

I was originally considering an rrule for this, but there was some discussion on checkpointed that argued for using Zygote directly. I didn't follow the reasoning there, so I decided to just follow the form of checkpointed.

Overall, for discussions of where this should go: it should probably be wherever checkpointed is, and that is here.

Is there a complete, working minimal example that demonstrates this functionality?

I was going to push one to Jjama3.jl when we've finalized the interface, but this is the core of it:

#Model function def, where model.layers[i] is a Jjama3 TransformerBlock (just an RMSnorm -> attention -> RMSnorm -> feed_forward)
function forward(model::Transformer, tokens::AbstractArray{Int}, opt_state) #<- Note opt_state passed in here
    h = model.tok_embeddings(tokens)
    rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
    mask = Jjama3.create_mask(h)
    for i in 1:length(model.layers)
        #Instead of:
        #h = model.layers[i](h, model.pos, rope, mask)
        #We do:
        h = Zygote.eager_update!((model.layers[i], h, model.pos, rope, mask), (Optimisers.update!, opt_state.layers[i]))    
    end
    h = model.norm(h)
    output = model.output(h)
    return output
end

#Normal opt_state setup:
opt_state = Flux.setup(Apollo(lr), model)

#Training loop:
for i in 1:1000000000
    train_toks = ...
    l, grads = Flux.withgradient(model) do m
        logit_loss(forward(m, train_toks[1:end-1,:], opt_state), train_toks[2:end,:])
    end
    #Then a catch-all for any parameters not updated eagerly:
    Flux.update!(opt_state, model, grads[1])
end

I have verified that this is training a model whose weights and gradients can't both fit on the GPU, in combination with the new Optimisers.jl PR I opened, which removes the need to also store the moments of the gradients (using a low-rank trick): FluxML/Optimisers.jl#196 (comment)

Together, these bring the memory footprint from a minimum of 4x the weights (weights + grads + moment1 + moment2) down to ~1.3x the weights (with some overhead for the low-rank projections and the activations themselves). I've tested up to a 7.5 billion parameter model.
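
As a back-of-the-envelope check of those numbers (assuming 32-bit parameters; the exact precision isn't the point here):

nparams = 7.5e9                       # 7.5 billion parameter model
weights_gb = nparams * 4 / 1e9        # ≈ 30 GB of Float32 weights
standard_gb = 4 * weights_gb          # weights + grads + moment1 + moment2 ≈ 120 GB
eager_gb = 1.3 * weights_gb           # ≈ 39 GB, before activations and projections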

@CarloLucibello
Member

Perhaps, but it is possible that this might be useful for non-Optimisers.jl opts, like if someone rolls their own (not using Optimisers), or is outside of the flux ecosystem entirely. Maybe Flux should have something with your simplified interface that calls this version? Note: rolling your own optimizer is very likely when working with very large models.

ok, we can have a wrapper in flux

@ToucheSir
Member

ToucheSir commented Dec 14, 2024

I was going to push one to Jjama3.jl when we've finalized the interface, but this is the core of it:

Thanks, this really helps. It seems like the main difference here is that eager_update! does not behave like PyTorch's register_post_accumulate_grad_hook. It's more of a register_after_any_grad_hook. That means we either:

  1. Have this mechanism hold onto intermediate (i.e. non-accumulated) gradients somehow and accumulate them in-place until some signal that everything has been accumulated.
  2. Add guardrails (at the bare, bare minimum docs, but ideally code) that prevent people from using eager_update! with non-accumulated gradients, e.g. error on forward_incorrect here:
shared_layer = Dense(...)

model = (;
  branch_1 = Chain(shared_layer, more_layers...),
  branch_2 = Chain(other_layer, shared_layer, another_layer)
)
opt_state = Flux.setup(Apollo(lr), model)

function forward_incorrect(model, x, opt_state) #<- Note opt_state passed in here
  # using the `eager_update!(opt_state, model, xs...) = model(xs...)` method proposed above for brevity

  # optimizer step run for shared_layer using branch_1 gradients only!
  y_1 = eager_update!(opt_state.branch_1, model.branch_1, x)
  # optimizer step run for shared_layer using branch_2 gradients only!
  y_2 = eager_update!(opt_state.branch_2, model.branch_2, x)

  return y_1 + y_2
end

function forward_correct(model, x, opt_state) #<- Note opt_state passed in here
  # using the `eager_update!(f::Function, opt_state, model, xs...) = f(model, xs...)` method proposed above for brevity

  # optimizer step run for shared_layer using accumulated gradients from both branches
  y = eager_update!(opt_state, model, x) do model, x
    y_1 = model.branch_1(x)
    y_2 = model.branch_2(x)
    return y_1 + y_2
  end

  return y
end

The problem I see is that one has to wrap every code path shared_layer could get gradients from in order to make the gradient update correct (i.e. equal to not using checkpointing), but then eager_update! becomes just as inefficient as not using eager_update!.

I was originally considering an rrule for this but then there was some discussion on checkpointed that argued for using Zygote directly, and I didn't follow the reasoning there, so I decided to just follow the form of checkpointed.
Overall, for discussions of where this should go: it should probably be wherever checkpointed is, and that is here.

I think that may be a misunderstanding based on out-of-date historical discussion? Zygote is basically on life support at this point, and Flux wants to be rid of it as soon as possible. As such, nice new functionality should ideally find a different home.

@murrellb
Author

Add guardrails (at the bare, bare minimum docs, but ideally code) that prevents people from using eager_update! with non-accumulated gradients. e.g. error on forward_incorrect here:

Agreed re: docs. I had this warning in my first commit:
[screenshot of the docstring warning from the first commit]
...but it looks like I accidentally dropped it in an update. To me this is like the restriction that checkpointed must be a pure function, otherwise you will get the wrong answer.

The problem I see is that one has to wrap every code path shared_layer could get gradients from in order to make the gradient update correct (i.e. equal to not using checkpointing), but then eager_update! becomes just as inefficient as not using eager_update!.

Models with a repeating core that don't have any shared parameters across the repeating layers are a large and critical class, so having a simple trick that helps with these but doesn't help when there are shared parameters seems fine to me? If a layer shares all of its parameters then you don't do this sort of thing, and if it shares some of its parameters then you can often rewrite the layers themselves to separate out the components that share parameters from those that don't, and then use this for the components that don't.
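
As a concrete illustration of that split (a sketch with hypothetical layer names: a shared embedding plus a stack of blocks that share nothing):

function forward(model, x, opt_state)
    h = model.embed(x)                       # shared parameters: not updated eagerly
    for i in 1:length(model.blocks)          # per-block parameters: no sharing
        h = Zygote.eager_update!((model.blocks[i], h),
                                 (Optimisers.update!, opt_state.blocks[i]))
    end
    return model.head(h)                     # remaining parts handled by the usual Flux.update! afterwards
end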

I think that may be a misunderstanding based on #884?

Yes, that was the discussion I saw. And as I said I couldn't follow the argument to know whether or not this would fit as an rrule. Is there an rrule for the equivalent of checkpointed? Or is that grandfathered into Zygote?

Zygote is basically on life support at this point, and Flux wants to be rid of it as soon as possible. As such, nice new functionality should ideally find a different home.

Well, Enzyme errors whenever we look at it, so I hope this shift isn't too precipitous. But then this becomes a question of whether the Flux wrapper for this should use e.g. an rrule version of this trick (instead of the Zygote one), and not a question of whether or not this should be in Zygote.

@ToucheSir
Member

ToucheSir commented Dec 15, 2024

The docstring warning LGTM and should probably be sufficient for now. If this feature becomes widely used, we can think about more guardrails.

Yes, that was the discussion I saw. And as I said I couldn't follow the argument to know whether or not this would fit as an rrule. Is there an rrule for the equivalent of checkpointed? Or is that grandfathered into Zygote?

The feature that would allow someone to write a checkpointed function using ChainRules(Core) was added after the Zygote PR that added checkpointed. Although it has been used for similar things, nobody got around to adding one. I suspect this is because few people needed checkpointing, and Zygote was/is the only widely-used AD that understands ChainRules.
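
For concreteness, here's a hedged sketch of what a ChainRules-level checkpointed could look like with that feature (RuleConfig plus rrule_via_ad); illustrative only, not an existing rule:

using ChainRulesCore

# Hypothetical primal; Zygote's own checkpointed is implemented separately.
checkpointed(f, xs...) = f(xs...)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(checkpointed), f, xs...)
    y = f(xs...)                              # plain forward call: no pullback state is kept
    function checkpointed_pullback(ȳ)
        _, pb = rrule_via_ad(cfg, f, xs...)   # recompute the forward, this time building a pullback
        return (NoTangent(), pb(ȳ)...)        # tangents for (checkpointed, f, xs...)
    end
    return y, checkpointed_pullback
end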

On that note, this discussion reminded me of another Zygote utility function: hook. We could save some logic by writing eager_update! in terms of hook and checkpointed:

function eager_update!(f, (model, xs...), (update!, state))
    function update_hook(dmodel)
        update!(state, model, dmodel)
        return nothing
    end
    return Zygote.checkpointed(f, Zygote.hook(update_hook, model), xs...)
end

While we're at it, perhaps the interface could be simplified as well. I think the key here is realizing that checkpointing and eager updates can be decoupled:

function eager_update!(model, opt_state, update! = Optimisers.update!)
    function update_hook(dmodel)
        update!(opt_state, model, dmodel)
        return nothing
    end
    return Zygote.hook(update_hook, model)
end

# So instead of
eager_update!(f, (model, xs...), (Optimisers.update!, opt_state))

# You could write
Zygote.checkpointed(f, eager_update!(model, opt_state), xs...)

# Or even
Zygote.checkpointed(f, eager_update!(model1, opt_state1), eager_update!(model2, opt_state2), xs...)

# Or not checkpoint, which would be equivalent to https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
f(eager_update!(model, opt_state), xs...)

Either way, I think this exercise suggests that eager_update! should live in Flux because we can then have a method which defaults update! to Optimisers.update!.

@mcabbott
Member

Maybe Fluxperimental.jl is the right place? I see roughly what this is, but perhaps the very best way to wrap this up isn't entirely clear yet.

@ToucheSir
Member

Fluxperimental could work too. Another motivation for having this in a higher-level package than Zygote is that we could define an overload of this function (and checkpointed) for Enzyme too.

@CarloLucibello
Member

I think the key here is realizing that checkpointing and eager updates can be decoupled:

Wow, this looks very nice. For discoverability I would prefer to have eager_update!(model, opt_state, update! = Optimisers.update!) in Flux rather than Fluxperimental, maybe saying in the docstring that the function is still experimental and works only with Zygote.
It is clear that we want this functionality as part of the stable API at some point, and I think the elegance of Brian's proposal is hard to beat, so it could be that we won't have to revise the interface.

@murrellb
Author

murrellb commented Dec 15, 2024

We could save some logic by writing eager_update! in terms of hook and checkpointed:

Yes this is the way. Two points:

  • The arguments should probably be eager_update!(state, model) to match the ordering that Optimisers.jl uses with update!(state, model, grads).
  • This should live in Zygote as eager_update!(state, model, update!), and then Flux.jl can have a wrapper where the third argument defaults to Optimisers.update!. This allows one to use it in a non-Flux setup that still uses Zygote, and lets Flux easily switch what the wrapper calls when it switches backends.

I've updated my PR accordingly.
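
For reference, a rough sketch of that shape, combining the hook-based version above with the (state, model, update!) ordering (details may differ from what ends up in the PR):

# Zygote side: a hook that updates in place, then discards the gradient.
function eager_update!(state, model, update!)
    function update_hook(dmodel)
        update!(state, model, dmodel)
        return nothing                  # drop the gradient once the parameters are updated
    end
    return Zygote.hook(update_hook, model)
end

# Flux side: wrapper defaulting the update function to Optimisers.update!.
eager_update!(state, model) = eager_update!(state, model, Optimisers.update!)

# Used together with checkpointing, mirroring the pattern above:
Zygote.checkpointed(f, eager_update!(opt_state, model), xs...)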

@murrellb murrellb marked this pull request as ready for review December 16, 2024 00:57