This folder contains the code to perform Half-Quadratic Quantization (<b>HQQ</b>).
### Installation
Clone the repo and run ```pip install .``` from this folder.

### Basic Usage
To perform quantization via HQQ, you simply need to replace the ```torch.nn.Linear``` layers as follows:
```Python
from hqq.quantize.core import *
#Quantization settings
quant_config = hqq_base_quant_config(nbits=4, group_size=64)
#Replace linear layer
hqq_layer = HQQLinear(your_linear_layer, quant_config, del_orig=True)
#del_orig=True will remove the original linear layer from memory
```
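For illustration, here is a rough usage sketch (not part of the original example); it assumes a CUDA device and that the quantized layer expects half-precision inputs, which may differ from your setup:
```Python
import torch
from hqq.quantize.core import *

#Toy linear layer (sizes are arbitrary, for illustration only)
linear = torch.nn.Linear(4096, 4096)

#Quantize it with the same settings as above
quant_config = hqq_base_quant_config(nbits=4, group_size=64)
hqq_layer = HQQLinear(linear, quant_config, del_orig=True)

#Forward pass (assumes the quantized layer lives on the GPU in half precision)
x = torch.randn(2, 4096, dtype=torch.float16, device='cuda')
y = hqq_layer(x)
print(y.shape) #expected: torch.Size([2, 4096])
```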

### LLama2 Quantization 🦙
First, make sure to install the following dependencies:
```pip install transformers[torch] datasets xformers accelerate```

You can quantize a LLama2 HuggingFace model as follows:

```Python
import torch, transformers
model_id = "meta-llama/Llama-2-7b-hf"

#Load model on the CPU
######################
model = transformers.AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

#Quantize the model
######################
from hqq.quantize.core import hqq_base_quant_config
from hqq.models.llama import LlamaHQQ

quant_config = hqq_base_quant_config(nbits=4, group_size=64)
LlamaHQQ.quantize_model(model, quant_config=quant_config)
```

You can save/load the quantized models as follows:
```Python
#Save
LlamaHQQ.save_quantized(model, save_dir=save_dir)
#Load from local directory or Hugging Face Hub
model = LlamaHQQ.from_quantized(save_dir)
```
We provide a complete example to quantize LLama2 models that you can find in the ```llama2_benchmark``` folder. By default, it quantizes the LLama2-7B model with 4-bit precision and reports the perplexity on wikitext-2.

Additionally, to run the GPTQ and AWQ demos you need the following:
```pip install auto-gptq[triton]==0.4.2 autoawq triton==2.0.0```

Then set your HuggingFace 🤗 token via the CLI or inside the demo files, and you're all set!
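If you prefer to stay in Python, a minimal sketch of setting the token programmatically (using the ```huggingface_hub``` package that ships with ```transformers```; the token value below is a placeholder) would be:
```Python
#Programmatic alternative to the CLI login
from huggingface_hub import login
login(token="hf_xxx") #replace with your own HuggingFace 🤗 token
```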

### ViT Quantization 🖼️
Make sure to install _timm_ via ```pip install timm``` first.

You can quantize a ViT model as follows:
```Python
import timm, torch
model_id = 'vit_large_patch14_clip_224.laion2b'

#Load model on CPU
model = timm.create_model(model_id, pretrained=True)

#Quantize
from hqq.quantize.core import hqq_base_quant_config
from hqq.models.vit import ViTHQQ
quant_config = hqq_base_quant_config(nbits=4, group_size=64)
ViTHQQ.quantize_model(model, quant_config=quant_config)
```

You can also save/load the quantized ViT models as follows:
```Python
#Save
ViTHQQ.save_quantized(model, save_dir=save_dir)
#Load from local directory or Hugging Face Hub
model = ViTHQQ.from_quantized(save_dir)
```

We provide a complete example to quantize ViT models that you can find in the ```vit_example``` folder. The script shows how to quantize a _timm_ ViT model and compares the dot score between the quantized and the original model predictions.
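As a rough illustration of that comparison (a hedged sketch, not the actual ```vit_example``` script; it assumes both models fit on the GPU in half precision and uses the cosine score as a stand-in for the dot score):
```Python
import torch, timm
import torch.nn.functional as F

#Reference (non-quantized) model
model_ref = timm.create_model(model_id, pretrained=True).half().cuda().eval()

#Compare predictions on the same random input (model is the quantized ViT from above)
x = torch.randn(1, 3, 224, 224, dtype=torch.float16, device='cuda')
with torch.no_grad():
    score = F.cosine_similarity(model_ref(x).float(), model(x).float())
print(score.item()) #should stay close to 1.0 if quantization preserves the predictions
```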


### Quantize Custom Models 🗜️
If you want to quantize your own model architecture, you need to write a patching function that goes through all the linear layers and replaces them with ```HQQLinear```. You can follow the examples provided in ```hqq/models```.
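For illustration, a minimal (hypothetical) patching function could look like the sketch below; the actual patching logic in ```hqq/models``` is model-specific and may handle additional details such as device placement:
```Python
import torch
from hqq.quantize.core import *

def patch_linear_layers(model, quant_config):
    #Recursively replace every torch.nn.Linear with an HQQLinear layer
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            setattr(model, name, HQQLinear(module, quant_config, del_orig=True))
        else:
            patch_linear_layers(module, quant_config)
    return model

#Usage:
#quant_config = hqq_base_quant_config(nbits=4, group_size=64)
#model = patch_linear_layers(model, quant_config)
```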

### Models from Hugging Face Hub 🤗
We provide pre-quantized LLama2 models that you can use directly from the [Hugging Face Hub](https://huggingface.co/mobiuslabsgmbh).

Here's an example:
```Python
from hqq.models.llama import LlamaHQQ
import transformers

model_id = 'mobiuslabsgmbh/Llama-2-7b-hf-4bit_g64-HQQ'
#Load the tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
#Load the model
model = LlamaHQQ.from_quantized(model_id)
```
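Once loaded, the model can in principle be used like any other HuggingFace causal LM. A minimal sketch, assuming the quantized model sits on the GPU:
```Python
import torch

prompt = "Half-Quadratic Quantization is"
inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```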

Have fun 🚀!
