Skip to content

Commit

Permalink
fix_lsq_input_scale_untrained_bug (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiang-Stan authored Nov 17, 2022
1 parent 49cd8b5 commit b908f29
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion sparsebit/quantization/quantizers/lsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def calc_qparams(self):
scale = 2 * x_oc.abs().mean(axis=1) / math.sqrt(self.qdesc.qmax)
else:
scale = 2 * x_oc.abs().mean() / math.sqrt(self.qdesc.qmax)
self.scale = nn.Parameter(self._broadcast_qparams(scale)).to(self.device)
self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
self.init_params = True
return self.scale, self.zero_point
Expand Down
14 changes: 4 additions & 10 deletions sparsebit/quantization/quantizers/lsq_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,23 @@ def calc_qparams(self):
x_oc = self.observer.get_calibration_data(c_first=True)
assert (
self.is_symmetric
), "LSQ+ only support per-channel-sysmetric quant for weight"
), "LSQ+ only support per-channel-symmetric quant for weight"
mean, std = x_oc.mean(axis=1), x_oc.std(axis=1)
scale = (
2
* torch.maximum((mean - 3 * std).abs(), (mean + 3 * std).abs())
/ (self.qdesc.qmax - self.qdesc.qmin)
)
self.scale = nn.Parameter(self._broadcast_qparams(scale)).to(
self.device
)
self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
else:
assert (
not self.is_symmetric
), "LSQ+ only support per-tensor-affine quant for activation"
scale, zero_point = self.observer.calc_qparams()
self.scale = nn.Parameter(self._broadcast_qparams(scale)).to(
self.device
)
self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
zero_point = zero_point.clamp(self.qdesc.qmin, self.qdesc.qmax)
self.zero_point = nn.Parameter(self._broadcast_qparams(zero_point)).to(
self.device
)
self.zero_point = nn.Parameter(self._broadcast_qparams(zero_point.to(self.device)))
self.init_params = True
return self.scale, self.zero_point

Expand Down

0 comments on commit b908f29

Please sign in to comment.