Commit `065a5d1` (BuxianChen, Jan 3, 2024): update torch.quantization blog (update src directory). Showing 1 changed file, `_drafts/2023-11-27-pytorch-quantization.md`, with 236 additions and 0 deletions.

This section walks through the layout of PyTorch's quantization source tree and the hierarchy of its API, with particular attention to the top-level interfaces. It mostly follows [A4](https://pytorch.org/docs/2.1/quantization-support.html), but is by no means a complete reference.

An excerpt of the source tree:

```
torch/ao/
- nn/
  - intrinsic/
    - modules/
    - qat/modules/
    - quantized/
      - modules/
      - dynamic/modules/
  - qat/
    - modules/
    - dynamic/modules/
  - quantizable/modules/  # does not seem to appear in the module abbreviations below??
  - quantized/
    - modules/
    - dynamic/modules/
    - reference/modules/
  - functional.py
- quantization/
  - fx/
  - quantize.py
  - quantize_fx.py
  - qconfig.py
  - qconfig_mapping.py
  - observer.py
  - ...
```

The module abbreviations and where the modules live are summarized below; see the table (and the import check right after it) for details.

- The new location is torch.ao.nn[.xxx]; the corresponding old location is torch.nn[.xxx]
- Everything under torch.ao.nn.intrinsic is related to fused layers, while the torch.ao.nn.[qat,quantized] directories hold the counterparts of layers such as `nn.Linear` and `nn.Conv2d`
- torch.ao.nn.qat.dynamic does not support Conv2d
- The abbreviation rules are:
  - `nn`: torch.ao.nn
  - `i`: intrinsic
  - `q`: quantized
  - `qat`: qat
  - `d`: dynamic
  - `r`: reference

<table>
<tr>
<th> Module abbreviation (torch/ao/quantization/quantization_mappings.py) </th>
<th> Module name (after migration: torch.ao) </th>
<th> Module name (before migration: torch.nn) </th>
</tr>
<tr>
<td> nni </td>
<td> torch.ao.nn.intrinsic[.modules.fused.LinearReLU] </td>
<td> torch.nn.intrinsic[.modules.fused.LinearReLU] </td>
</tr>
<tr>
<td> nniq </td>
<td> torch.ao.nn.intrinsic.quantized[.modules.linear_relu.LinearReLU] </td>
<td> torch.nn.intrinsic.quantized[.modules.linear_relu.LinearReLU] </td>
</tr>
<tr>
<td> nniqd </td>
<td> torch.ao.nn.intrinsic.quantized.dynamic[.modules.linear_relu.LinearReLU] </td>
<td> torch.nn.intrinsic.quantized.dynamic[.modules.linear_relu.LinearReLU] </td>
</tr>
<tr>
<td> nniqat </td>
<td> torch.ao.nn.intrinsic.qat[.modules.linear_relu.LinearReLU] </td>
<td> torch.nn.intrinsic.qat[.modules.linear_relu.LinearReLU] </td>
</tr>
<tr>
<td> nnq </td>
<td> torch.ao.nn.quantized[.modules.conv.Conv2d] </td>
<td> torch.nn.quantized[.modules.conv.Conv2d] </td>
</tr>
<tr>
<td> nnqr </td>
<td> torch.ao.nn.quantized.reference[.modules.conv.Conv2d] </td>
<td> torch.nn.quantized.reference[.modules.conv.Conv2d] </td>
</tr>
<tr>
<td> nnqd </td>
<td> torch.ao.nn.quantized.dynamic[.modules.conv.Conv2d] </td>
<td> torch.nn.quantized.dynamic[.modules.conv.Conv2d] </td>
</tr>
<tr>
<td> nnqat </td>
<td> torch.ao.nn.qat[.modules.conv.Conv2d] </td>
<td> torch.nn.qat[.modules.conv.Conv2d] </td>
</tr>
<tr>
<td> nnqatd </td>
<td> torch.ao.nn.qat.dynamic[.modules.linear.Linear] </td>
<td> torch.nn.qat.dynamic[.modules.linear.Linear] </td>
</tr>
</table>
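
A quick way to sanity-check the abbreviations is to import the namespaces directly. A minimal sketch (assuming PyTorch 2.x, where the old `torch.nn.*` locations are kept as backward-compatible aliases of `torch.ao.nn.*`):

```python
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.qat as nnqat

print(nni.LinearReLU)   # defined in torch.ao.nn.intrinsic.modules.fused
print(nniq.LinearReLU)  # defined in torch.ao.nn.intrinsic.quantized.modules.linear_relu
print(nnq.Conv2d)       # defined in torch.ao.nn.quantized.modules.conv
print(nnqd.Linear)      # defined in torch.ao.nn.quantized.dynamic.modules.linear
print(nnqat.Conv2d)     # defined in torch.ao.nn.qat.modules.conv

# The pre-migration namespace should resolve to the same classes
import torch.nn.quantized as old_nnq
print(old_nnq.Conv2d is nnq.Conv2d)  # expected: True
```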


The top-level APIs:

- torch.ao.quantization.quantize: static quantization

## Static Quantization (TODO)
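
The eager-mode post-training static quantization workflow (QuantStub/DeQuantStub markers, module fusion, `prepare`, calibration, `convert`) is shown below; it closely follows the standard example from the PyTorch documentation: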

```python
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()  # QuantStub converts tensors from floating point to quantized
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()  # DeQuantStub converts tensors from quantized to floating point

    def forward(self, x):
        x = self.quant(x)  # manually specify where tensors will be converted from floating point to quantized in the quantized model
        x = self.conv(x)
        x = self.relu(x)
        x = self.dequant(x)  # manually specify where tensors will be converted from quantized to floating point in the quantized model
        return x

model_fp32 = M()
model_fp32.eval() # model must be set to eval mode for static quantization logic to work

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
# for server inference.
# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# Fuse the activations to preceding layers, where applicable.
# This needs to be done manually depending on the model architecture.
# Common fusions include `conv + relu` and `conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
res = model_int8(input_fp32)
```
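
As a quick sanity check after running the example above (a hypothetical inspection, not part of the original snippet), the fused `conv + relu` should have been replaced by a quantized module, its weight should be an int8 quantized tensor, and the output is converted back to float by the `DeQuantStub`:

```python
print(type(model_int8.conv))           # the fused, quantized ConvReLU2d (torch.ao.nn.intrinsic.quantized)
print(model_int8.conv.weight().dtype)  # torch.qint8
print(res.dtype)                       # torch.float32, dequantized at the DeQuantStub
```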

#### A closer look at `torch.ao.quantization.quantize.prepare`

Let's start with an analysis of the top-level interface `prepare` ([source](https://github.com/pytorch/pytorch/blob/v2.0.0/torch/ao/quantization/quantize.py#L263)):

```python
def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.
    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.
    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function
    .. code-block:: python
        # Example of prepare_custom_config_dict:
        prepare_custom_config_dict = {
            # user will manually define the corresponding observed
            # module class which has a from_float class method that converts
            # float custom module to observed custom module
            "float_to_observed_custom_module_class": {
                CustomModule: ObservedCustomModule
            }
        }
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        # i.e. this returns _DEFAULT_CUSTOM_CONFIG_DICT below: a dict of dicts with two keys,
        # "float_to_observed_custom_module_class" and "observed_to_quantized_custom_module_class";
        # the inner dicts map nn.Module classes to their observed/quantized counterparts
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()  # a set of quantizable module/op types (it contains nn.Linear, among others); see the probe after this code block
    propagate_qconfig_(model, qconfig_dict=None)  # note that in the example above, model.qconfig was assigned manually before calling prepare
    # propagate_qconfig_ recursively sets the qconfig attribute on model's submodules;
    # if a submodule was manually given a qconfig different from model.qconfig before prepare was called,
    # that manual assignment is preserved and is not overridden by the global model.qconfig

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodule got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    _add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model


import torch.nn as nn
_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,  # torch.ao.nn.quantizable.modules.rnn.LSTM
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,  # torch.ao.nn.quantizable.modules.activation.MultiheadAttention
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,  # torch.ao.nn.quantized.modules.rnn.LSTM
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,  # torch.ao.nn.quantized.modules.activation.MultiheadAttention
    }
}
#
```
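
The `get_default_qconfig_propagation_list` call above can be probed directly. A minimal sketch (assuming PyTorch 2.x; the function lives in `torch.ao.quantization.quantization_mappings`):

```python
import torch.nn as nn
from torch.ao.quantization.quantization_mappings import get_default_qconfig_propagation_list

# The "propagation list" is the set of module/op types that the eager-mode workflow treats
# as quantizable, i.e. the types propagate_qconfig_ / _add_observer_ will attach a qconfig
# or observer to.
allow = get_default_qconfig_propagation_list()
print(len(allow))                              # a few dozen entries
print(nn.Linear in allow, nn.Conv2d in allow)  # expected: True True
```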

**`propagate_qconfig_`**
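
A rough illustration of the propagation behavior described in the comments above (a minimal sketch, assuming the PyTorch 2.x eager-mode API; the toy `nn.Sequential` model is made up for the example):

```python
import torch.nn as nn
from torch.ao.quantization import propagate_qconfig_, get_default_qconfig

model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.ReLU(), nn.Linear(4, 4))
model.qconfig = get_default_qconfig("x86")  # global qconfig on the root module
model[2].qconfig = None                     # manually opt the Linear out of quantization

propagate_qconfig_(model)                   # recursively fills .qconfig on all submodules

print(model[0].qconfig is not None)  # True: the Conv2d inherited the global qconfig
print(model[2].qconfig)              # None: the manual per-submodule assignment is preserved
```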



## QAT (tensorflow)
