update flash_attn
ZubinGou committed Dec 5, 2023
1 parent 2ca2955 commit 10fe71a
Showing 3 changed files with 21 additions and 45 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -95,6 +95,7 @@ We recommend using [Conda](https://docs.conda.io/projects/miniconda) to manage y
git clone https://github.com/microsoft/ToRA.git && cd ToRA/src
conda create -n tora python=3.10
conda activate tora
pip install packaging==22.0
pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8 for example
pip install -r requirements.txt
```
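The new `pip install packaging==22.0` step is added before `pip install -r requirements.txt`, presumably because flash-attn's build script imports `packaging` (and `torch`) at install time, so both need to be present before pip reaches the `flash_attn==2.3.6` pin below. A minimal post-install sanity check one might run, assuming a CUDA build of torch and the versions pinned in this commit (illustrative, not part of the repository):

```python
# Post-install sanity check (illustrative; not repository code).
import torch
from flash_attn import __version__ as flash_attn_version

print("torch:", torch.__version__)        # expected: 2.0.1+cu118 per the README
print("flash-attn:", flash_attn_version)  # expected: 2.3.6 per requirements.txt
print("CUDA available:", torch.cuda.is_available())

# Mirrors the string-based version gate used by the patched attention below.
assert flash_attn_version >= "2.1.0", "KV-cache support needs flash-attn >= 2.1.0"
```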
12 changes: 4 additions & 8 deletions src/requirements.txt
@@ -1,36 +1,32 @@
packaging
xformers==0.0.21
tqdm
sentencepiece
datasets
transformers==4.31.0
datasets==2.14.5
deepspeed==0.11.0
accelerate==0.21.0
flash_attn==2.3.6
tensorboard
peft
bitsandbytes
evaluate
tokenizers
protobuf
openai
tiktoken
rouge_score
wandb
gradio
markupsafe
termcolor
jsonlines
mecab-python3
unidic-lite
einops
scipy
flash_attn==2.0.1
fire
flask
gpustat
rich
bitarray
ray==2.6.3
vllm==0.1.4 # ! Ensure version <= 0.1.4, later versions have unsolved bugs.
xformers==0.0.21
cvxpy
func-timeout
timeout-decorator
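For reproducibility, the pins that matter most in this commit are `flash_attn==2.3.6` (raised from 2.0.1) and `vllm==0.1.4`. A small standard-library sketch, not part of the repo, that cross-checks a few of these pins against the active environment (package names taken verbatim from the requirements file):

```python
# Cross-check selected pins from src/requirements.txt against the environment.
from importlib.metadata import PackageNotFoundError, version

pins = {"flash_attn": "2.3.6", "vllm": "0.1.4", "transformers": "4.31.0"}
for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "OK" if installed == expected else f"mismatch (expected {expected})"
    print(f"{name}: {installed} -> {status}")
```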
53 changes: 16 additions & 37 deletions src/train/llama2_flash_attn_monkey_patch.py
@@ -5,6 +5,7 @@
from typing import Optional, Tuple

import torch
from flash_attn import __version__ as flash_attn_version
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (
flash_attn_func,
@@ -39,6 +40,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
@@ -61,50 +63,31 @@ def forward(
kv_seq_len = k.shape[1]
past_kv_len = 0
if past_key_value is not None:
past_kv_len = past_key_value[0].shape[1]
past_kv_len = past_key_value[0].shape[2]
kv_seq_len += past_kv_len

cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

if past_key_value is not None:
assert (
flash_attn_version >= "2.1.0"
), "past_key_value support requires flash-attn >= 2.1.0"
# reuse k, v
k = torch.cat([past_key_value[0], k], dim=1)
v = torch.cat([past_key_value[1], v], dim=1)

past_key_value = (k, v) if use_cache else None

key_padding_mask = attention_mask
# Ideally we could just do this:
# q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
# but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
# which then breaks the causality logic. Probably if q_len >> past_kv_len we should
# just skip flash attention. Leaving this in for now to demonstrate correctness of
# flash attention information even when q needs padding.
# TODO(siddartha): delegate back to original implementation on this condition.
if past_kv_len > 0:
q = torch.cat(
(
torch.full(
(bsz, past_kv_len, self.num_heads, self.head_dim),
0.0,
dtype=q.dtype,
device=q.device,
),
q,
),
dim=1,
)
k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)

past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None

if key_padding_mask is None:
if attention_mask is None:
output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
bsz, q_len + past_kv_len, -1
bsz, q_len, -1
)
else:
q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
# We can skip concat and call unpad twice but seems better to call unpad only once.
kv, _, cu_k_lens, max_k = unpad_input(
torch.stack((k, v), dim=2), key_padding_mask
torch.stack((k, v), dim=2), attention_mask
)
output_unpad = flash_attn_varlen_kvpacked_func(
q,
@@ -118,11 +101,7 @@ def forward(
causal=True,
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)

# Need to strip off the zero query outputs.
if past_kv_len > 0:
output = output[:, past_kv_len:, ...]
output = pad_input(output_unpad, indices, bsz, q_len)

return self.o_proj(output), None, past_key_value

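The `shape[1]` → `shape[2]` and `transpose(1, 2)` changes in the hunks above follow from two different tensor layouts: the Hugging Face Llama KV cache is stored as `[bsz, num_heads, seq_len, head_dim]`, while the flash-attn kernels consume `[bsz, seq_len, num_heads, head_dim]`. A shape-only sketch of that round trip (toy sizes, CPU-only, not repository code):

```python
import torch

bsz, num_heads, head_dim, past_len, q_len = 2, 32, 128, 16, 4

# HF Llama cache layout: [bsz, num_heads, seq_len, head_dim] -> length is dim 2.
past_k = torch.randn(bsz, num_heads, past_len, head_dim)
assert past_k.shape[2] == past_len

# flash-attn layout for the new keys: [bsz, seq_len, num_heads, head_dim].
new_k = torch.randn(bsz, q_len, num_heads, head_dim)

# Move the cache into flash-attn layout and concatenate along the sequence dim...
k = torch.cat([past_k.transpose(1, 2), new_k], dim=1)   # [bsz, past+q, heads, dim]
# ...then transpose back when returning the updated cache to Hugging Face.
updated_cache_k = k.transpose(1, 2)                      # [bsz, heads, past+q, dim]
assert updated_cache_k.shape == (bsz, num_heads, past_len + q_len, head_dim)
```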
@@ -248,7 +227,7 @@ def test():
use_cache=True,
)
parts.append(part)
past_kv_len = past_kv[0].shape[1]
past_kv_len = past_kv[0].shape[2]

print(
f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
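The padded-batch branch of the patched `forward` removes padding before calling the variable-length kernel and restores it afterwards; only the last `q_len` columns of the attention mask apply to the queries, which is what the new `attention_mask[:, -q_len:]` slice expresses. An illustrative end-to-end sketch of that flow, assuming a CUDA device, fp16 tensors, and flash-attn 2.3.6 as pinned in this commit (toy shapes, not repository code):

```python
import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func

bsz, q_len, past_len, num_heads, head_dim = 2, 5, 3, 4, 64
kv_len = past_len + q_len
device, dtype = "cuda", torch.float16

q = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(bsz, kv_len, num_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(bsz, kv_len, num_heads, head_dim, device=device, dtype=dtype)

# Padding mask covers past + current tokens; give sequence 1 two padded slots.
attention_mask = torch.ones(bsz, kv_len, dtype=torch.bool, device=device)
attention_mask[1, :2] = False

# Queries correspond to the last q_len positions only, hence the slice.
q_unpad, indices, cu_q_lens, max_q = unpad_input(q, attention_mask[:, -q_len:])
kv_unpad, _, cu_k_lens, max_k = unpad_input(
    torch.stack((k, v), dim=2), attention_mask
)

out_unpad = flash_attn_varlen_kvpacked_func(
    q_unpad, kv_unpad, cu_q_lens, cu_k_lens, max_q, max_k,
    0.0, softmax_scale=None, causal=True,
)
# Re-pad to [bsz, q_len, num_heads * head_dim], mirroring the patched forward.
out = pad_input(out_unpad.reshape(-1, num_heads * head_dim), indices, bsz, q_len)
print(out.shape)
```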

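Since this file is a monkey patch, it only takes effect if the patched `forward` is swapped into `transformers`' Llama attention before the model is instantiated. A hypothetical usage sketch, assuming the module exposes a `replace_llama_attn_with_flash_attn()` helper as comparable patches do; the actual entry point in this repository may differ:

```python
# Hypothetical usage; the helper name and checkpoint are assumed, not taken from this diff.
import torch
from transformers import AutoModelForCausalLM

from llama2_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()  # patch LlamaAttention.forward first
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",       # placeholder checkpoint
    torch_dtype=torch.float16,
)
```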