
Commit

revert more changes
lljbash committed Jan 19, 2024
1 parent 290ee88 commit ff3baf3
Showing 1 changed file with 3 additions and 43 deletions.
internlm/model/utils.py — 46 changes: 3 additions & 43 deletions
@@ -5,14 +5,14 @@

import torch
import torch.nn.functional as F
-# from flash_attn.ops.fused_dense import FusedDenseFunc
+from flash_attn.ops.fused_dense import FusedDenseFunc
from flash_attn.utils.distributed import (
all_gather_raw,
all_reduce_raw,
reduce_scatter_raw,
)
from torch import Tensor
-from torch.cuda.amp import custom_fwd, custom_bwd
+from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup

from internlm.core.context import global_context as gpc
@@ -96,49 +96,9 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):


# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
-class FusedDenseFuncTorch(torch.autograd.Function):
+class FusedDenseFuncTorch(FusedDenseFunc):
    """A custom PyTorch module extending FusedDenseFunc."""

-    @staticmethod
-    @custom_fwd
-    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
-                sequence_parallel=True):
-        """
-        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
-        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
-        """
-        ctx.compute_weight_gradient = weight.requires_grad
-        ctx.return_residual = return_residual
-        ctx.process_group = process_group
-        ctx.sequence_parallel = sequence_parallel
-
-        if torch.is_autocast_enabled():
-            x = x.to(dtype=torch.get_autocast_gpu_dtype())
-        x = x.contiguous()
-        if process_group is not None and sequence_parallel:
-            # We want to kick off the all_gather early, before weight dtype conversion
-            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
-        else:
-            total_x = x
-
-        if torch.is_autocast_enabled():
-            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
-            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
-        weight = weight.contiguous()
-        if process_group is not None and sequence_parallel:
-            handle_x.wait()
-        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
-        batch_dim = batch_shape.numel()
-        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
-        if min(batch_dim, n, *weight.shape) > 65535 * 32:
-            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
-        output = F.linear(total_x, weight, bias)
-        if ctx.compute_weight_gradient:
-            ctx.save_for_backward(x, weight)
-        else:
-            ctx.save_for_backward(weight)
-        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
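For orientation: the deleted forward appears to duplicate the upstream flash-attn implementation that FusedDenseFuncTorch now inherits by subclassing FusedDenseFunc. The pattern it implemented was to kick off the sequence-parallel all-gather of x asynchronously, convert weight and bias to the autocast dtype while that communication is in flight, and then run a plain F.linear on the gathered activations. The sketch below restates that pattern as a standalone function; sequence_parallel_linear_forward is a hypothetical name used only for illustration and is not part of either repository.

import torch
import torch.nn.functional as F
from typing import Optional

from flash_attn.utils.distributed import all_gather_raw
from torch.distributed import ProcessGroup


def sequence_parallel_linear_forward(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch of the forward pattern used by the inherited FusedDenseFunc."""
    if torch.is_autocast_enabled():
        x = x.to(dtype=torch.get_autocast_gpu_dtype())
    x = x.contiguous()
    if process_group is not None and sequence_parallel:
        # Start the all-gather early so it overlaps with the dtype conversion below.
        total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
    else:
        total_x = x
    if torch.is_autocast_enabled():
        weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
        bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
    weight = weight.contiguous()
    if process_group is not None and sequence_parallel:
        # Wait for the gathered sequence-parallel shards before the matmul.
        handle_x.wait()
    # Ordinary linear layer on the gathered activations.
    return F.linear(total_x, weight, bias)

Overlapping the all-gather with the dtype conversion hides part of the communication latency; the real autograd function additionally enforces the matrix-dimension limit and saves tensors for the backward pass, as the deleted lines above show.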
