Stage3: Use new torch grad accumulation hooks API
* This commit addresses the issue reported in microsoft#6718.
* The existing code used a hook on the grad_acc autograd node to reduce param grads. Constructs such as `param.data = replicated_tensor.data`, used in `allgather_params(..)`, are compiled into `param.set_()`, which prevents the hook attached to the grad_acc node from being called.
* This is a known torch issue: pytorch/pytorch#139742.
* The above caused accuracy issues, which could be temporarily worked around by disabling torch compile whenever activation checkpointing is used.
* This commit provides a clean solution by replacing the hook on the grad_acc node with a hook registered through the new, robust API on the param itself: `param.register_post_accumulate_grad_hook(..)` (see the sketch below).
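
For illustration, a minimal standalone sketch of the two registration patterns. The names (`reduce_grad`, `param`) and the hook bodies are illustrative placeholders, not DeepSpeed's actual code; only the two PyTorch hook mechanisms are real API.

```python
import torch

# Old pattern: locate the AccumulateGrad (grad_acc) node behind the param
# and hook it. The node is reached indirectly through a throwaway expand_as
# view, and the hook silently stops firing if the param's autograd identity
# changes (e.g. via set_() under torch compile, pytorch/pytorch#139742).
param = torch.nn.Parameter(torch.randn(4))

def reduce_grad(*_):
    # stand-in for the real gradient-reduction logic
    print("gradient accumulated")

param_tmp = param.expand_as(param)                 # non-leaf view, so grad_fn exists
grad_acc = param_tmp.grad_fn.next_functions[0][0]  # the AccumulateGrad node
grad_acc.register_hook(reduce_grad)                # fragile: tied to the node, not the param

# New pattern: register the hook on the param itself. It runs after the
# gradient has been accumulated into param.grad, independent of how the
# param's underlying storage was swapped in the meantime.
param2 = torch.nn.Parameter(torch.randn(4))
param2.register_post_accumulate_grad_hook(lambda p: reduce_grad(p))

# Trigger both hooks with a backward pass.
(param.sum() + param2.sum()).backward()
```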