Fix (examples/export): correct groupwise export (#832)
Giuseppe5 authored Feb 8, 2024
1 parent 6dd41d5 commit 4bd515a
Showing 1 changed file with 53 additions and 22 deletions.
75 changes: 53 additions & 22 deletions src/brevitas_examples/llm/llm_quant/export.py
@@ -6,23 +6,41 @@
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager
import math
import warnings

import numpy as np
import torch
from torch.nn import Module

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.manager import _set_layer_export_handler
from brevitas.export.manager import _set_layer_export_mode
from brevitas.export.manager import _set_proxy_export_handler
from brevitas.export.manager import _set_proxy_export_mode
from brevitas.export.manager import BaseManager
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.nn import QuantLinear
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector


# TODO: Improve Groupwise export
def clip_kwargs(narrow, signed, bit_width):
    if narrow or bit_width != 8. and bit_width != 32.:
        if signed and (bit_width < 8. or narrow and bit_width <= 8.):
            dtype = torch.int8
        elif not signed and (bit_width < 8. or narrow and bit_width <= 8.):
            dtype = torch.uint8
        elif signed and (bit_width < 32. or narrow and bit_width <= 32.):
            dtype = torch.int32
        else:
            raise RuntimeError(f"Sign {signed} and bit width {bit_width} not supported for export.")
        return {
            'min_val': min_int(signed, narrow, bit_width).to(dtype),
            'max_val': max_int(signed, narrow, bit_width).to(dtype)}
    else:
        return None

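For reference (not part of the diff): a minimal usage sketch of clip_kwargs, assuming a Brevitas install that includes this commit and passing bit_width as a tensor, as the handler does.

import torch

from brevitas_examples.llm.llm_quant.export import clip_kwargs

# 4-bit signed, non-narrow: clip to [-8, 7] as int8
print(clip_kwargs(narrow=False, signed=True, bit_width=torch.tensor(4.)))
# 8-bit signed, narrow range: clip to [-127, 127] as int8
print(clip_kwargs(narrow=True, signed=True, bit_width=torch.tensor(8.)))
# 8-bit signed, non-narrow: the full int8 range already applies, so no clip is needed
print(clip_kwargs(narrow=False, signed=True, bit_width=torch.tensor(8.)))  # None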

class WeightBlockQuantHandlerBase(BaseHandler, ABC):
    handled_layer = WeightQuantProxyFromInjector

@@ -79,6 +97,8 @@ def prepare_for_export(self, module):
        assert self.bit_width <= 8., "Only 8b or lower is supported."
        quant_layer = module.tracked_module_list[0]
        quant_weight = quant_layer.quant_weight()
        signed = module.is_signed
        self.int_dtype = torch.int8 if signed else torch.uint8
        self.dtype = quant_weight.value.dtype
        self.scale = self.export_scale(module, self.bit_width).detach()
        self.expanded_scaling_shape = self.scaling_impl(module).expanded_scaling_shape
@@ -87,30 +107,41 @@ def prepare_for_export(self, module):
        self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach()
        self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape
        self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape

+        self.clip_kwargs = clip_kwargs(
+            module.is_narrow_range, module.is_signed, quant_weight.bit_width)
+
-    def forward(self, x):
-        scale = self.scale.expand(self.expanded_scaling_shape).contiguous()
-        # contiguous above is to avoid the reshape below being mapped to a unsafe view
-        scale = scale.view(self.reshaped_scaling_shape)
-
-        # Explicitly export custom Q/DQ to avoid aggressive constant folding
-        x = x / scale
-        if self.zero_point is not None:
-            zero_point = self.zero_point.expand(self.expanded_zero_point_shape).contiguous()
-            # contiguous above is to avoid the reshape below being mapped to a unsafe view
-            zero_point = zero_point.view(self.reshaped_zero_point_shape)
-            # avoid unsigned subtraction
-            x = x.to(self.dtype) + zero_point.to(self.dtype)
-        else:
-            zero_point = torch.zeros_like(scale)
-            self.zero_point = None
-
-        int_weight = torch.round(x)
-        if self.zero_point is not None:
-            int_weight = int_weight.to(self.dtype) - zero_point.to(self.dtype)
-
-        quant_weight = int_weight * scale
-
-        return quant_weight, scale, zero_point, self.bit_width
+    def forward(self, x):
+        scale = self.scale
+        zero_point = self.zero_point
+        bit_width = self.bit_width
+        # If zero point is not defined, it's all zeros
+        if self.zero_point is None:
+            zero_point = torch.zeros_like(scale)
+        else:
+            zero_point = self.zero_point
+
+        # QCDQ
+        x = x.view(self.expanded_scaling_shape)
+        x = torch.round((x / scale) + zero_point).type(self.int_dtype)
+        if self.clip_kwargs is not None:
+            x = torch.clip(x, min=self.clip_kwargs['min_val'], max=self.clip_kwargs['max_val'])
+        x = (x.type(self.dtype) - zero_point) * scale
+
+        # Fix shape post quantization
+        scale = scale.expand(self.expanded_scaling_shape).contiguous().view(
+            self.reshaped_scaling_shape)
+        # If zero_point is not defined, propagate same shape as scale
+        if self.zero_point is None:
+            zero_point = torch.zeros_like(scale).type(self.int_dtype)
+        else:
+            zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view(
+                self.reshaped_zero_point_shape).type(self.int_dtype)
+        x = x.view(self.reshaped_scaling_shape)
+
+        return x, scale, zero_point, bit_width


class LinearWeightBlockQuantHandler(WeightBlockQuantHandlerBase, ABC):
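
For context (not part of the diff): a self-contained sketch of the groupwise quantize-clip-dequantize (QCDQ) pattern that the new forward above expresses. The shapes, group_size and the 4-bit range used here are illustrative assumptions, not values taken from the commit.

import torch

out_features, in_features, group_size = 4, 16, 8
weight = torch.randn(out_features, in_features)

# Group the input channels: (out, groups, group_size) plays the role of
# expanded_scaling_shape, while the original (out, in) layout plays the role
# of reshaped_scaling_shape.
x = weight.view(out_features, in_features // group_size, group_size)

# One scale per (output channel, group); zero point is zero in this sketch.
scale = x.abs().amax(dim=-1, keepdim=True) / 7.  # 4-bit signed, non-narrow
zero_point = torch.zeros_like(scale)

# Quantize, clip to the 4-bit integer range, dequantize.
q = torch.round(x / scale + zero_point).to(torch.int8)
q = torch.clip(q, min=-8, max=7)
deq = (q.to(weight.dtype) - zero_point) * scale

# Restore the original 2D layout, as the handler does via reshaped_scaling_shape.
fake_quant_weight = deq.view(out_features, in_features)
print(fake_quant_weight.shape)  # torch.Size([4, 16])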
