
Commit 2615e8d

update torchao READMEs with new configuration APIs
Summary: This updates the README files with the names of the new workflow configuration APIs.

Test Plan: CI
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 3f1727d221c26726c9747b149bfb05f459881cb1
ghstack-comment-id: 2655664399
Pull Request resolved: #1711
1 parent 69a5a53 commit 2615e8d

File tree

3 files changed: +44 -44 lines changed

README.md
torchao/quantization/README.md
torchao/quantization/qat/README.md

README.md

+13-13
@@ -29,16 +29,16 @@ For inference, we have the option of
```python
from torchao.quantization.quant_api import (
    quantize_,
-    int8_dynamic_activation_int8_weight,
-    int4_weight_only,
-    int8_weight_only
+    Int8DynamicActivationInt8WeightConfig,
+    Int4WeightOnlyConfig,
+    Int8WeightOnlyConfig
)
-quantize_(m, int4_weight_only())
+quantize_(m, Int4WeightOnlyConfig())
```

-For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.
+For gpt-fast `Int4WeightOnlyConfig()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline.

-If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, int8_weight_only(), device="cuda")` which will send and quantize each layer individually to your GPU.
+If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, Int8WeightOnlyConfig(), device="cuda")` which will send and quantize each layer individually to your GPU.

If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer.
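
For reference, the autoquant flow linked in that last line is essentially a one-line wrapper around `torch.compile`; a minimal sketch, assuming a CUDA setup and hypothetical `get_model()` / `example_input` placeholders that are not part of this diff:

```python
import torch
import torchao

model = get_model().cuda().to(torch.bfloat16)  # hypothetical model constructor
example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)  # placeholder input

# wrap with torch.compile, then let autoquant profile and quantize layer by layer
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(example_input)  # benchmarking, quantization and compilation all happen on this first call
```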

@@ -63,27 +63,27 @@ Post-training quantization can result in a fast and compact model, but may also
```python
from torchao.quantization import (
    quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
)

# Insert fake quantization
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# Run training... (not shown)

# Convert fake quantization to actual quantized operations
-quantize_(my_model, from_intx_quantization_aware_training())
-quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
+quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

### Float8
@@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`, so if you love writing kernels but hate packaging them so they work on all operating systems and CUDA versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow:

-1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

torchao/quantization/README.md

+22-22
@@ -82,7 +82,7 @@ model(input)

When used as in the example above, with the `autoquant` api called alongside torch.compile, autoquant sets up the model so that when it's run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model.

-When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.
+When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option to add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow.

Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods.
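
A minimal sketch of the save/restore flow described in that last paragraph, assuming the cache dictionary is picklable as the text implies:

```python
import pickle

import torchao.quantization

# after autoquant has run once, persist its benchmark results
with open("quantization-cache.pkl", "wb") as f:
    pickle.dump(torchao.quantization.AUTOQUANT_CACHE, f)

# in a later process, restore the cache so autoquant reuses the same per-layer choices
with open("quantization-cache.pkl", "rb") as f:
    torchao.quantization.AUTOQUANT_CACHE.update(pickle.load(f))
```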

@@ -109,13 +109,13 @@ be applied individually. While there are a large variety of quantization apis, t

```python
# for torch 2.4+
-from torchao.quantization import quantize_, int4_weight_only
+from torchao.quantization import quantize_, Int4WeightOnlyConfig
group_size = 32

# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improve accuracy through the
-# use_hqq flag for `int4_weight_only` quantization
+# use_hqq flag for `Int4WeightOnlyConfig` quantization
use_hqq = False
-quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq))

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
@@ -128,8 +128,8 @@ Note: The quantization error incurred by applying int4 quantization to your mode

```python
# for torch 2.4+
-from torchao.quantization import quantize_, int8_weight_only
-quantize_(model, int8_weight_only())
+from torchao.quantization import quantize_, Int8WeightOnlyConfig
+quantize_(model, Int8WeightOnlyConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -140,8 +140,8 @@ change_linear_weights_to_int8_woqtensors(model)

```python
# for torch 2.4+
-from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
-quantize_(model, int8_dynamic_activation_int8_weight())
+from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
+quantize_(model, Int8DynamicActivationInt8WeightConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -152,8 +152,8 @@ change_linear_weights_to_int8_dqtensors(model)

```python
# for torch 2.5+
-from torchao.quantization import quantize_, float8_weight_only
-quantize_(model, float8_weight_only())
+from torchao.quantization import quantize_, Float8WeightOnlyConfig
+quantize_(model, Float8WeightOnlyConfig())
```

Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -162,8 +162,8 @@ Supports all dtypes for original weight and activation. This API is only tested

```python
# for torch 2.4+
-from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
-quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
+from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
+quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
```

Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
@@ -172,8 +172,8 @@ Supports all dtypes for original weight and activation. This API is only tested

```python
# for torch 2.5+
-from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
-quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
+from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig
+quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
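
Since both float8 recipes require CUDA compute capability 8.9 or greater, a guard along these lines (not part of the diff; `model` is assumed to be defined elsewhere) can avoid surprises on older GPUs:

```python
import torch

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow, quantize_

# H100 reports (9, 0); Ada-generation cards such as the L4 or RTX 4090 report (8, 9)
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    model = model.to(torch.bfloat16)  # per-row scaling expects bf16 weights and activations
    quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
else:
    print("skipping float8 quantization: requires CUDA compute capability >= 8.9")
```
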
@@ -182,14 +182,14 @@ Per-row scaling is only supported for bfloat16 weight and activation. This API i

```python
# for torch 2.4+
-from torchao.quantization import quantize_, fpx_weight_only
-quantize_(model, fpx_weight_only(3, 2))
+from torchao.quantization import quantize_, FPXWeightOnlyConfig
+quantize_(model, FPXWeightOnlyConfig(3, 2))
```

You can find more information [here](../dtypes/floatx/README.md). It should be noted that, while most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, fp6 works primarily with the fp16 dtype for performance.

## Affine Quantization Details
-Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualifies as Affine Quantization.
+Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualifies as Affine Quantization.

### Quantization Primitives
We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified the existing quant primitives into `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`, which can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization; these can be used to implement the unified quantized tensor subclass.
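
To make the affine formula above concrete, here is a small self-contained illustration of asymmetric per-tensor affine quantization in plain PyTorch; it mirrors the math only and is not the torchao `choose_qparams_affine`/`quantize_affine` primitives themselves:

```python
import torch

def affine_quantize_per_tensor(x: torch.Tensor, n_bits: int = 8):
    # map [x.min(), x.max()] onto the signed integer range for n_bits
    qmin, qmax = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    # quantized_val = high_precision_float_val / scale + zero_point
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax).to(torch.int8)
    return q, scale, zero_point

def affine_dequantize(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(4, 8)
q, scale, zero_point = affine_quantize_per_tensor(x)
x_hat = affine_dequantize(q, scale, zero_point)
print((x - x_hat).abs().max())  # small reconstruction error
```
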
@@ -200,7 +200,7 @@ Note: these primitive ops supports two "types" of quantization, distinguished by
We also have a unified quantized tensor subclass that implements how to get a quantized tensor from a floating point tensor and what it means to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`. With this we can dispatch to different operators (e.g. the `int4mm` op) based on device (cpu, cuda), quantization settings (`int4`, `int8`) and packing format (e.g. a format optimized for the cpu int4 mm kernel).

#### Layouts
-We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and packs the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
+We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for workflows backing `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig` and also as a default layout. `tensor_core_tiled` layout is used for workflows backing `Int4WeightOnlyConfig` quantization and packs the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.

### Zero Point Domains
```ZeroPointDomain``` is used to control the data types of zero points. ```ZeroPointDomain.None``` means zero_point is None, ```ZeroPointDomain.FLOAT``` means zero_point is in the floating point domain and ```ZeroPointDomain.INT``` means integer domain. For detailed implementation of different zero point data types, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py).
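
As an illustration of the layout concept, a packing format can be requested when building a config. The sketch below assumes a `layout=` argument on the config and a `TensorCoreTiledLayout` class in `torchao.dtypes`, as in recent torchao releases; neither is shown in this diff:

```python
import torch

from torchao.dtypes import TensorCoreTiledLayout  # assumed import path
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# request the tensor_core_tiled packing used by the tinygemm int4mm kernel;
# inner_k_tiles controls how the weight is tiled for that kernel
quantize_(model, Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)))
```
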
@@ -223,7 +223,7 @@ from torchao.dtypes import to_affine_quantized_intx
import copy
from torchao.quantization.quant_api import (
    quantize_,
-    int4_weight_only,
+    Int4WeightOnlyConfig,
)

class ToyLinearModel(torch.nn.Module):
@@ -249,9 +249,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
# apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
group_size = 32
# only works for torch 2.4+
-quantize_(m, int4_weight_only(group_size=group_size))
+quantize_(m, Int4WeightOnlyConfig(group_size=group_size))
## If different zero_point_domain needed
-# quantize_(m, int4_weight_only(group_size=group_size), zero_point_domain=ZeroPointDomain.FLOAT)
+# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))

# temporary workaround for tensor subclass + torch.compile
# NOTE: this is only needed for torch version < 2.5
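
The "temporary workaround" referenced in those last two comments is not included in this hunk; it is typically a guard along the following lines, where `m` is the quantized model from the snippet above (a sketch assuming the `unwrap_tensor_subclass` and `TORCH_VERSION_AT_LEAST_2_5` helpers in `torchao.utils`, which this diff does not confirm):

```python
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass  # assumed helpers

# on torch < 2.5, tensor subclasses must be unwrapped before compiling the quantized model
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(m)

m = torch.compile(m, mode="max-autotune")
```
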
@@ -360,7 +360,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f
| | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 |
| | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 |

-You can try out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only`; an example can be found in `torchao/_models/llama/generate.py`.
+You can try out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in `torchao/_models/llama/generate.py`.

### int8_dynamic_activation_intx_weight Quantization
We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computer with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 performance cores, 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used.
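
As a rough sketch of what a `UIntXWeightOnlyConfig` call might look like for the `uintx-4-64-hqq` row above, assuming the config keeps the `(dtype, group_size, use_hqq)` parameters of the old `uintx_weight_only` constructor and that `model` is defined elsewhere (not confirmed by this diff):

```python
import torch

from torchao.quantization import UIntXWeightOnlyConfig, quantize_

# 4-bit weights, groups of 64, hqq enabled, matching the "uintx-4-64-hqq" row
quantize_(model, UIntXWeightOnlyConfig(torch.uint4, group_size=64, use_hqq=True))
```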

torchao/quantization/qat/README.md

+9-9
@@ -71,22 +71,22 @@ def train_loop(m: torch.nn.Module):

The recommended way to run QAT in torchao is through the `quantize_` API:
1. **Prepare:** specify how weights and/or activations are to be quantized through
-[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
+[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242)
2. **Convert:** quantize the model using the standard post-training quantization (PTQ)
-functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)
+functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606)

For example:


```python
from torchao.quantization import (
    quantize_,
-    int8_dynamic_activation_int4_weight,
+    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
-    from_intx_quantization_aware_training,
-    intx_quantization_aware_training,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
)
model = get_model()

@@ -96,7 +96,7 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
-    intx_quantization_aware_training(activation_config, weight_config),
+    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# train
@@ -105,8 +105,8 @@ train_loop(model)
# convert: transform fake quantization ops into actual quantized ops
# swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
-quantize_(model, from_intx_quantization_aware_training())
-quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
+quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

# inference or generate
```
@@ -117,7 +117,7 @@ the following with a filter function during the prepare step:
```
quantize_(
    m,
-    intx_quantization_aware_training(weight_config=weight_config),
+    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```
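
Putting the two prepare calls together, a combined prepare step that fake quantizes both linears and embeddings might look like the sketch below; it is illustrative only and reuses the `FakeQuantizeConfig` objects and the hypothetical `get_model()` from the example earlier in this README:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig

model = get_model()  # hypothetical constructor, as in the example above
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# fake quantize activations and weights of linear layers (the default filter)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)
# additionally fake quantize embedding weights only
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```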
