diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
new file mode 100644
index 00000000..6f49a510
--- /dev/null
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
@@ -0,0 +1,15 @@
+BACKEND: virtual
+W:
+    QSCHEME: per-channel-symmetric
+    QUANTIZER:
+        TYPE: lsq
+        BIT: 4
+A:
+    QSCHEME: per-tensor-affine
+    QUANTIZER:
+        TYPE: lsq
+        BIT: 4
+REGULARIZER:
+    ENABLE: True
+    TYPE: dampen
+    LAMBDA: 0.01
diff --git a/sparsebit/quantization/regularizers/__init__.py b/sparsebit/quantization/regularizers/__init__.py
new file mode 100644
index 00000000..6da60002
--- /dev/null
+++ b/sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
+REGULARIZERS_MAP = {}
+
+
+def register_regularizer(regularizer):
+    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
+    return regularizer
+
+
+from .base import Regularizer
+from . import dampen, pact  # imported for the side effect of registration
+
+
+def build_regularizer(config):
+    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
+    return regularizer
diff --git a/sparsebit/quantization/regularizers/base.py b/sparsebit/quantization/regularizers/base.py
new file mode 100644
index 00000000..45e3cb96
--- /dev/null
+++ b/sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
+class Regularizer(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self, model):
+        raise NotImplementedError
diff --git a/sparsebit/quantization/regularizers/dampen.py b/sparsebit/quantization/regularizers/dampen.py
new file mode 100644
index 00000000..6408738a
--- /dev/null
+++ b/sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,39 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Dampen"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self._lambda = config.REGULARIZER.LAMBDA
+
+    def _get_loss(self, x, quantizer):
+        x_q = quantizer(x)
+        qmin, qmax = quantizer.qdesc.qrange
+        scale, zero_point = quantizer._qparams_preprocess(x)
+        scale = scale.detach()
+        zero_point = zero_point.detach()
+        # clamp the latent weights into the representable range of the quantizer
+        min_val = (qmin - zero_point) * scale
+        max_val = (qmax - zero_point) * scale
+        x_c = torch.min(torch.max(x, min_val), max_val)
+        # dampen loss: squared distance between quantized and clamped weights
+        loss = ((x_q - x_c) ** 2).sum()
+        return loss
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, m in model.named_modules():
+            if (
+                hasattr(m, "weight")
+                and hasattr(m, "weight_quantizer")
+                and m.weight_quantizer
+                and m.weight_quantizer.is_enable
+            ):
+                loss += self._get_loss(m.weight, m.weight_quantizer)
+        return loss * self._lambda
diff --git a/sparsebit/quantization/regularizers/pact.py b/sparsebit/quantization/regularizers/pact.py
new file mode 100644
index 00000000..feb49ed1
--- /dev/null
+++ b/sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,21 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Pact"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self._lambda = config.REGULARIZER.LAMBDA
+
+    def __call__(self, model):
+        # L2 penalty on PACT's learnable clipping parameters (alpha)
+        loss = 0.0
+        for n, p in model.named_parameters():
+            if "alpha" in n:
+                loss += (p ** 2).sum()
+        return loss * self._lambda
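
For context, the dampen regularizer adds LAMBDA * sum((Q(w) - clip(w, min_val, max_val))^2) over every module with an enabled weight quantizer, pulling the latent weights toward their quantized values as training converges. A minimal sketch of how it would plug into a QAT training step follows; qmodel, criterion, optimizer, train_loader, and config are illustrative placeholders, while build_regularizer and the REGULARIZER.ENABLE flag come from this diff:

    from sparsebit.quantization.regularizers import build_regularizer

    regularizer = build_regularizer(config) if config.REGULARIZER.ENABLE else None

    for images, targets in train_loader:
        outputs = qmodel(images)
        loss = criterion(outputs, targets)
        if regularizer is not None:
            # regularizer(qmodel) already includes the LAMBDA scaling
            loss = loss + regularizer(qmodel)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()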