Commit ffe7970

draft for supporting flash attention

POI-WX committed Jan 29, 2024
1 parent 9cd4dcc commit ffe7970
Showing 4 changed files with 217 additions and 33 deletions.
70 changes: 70 additions & 0 deletions csrc/extensions.cpp
@@ -125,6 +125,70 @@ auto extMultiHeadAttentionBackward(const at::Tensor& grad_out,
std::move(grad_v));
}

auto extFlashAttention(const at::Tensor& q, const at::Tensor& k,
const at::Tensor& v, double p_dropout,
double softmax_scale, bool is_causal) {
const auto batch_size = q.sizes()[0];
const auto q_seq_len = q.sizes()[1];
const auto head_num = q.sizes()[2];

auto out = at::empty_like(q);

// softmax_max and softmax_sum follow the Ascend flash-attention layout of
// [batch_size, head_num, q_seq_len, 8] in float32; softmax_out is kept as an
// empty placeholder tensor.
const auto softmax_stat_options = q.options().dtype(at::kFloat);
auto softmax_max = at::empty({batch_size, head_num, q_seq_len, 8}, softmax_stat_options);
auto softmax_sum = at::empty({batch_size, head_num, q_seq_len, 8}, softmax_stat_options);
auto softmax_out = at::empty({0}, q.options());

auto gen = createDIPUGenerator();

callDiopi(diopiFlashAttention, out, gen, softmax_max, softmax_sum,
softmax_out, q, k, v, p_dropout, softmax_scale, is_causal);
return std::make_tuple(std::move(out), std::move(softmax_max),
std::move(softmax_sum), std::move(softmax_out),
std::move(gen));
}

// grad_q, grad_k, grad_v are output args, and should be pre-allocated.
auto extFlashAttentionBackward(c10::optional<at::Tensor>& grad_q_opt,
c10::optional<at::Tensor>& grad_k_opt,
c10::optional<at::Tensor>& grad_v_opt,
const at::Tensor& grad_out, const at::Tensor& q,
const at::Tensor& k, const at::Tensor& v,
const at::Tensor& out,
const at::Tensor& softmax_max,
const at::Tensor& softmax_sum,
const at::Tensor& softmax_out,
at::Generator& gen, double p_dropout,
double softmax_scale, bool is_causal) {
auto grad_q = grad_q_opt.has_value() ? grad_q_opt.value() : at::empty_like(q);
auto grad_k = grad_k_opt.has_value() ? grad_k_opt.value() : at::empty_like(k);
auto grad_v = grad_v_opt.has_value() ? grad_v_opt.value() : at::empty_like(v);
callDiopi(diopiFlashAttentionBackward, grad_q, grad_k, grad_v, grad_out, q, k,
v, out, softmax_max, softmax_sum, softmax_out, gen, p_dropout,
softmax_scale, is_causal);
return std::make_tuple(std::move(grad_q), std::move(grad_k),
std::move(grad_v));
}

auto extMultiHeadAttentionVarLen(at::Tensor& q, at::Tensor& k, at::Tensor& v,
const at::Tensor& cum_seq_q,
const at::Tensor& cum_seq_k,
@@ -275,6 +339,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mha_varlen_bwd", &extMultiHeadAttentionVarLenBackward,
"deeplink ext_mha_varlen_bwd");
}
if (&diopiFlashAttention != nullptr) {
m.def("fa_fwd", &extFlashAttention, "deeplink ext_fa_fwd");
}
if (&diopiFlashAttentionBackward != nullptr) {
m.def("fa_bwd", &extFlashAttentionBackward, "deeplink ext_fa_bwd");
}
if (&diopiDestIndexCopyKV != nullptr) {
m.def("dest_index_copy_kv", &extDestIndexCopyKV,
"deeplink ext_dest_index_copy_kv");
49 changes: 49 additions & 0 deletions deeplink_ext/internlm_ops/mha/fa_kvpacked_func.py
@@ -0,0 +1,49 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")


class DeepLinkFlashAttentionKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, kv, dropout_p, softmax_scale, causal):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_max, softmax_sum, softmax_out, rng = ext.fa_fwd(
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal
)
ctx.save_for_backward(
q, kv, out, softmax_max, softmax_sum, softmax_out, rng.get_state()
)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out

@staticmethod
def backward(ctx, dout):
q, kv, out, softmax_max, softmax_sum, softmax_out, rng_state = ctx.saved_tensors
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
rng = torch.Generator(device=q.device)
rng.set_state(rng_state)
ext.fa_bwd(
dq,
dkv[:, :, 0],
dkv[:, :, 1],
dout,
q,
kv[:, :, 0],
kv[:, :, 1],
out,
softmax_max,
softmax_sum,
softmax_out,
rng,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
# one gradient (or None) per forward input: q, kv, dropout_p, softmax_scale, causal
return dq, dkv, None, None, None
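
For reference, a minimal usage sketch of DeepLinkFlashAttentionKVPackedFunc follows; the shapes, half dtype, and the "cuda" device string are illustrative assumptions only and are not part of this commit.

# Hypothetical example: q is [batch, seqlen, heads, headdim] and kv is
# [batch, seqlen, 2, heads, headdim]; replace "cuda" with the DIPU device
# the extension is built for.
import torch
from deeplink_ext.internlm_ops.mha.fa_kvpacked_func import (
    DeepLinkFlashAttentionKVPackedFunc,
)

q = torch.randn(2, 128, 8, 64, dtype=torch.half, device="cuda", requires_grad=True)
kv = torch.randn(2, 128, 2, 8, 64, dtype=torch.half, device="cuda", requires_grad=True)

# arguments: q, kv, dropout_p, softmax_scale (None falls back to headdim ** -0.5), causal
out = DeepLinkFlashAttentionKVPackedFunc.apply(q, kv, 0.0, None, True)
out.sum().backward()  # runs ext.fa_bwd and fills q.grad / kv.grad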
48 changes: 48 additions & 0 deletions deeplink_ext/internlm_ops/mha/fa_qkvpacked_func.py
@@ -0,0 +1,48 @@
# Copyright (c) 2024, DeepLink.

import torch
import deeplink_ext.cpp_extensions as ext

assert hasattr(ext, "fa_fwd") and hasattr(ext, "fa_bwd")


class DeepLinkFlashAttentionQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, dropout_p, softmax_scale, causal):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_max, softmax_sum, softmax_out, rng = ext.fa_fwd(
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, causal
)
ctx.save_for_backward(
qkv, out, softmax_max, softmax_sum, softmax_out, rng.get_state()
)
ctx.dropout_p = dropout_p
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out

@staticmethod
def backward(ctx, dout):
qkv, out, softmax_max, softmax_sum, softmax_out, rng_state = ctx.saved_tensors
dqkv = torch.empty_like(qkv)
rng = torch.Generator(device=qkv.device)
rng.set_state(rng_state)
ext.fa_bwd(
dqkv[:, :, 0],
dqkv[:, :, 1],
dqkv[:, :, 2],
dout,
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
out,
softmax_max,
softmax_sum,
softmax_out,
rng,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
)
# one gradient (or None) per forward input: qkv, dropout_p, softmax_scale, causal
return dqkv, None, None, None
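
Similarly, a minimal sketch of the QKV-packed function; the shapes, dtype, and device string are again assumptions for illustration.

# Hypothetical example: qkv is [batch, seqlen, 3, heads, headdim].
import torch
from deeplink_ext.internlm_ops.mha.fa_qkvpacked_func import (
    DeepLinkFlashAttentionQKVPackedFunc,
)

qkv = torch.randn(2, 128, 3, 8, 64, dtype=torch.half, device="cuda", requires_grad=True)

# arguments: qkv, dropout_p, softmax_scale (None falls back to headdim ** -0.5), causal
out = DeepLinkFlashAttentionQKVPackedFunc.apply(qkv, 0.0, None, True)
out.sum().backward()  # gradient lands in qkv.grad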
83 changes: 50 additions & 33 deletions deeplink_ext/internlm_ops/mha/mha.py
@@ -1,10 +1,12 @@
# Copyright (c) 2023, DeepLink.

import torch.nn as nn
from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc
from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc
from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc
from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc
# from .mha_qkvpacked_func import DeepLinkMultiHeadAttentionQKVPackedFunc
# from .mha_varlen_qkvpacked_func import DeepLinkMultiHeadAttentionVarLenQKVPackedFunc
# from .mha_kvpacked_func import DeepLinkMultiHeadAttentionKVPackedFunc
# from .mha_varlen_kvpacked_func import DeepLinkMultiHeadAttentionVarLenKVPackedFunc
from .fa_qkvpacked_func import DeepLinkFlashAttentionQKVPackedFunc
from .fa_kvpacked_func import DeepLinkFlashAttentionKVPackedFunc


class DeepLinkSelfAttention(nn.Module):
@@ -46,24 +48,31 @@ def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
"""
if cu_seqlens is None:
# padded
return DeepLinkMultiHeadAttentionQKVPackedFunc.apply(
# return DeepLinkMultiHeadAttentionQKVPackedFunc.apply(
# qkv,
# self.dropout_p if self.training else 0.0,
# self.softmax_scale,
# causal if causal is not None else self.causal,
# False,
# )
# for ascend
return DeepLinkFlashAttentionQKVPackedFunc.apply(
qkv,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
)
else:
# unpadded
return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply(
qkv,
cu_seqlens,
max_seqlen,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
False,
)
# else:
# # unpadded
# return DeepLinkMultiHeadAttentionVarLenQKVPackedFunc.apply(
# qkv,
# cu_seqlens,
# max_seqlen,
# self.dropout_p if self.training else 0.0,
# self.softmax_scale,
# causal if causal is not None else self.causal,
# False,
# )


class DeepLinkCrossAttention(nn.Module):
Expand All @@ -85,25 +94,33 @@ def forward(
):
if cu_seqlens_q is None:
# padded
return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
q,
kv,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
False,
)
else:
# unpadded
return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
# return DeepLinkMultiHeadAttentionKVPackedFunc.apply(
# q,
# kv,
# self.dropout_p if self.training else 0.0,
# self.softmax_scale,
# causal if causal is not None else self.causal,
# False,
# )
# for ascend
return DeepLinkFlashAttentionKVPackedFunc.apply(
q,
kv,
self.dropout_p if self.training else 0.0,
self.softmax_scale,
causal if causal is not None else self.causal,
)
# else:
# # unpadded
# return DeepLinkMultiHeadAttentionVarLenKVPackedFunc.apply(
# q,
# kv,
# cu_seqlens_q,
# cu_seqlens_k,
# max_seqlen_q,
# max_seqlen_k,
# self.dropout_p if self.training else 0.0,
# self.softmax_scale,
# causal if causal is not None else self.causal,
# False,
# )
