Commit f5703b0
Add Auto-Round support (#581)
Squashed commit history (each commit signed off by yiliu30 <[email protected]>):

* initial flow for autoround
* update flow
* use int4 kernel
* remove debug code
* update the forward
* clean code
* e2e example
* refine code
* add requirements for test
* update test
* update the readme
* add readme
* update the filenames
* update the np version
* add demo
* format
* add more docs
* format
* add doc
* use `AffineQuantizedTensor`
* impl ar using multensors
* clean code
* use hook + multensors
* separate mul_tensors into a new file
* fix typos
* rename mul_tensor to multi_tensor
* enable amp
* eval model
* add gen examples
* add warmup to benchmark
* add benchmark
* clean code
* format code
* use tiny kernel
* add more note
* format
* correct typos
* remove hard code
* use intx
* enable offload for multitensor
* update the default config
* refine note
* update the version check
* format
* update
* add ut
* format
* add scripts
* format code
* format
* update
* fix typo
* refine bench code
* Enable `use_optimized_layer_output` and AO' llama (#12)
* Refine the Doc (#14)
* add more docstring
* add paper link
* correct some note
* add cmd
* udpdate the scripts
* revert some change
* Add a lightweight configuration for quick benchmarking (#15)
* update quant method name
* Wrap model's buffers and params to `MultiTensor` & update the results (#16)
  * wrap model's buffers and params to `MultiTensor` and update the results
1 parent 0987dd6 commit f5703b0

File tree

13 files changed: +1477 −2 lines

test/prototype/test_autoround.py

+174
@@ -0,0 +1,174 @@
import pytest
from torchao.prototype.autoround.utils import is_auto_round_available

if not is_auto_round_available():
    pytest.skip("AutoRound is not available", allow_module_level=True)

import torch
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
from torchao import quantize_

from torchao.dtypes import AffineQuantizedTensor
from torchao.prototype.autoround.core import (
    apply_auto_round,
    prepare_model_for_applying_auto_round_,
)
from torchao.prototype.autoround.multi_tensor import MultiTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


# Copied from https://github.com/pytorch/ao/pull/721
class TwoLinear(torch.nn.Module):
    def __init__(self, in_features=64, out_features=128):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features, out_features)
        self.linear2 = torch.nn.Linear(in_features, out_features)

    def forward(self, x, y):
        x = self.linear1(x)
        y = self.linear2(y)
        return x + y


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.two_linear1 = TwoLinear()
        self.two_linear2 = TwoLinear(128, 256)

    def forward(self, x, y):
        x1 = self.two_linear1(x, y)
        x2 = self.two_linear2(x1, x1)
        return x2


def _is_two_linear(mod, fqn):
    return isinstance(mod, TwoLinear)


class ModelWithInplaceOp(torch.nn.Module):
    def __init__(self, DIM=128):
        super().__init__()
        self.lin = torch.nn.Linear(DIM, DIM)
        self.register_buffer("other", torch.zeros(DIM, DIM))

    def forward(self, x, idx):
        x = x + self.lin(x)
        # update buffer
        self.other[idx] = x
        return x


class M2(torch.nn.Module):
    def __init__(self, DIM=128):
        super().__init__()
        self.m1 = ModelWithInplaceOp(DIM)
        self.m2 = ModelWithInplaceOp(DIM)

    def forward(self, x, idx):
        x = self.m1(x, idx)
        x = self.m2(x, idx)
        return x


def _check_params_and_buffers_type(module, check_fun):
    return [check_fun(p) for p in module.parameters()] + [
        check_fun(b) for b in module.buffers()
    ]


class TestAutoRound(TestCase):

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later")
    @parametrize("device", _AVAILABLE_DEVICES)
    @torch.no_grad()
    def test_auto_round(self, device: str):
        # End-to-end flow: prepare -> calibrate with MultiTensor -> quantize_ with apply_auto_round().
        example_inputs = (
            torch.randn(32, 64).to(device),
            torch.randn(32, 64).to(device),
        )
        m = M().eval().to(device)
        before_quant = m(*example_inputs)
        prepare_model_for_applying_auto_round_(
            m,
            is_target_module=_is_two_linear,
            bits=7,
            group_size=32,
            iters=20,
            device=device,
        )
        assert all(
            _check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor))
        ), "Expected all parameters and buffers to be `MultiTensor`."
        input1 = []
        input2 = []
        for _ in range(10):
            input1.append(torch.randn(32, 64).to(device))
            input2.append(torch.randn(32, 64).to(device))

        mt_input1 = MultiTensor(input1)
        mt_input2 = MultiTensor(input2)
        out = m(mt_input1, mt_input2)
        assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}"
        assert all(
            _check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor))
        ), "Expected all parameters and buffers to have been converted back to plain tensors."
        quantize_(m, apply_auto_round(), _is_two_linear, device=device)
        for l in m.modules():
            if isinstance(l, torch.nn.Linear):
                assert isinstance(l.weight, AffineQuantizedTensor)
        after_quant = m(*example_inputs)
        assert after_quant is not None, "Quantized model forward pass failed"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later")
    @parametrize("device", _AVAILABLE_DEVICES)
    @torch.no_grad()
    def test_wrap_model_with_multi_tensor(self, device: str):
        # Same flow, but the target modules perform in-place buffer updates.
        _is_model_with_inplace_op = lambda mod, fqn: isinstance(mod, ModelWithInplaceOp)

        DIM = 128
        m = M2(DIM).eval().to(device)
        prepare_model_for_applying_auto_round_(
            m,
            is_target_module=_is_model_with_inplace_op,
            bits=7,
            group_size=32,
            iters=20,
            device=device,
        )
        assert all(
            _check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor))
        ), "Expected all parameters and buffers to be `MultiTensor`."
        input1 = []
        input2 = []
        for _ in range(2):
            input1.append(torch.randn(DIM, DIM).to(device))
            input2.append(torch.randint(0, DIM, (DIM,), dtype=torch.long).to(device))

        mt_input1 = MultiTensor(input1)
        mt_input2 = MultiTensor(input2)
        out = m(mt_input1, mt_input2)
        assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}"
        assert all(
            _check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor))
        ), "Expected all parameters and buffers to have been converted back to plain tensors."
        quantize_(m, apply_auto_round(), _is_model_with_inplace_op, device=device)
        for l in m.modules():
            if isinstance(l, torch.nn.Linear):
                assert isinstance(l.weight, AffineQuantizedTensor)
        after_quant = m(input1[0], input2[0])
        assert after_quant is not None, "Quantized model forward pass failed"


instantiate_parametrized_tests(TestAutoRound)

if __name__ == "__main__":
    run_tests()
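For reference, the flow these tests exercise can be condensed into a short standalone sketch. This is a minimal sketch, not the test itself: it reuses the toy `M` and `_is_two_linear` helpers defined in the test above, and the bit width, group size, and iteration count are illustrative.

```python
import torch

from torchao import quantize_
from torchao.prototype.autoround.core import (
    apply_auto_round,
    prepare_model_for_applying_auto_round_,
)
from torchao.prototype.autoround.multi_tensor import MultiTensor

# 1) Register optimization hooks and wrap params/buffers as MultiTensor.
model = M().eval()
prepare_model_for_applying_auto_round_(
    model,
    is_target_module=_is_two_linear,
    bits=4,
    group_size=32,
    iters=20,
    device="cpu",
)

# 2) Replay all calibration batches in a single forward call; the hooks
#    tune the rounding/clipping values of each target module.
x = MultiTensor([torch.randn(32, 64) for _ in range(10)])
y = MultiTensor([torch.randn(32, 64) for _ in range(10)])
model(x, y)

# 3) Materialize the tuned quantization parameters as AffineQuantizedTensor weights.
quantize_(model, apply_auto_round(), _is_two_linear)
```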

torchao/_models/llama/benchmarks.sh

+12
@@ -12,6 +12,12 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

+# auto-round w/ quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
+# auto-round w/o quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0
+
+
 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
@@ -23,6 +29,12 @@
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

+# auto-round w/ quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
+# auto-round w/o quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0
+
+
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
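The `autoround` argument also accepts a longer dash-separated configuration, parsed in `generate.py` (next diff). A hypothetical fully specified invocation, with illustrative values, would look like:

```bash
# autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
# (illustrative values; trailing fields may be omitted and fall back to the lightweight defaults)
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile \
    --quantization autoround-cuda-1-200-128-8-2048-128 --write_result benchmark_results.txt
```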

torchao/_models/llama/generate.py

+49-2
@@ -30,7 +30,7 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock
 from torchao._models.llama.tokenizer import get_tokenizer

 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
@@ -219,6 +219,53 @@ def main(
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
+
+        if "autoround" in quantization:
+            from torchao.prototype.autoround.autoround_llm import quantize_model_with_autoround_
+            from transformers import AutoTokenizer
+
+            _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
+            # Parse args from the quantization string:
+            # autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>
+            # Omitted fields fall back to a lightweight configuration for generation benchmarking.
+            _quant_args = quantization.split("-")
+            _default_quant_args = [True, 1, 128, 1, 512, 32]
+            _model_device = _quant_args[1] if len(_quant_args) > 1 else device
+            _quant_args = _quant_args[2:]
+            quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
+                int(x) for x in _quant_args
+            ] + _default_quant_args[len(_quant_args) :]
+            model = model.to(_model_device)
+            print(
+                (
+                    f"Quantizing model with autoround(iters={iters}, groupsize={groupsize}, "
+                    f"quant_lm_head={quant_lm_head}, batch_size={batch_size}, seqlen={seqlen}, nsamples={nsamples})"
+                )
+            )
+            with torch.device(_model_device):
+                model.setup_caches(
+                    max_batch_size=batch_size, max_seq_length=seqlen, training=True
+                )
+
+            if quant_lm_head:
+                is_target_module = (
+                    lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn
+                )
+            else:
+                is_target_module = lambda mod, fqn: isinstance(mod, TransformerBlock)
+            quantize_model_with_autoround_(
+                model=model,
+                tokenizer=_tokenizer,
+                is_target_module=is_target_module,
+                bits=4,
+                seqlen=seqlen,
+                bs=batch_size,
+                iters=iters,
+                nsamples=nsamples,
+            )
+            model.to(device)
+            model.reset_caches()
+
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "autoquant" == quantization:
@@ -387,7 +434,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
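To make the defaults-padding in the parsing above concrete, here is a small self-contained sketch using a hypothetical quantization string:

```python
# Mirrors the parsing added to generate.py: fields omitted from the string
# fall back to the lightweight benchmarking defaults.
quantization = "autoround-cuda-1-200-128"  # hypothetical CLI value
_default_quant_args = [True, 1, 128, 1, 512, 32]

_quant_args = quantization.split("-")
_model_device = _quant_args[1] if len(_quant_args) > 1 else "cuda"
_quant_args = _quant_args[2:]
quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples = [
    int(x) for x in _quant_args
] + _default_quant_args[len(_quant_args):]

print(quant_lm_head, iters, groupsize, batch_size, seqlen, nsamples)
# -> 1 200 128 1 512 32  (batch_size, seqlen, nsamples fall back to the defaults)
```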

torchao/_models/llama/model.py

+9
@@ -190,7 +190,16 @@
             dtype,
             use_scaled=self.config.use_scaled_rope
         )
+
+    def reset_caches(self):
+        """Reset caches.

+        The caches used for training and inference may differ; reset them before switching between the two.
+        """
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+        self.freqs_cis: Optional[Tensor] = None
+        self.mask_cache: Optional[Tensor] = None

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         """Forward pass of the model.

torchao/prototype/autoround/README.md

+104
@@ -0,0 +1,104 @@
# Auto-Round

Auto-Round is an advanced quantization algorithm designed for low-bit LLM inference. It leverages [sign gradient descent](https://arxiv.org/abs/1905.12938) to fine-tune the rounding values and min/max clipping values of weights. The approach is competitive with recent methods, introduces no additional inference overhead, and keeps tuning costs low. This module provides end-to-end examples for quantizing floating-point models to low bit-widths, integrated with torchao's `quantize_` API and low-bit kernels.
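At a high level (notation ours, simplified from the paper), Auto-Round learns a bounded rounding perturbation $V$ and min/max adjustment terms $\alpha, \beta$ for each weight group, optimized with signed gradient descent so that the quantized layer reproduces the original layer's output on calibration data:

$$
\tilde{W} = s \cdot \left(\operatorname{clamp}\!\left(\left\lfloor \frac{W}{s} + V \right\rceil + zp,\ 0,\ 2^{\text{bits}} - 1\right) - zp\right),
\qquad
\min_{V,\,\alpha,\,\beta}\ \left\lVert W X - \tilde{W} X \right\rVert_F^2
$$

where $s$ and $zp$ are the scale and zero point derived from the ($\alpha$, $\beta$-adjusted) min/max of each weight group, $V \in [-0.5, 0.5]$, and $X$ is the calibration input to the layer.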
## Usage

### Quick Start

```bash
python autoround_llm.py -m /model/name/or/path
```


> [!NOTE]
> Before running, make sure `auto-round` is installed: `pip install -r requirements.txt`.


### Detailed Usage

`Auto-Round` is a calibration-based quantization algorithm. The flow involves three main steps: 1) insert hooks into the modules you want to quantize, 2) wrap the calibration data with `MultiTensor` and run the model, and 3) replace the optimized weights with `AffineQuantizedTensor` to select the appropriate low-bit kernel.

> [!NOTE]
> To learn more about the flow and `MultiTensor`, please refer to [this example](https://github.com/pytorch/ao/blob/main/tutorials/calibration_flow/gptq_like.py).

#### Step 1: Prepare the Model
```python
import torch
import transformers

model = ...  # Load your model
model_device = next(model.parameters()).device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define a function to identify target modules for quantization.
# For example, to apply Auto-Round to all decoder layers and the `lm_head` in a Llama model:
decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
# Prepare the model for Auto-Round
from torchao.prototype.autoround.core import prepare_model_for_applying_auto_round_

prepare_model_for_applying_auto_round_(
    model,
    is_target_module=is_target_module,
    bits=4,
    group_size=128,
    iters=200,
    device=device,
)
```
> [!NOTE]
> To avoid OOM issues, load the model on CPU and set `device` to `'cuda'`.

#### Step 2: Apply Optimization
Wrap all inputs as a `MultiTensor` to track all calibration data for the optimized modules:

```python
from torchao.prototype.autoround.multi_tensor import MultiTensor

input_ids_lst = []
for data in dataloader:
    input_ids_lst.append(data["input_ids"].to(model_device))

multi_t_input_ids = MultiTensor(input_ids_lst)
# The optimization is applied during the forward pass
out = model(multi_t_input_ids)
```
#### Step 3: Finalize Quantization
After obtaining the optimized `zero_point` and `scale` values, create an `AffineQuantizedTensor`
for each target weight to select the appropriate low-bit kernel.

```python
from torchao import quantize_
from torchao.prototype.autoround.core import apply_auto_round

quantize_(model, apply_auto_round(), is_target_module)
```
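As a quick sanity check (mirroring the unit tests in this PR), you can verify that the target linear layers now hold `AffineQuantizedTensor` weights. The scan below assumes every `Linear` is covered by `is_target_module`; otherwise restrict it to the targeted submodules:

```python
import torch
from torchao.dtypes import AffineQuantizedTensor

for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        assert isinstance(module.weight, AffineQuantizedTensor), f"{name} was not quantized"
```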

## End-to-End Results
### [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
|                 | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
| --------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16            | 0.7080 | 0.6783 | 0.8003 | 0.7403     | 0.5910    | 0.7303         |
| auto-round-4bit | 0.6988 | 0.6533 | 0.7949 | 0.7372     | 0.5837    | 0.7250         |
| torchao-int4wo  | 0.6883 | 0.6363 | 0.7938 | 0.7348     | 0.5784    | 0.6980         |

### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
|                 | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
| --------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16            | 0.6881 | 0.6389 | 0.7840 | 0.7222     | 0.5772    | 0.7184         |
| auto-round-4bit | 0.6818 | 0.6232 | 0.7862 | 0.7230     | 0.5661    | 0.7105         |
| torchao-int4wo  | 0.6728 | 0.5939 | 0.7737 | 0.7222     | 0.5612    | 0.7132         |


### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|                 | Avg.   | Mmlu   | Piqa   | Winogrande | Hellaswag | Lambada_openai |
| --------------- | ------ | ------ | ------ | ---------- | --------- | -------------- |
| bf16            | 0.6347 | 0.4647 | 0.7644 | 0.6606     | 0.577     | 0.7070         |
| auto-round-4bit | 0.6327 | 0.4534 | 0.7590 | 0.6661     | 0.5706    | 0.7143         |
| torchao-int4wo  | 0.6252 | 0.4427 | 0.7617 | 0.6654     | 0.5674    | 0.6889         |

> [!NOTE]
> - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`. <br>
> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly.


## Credits

- Paper: https://arxiv.org/abs/2309.05516
- Authors: [Intel® Neural Compressor Team](https://github.com/intel/neural-compressor)
