Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stage3: Use new torch grad accumulation hooks API #6773

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

deepcharm
Copy link
Contributor

  • This commit addresses a Deepspeed issue #6718
  • The existing code has been using the grad_acc node hook to reduce params grads.
    The constructs such as param.data = replicated_tensor.data used in allgather_params(..)
    are compiled into param.set() causing the hook assigned to the grad_acc node not being called.
  • The above caused accuracy issues and could be temporarily solved by simply disabling the torch compile when activation checkpointing is used.
  • This commit provides a clean solution by replacing the hook on a grad_acc node to a hook using a new and robust hook API on a param itself: param.register_post_accumulate_grad_hook(..)

* This commit addresses an issue reported in:
  microsoft#6718
* The existing code has been using the grad_acc node hook to reduce params grads.
  The constructs such as param.data = replicated_tensor.data used in
  allgather_params(..) are compiled into param.set() causing the hook assigned
  to the grad_acc node not being called.
* This is a known torch issue pytorch/pytorch#139742.
* The above caused accuracy issues and could be temporarily solved by simply
  disabling the torch compile when activation checkpointing is used.
* This commit provides a clean solution by replacing the hook on a grad_acc node
  to a hook using a new and robust hook API on a param itself:
  param.register_post_accumulate_grad_hook(..)
self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads))
self.grad_accs.append(grad_acc)
self._grad_acc_hooks.append(
param.register_post_accumulate_grad_hook(reduce_partition_and_remove_grads))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which pytorch version introduced this API? How should we handle older versions?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants