Merge pull request #50 from hao-ai-lab/dev
flash attention and sampling
Viol2000 authored Feb 14, 2024
2 parents 1292680 + 089f862 commit 9d50de4
Showing 13 changed files with 4,903 additions and 575 deletions.
53 changes: 51 additions & 2 deletions README.md
@@ -4,6 +4,11 @@
| <a href="https://arxiv.org/abs/2402.02057"><b>Paper</b></a> | <a href="https://lmsys.org/blog/2023-11-21-lookahead-decoding/"><b>Blog</b></a> | <a href="https://github.com/hao-ai-lab/LookaheadDecoding/issues/13"><b>Roadmap</b></a> |
</p>

---
*News* 🔥
- [2024/2] The Lookahead Decoding paper is now available on [arXiv](https://arxiv.org/abs/2402.02057). [Sampling](#use-lookahead-decoding-in-your-own-code) and [FlashAttention](#flashattention-support) are now supported, along with advanced features for better token prediction.

---
## Introduction
We introduce lookahead decoding:
- A parallel decoding algorithm to accelerate LLM inference.
@@ -138,7 +143,7 @@ USE_LADE=0 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-ch
```

### Use Lookahead decoding in your own code
You can import and use Lookahead decoding in your own code in three LoCs. You also need to set ```USE_LADE=1``` on the command line or set ```os.environ["USE_LADE"]="1"``` in your Python script. Note that Lookahead decoding only supports LLaMA models for now.

```python
import lade
Expand All @@ -148,14 +153,58 @@ lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)
#You can obtain a better performance by tuning LEVEL/WINDOW_SIZE/GUESS_SET_SIZE on your own device.
```
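
If you would rather enable Lookahead decoding from inside a Python script than on the command line, a minimal sketch might look like this (assuming ```USE_LADE``` has to be set before ```lade``` is imported so the flag is visible at import time):
```python
import os
os.environ["USE_LADE"] = "1"  # enable Lookahead decoding (equivalent to USE_LADE=1 on the command line)

import lade
lade.augment_all()
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)
```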

Then you can speed up the decoding process. Here is an example using greedy search:
```python
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained
```

Here is an example using sampling:
```python
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
sample_output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) #speedup obtained
```
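
In both examples ```generate``` returns token ids rather than text; a minimal sketch for decoding them, using the ```tokenizer``` and ```sample_output``` defined above:
```python
# turn the generated token ids back into a string
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
```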

### FlashAttention Support
Install the original FlashAttention:
```bash
pip install flash-attn==2.3.3 #original FlashAttention
```
There are two ways to install the FlashAttention build specialized for Lookahead Decoding:
1) Download a pre-built package from https://github.com/Viol2000/flash-attention-lookahead/releases/tag/v2.3.3 and install it (fast, recommended).
For example, if you have cuda==11.8, python==3.9, and torch==2.1, you should do the following (see the version-check sketch after this list if you are unsure which wheel matches your environment):
```bash
wget https://github.com/Viol2000/flash-attention-lookahead/releases/download/v2.3.3/flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
pip install flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
```
2) Install from source (slow, not recommended):
```bash
git clone https://github.com/Viol2000/flash-attention-lookahead.git
cd flash-attention-lookahead && python setup.py install
```
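
If you are unsure which pre-built wheel matches your environment, a small sketch like the following (assuming PyTorch is already installed) prints the Python, torch, and CUDA versions that the wheel filename has to match:
```python
# print the versions that determine which flash_attn_lade wheel to pick
import sys
import torch

print("python:", sys.version.split()[0])  # e.g. 3.9.x -> cp39
print("torch :", torch.__version__)       # e.g. 2.1.x -> torch2.1
print("cuda  :", torch.version.cuda)      # e.g. 11.8  -> cu118
```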

Here is an example script to run the models with FlashAttention:
```bash
python minimal-flash.py #no Lookahead decoding, w/ FlashAttention
USE_LADE=1 LOAD_LADE=1 python minimal-flash.py #use Lookahead decoding, w/ FlashAttention, ~20% faster than w/o FlashAttention
```

In your own code, you need to set ```USE_FLASH=True``` when calling ```config_lade```, and set ```attn_implementation="flash_attention_2"``` when calling ```AutoModelForCausalLM.from_pretrained```.
```python
import lade
lade.augment_all()
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, USE_FLASH=True, DEBUG=0)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device, attn_implementation="flash_attention_2")
model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained
```
We will integrate FlashAttention directly into this repo for simpler installation and usage.

## Citation
```bibtex
@misc{fu2024break,
      title={Break the Sequential Dependency of LLM Inference Using Lookahead Decoding},
      author={Yichao Fu and Peter Bailis and Ion Stoica and Hao Zhang},
      year={2024},
      eprint={2402.02057},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```
