Skip to content

Commit

Permalink
Feat (brevitas_examples): per-row po2 int/float_ocp act quant (#1102)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Nov 27, 2024
1 parent ac0b5f5 commit 5e473c4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/brevitas_examples/common/generative/quant_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad

Expand All @@ -19,16 +20,31 @@ def __init__(
self,
scaling_stats_impl: nn.Module,
dynamic_scaling_broadcastable_fn: Callable,
scaling_stats_input_view_shape_impl: nn.Module) -> None:
scaling_stats_input_view_shape_impl: nn.Module,
restrict_scaling_impl: nn.Module,
restrict_threshold_impl: nn.Module = None,
scaling_min_val=None) -> None:
super(RuntimeDynamicStatsScaling, self).__init__()
# Ensure retro-compatibility with shared threshold/scaling restrict
if restrict_threshold_impl is None:
restrict_threshold_impl = restrict_scaling_impl
self.scaling_stats_input_view_shape_impl = scaling_stats_input_view_shape_impl
self.stats_impl = scaling_stats_impl
self.dynamic_scaling_broadcastable_fn = dynamic_scaling_broadcastable_fn
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
self.restrict_clamp_scaling = _RestrictClampValue(
scaling_min_val=scaling_min_val, restrict_value_impl=restrict_scaling_impl)
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
self.restrict_clamp_threshold = _RestrictClampValue(
restrict_value_impl=restrict_threshold_impl)

def forward(self, x, threshold) -> Tensor:
shape = x.shape
threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
x = self.scaling_stats_input_view_shape_impl(x)
x = self.stats_impl(x) / threshold
x = self.stats_impl(x)
x = self.restrict_clamp_scaling(self.restrict_scaling_pre(x))
x = x / threshold

x = self.dynamic_scaling_broadcastable_fn(x, shape)
return x
Expand Down
6 changes: 6 additions & 0 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import Fp8e4m3WeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
Expand Down Expand Up @@ -170,6 +172,8 @@
'sym': Int8DynamicActPerGroupFloat}}},
'po2_scale': {
'stats': {
'per_row': {
'sym': Int8DynamicActPerRowFixedPoint,},
'per_group': {
'sym': MXInt8Act}}}}},
'float': {
Expand All @@ -194,6 +198,8 @@
'dynamic': {
'po2_scale': {
'stats': {
'per_row': {
'sym': FP8e4m3OCPDynamicActPerRowFixedPoint},
'per_group': {
'sym': MXFloat8e4m3Act}}}}},
'float_fnuz': {
Expand Down
21 changes: 21 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from torch import nn

from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.function_wrapper.shape import OverOutputFeaturesView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
Expand All @@ -16,7 +17,9 @@
from brevitas.inject import ExtendedInjector
from brevitas.inject import this
from brevitas.inject import value
from brevitas.inject.enum import RestrictValueType
from brevitas.inject.enum import ScalingPerOutputType
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_parameter_quant import \
GroupwiseWeightFloatQuantProxyFromInjector
from brevitas.proxy.groupwise_float_runtime_quant import GroupwiseActFloatQuantProxyFromInjector
Expand Down Expand Up @@ -78,6 +81,11 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
scaling_per_output_channel = True


class Int8DynamicActPerRowFixedPoint(Int8DynamicActPerRowFloat):
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = FloorSte


class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
"""
Symmetric quantizer with per group scale.
Expand Down Expand Up @@ -120,3 +128,16 @@ class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFl
scaling_impl = RuntimeDynamicGroupStatsScaling
scaling_per_output_type = ScalingPerOutputType.GROUP
scaling_stats_op = 'min_max'


class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3ActPerTensorFloat):
"""
Symmetric quantizer with per row dynamic scale.
"""
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
scaling_stats_op = 'min_max'
scaling_per_output_channel = True
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = FloorSte
proxy_class = ActFloatQuantProxyFromInjector

0 comments on commit 5e473c4

Please sign in to comment.