Commit: add regularizer
lz02k committed Nov 17, 2022
1 parent b908f29 commit 126c30d
Showing 5 changed files with 106 additions and 0 deletions.
@@ -0,0 +1,15 @@
BACKEND: virtual
W:
    QSCHEME: per-channel-symmetric
    QUANTIZER:
        TYPE: lsq
        BIT: 4
A:
    QSCHEME: per-tensor-affine
    QUANTIZER:
        TYPE: lsq
        BIT: 4
REGULARIZER:
    ENABLE: True
    TYPE: dampen
    LAMBDA: 0.01
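The REGULARIZER block is what this commit wires up: TYPE selects a registered regularizer class ("dampen" here), LAMBDA scales the returned penalty, and ENABLE is presumably checked at the call site, which sits outside this diff. A minimal sketch of that mapping, assuming a yacs-style config object with attribute access:

    from sparsebit.quantization.regularizers import build_regularizer

    # Hedged sketch: gate on ENABLE and let TYPE pick the class from the
    # registry. The actual call site is not part of this commit.
    regularizer = None
    if config.REGULARIZER.ENABLE:
        regularizer = build_regularizer(config)  # "dampen" -> dampen.Regularizer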
15 changes: 15 additions & 0 deletions sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
REGULARIZERS_MAP = {}


def register_regularizer(regularizer):
    # Key by the class-level TYPE attribute, lowercased for case-insensitive lookup.
    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
    return regularizer


from .base import Regularizer
from . import dampen  # importing the module runs @register_regularizer


def build_regularizer(config):
    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
    return regularizer
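Note that pact.py (added below) is not imported here, so as of this commit the "pact" key never enters the registry unless the module is imported elsewhere; only "dampen" is reachable through the builder. A quick sketch of the registry contract:

    from sparsebit.quantization.regularizers import REGULARIZERS_MAP, build_regularizer

    # After `from . import dampen` has run, the decorator has filled the map.
    assert "dampen" in REGULARIZERS_MAP
    # build_regularizer then instantiates the class with the full config:
    # reg = build_regularizer(config)  # requires config.REGULARIZER.TYPE == "dampen"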
6 changes: 6 additions & 0 deletions sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
class Regularizer(object):
    def __init__(self, config):
        self.config = config

    def __call__(self, model):
        raise NotImplementedError  # subclasses return a scalar penalty for the model
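The contract is small: a subclass sets a TYPE tag and returns a scalar penalty from __call__(model). For illustration, a hypothetical L1 weight regularizer following the same pattern (not part of this commit; it would live in its own module and need an import in __init__.py to register):

    from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
    from sparsebit.quantization.regularizers import register_regularizer


    @register_regularizer
    class Regularizer(BaseRegularizer):
        TYPE = "L1"  # selected via REGULARIZER.TYPE: l1 in the config

        def __init__(self, config):
            super(Regularizer, self).__init__(config)
            self._lambda = config.REGULARIZER.LAMBDA

        def __call__(self, model):
            loss = 0.0
            for n, p in model.named_parameters():
                if "weight" in n:
                    loss += p.abs().sum()
            return loss * self._lambda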
49 changes: 49 additions & 0 deletions sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,49 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Dampen"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config
        self._lambda = config.REGULARIZER.LAMBDA

    def _get_loss(self, x, quantizer):
        x_q = quantizer(x)  # fake-quantized weights
        qmin, qmax = quantizer.qdesc.qrange
        scale, zero_point = quantizer._qparams_preprocess(x)
        # Detach so the penalty shapes the weights, not the quantization params.
        scale = scale.detach()
        zero_point = zero_point.detach()
        # Bounds of the representable range, mapped back to the float domain.
        min_val = (qmin - zero_point) * scale
        max_val = (qmax - zero_point) * scale
        # Clamp x into that range, then penalize its distance to the quantized value.
        x_c = torch.min(torch.max(x, min_val), max_val)
        loss = ((x_q - x_c) ** 2).sum()
        return loss

    def __call__(self, model):
        loss = 0.0
        for n, m in model.named_modules():
            if (
                hasattr(m, "weight")
                and hasattr(m, "weight_quantizer")
                and m.weight_quantizer
                and m.weight_quantizer.is_enable
            ):
                loss += self._get_loss(m.weight, m.weight_quantizer)
        return loss * self._lambda
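In short, the dampen penalty is sum((Q(w) - clip(w, min_val, max_val))^2): it pulls each weight toward its quantized value inside the representable range, a dampening loss of the kind used to suppress weight oscillations during quantization-aware training. A sketch of how it might slot into a training step, with qmodel, criterion, optimizer, and loader assumed from the surrounding training code, which this diff does not show:

    from sparsebit.quantization.regularizers import build_regularizer

    regularizer = build_regularizer(config)
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(qmodel(inputs), targets)
        if config.REGULARIZER.ENABLE:
            loss = loss + regularizer(qmodel)  # adds the lambda-scaled dampen penalty
        loss.backward()
        optimizer.step()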
21 changes: 21 additions & 0 deletions sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,21 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
    TYPE = "Pact"

    def __init__(self, config):
        super(Regularizer, self).__init__(config)
        self.config = config
        self._lambda = config.REGULARIZER.LAMBDA

    def __call__(self, model):
        loss = 0.0
        for n, p in model.named_parameters():
            # L2 penalty on every learned clipping level ("alpha"), as in PACT.
            if "alpha" in n:
                loss += (p ** 2).sum()
        return loss * self._lambda
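PACT learns a clipping level alpha per activation layer and regularizes it with an L2 term so the clip range stays tight. The name match above assumes such a parameter is registered as "alpha"; a hypothetical module it would catch (illustrative only, not from this commit):

    import torch
    import torch.nn as nn


    class PACTClip(nn.Module):
        """Hypothetical PACT-style activation clip; `alpha` is the learned bound."""

        def __init__(self, init_alpha=6.0):
            super().__init__()
            self.alpha = nn.Parameter(torch.tensor(init_alpha))

        def forward(self, x):
            # y = clip(x, 0, alpha), kept differentiable w.r.t. alpha.
            return torch.minimum(torch.relu(x), self.alpha)

Any model containing such a module would contribute alpha ** 2 to the penalty above.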
