Fix tp mem cache #203
Conversation
…rentiable distributed operations
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype)
pass device=tensor.device and requires_grad=tensor.requires_grad as well?
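To make the suggestion concrete, here is a minimal, hypothetical buffer sketch (not nanotron's actual MemoryBuffer, whose signature may differ): it keys its cached flat allocation by name, dtype, and device, which is why the device has to be passed through explicitly.

```python
import torch

class SimpleMemoryBuffer:
    """Hypothetical sketch of a reusable buffer cache, keyed by (name, dtype, device)."""

    _buffers = {}  # shared across instances, like a process-wide cache

    def get(self, name, shape, dtype, device):
        numel = 1
        for dim in shape:
            numel *= dim
        key = (name, dtype, device)
        buffer = self._buffers.get(key)
        if buffer is None or buffer.numel() < numel:
            # Gathered activations never need requires_grad here: the custom
            # backward of the distributed op produces the gradient explicitly.
            buffer = torch.empty(numel, dtype=dtype, device=device, requires_grad=False)
            self._buffers[key] = buffer
        return buffer[:numel].view(shape)

# Hypothetical usage mirroring the suggestion above:
# unsharded_tensor = SimpleMemoryBuffer().get(
#     "dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype, device=tensor.device
# )
```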
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
sharded_tensor = MemoryBuffer().get(
    "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype
)
same as above
nice catch! lgtm
Thanks for the comments! I have one question regarding requires_grad: in principle we shouldn't require gradients on the gathered tensors, right? The custom backward handles the gradient computation for the parameters anyway. At least, training runs seamlessly without setting requires_grad on those tensors.
Yeah, I think we can set it to False here (it seems like Megatron does the same as well).
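For context, this is the pattern being discussed, sketched under the assumption that the gather is wrapped in a custom autograd.Function the way Megatron does it: the gathered buffer can stay requires_grad=False because the gradient is produced explicitly in backward() rather than traced through the buffer.

```python
import torch
import torch.distributed as dist

class _GatherAlongFirstDim(torch.autograd.Function):
    """Sketch only (not the PR's exact code): differentiable all-gather whose
    output buffer does not require grad."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        out_shape = (tensor.shape[0] * group.size(), *tensor.shape[1:])
        # Plain buffer: no autograd tracking is needed on the gathered activation.
        gathered = torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
        dist.all_gather_into_tensor(gathered, tensor.contiguous(), group=group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        shard_shape = (grad_output.shape[0] // group.size(), *grad_output.shape[1:])
        grad_input = torch.empty(shard_shape, dtype=grad_output.dtype, device=grad_output.device)
        # Each rank's shard gradient is the sum of contributions from all ranks.
        dist.reduce_scatter_tensor(grad_input, grad_output.contiguous(), group=group, op=dist.ReduceOp.SUM)
        return grad_input, None
```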
Hi, thanks for the PR, it's really nice! I tested it by training for 100 steps on the Tiny Story dataset and compared the loss with our current code, and I found an abnormal difference. Could you check whether you observe the same thing on your side? This is my config file; you may have to change it a little, but the idea is to compare the loss before and after the change with the same hyperparameters. Thanks a lot for the work.
That's interesting. Thanks for letting me know, I will investigate further and come back with some results soon :)
I was able to reproduce the error. To fix the issue, I followed Megatron's design and fused the all-gather and linear operations into a single module. The loss progression now matches the main branch.

I added two configurations of the optimization. Since the all-gather and the linear live in the same module, we can control whether to recompute the all-gather during the backward pass or cache it instead. As expected, recomputing yields larger memory savings at the cost of throughput, but both modes are more memory-efficient than the current implementation, and both provide at least comparable tok/sec to the current main. The behavior is controlled by a configuration flag.

I attach wandb logs of four runs on two different configurations that validate these claims. In both, blue is the baseline (main branch) implementation, red is the incorrect first version of this PR, green is the no-recompute mode (moderate memory savings and slightly faster than the baseline), and purple is the recompute mode (large memory savings and, on average, as fast as the baseline). The first plots correspond to the tiny llama configuration you shared earlier; the second plot corresponds to a llama8b run. Except for the incorrect version, all lines are pretty much identical in the loss curves.

Let me know if you have any suggestions.
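As a rough illustration of the design described above, a fused all-gather + linear autograd function with a recompute switch could look like the following; the class name, flag name, and caching details are illustrative assumptions, not the PR's actual API.

```python
import torch
import torch.distributed as dist

class _AllGatherLinear(torch.autograd.Function):
    """Sketch of the fused all-gather + linear idea (illustrative, not the PR's code)."""

    @staticmethod
    def forward(ctx, input_, weight, group, recompute_allgather):
        gathered_shape = (input_.shape[0] * group.size(), *input_.shape[1:])
        total_input = torch.empty(gathered_shape, dtype=input_.dtype, device=input_.device)
        dist.all_gather_into_tensor(total_input, input_.contiguous(), group=group)

        ctx.group = group
        ctx.recompute_allgather = recompute_allgather
        if recompute_allgather:
            # Keep only the small sharded input and re-gather it in backward:
            # more communication, much lower activation memory.
            ctx.save_for_backward(input_, weight)
        else:
            # Cache the full gathered input: faster backward, higher memory.
            ctx.save_for_backward(total_input, weight)
        return total_input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, weight = ctx.saved_tensors
        group = ctx.group

        if ctx.recompute_allgather:
            gathered_shape = (saved_input.shape[0] * group.size(), *saved_input.shape[1:])
            total_input = torch.empty(gathered_shape, dtype=saved_input.dtype, device=saved_input.device)
            dist.all_gather_into_tensor(total_input, saved_input.contiguous(), group=group)
        else:
            total_input = saved_input

        grad_input = (grad_output @ weight).contiguous()
        grad_weight = (
            grad_output.reshape(-1, grad_output.shape[-1]).t()
            @ total_input.reshape(-1, total_input.shape[-1])
        )

        # Scatter the input gradient back to the sharded layout, summing over ranks.
        shard_shape = (grad_input.shape[0] // group.size(), *grad_input.shape[1:])
        sub_grad_input = torch.empty(shard_shape, dtype=grad_input.dtype, device=grad_input.device)
        dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
        return sub_grad_input, grad_weight, None, None
```

The flag only changes what gets saved for the backward pass, which is exactly where the memory-vs-throughput trade-off described above comes from.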
sub_grad_input = torch.empty(
    input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
Seems like dist.reduce_scatter needs grad_input to be contiguous (cf. https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305). I am not sure that grad_input = grad_output @ weight is contiguous (although you do have grad_output = grad_output.contiguous()). Maybe, to be sure, we should do grad_input = grad_input.contiguous() before running the reduce_scatter? What do you think?
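A minimal sketch of the suggested guard, reusing the names from the snippet above (so not self-contained on its own):

```python
# The reduce-scatter collective expects contiguous tensors, so be explicit here.
grad_input = grad_output @ weight
grad_input = grad_input.contiguous()  # no-op if the matmul output is already contiguous
sub_grad_input = torch.empty(
    input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
```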
Updated the PR with the suggestions mentioned! Let me know if I'm missing something.
All points were addressed, LGTM!
Nanotron seems to consume disproportionately more memory for its activations than Megatron. This is due to at least the following factors:
Attached: Memory traces of the default nanotron implementation (which OOMs), the current PR implementation, and Megatron. The traces are from the first rank of a TP=8, PP=4, DP=1 llama70b run of 5 iterations (sequence length 8k, micro-batch size 1, gradient accumulation 4, synchronous TP and reduce_scatter mode).
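For anyone who wants to reproduce the comparison, traces like these can be captured with PyTorch's built-in memory-snapshot tooling (standard PyTorch, not part of this PR); a minimal sketch:

```python
import torch

# Start recording CUDA allocator events (available in recent PyTorch releases).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a handful of training iterations here ...

# Dump a snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("memory_trace.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```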
I think these changes are important, as they allow training larger models with significantly lower memory requirements.
Let me know if you have any suggestions, and I'd be happy to make adjustments to upstream this feature! :)