diff --git a/examples/quantization_aware_training/cifar10/basecase/main.py b/examples/quantization_aware_training/cifar10/basecase/main.py
index a8eababb..3b3e6926 100644
--- a/examples/quantization_aware_training/cifar10/basecase/main.py
+++ b/examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math
 
 import torch
 import torch.nn as nn
@@ -27,7 +28,7 @@
     raise NotImplementedError("This example should run on a GPU device.")  # make sure we run on a GPU
 
 
-config = "qconfig_lsq.yaml"  # QAT config file: quantization method (dorefa/lsq), weight and activation bit-widths, etc.
+config = "qconfig_lsq_dampen.yaml"  # QAT config file: quantization method (dorefa/lsq), weight and activation bit-widths, etc.
 workers = 4
 epochs = 200
 start_epoch = 0
@@ -38,8 +39,7 @@
 print_freq = 100
 pretrained = ""
 
 qconfig = parse_qconfig(config)
-is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-regularizer_lambda = 1e-4
+regularizer_schedule = "cosine" if qconfig.REGULARIZER.TYPE == "dampen" else "keep"
 model = resnet20(num_classes=10)  # use resnet20 as the base model
 if pretrained:  # optionally load the weights saved in `pretrained`
@@ -109,21 +109,8 @@
     optimizer, milestones=[100, 150], last_epoch=start_epoch - 1
 )
 
 
-# PACT adds an L2-regularization on alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p ** 2).sum()
-    return loss
-
-def get_regularizer_loss(model, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
-    else:
-        return torch.tensor(0.).cuda()
-def train(train_loader, model, criterion, optimizer, epoch):
+def train(train_loader, model, criterion, optimizer, epoch, schedule_value=1.0):
     batch_time = AverageMeter("Time", ":6.3f")
     data_time = AverageMeter("Data", ":6.3f")
     losses = AverageMeter("Loss", ":.4e")
@@ -151,7 +138,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, scale=regularizer_lambda)
+        regular_loss = model.get_regularizer_loss() * schedule_value
         loss = ce_loss + regular_loss
 
         # measure accuracy and record loss
@@ -311,12 +298,18 @@ def accuracy(output, target, topk=(1,)):
 best_acc1 = 0
 for epoch in range(start_epoch, epochs):
     # train for one epoch
+    if regularizer_schedule == "cosine":
+        coeff = (epoch - start_epoch + 1) / (epochs - start_epoch)
+        schedule_value = 1 - 0.5 * (1.0 + math.cos(math.pi * coeff))
+    else:
+        schedule_value = 1.0
     train(
         trainloader,
         model,
         criterion,
         optimizer,
         epoch,
+        schedule_value=schedule_value,
     )
 
     # evaluate on validation set
diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
index a191746c..11d9a983 100644
--- a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
@@ -9,3 +9,7 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
+  LAMBDA: 0.0001
diff --git a/sparsebit/quantization/quant_config.py b/sparsebit/quantization/quant_config.py
index c5abff97..0e7f9703 100644
--- a/sparsebit/quantization/quant_config.py
+++ b/sparsebit/quantization/quant_config.py
@@ -38,6 +38,11 @@
 _C.A.OBSERVER.LAYOUT = "NCHW"  # NCHW / NLC
 _C.A.SPECIFIC = []
 
+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+_C.REGULARIZER.LAMBDA = 0.0
+
 
 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
diff --git a/sparsebit/quantization/quant_model.py b/sparsebit/quantization/quant_model.py
index 51675d57..468af277 100644
--- a/sparsebit/quantization/quant_model.py
+++ b/sparsebit/quantization/quant_model.py
@@ -20,6 +20,7 @@
 from sparsebit.quantization.quantizers import Quantizer
 from sparsebit.quantization.tools import QuantizationErrorProfiler
 from sparsebit.quantization.converters import simplify, fuse_operations
+from sparsebit.quantization.regularizers import build_regularizer
 
 __all__ = ["QuantModel"]
 
@@ -34,6 +35,7 @@ def __init__(self, model: nn.Module, config):
         self._run_simplifiers()
         self._convert2quantmodule()
         self._build_quantizer()
+        self._build_regularizer()
         self._run_fuse_operations()
 
     def _convert2quantmodule(self):
@@ -119,11 +121,17 @@ def _sub_build(src, module_name):
                 update_config(_config, "A", _sub_build(self.cfg.A, node.target))
                 identity_module.build_quantizer(_config)
 
+    def _build_regularizer(self):
+        if self.cfg.REGULARIZER.ENABLE:
+            self.regularizer = build_regularizer(self.cfg)
+        else:
+            self.regularizer = None
+
     def _run_simplifiers(self):
         self.model = simplify(self.model)
 
     def _run_fuse_operations(self):
-        if self.cfg.SCHEDULE.BN_TUNING: # first disable fuse bn
+        if self.cfg.SCHEDULE.BN_TUNING:  # first disable fuse bn
             update_config(self.cfg.SCHEDULE, "FUSE_BN", False)
             self.model = fuse_operations(self.model, self.cfg.SCHEDULE)
             self.model.graph.print_tabular()
@@ -144,7 +152,9 @@ def batchnorm_tuning(self):
         yield
         self.model.eval()
         update_config(self.cfg.SCHEDULE, "FUSE_BN", True)
-        self.model = fuse_operations(self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"])
+        self.model = fuse_operations(
+            self.model, self.cfg.SCHEDULE, custom_fuse_list=["fuse_bn"]
+        )
         self.set_quant(w_quant=False, a_quant=False)
 
     def prepare_calibration(self):
@@ -210,6 +220,12 @@ def set_quant(self, w_quant=False, a_quant=False):
             if isinstance(m, QuantOpr):
                 m.set_quant(w_quant, a_quant)
 
+    def get_regularizer_loss(self):
+        if self.regularizer is None:
+            return torch.tensor(0.).to(self.device)
+        else:
+            return self.regularizer(self.model)
+
     def export_onnx(
         self,
         dummy_data,
diff --git a/sparsebit/quantization/regularizers/__init__.py b/sparsebit/quantization/regularizers/__init__.py
index 6da60002..2fda5946 100644
--- a/sparsebit/quantization/regularizers/__init__.py
+++ b/sparsebit/quantization/regularizers/__init__.py
@@ -7,7 +7,7 @@ def register_regularizer(regularizer):
 
 
 from .base import Regularizer
-from . import dampen
+from . import dampen, pact
 
 
 def build_regularizer(config):
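
Note: the new sparsebit/quantization/regularizers/pact.py module pulled in by `from . import dampen, pact` is not included in this diff. The sketch below shows one plausible shape for it, reproducing the alpha L2 penalty that this diff removes from main.py and scaling it by the new REGULARIZER.LAMBDA option. It assumes the Regularizer base class accepts the parsed config, that register_regularizer keys off a TYPE class attribute, and that instances are called on the model (as get_regularizer_loss does via self.regularizer(self.model)); the class name PACTRegularizer and the scale attribute are illustrative, not necessarily what the real module defines.

# Hypothetical sparsebit/quantization/regularizers/pact.py (sketch only, not part of this diff)
import torch

from sparsebit.quantization.regularizers import Regularizer, register_regularizer


@register_regularizer
class PACTRegularizer(Regularizer):
    TYPE = "pact"

    def __init__(self, config):
        super().__init__(config)
        # REGULARIZER.LAMBDA replaces the hard-coded regularizer_lambda = 1e-4 from main.py
        self.scale = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        # L2 penalty on every PACT clipping parameter named "alpha",
        # mirroring the get_pact_regularizer_loss() deleted above
        loss = torch.tensor(0.0, device=next(model.parameters()).device)
        for name, param in model.named_parameters():
            if "alpha" in name:
                loss = loss + (param ** 2).sum()
        return loss * self.scale

Under the changes in main.py, this loss is added to the task loss as model.get_regularizer_loss() * schedule_value; for the pact/keep case schedule_value stays 1.0, while for the dampen regularizer the cosine schedule ramps it from roughly 0 at the first epoch to 1 at the last, so the regularization is phased in gradually.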