diff --git a/_drafts/2023-11-27-pytorch-quantization.md b/_drafts/2023-11-27-pytorch-quantization.md
index deb98df..80c5d8b 100644
--- a/_drafts/2023-11-27-pytorch-quantization.md
+++ b/_drafts/2023-11-27-pytorch-quantization.md
@@ -150,6 +150,102 @@ Pytorch 原生支持的量化算法因为只支持 CPU, 所以应该暂时没啥
 This section goes through the layout of PyTorch's quantization source tree and the layering of its API, focusing on the top-level interfaces. It is mainly an organization of [A4](https://pytorch.org/docs/2.1/quantization-support.html) and is by no means a complete introduction.
 
+An excerpt of the source tree:
+
+```
+torch/ao/
+  - nn/
+    - intrinsic/
+      - modules/
+      - qat/modules/
+      - quantized/
+        - modules/
+        - dynamic/modules/
+    - qat/
+      - modules/
+      - dynamic/modules/
+    - quantizable/modules/   # seemingly not covered by the abbreviation table below??
+    - quantized/
+      - modules/
+      - dynamic/modules/
+      - reference/modules/
+      - functional.py
+  - quantization/
+    - fx/
+    - quantize.py
+    - quantize_fx.py
+    - qconfig.py
+    - qconfig_mapping.py
+    - observer.py
+    - ...
+```
+
+The module abbreviations and their locations are summarized below; see the table that follows (and the import sketch after it):
+
+- The new location is torch.ao.nn[.xxx]; the corresponding old location is torch.nn[.xxx]
+- Everything under torch.ao.nn.intrinsic is about fused layers, whereas torch.ao.nn.[qat,quantized] holds the counterparts of layers such as `nn.Linear` and `nn.Conv2d`
+- torch.ao.nn.qat.dynamic does not support Conv2d
+- The abbreviation rules are:
+  - `nn`: torch.ao.nn
+  - `i`: intrinsic
+  - `q`: quantized
+  - `qat`: qat
+  - `d`: dynamic
+  - `r`: reference
+
+| Abbreviation (in torch/ao/quantization/quantization_mappings.py) | Module (after migration: torch.ao) | Module (before migration: torch.nn) |
+| --- | --- | --- |
+| nni | torch.ao.nn.intrinsic[.modules.fused.LinearReLU] | torch.nn.intrinsic[.modules.fused.LinearReLU] |
+| nniq | torch.ao.nn.intrinsic.quantized[.modules.linear_relu.LinearReLU] | torch.nn.intrinsic.quantized[.modules.linear_relu.LinearReLU] |
+| nniqd | torch.ao.nn.intrinsic.quantized.dynamic[.modules.linear_relu.LinearReLU] | torch.nn.intrinsic.quantized.dynamic[.modules.linear_relu.LinearReLU] |
+| nniqat | torch.ao.nn.intrinsic.qat[.modules.linear_relu.LinearReLU] | torch.nn.intrinsic.qat[.modules.linear_relu.LinearReLU] |
+| nnq | torch.ao.nn.quantized[.modules.conv.Conv2d] | torch.nn.quantized[.modules.conv.Conv2d] |
+| nnqr | torch.ao.nn.quantized.reference[.modules.conv.Conv2d] | torch.nn.quantized.reference[.modules.conv.Conv2d] |
+| nnqd | torch.ao.nn.quantized.dynamic[.modules.conv.Conv2d] | torch.nn.quantized.dynamic[.modules.conv.Conv2d] |
+| nnqat | torch.ao.nn.qat[.modules.conv.Conv2d] | torch.nn.qat[.modules.conv.Conv2d] |
+| nnqatd | torch.ao.nn.qat.dynamic[.modules.linear.Linear] | torch.nn.qat.dynamic[.modules.linear.Linear] |
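+
+As a quick check of the table above, each abbreviation corresponds to an importable package. The sketch below is illustrative and assumes PyTorch >= 2.0; the aliases nni/nniq/nnqat/nnq/nnqd mirror the import aliases used in torch/ao/quantization/quantization_mappings.py, while `nnqa` is just a local name picked here for the quantizable package:
+
+```python
+import torch.ao.nn.intrinsic as nni             # fused float modules, e.g. LinearReLU
+import torch.ao.nn.intrinsic.quantized as nniq  # fused + statically quantized modules
+import torch.ao.nn.qat as nnqat                 # modules with fake-quant for QAT
+import torch.ao.nn.quantized as nnq             # statically quantized modules
+import torch.ao.nn.quantized.dynamic as nnqd    # dynamically quantized modules
+import torch.ao.nn.quantizable as nnqa          # alias used for this sketch only
+
+# each package exposes counterparts of the float layers
+print(nni.LinearReLU, nniq.LinearReLU)
+print(nnq.Conv2d, nnqd.Linear, nnqat.Conv2d)
+print(nnqa.LSTM, nnqa.MultiheadAttention)
+```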
+
 
 The top-level APIs:
 
 - torch.ao.quantization.quantize: static quantization
@@ -840,6 +936,146 @@ if __name__ == "__main__":
 
 ## Static Quantization(TODO)
 
+```python
+import torch
+
+# define a floating point model where some layers could be statically quantized
+class M(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.quant = torch.ao.quantization.QuantStub()      # QuantStub converts tensors from floating point to quantized
+        self.conv = torch.nn.Conv2d(1, 1, 1)
+        self.relu = torch.nn.ReLU()
+        self.dequant = torch.ao.quantization.DeQuantStub()  # DeQuantStub converts tensors from quantized to floating point
+
+    def forward(self, x):
+        x = self.quant(x)    # manually specify where tensors will be converted from floating point to quantized in the quantized model
+        x = self.conv(x)
+        x = self.relu(x)
+        x = self.dequant(x)  # manually specify where tensors will be converted from quantized to floating point in the quantized model
+        return x
+
+model_fp32 = M()
+model_fp32.eval()  # model must be set to eval mode for static quantization logic to work
+
+# attach a global qconfig, which contains information about what kind
+# of observers to attach. Use 'x86' for server inference and 'qnnpack'
+# for mobile inference. Other quantization configurations such as selecting
+# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
+# can be specified here.
+# Note: the old 'fbgemm' is still available but 'x86' is the recommended default
+# for server inference.
+# model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
+model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')
+
+# Fuse the activations to preceding layers, where applicable.
+# This needs to be done manually depending on the model architecture.
+# Common fusions include `conv + relu` and `conv + batchnorm + relu`
+model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
+
+# Prepare the model for static quantization. This inserts observers in
+# the model that will observe activation tensors during calibration.
+model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
+
+# calibrate the prepared model to determine quantization parameters for activations
+# in a real world setting, the calibration would be done with a representative dataset
+input_fp32 = torch.randn(4, 1, 4, 4)
+model_fp32_prepared(input_fp32)
+
+# Convert the observed model to a quantized model. This does several things:
+# quantizes the weights, computes and stores the scale and bias value to be
+# used with each activation tensor, and replaces key operators with quantized
+# implementations.
+model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
+
+# run the model, relevant calculations will happen in int8
+res = model_int8(input_fp32)
+```
+
+#### A brief look at `torch.ao.quantization.quantize.prepare`
+
+Let's start by analyzing the top-level interface `prepare` ([source](https://github.com/pytorch/pytorch/blob/v2.0.0/torch/ao/quantization/quantize.py#L263)):
+
+```python
+def prepare(model, inplace=False, allow_list=None,
+            observer_non_leaf_module_list=None,
+            prepare_custom_config_dict=None):
+    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
+
+    Quantization configuration should be assigned preemptively
+    to individual submodules in `.qconfig` attribute.
+
+    The model will be attached with observer or fake quant modules, and qconfig
+    will be propagated.
+
+    Args:
+        `model`: input model to be modified in-place
+        `inplace`: carry out model transformations in-place, the original module is mutated
+        `allow_list`: list of quantizable modules
+        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
+        `prepare_custom_config_dict`: customization configuration dictionary for prepare function
+
+    .. code-block:: python
+
+       # Example of prepare_custom_config_dict:
+       prepare_custom_config_dict = {
+           # user will manually define the corresponding observed
+           # module class which has a from_float class method that converts
+           # float custom module to observed custom module
+           "float_to_observed_custom_module_class": {
+               CustomModule: ObservedCustomModule
+           }
+       }
+
+    """
+    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
+    if prepare_custom_config_dict is None:
+        # i.e. this returns the _DEFAULT_CUSTOM_CONFIG_DICT shown below: a dict of dicts with two keys,
+        # "float_to_observed_custom_module_class" and "observed_to_quantized_custom_module_class";
+        # the inner dicts hold nn.Module-to-nn.Module mappings
+        prepare_custom_config_dict = get_default_custom_config_dict()
+    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+
+    if not inplace:
+        model = copy.deepcopy(model)
+
+    # TODO: remove allow_list
+    qconfig_propagation_list = allow_list
+    if allow_list is None:
+        qconfig_propagation_list = get_default_qconfig_propagation_list()  # ??? meaning not entirely clear; it is a set that covers nn.Linear among others
+    propagate_qconfig_(model, qconfig_dict=None)  # note: in the example above, model.qconfig was assigned manually before prepare was called
+    # propagate_qconfig_ recursively sets the qconfig attribute on model's submodules;
+    # if a submodule was manually given a qconfig different from model.qconfig before prepare was called,
+    # that manual assignment is kept and is not overridden by the global model.qconfig
+
+    # sanity check common API misusage
+    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
+        warnings.warn("None of the submodule got qconfig applied. Make sure you "
+                      "passed correct configuration through `qconfig_dict` or "
+                      "by assigning the `.qconfig` attribute directly on submodules")
+
+    _add_observer_(
+        model, qconfig_propagation_list, observer_non_leaf_module_list,
+        custom_module_class_mapping=custom_module_class_mapping)
+    return model
+
+
+import torch.nn as nn
+_DEFAULT_CUSTOM_CONFIG_DICT = {
+    'float_to_observed_custom_module_class': {
+        nn.LSTM: nn.quantizable.LSTM,                              # torch.ao.nn.quantizable.modules.rnn.LSTM
+        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,  # torch.ao.nn.quantizable.modules.activation.MultiheadAttention
+    },
+    'observed_to_quantized_custom_module_class': {
+        nn.quantizable.LSTM: nn.quantized.LSTM,                              # torch.ao.nn.quantized.modules.rnn.LSTM
+        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,  # torch.ao.nn.quantized.modules.activation.MultiheadAttention
+    }
+}
+```
+
+**`propagate_qconfig_`**
+
 
 ## QAT (tensorflow)