From ffe7970a89bc9d6d3d6e0dbe72e3e7daa2dc2c87 Mon Sep 17 00:00:00 2001
From: POI-WX
Date: Mon, 29 Jan 2024 21:22:52 +0800
Subject: [PATCH] draft for supporting flash attention

---
 csrc/extensions.cpp                           | 56 +++++++++++++
 .../internlm_ops/mha/fa_kvpacked_func.py      | 49 ++++++++++++
 .../internlm_ops/mha/fa_qkvpacked_func.py     | 48 ++++++++++++
 deeplink_ext/internlm_ops/mha/mha.py          | 83 +++++++++++--------
 4 files changed, 203 insertions(+), 33 deletions(-)
 create mode 100644 deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
 create mode 100644 deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py

diff --git a/csrc/extensions.cpp b/csrc/extensions.cpp
index 09c6e3e2..654ef942 100644
--- a/csrc/extensions.cpp
+++ b/csrc/extensions.cpp
@@ -125,6 +125,56 @@ auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
                          std::move(grad_v));
 }
 
+auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
+                       const at::Tensor& v, double p_dropout,
+                       double softmax_scale, bool is_causal) {
+  const auto batch_size = q.sizes()[0];
+  const auto q_seq_len = q.sizes()[1];
+  const auto head_num = q.sizes()[2];
+  const auto k_seq_len = k.sizes()[1];
+
+  auto out = at::empty_like(q);
+  // Softmax workspaces, following the shapes from the ascend reference:
+  // softmax_max / softmax_sum are [B, N, S0, 8] float tensors, and
+  // softmax_out is an empty placeholder tensor.
+  const auto softmax_option = q.options().dtype(at::kFloat);
+  auto softmax_max =
+      at::empty({batch_size, head_num, q_seq_len, 8}, softmax_option);
+  auto softmax_sum =
+      at::empty({batch_size, head_num, q_seq_len, 8}, softmax_option);
+  auto softmax_out = at::empty({0}, q.options());
+
+  auto gen = createDIPUGenerator();
+
+  callDiopi(diopiFlashAttention, out, gen, softmax_max, softmax_sum,
+            softmax_out, q, k, v, p_dropout, softmax_scale, is_causal);
+  return std::make_tuple(std::move(out), std::move(softmax_max),
+                         std::move(softmax_sum), std::move(softmax_out),
+                         std::move(gen));
+}
+
+// grad_q, grad_k, grad_v are output args, and should be pre-allocated.
+auto extFlashAttentionBackward(c10::optional<at::Tensor>& grad_q_opt,
+                               c10::optional<at::Tensor>& grad_k_opt,
+                               c10::optional<at::Tensor>& grad_v_opt,
+                               const at::Tensor& grad_out, const at::Tensor& q,
+                               const at::Tensor& k, const at::Tensor& v,
+                               const at::Tensor& out,
+                               const at::Tensor& softmax_max,
+                               const at::Tensor& softmax_sum,
+                               const at::Tensor& softmax_out,
+                               at::Generator& gen, double p_dropout,
+                               double softmax_scale, bool is_causal) {
+  auto grad_q = grad_q_opt.has_value() ? grad_q_opt.value() : at::empty_like(q);
+  auto grad_k = grad_k_opt.has_value() ? grad_k_opt.value() : at::empty_like(k);
+  auto grad_v = grad_v_opt.has_value() ? grad_v_opt.value() : at::empty_like(v);
+  callDiopi(diopiFlashAttentionBackward, grad_q, grad_k, grad_v, grad_out, q, k,
+            v, out, softmax_max, softmax_sum, softmax_out, gen, p_dropout,
+            softmax_scale, is_causal);
+  return std::make_tuple(std::move(grad_q), std::move(grad_k),
+                         std::move(grad_v));
+}
+
 auto extMultiHeadAttentionVarLen(at::Tensor& q, at::Tensor& k, at::Tensor& v,
                                  const at::Tensor& cum_seq_q,
                                  const at::Tensor& cum_seq_k,
@@ -275,6 +325,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("mha_varlen_bwd", &extMultiHeadAttentionVarLenBackward,
           "deeplink ext_mha_varlen_bwd");
   }
+  if (&diopiFlashAttention != nullptr) {
+    m.def("fa_fwd", &extFlashAttention, "deeplink ext_fa_fwd");
+  }
+  if (&diopiFlashAttentionBackward != nullptr) {
+    m.def("fa_bwd", &extFlashAttentionBackward, "deeplink ext_fa_bwd");
+  }
   if (&diopiDestIndexCopyKV != nullptr) {
     m.def("dest_index_copy_kv", &extDestIndexCopyKV,
           "deeplink ext_dest_index_copy_kv");
diff --git a/deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py b/deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
new file mode 100644
index 00000000..6c6d2509
--- /dev/null
+++ b/deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2024, DeepLink.
+
+import torch
+import deeplink_ext.cpp_extensions as ext
+
+assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
+
+
+class DeepLinkFlashAttentionKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        out, softmax_max, softmax_sum, softmax_out, rng = ext.fa_fwd(
+            q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal
+        )
+        ctx.save_for_backward(
+            q, kv, out, softmax_max, softmax_sum, softmax_out, rng.get_state()
+        )
+        ctx.dropout_p = dropout_p
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        q, kv, out, softmax_max, softmax_sum, softmax_out, rng_state = ctx.saved_tensors
+        dq = torch.empty_like(q)
+        dkv = torch.empty_like(kv)
+        rng = torch.Generator(device=q.device)
+        rng.set_state(rng_state)
+        ext.fa_bwd(
+            dq,
+            dkv[:, :, 0],
+            dkv[:, :, 1],
+            dout,
+            q,
+            kv[:, :, 0],
+            kv[:, :, 1],
+            out,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+            rng,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+        )
+        return dq, dkv, None, None, None
diff --git a/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
new file mode 100644
index 00000000..87ef312f
--- /dev/null
+++ b/deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2024, DeepLink.
+
+import torch
+import deeplink_ext.cpp_extensions as ext
+
+assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")
+
+
+class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, qkv, dropout_p, softmax_scale, causal):
+        if softmax_scale is None:
+            softmax_scale = qkv.shape[-1] ** (-0.5)
+        out, softmax_max, softmax_sum, softmax_out, rng = ext.fa_fwd(
+            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, causal
+        )
+        ctx.save_for_backward(
+            qkv, out, softmax_max, softmax_sum, softmax_out, rng.get_state()
+        )
+        ctx.dropout_p = dropout_p
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out
+
+    @staticmethod
+    def backward(ctx, dout):
+        qkv, out, softmax_max, softmax_sum, softmax_out, rng_state = ctx.saved_tensors
+        dqkv = torch.empty_like(qkv)
+        rng = torch.Generator(device=qkv.device)
+        rng.set_state(rng_state)
+        ext.fa_bwd(
+            dqkv[:, :, 0],
+            dqkv[:, :, 1],
+            dqkv[:, :, 2],
+            dout,
+            qkv[:, :, 0],
+            qkv[:, :, 1],
+            qkv[:, :, 2],
+            out,
+            softmax_max,
+            softmax_sum,
+            softmax_out,
+            rng,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+        )
+        return dqkv, None, None, None
diff --git a/deeplink_ext/internlm_ops/mha/mha.py b/deeplink_ext/internlm_ops/mha/mha.py
index c718ecd2..be650336 100644
--- a/deeplink_ext/internlm_ops/mha/mha.py
+++ b/deeplink_ext/internlm_ops/mha/mha.py
@@ -1,10 +1,12 @@
 # Copyright (c) 2023, DeepLink.
 
 import torch.nn as nn
-from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc
-from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc
-from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc
-from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc
+# from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc
+# from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc
+# from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc
+# from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc
+from .fa_qkvpacked_func import DeepLinkFlashAttentionQKVPackedFunc
+from .fa_kvpacked_func import DeepLinkFlashAttentionKVPackedFunc
 
 
 class DeepLinkSelfAttention(nn.Module):
@@ -46,24 +48,31 @@ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
         """
         if cu_seqlens is None:
             # padded
-            return DeepLinkMultiHeadAttentionQKVPackedFunc.apply(
+            # return DeepLinkMultiHeadAttentionQKVPackedFunc.apply(
+            #     qkv,
+            #     self.dropout_p if self.training else 0.0,
+            #     self.softmax_scale,
+            #     causal if causal is not None else self.causal,
+            #     False,
+            # )
+            # for ascend
+            return DeepLinkFlashAttentionQKVPackedFunc.apply(
                 qkv,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
                 causal if causal is not None else self.causal,
-                False,
-            )
-        else:
-            # unpadded
-            return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply(
-                qkv,
-                cu_seqlens,
-                max_seqlen,
-                self.dropout_p if self.training else 0.0,
-                self.softmax_scale,
-                causal if causal is not None else self.causal,
-                False,
             )
+        # else:
+        #     # unpadded
+        #     return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply(
+        #         qkv,
+        #         cu_seqlens,
+        #         max_seqlen,
+        #         self.dropout_p if self.training else 0.0,
+        #         self.softmax_scale,
+        #         causal if causal is not None else self.causal,
+        #         False,
+        #     )
 
 
 class DeepLinkCrossAttention(nn.Module):
@@ -85,25 +94,33 @@ def forward(
     ):
         if cu_seqlens_q is None:
             # padded
-            return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
-                q,
-                kv,
-                self.dropout_p if self.training else 0.0,
-                self.softmax_scale,
-                causal if causal is not None else self.causal,
-                False,
-            )
-        else:
-            # unpadded
-            return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
+            # return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
+            #     q,
+            #     kv,
+            #     self.dropout_p if self.training else 0.0,
+            #     self.softmax_scale,
+            #     causal if causal is not None else self.causal,
+            #     False,
+            # )
+            # for ascend
+            return DeepLinkFlashAttentionKVPackedFunc.apply(
                 q,
                 kv,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
                 self.dropout_p if self.training else 0.0,
                 self.softmax_scale,
                 causal if causal is not None else self.causal,
-                False,
             )
+        # else:
+        #     # unpadded
+        #     return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
+        #         q,
+        #         kv,
+        #         cu_seqlens_q,
+        #         cu_seqlens_k,
+        #         max_seqlen_q,
+        #         max_seqlen_k,
+        #         self.dropout_p if self.training else 0.0,
+        #         self.softmax_scale,
+        #         causal if causal is not None else self.causal,
+        #         False,
+        #     )
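
Usage note: below is a minimal smoke-test sketch (not part of the diff) for the new
qkv-packed path. It assumes the extension is built with diopiFlashAttention /
diopiFlashAttentionBackward available and that qkv is packed as
(batch, seqlen, 3, nheads, headdim), matching the qkv[:, :, i] indexing above; the
device string, dtype, and sizes are placeholders.

import torch

from deeplink_ext.internlm_ops.mha.fa_qkvpacked_func import (
    DeepLinkFlashAttentionQKVPackedFunc,
)

# Placeholder device/dtype; use whatever device the DIPU build actually exposes.
device = "cuda"
batch, seqlen, nheads, headdim = 2, 128, 8, 64
qkv = torch.randn(
    batch, seqlen, 3, nheads, headdim,
    dtype=torch.half, device=device, requires_grad=True,
)

# softmax_scale=None falls back to headdim ** -0.5 inside forward();
# causal=True exercises the causal-mask path, and dropout is disabled here.
out = DeepLinkFlashAttentionQKVPackedFunc.apply(qkv, 0.0, None, True)
out.sum().backward()  # calls ext.fa_bwd and populates qkv.grad
print(out.shape, qkv.grad.shape)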