From 3c5f47563f697702c1e305fa01b7563f54b747fc Mon Sep 17 00:00:00 2001
From: Yizhou Wang
Date: Wed, 3 Apr 2024 01:50:47 +0800
Subject: [PATCH] remove contiguous copy for flash-attn opbuilder (#372)

* remove unnecessary code for latest flash-attn opbuilder

* add use-flash-attn-builder to make flash_attn usage clear and compatible

* use hasattr
---
 megatron/arguments.py         |  4 ++-
 megatron/model/transformer.py | 63 +++++++++++++++++++++++------------
 2 files changed, 45 insertions(+), 22 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 461a2441c3..631d4b12e8 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -421,7 +421,7 @@ def validate_args(args, defaults={}):
         args.compression_training = False
 
     # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2
+    args.use_flash_attn = args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or args.use_flash_attn_builder
 
     # AML
     if args.aml_data_download_path is not None:
@@ -910,6 +910,8 @@ def _add_training_args(parser):
                        'https://arxiv.org/abs/2307.08691')
     group.add_argument('--use-flash-attn-triton', action='store_true',
                        help='use FlashAttention implementation of attention using Triton.')
+    group.add_argument('--use-flash-attn-builder', action='store_true',
+                       help='use FlashAttention op builder.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 924f46c688..3449cf26f3 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -381,7 +381,19 @@ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
         # Use FlashAttention-2 when args.use_flash_attn_v2 is True
         args = get_args()
 
-        self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+        self.use_flash_attn_builder_v1 = False
+        self.use_flash_attn_builder_v2 = False
+        self.use_flash_attn = False
+        if args.use_flash_attn_builder:
+            if hasattr(flash_attn_builder, 'flash_attn_func'):
+                self.flash_attn_func = flash_attn_builder.flash_attn_func
+                self.use_flash_attn_builder_v1 = True
+            else:
+                self.flash_attn_func = flash_attn_builder.flash_attn_func_v2
+                self.use_flash_attn_builder_v2 = True
+        else:
+            self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func
+            self.use_flash_attn = True
 
     def forward(self, q, k, v):
         """Implements the multihead softmax attention.
@@ -392,22 +404,19 @@ def forward(self, q, k, v):
 
         assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
         assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))
-        # if get_accelerator().device_name() == 'cuda':
-        #     assert all((i.is_cuda for i in (q,k,v)))
-        # else:
-        #     assert all((i.is_xpu for i in (q,k,v)))
 
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = k.shape[1]
 
-        if get_accelerator().device_name() == 'cuda':
-            # goes for cuda device
+        if self.use_flash_attn:
             q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
             cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                         device=q.device)
-        else:
-            # goes for other device
+        elif self.use_flash_attn_builder_v1:
             q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
+        else:
+            # use_flash_attn_builder_v2
+            q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in [q, k, v]]
 
         if self.training:
             # during training q,k,v always have same seqlen
@@ -424,16 +433,26 @@ def forward(self, q, k, v):
                         device=q.device) if get_accelerator().device_name() == 'cuda' else None
             dropout_p = 0
 
-        output = self.flash_attn_func(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-            dropout_p,
-            softmax_scale=self.softmax_scale, causal=is_causal
-        ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
-            q, k, v, self.dropout_p, self.softmax_scale, is_causal
-        )
+        if self.use_flash_attn:
+            output = self.flash_attn_func(
+                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                dropout_p,
+                softmax_scale=self.softmax_scale, causal=is_causal
+            )
+        else:
+            # use_flash_attn_builder
+            output = self.flash_attn_func(
+                q, k, v, self.dropout_p, self.softmax_scale, is_causal
+            )
+
+        if self.use_flash_attn:
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        elif self.use_flash_attn_builder_v1:
+            output = rearrange(output, 'b h s d -> b s h d').contiguous()
+        else:
+            # use_flash_attn_builder_v2:
+            output = rearrange(output, 'b h s d -> b s h d')
 
-        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) if get_accelerator().device_name() == 'cuda' else rearrange(
-            output, 'b h s d -> b s h d').contiguous()
         return output
 
 class FlashSelfAttentionTriton(torch.nn.Module):
@@ -492,7 +511,8 @@ def __init__(self, config, layer_number,
         self.num_key_value_heads = config.num_key_value_heads
         self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)
 
-        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2) \
+        self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
+            args.use_flash_attn_builder) \
             and attention_type == AttnType.self_attn \
             and self.attn_mask_type == AttnMaskType.causal
         self.use_flash_attn_triton = args.use_flash_attn_triton
@@ -504,12 +524,13 @@ def __init__(self, config, layer_number,
             flash_attn_builder = None
 
         if args.use_flash_attn_v1:
-            assert flash_attn_unpadded_func != None or flash_attn_builder != None, ("Cannot import FlashAttention v1 "
-                                                      "and Cannot find FlashAttention Builder")
+            assert flash_attn_unpadded_func != None, "Cannot import FlashAttention v1 "
         if args.use_flash_attn_v2:
             assert flash_attn_varlen_func != None, "Cannot import FlashAttention v2 "
         if args.use_flash_attn_triton:
             assert flash_attn_func != None, "Cannot import FlashAttention triton "
+        if args.use_flash_attn_builder:
+            assert flash_attn_builder != None, "Cannot find FlashAttention op builder "
 
         assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                       'self-attention for now')
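
Illustration (not part of the diff): a minimal, self-contained restatement of the hasattr-based dispatch this patch adds to FlashSelfAttention, for readers skimming the hunks above. The attribute names flash_attn_func / flash_attn_func_v2, the (q, k, v, dropout_p, softmax_scale, causal) call order, and the (b, s, h, d) input layout are taken from the diff; the helper names and the `builder` argument, which stands in for the object transformer.py keeps in flash_attn_builder, are hypothetical and exist only for this sketch.

from einops import rearrange

def select_builder_kernel(builder):
    # Older op builders ("v1") expose flash_attn_func and are fed contiguous
    # (b, h, s, d) tensors; newer ones expose flash_attn_func_v2 and are called
    # on the plain rearranged view, which is the copy this patch removes.
    if hasattr(builder, 'flash_attn_func'):
        return builder.flash_attn_func, True   # needs .contiguous()
    return builder.flash_attn_func_v2, False   # strided view is accepted

def builder_attention(builder, q, k, v, dropout_p, softmax_scale, causal):
    # q, k, v arrive as (b, s, h, d), matching FlashSelfAttention.forward.
    kernel, needs_contiguous = select_builder_kernel(builder)
    if needs_contiguous:
        q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in (q, k, v)]
    else:
        q, k, v = [rearrange(x, 'b s h d -> b h s d') for x in (q, k, v)]
    out = kernel(q, k, v, dropout_p, softmax_scale, causal)
    out = rearrange(out, 'b h s d -> b s h d')
    return out.contiguous() if needs_contiguous else out

On the command line this path is selected with --use-flash-attn-builder; validate_args() above then folds the new flag into args.use_flash_attn alongside the existing v1/v2/Triton options.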