remove contiguous copy for flash-attn opbuilder (microsoft#372)
* remove unnecessary code for the latest flash-attn op builder

* add --use-flash-attn-builder to make flash_attn usage clear and compatible

* use hasattr to detect which entry point the builder's op exposes
YizhouZ authored Apr 2, 2024
1 parent 888a63a commit 3c5f475
Showing 2 changed files with 45 additions and 22 deletions.
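
The "use hasattr" note refers to probing which entry point the compiled flash-attn op exposes, as done in FlashSelfAttention.__init__ in the transformer.py hunks below. A minimal, self-contained sketch of that pattern, with a SimpleNamespace standing in for the module the op builder would load (the stand-in and its lambda are illustrative only, not part of this commit):

from types import SimpleNamespace

# Stand-in for the module that the flash-attn op builder would load
# (hypothetical object; the real module comes from the build step, not this diff).
flash_attn_builder = SimpleNamespace(flash_attn_func_v2=lambda *args: 'v2 kernel')

# The "use hasattr" change: pick whichever entry point the compiled op exposes.
if hasattr(flash_attn_builder, 'flash_attn_func'):
    flash_attn_func = flash_attn_builder.flash_attn_func       # older builder API
    use_flash_attn_builder_v1 = True
else:
    flash_attn_func = flash_attn_builder.flash_attn_func_v2    # newer builder API
    use_flash_attn_builder_v1 = False

print(use_flash_attn_builder_v1)  # False for the stand-in defined above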
4 changes: 3 additions & 1 deletion megatron/arguments.py
@@ -421,7 +421,7 @@ def validate_args(args, defaults={}):
         args.compression_training = False
 
     # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2
+    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or args.use_flash_attn_builder
 
     # AML
     if args.aml_data_download_path is not None:
@@ -910,6 +910,8 @@ def _add_training_args(parser):
                        'https://arxiv.org/abs/2307.08691')
     group.add_argument('--use-flash-attn-triton', action='store_true',
                        help='use FlashAttention implementation of attention using Triton.')
+    group.add_argument('--use-flash-attn-builder', action='store_true',
+                       help='use FlashAttention op builder.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
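Taken together, the two arguments.py changes register the new option and fold it into the shared switch. A small, self-contained sketch of that flow (the v1/v2 flag spellings are assumed from the attribute names used in validate_args and are not shown in this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-flash-attn-v1', action='store_true')
parser.add_argument('--use-flash-attn-v2', action='store_true')
parser.add_argument('--use-flash-attn-triton', action='store_true')
parser.add_argument('--use-flash-attn-builder', action='store_true',
                    help='use FlashAttention op builder.')

args = parser.parse_args(['--use-flash-attn-builder'])

# Mirrors the updated line in validate_args(): any one of the four options
# turns on the shared use_flash_attn switch.
args.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton
                       or args.use_flash_attn_v2 or args.use_flash_attn_builder)
print(args.use_flash_attn)  # True
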
63 changes: 42 additions & 21 deletions megatron/model/transformer.py
@@ -381,7 +381,19 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
 
         # Use FlashAttention-2 when args.use_flash_attn_v2 is True
         args = get_args()
-        self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+        self.use_flash_attn_builder_v1 = False
+        self.use_flash_attn_builder_v2 = False
+        self.use_flash_attn = False
+        if args.use_flash_attn_builder:
+            if hasattr(flash_attn_builder, 'flash_attn_func'):
+                self.flash_attn_func = flash_attn_builder.flash_attn_func
+                self.use_flash_attn_builder_v1 = True
+            else:
+                self.flash_attn_func = flash_attn_builder.flash_attn_func_v2
+                self.use_flash_attn_builder_v2 = True
+        else:
+            self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+            self.use_flash_attn = True
 
     def forward(self, q, k, v):
         """Implements the multihead softmax attention.
@@ -392,22 +404,19 @@ def forward(self, q, k, v):
 
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
         assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))
-        # if get_accelerator().device_name() == 'cuda':
-        #     assert all((i.is_cuda for i in (q,k,v)))
-        # else:
-        #     assert all((i.is_xpu for i in (q,k,v)))
 
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
 
-        if get_accelerator().device_name() == 'cuda':
-            # goes for cuda device
+        if self.use_flash_attn:
             q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
             cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                         device=q.device)
-        else:
-            # goes for other device
+        elif self.use_flash_attn_builder_v1:
             q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
+        else:
+            # use_flash_attn_builder_v2
+            q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]]
 
         if self.training:
             # during training q,k,v always have same seqlen
@@ -424,16 +433,26 @@ def forward(self, q, k, v):
                                         device=q.device) if get_accelerator().device_name() == 'cuda' else None
             dropout_p = 0
 
-        output = self.flash_attn_func(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            dropout_p,
-            softmax_scale=self.softmax_scale, causal=is_causal
-        ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
-            q, k, v, self.dropout_p, self.softmax_scale, is_causal
-        )
+        if self.use_flash_attn:
+            output = self.flash_attn_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                dropout_p,
+                softmax_scale=self.softmax_scale, causal=is_causal
+            )
+        else:
+            # use_flash_attn_builder
+            output = self.flash_attn_func(
+                q, k, v, self.dropout_p, self.softmax_scale, is_causal
+            )
 
-        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) if get_accelerator().device_name() == 'cuda' else rearrange(
-            output, 'b h s d -> b s h d').contiguous()
+        if self.use_flash_attn:
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        elif self.use_flash_attn_builder_v1:
+            output = rearrange(output, 'b h s d -> b s h d').contiguous()
+        else:
+            # use_flash_attn_builder_v2:
+            output = rearrange(output, 'b h s d -> b s h d')
+
         return output
 
 class FlashSelfAttentionTriton(torch.nn.Module):
@@ -492,7 +511,8 @@ def __init__(self, config, layer_number,
         self.num_key_value_heads = config.num_key_value_heads
         self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)
 
-        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2) \
+        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
+            args.use_flash_attn_builder) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
@@ -504,12 +524,13 @@ def __init__(self, config, layer_number,
                 flash_attn_builder = None
 
             if args.use_flash_attn_v1:
-                assert flash_attn_unpadded_func != None or flash_attn_builder != None, ("Cannot import FlashAttention v1 "
-                                                                                        "and Cannot find FlashAttention Builder")
+                assert flash_attn_unpadded_func != None, "Cannot import FlashAttention v1 "
             if args.use_flash_attn_v2:
                 assert flash_attn_varlen_func != None, "Cannot import FlashAttention v2 "
             if args.use_flash_attn_triton:
                 assert flash_attn_func != None, "Cannot import FlashAttention triton "
+            if args.use_flash_attn_builder:
+                assert flash_attn_builder != None, "Cannot find FlashAttention op builder "
 
             assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                           'self-attention for now')
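
The commit title's "remove contiguous copy" refers to the builder-v2 branches above: rearranging 'b s h d -> b h s d' only permutes strides, while the v1 builder path still calls .contiguous(), which materializes a fresh buffer for each of q, k and v (and for the output on the way back). A short illustration of that difference, assuming only torch and einops (shapes are arbitrary):

import torch
from einops import rearrange

q = torch.randn(2, 128, 16, 64)              # (batch, seq, heads, head_dim)

view = rearrange(q, 'b s h d -> b h s d')    # pure permutation: a strided view, no copy
print(view.is_contiguous())                  # False
print(view.data_ptr() == q.data_ptr())       # True  -> shares storage with q

copy = view.contiguous()                     # what the builder-v1 path still does
print(copy.is_contiguous())                  # True
print(copy.data_ptr() == q.data_ptr())       # False -> extra allocation and copy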
