diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index e175e4445..e8b42312a 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -195,8 +195,9 @@ def forward(self, x): tensor_shape = x.shape tensor_shape_list = list(tensor_shape) - tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size) - block_dim = self.group_dim + 1 if self.group_dim != -1 else -1 + tensor_shape_list[self.group_dim] = ( + tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size + block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list) tensor_shape_list.insert(block_dim, self.group_size) x = x.view(tensor_shape_list) return x diff --git a/src/brevitas/core/restrict_val.py b/src/brevitas/core/restrict_val.py index 59b3fe8ec..7d6d83231 100644 --- a/src/brevitas/core/restrict_val.py +++ b/src/brevitas/core/restrict_val.py @@ -24,7 +24,10 @@ class _RestrictClampValue(brevitas.jit.ScriptModule): - def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]): + def __init__( + self, + scaling_min_val: Optional[float] = None, + restrict_value_impl: Optional[Module] = None): super(_RestrictClampValue, self).__init__() if scaling_min_val is not None and scaling_min_val != 0: self.clamp_min_ste = ScalarClampMinSte(scaling_min_val) @@ -90,9 +93,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor) -> Tensor: return x @@ -116,9 +116,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.power_of_two(x) @@ -143,9 +140,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return Identity() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x / threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) @@ -171,9 +165,6 @@ def restrict_init_module(self): def restrict_init_inplace_module(self): return InplaceLogTwo() - def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor: - return x - threshold - @brevitas.jit.script_method def forward(self, x: Tensor): x = self.float_to_int_impl(x) diff --git a/src/brevitas/core/scaling/runtime.py b/src/brevitas/core/scaling/runtime.py index f11eb1f2a..09f891ed7 100644 --- a/src/brevitas/core/scaling/runtime.py +++ b/src/brevitas/core/scaling/runtime.py @@ -30,12 +30,18 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, affine_rescaling: bool = False, affine_shift_scale: bool = False, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(StatsFromParameterScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.parameter_list_stats = _ParameterListStats( scaling_stats_impl, scaling_shape, @@ -44,6 +50,7 @@ def __init__( tracked_parameter_list) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule): def __init__( self, restrict_scaling_impl: Module, + restrict_threshold_impl: Module, scaling_shape: Tuple[int, ...], scaling_min_val: Optional[float], affine_rescaling: bool, @@ -81,19 +89,22 @@ def __init__( else: self.affine_rescaling = Identity() self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() - self.restrict_scaling_impl = restrict_scaling_impl + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward( self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor: if threshold is None: threshold = torch.ones(1).type_as(stats) - threshold = self.restrict_scaling_pre(threshold) + threshold = self.restrict_threshold_pre(threshold) + threshold = self.restrict_clamp_threshold(threshold) stats = self.restrict_scaling_pre(stats) - stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold) stats = self.affine_rescaling(stats) stats = self.restrict_clamp_scaling(stats) + stats = stats / threshold return stats @@ -107,12 +118,17 @@ def __init__( affine_rescaling: bool = False, affine_shift_scale: bool = False, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: float = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(RuntimeStatsScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.runtime_stats = _RuntimeStats( scaling_stats_impl, scaling_shape, @@ -122,6 +138,7 @@ def __init__( device) self.stats_scaling_impl = _StatsScaling( restrict_scaling_impl, + restrict_threshold_impl, scaling_shape, scaling_min_val, affine_rescaling, @@ -173,20 +190,32 @@ def _load_from_state_dict( class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule): def __init__( - self, - group_size: int, - group_dim: int, - input_view_impl: Module, - scaling_stats_impl: Module, - scaling_min_val: Optional[float], - restrict_scaling_impl: Module = FloatRestrictValue()) -> None: + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + scaling_stats_impl: Module, + scaling_min_val: Optional[float], + restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None) -> None: super(RuntimeDynamicGroupStatsScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.group_size = group_size self.group_dim = group_dim self.scaling_stats_impl = scaling_stats_impl self.scaling_min_val = scaling_min_val self.input_view_impl = input_view_impl self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module( + ) + self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module( + ) @brevitas.jit.script_method def forward( @@ -196,7 +225,10 @@ def forward( if threshold is None: threshold = torch.ones(1).type_as(stats_input) stats_input_reshaped = self.input_view_impl(stats_input) - out = self.scaling_stats_impl(stats_input_reshaped) / threshold + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) + out = self.scaling_stats_impl(stats_input_reshaped) + # Apply log scaling + out = self.restrict_scaling_pre(out) # Scaling min val - out = self.restrict_clamp_scaling(out) + out = self.restrict_clamp_scaling(out) / threshold return out diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py index 4917b859a..13ead5afc 100644 --- a/src/brevitas/core/scaling/standalone.py +++ b/src/brevitas/core/scaling/standalone.py @@ -62,20 +62,27 @@ def __init__( self, scaling_init: Union[float, Tensor], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ConstScaling, self).__init__() + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) if isinstance(scaling_init, Tensor): scaling_init = scaling_init.to(device=device, dtype=dtype) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(scaling_init.detach()) else: scaling_init = restrict_scaling_impl.restrict_init_float(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device)) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -83,7 +90,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply any restriction to scaling # For IntQuant, this is no-op, retrocompatible. - threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) restricted_value = self.restrict_clamp_scaling(self.value()) restricted_value = restricted_value / threshold return restricted_value @@ -133,11 +140,16 @@ def __init__( scaling_init: Union[float, Tensor], scaling_shape: Optional[Tuple[int, ...]] = None, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterScaling, self).__init__() + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + if (isinstance(scaling_init, Tensor) and scaling_shape is not None and scaling_init.shape != SCALAR_SHAPE and scaling_init.shape != scaling_shape): raise RuntimeError("scaling_init.shape is non-scalar and != from scaling_shape.") @@ -149,12 +161,14 @@ def __init__( scaling_init = torch.tensor(scaling_init, dtype=dtype, device=device) scaling_init = restrict_scaling_impl.restrict_init_tensor(scaling_init) - self.restrict_init_module = restrict_scaling_impl.restrict_init_module() if scaling_init.shape == SCALAR_SHAPE and scaling_shape is not None: scaling_init = torch.full(scaling_shape, scaling_init, dtype=dtype, device=device) self.value = Parameter(scaling_init) self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl) + self.restrict_clamp_threshold = _RestrictClampValue( + restrict_value_impl=restrict_threshold_impl) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -162,7 +176,7 @@ def forward(self, placeholder: Tensor, threshold: Optional[Tensor] = None) -> Te threshold = torch.ones(1).type_as(placeholder) # We first apply any restriction to scaling # For IntQuant, this is no-op, retrocompatible. - threshold = self.restrict_clamp_scaling(self.restrict_init_module(threshold)) + threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold)) value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value)) return value / threshold @@ -193,6 +207,7 @@ def __init__( tracked_parameter_list: List[torch.nn.Parameter], scaling_shape: Tuple[int, ...], restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: @@ -203,26 +218,37 @@ def __init__( scaling_stats_input_view_shape_impl, scaling_stats_input_concat_dim, tracked_parameter_list) - self.restrict_scaling_impl = restrict_scaling_impl + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.stats_scaling_impl = _StatsScaling( - restrict_scaling_impl, scaling_shape, scaling_min_val, False, False, dtype, device) + restrict_scaling_impl, + restrict_threshold_impl, + scaling_shape, + scaling_min_val, + False, + False, + dtype, + device) + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() + self.restrict_inplace_scaling_pre = restrict_scaling_impl.restrict_init_inplace_module() + self.init_done: bool = brevitas.jit.Attribute(False, bool) self.local_loss_mode: bool = brevitas.jit.Attribute(False, bool) - self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) @brevitas.jit.script_method def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor: if threshold is None: threshold = torch.ones(1).type_as(ignored) - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependant on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.init_done: - threshold = self.restrict_inplace_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold return value else: stats = self.parameter_list_stats() @@ -230,11 +256,12 @@ def forward(self, ignored: Tensor, threshold: Optional[Tensor] = None) -> Tensor stats = stats + 0. * self.value if self.local_loss_mode: return self.stats_scaling_impl(stats, threshold) - stats = self.restrict_inplace_preprocess(stats) - threshold = self.restrict_inplace_preprocess(threshold) + stats = self.restrict_inplace_scaling_pre(stats) + threshold = self.stats_scaling_impl.restrict_clamp_threshold( + self.restrict_threshold_pre(threshold)) inplace_tensor_mul(self.value.detach(), stats) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(value)) + value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value)) + value = value / threshold self.init_done = True return value @@ -312,12 +339,18 @@ def __init__( scaling_stats_input_view_shape_impl: Module = OverBatchOverTensorView(), scaling_shape: Tuple[int, ...] = SCALAR_SHAPE, restrict_scaling_impl: Module = FloatRestrictValue(), + restrict_threshold_impl: Optional[Module] = None, scaling_stats_momentum: Optional[float] = DEFAULT_MOMENTUM, scaling_min_val: Optional[float] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None: super(ParameterFromRuntimeStatsScaling, self).__init__() assert collect_stats_steps > 0, 'Steps should be more than 0' + + # Ensure retro-compatibility with shared threshold/scaling restrict + if restrict_threshold_impl is None: + restrict_threshold_impl = restrict_scaling_impl + self.collect_stats_steps: int = brevitas.jit.Attribute(collect_stats_steps, int) self.counter: int = brevitas.jit.Attribute(0, int) self.stats_input_view_shape_impl = scaling_stats_input_view_shape_impl @@ -326,19 +359,17 @@ def __init__( scaling_stats_momentum, Optional[float]) self.register_buffer('buffer', torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device)) - self.restrict_scaling_impl = restrict_scaling_impl self.restrict_scaling = _RestrictValue(restrict_scaling_impl) + self.restrict_threshold = _RestrictValue(restrict_threshold_impl) self.clamp_scaling = _ClampValue(scaling_min_val) self.local_loss_mode: bool = brevitas.jit.Attribute( False, bool) # required to support MSE eval or variants self.restrict_inplace_preprocess = restrict_scaling_impl.restrict_init_inplace_module() - self.restrict_preprocess = restrict_scaling_impl.restrict_init_module() + self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module() + self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module() @brevitas.jit.script_method def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: - # Threshold division must happen after we update self.value, but before we apply restrict_preproces - # This is because we don't want to store a parameter dependent on a runtime value (threshold) - # And because restrict needs to happen after we divide by threshold if self.counter < self.collect_stats_steps: stats_input = self.stats_input_view_shape_impl(stats_input) stats = self.stats(stats_input) @@ -360,14 +391,16 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor: elif self.counter == self.collect_stats_steps: self.restrict_inplace_preprocess(self.buffer) inplace_tensor_mul(self.value.detach(), self.buffer) - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold self.counter = self.counter + 1 - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + return abs_binary_sign_grad(value) else: - threshold = self.restrict_preprocess(threshold) - value = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(value))) + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + value = self.clamp_scaling(self.restrict_scaling(self.value)) + value = value / threshold + return abs_binary_sign_grad(value) @brevitas.jit.script_method def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Tensor: @@ -378,12 +411,14 @@ def forward(self, stats_input: Tensor, threshold: Optional[Tensor] = None) -> Te return self.training_forward(stats_input, threshold) else: if self.counter <= self.collect_stats_steps: - out = self.buffer / threshold - out = self.restrict_preprocess(out) + out = self.buffer + out = self.restrict_scaling_pre(out) else: - threshold = self.restrict_preprocess(threshold) - out = self.restrict_scaling_impl.combine_scale_threshold(self.value, threshold) - out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out))) + out = self.value + threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold)) + out = self.clamp_scaling(self.restrict_scaling(out)) + out = out / threshold + out = abs_binary_sign_grad(self.clamp_scaling(out)) return out def state_dict(self, destination=None, prefix='', keep_vars=False): @@ -396,7 +431,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): del output_dict[prefix + 'value'] # Save buffer into value for any non-zero number of collection steps elif self.counter <= self.collect_stats_steps: - output_dict[prefix + 'value'] = self.restrict_preprocess(self.buffer) + output_dict[prefix + 'value'] = self.restrict_scaling_pre(self.buffer) return output_dict def _load_from_state_dict( diff --git a/src/brevitas/quant/experimental/mx_quant_ocp.py b/src/brevitas/quant/experimental/mx_quant_ocp.py index 2299c1783..5900fe663 100644 --- a/src/brevitas/quant/experimental/mx_quant_ocp.py +++ b/src/brevitas/quant/experimental/mx_quant_ocp.py @@ -1,9 +1,13 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from dependencies import this from dependencies import value from brevitas.core.function_wrapper.ops_ste import CeilSte +from brevitas.core.function_wrapper.ops_ste import FloorSte +from brevitas.core.restrict_val import PowerOfTwo +from brevitas.core.restrict_val import PowerOfTwoRestrictValue from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling from brevitas.inject import ExtendedInjector from brevitas.inject.enum import RestrictValueType @@ -43,17 +47,28 @@ class GroupwiseActProxyMixin(ExtendedInjector): proxy_class = GroupwiseActQuantProxyFromInjector +class RestrictThresholdMixin(ExtendedInjector): + restrict_value_float_to_int_impl = FloorSte + restrict_scaling_impl = PowerOfTwoRestrictValue + + class MXWeightMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_per_output_type = ScalingPerOutputType.GROUP + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXActMixin(ExtendedInjector): + threshold_mixin = RestrictThresholdMixin group_size = 32 restrict_scaling_type = RestrictValueType.POWER_OF_TWO - restrict_value_float_to_int_impl = CeilSte + restrict_value_float_to_int_impl = FloorSte scaling_impl = RuntimeDynamicGroupStatsScaling scaling_per_output_type = ScalingPerOutputType.GROUP @@ -65,6 +80,10 @@ def stats_reduce_dim(group_dim): else: return group_dim + 1 + @value + def restrict_threshold_impl(): + return this.threshold_mixin.restrict_scaling_impl + class MXFloat8e4m3Weight(MXWeightMixin, GroupwiseWeightFloatProxyMixin, diff --git a/src/brevitas/quant/solver/common.py b/src/brevitas/quant/solver/common.py index 4d46cc704..a4930e43d 100644 --- a/src/brevitas/quant/solver/common.py +++ b/src/brevitas/quant/solver/common.py @@ -178,7 +178,8 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None): elif scaling_per_output == ScalingPerOutputType.TENSOR: return None elif scaling_per_output == ScalingPerOutputType.GROUP: - return group_dim + 1 + reduce_dim = group_dim + 1 if group_dim != -1 else -1 + return reduce_dim @value def keepdim(scaling_per_output): diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 9460fadf1..eaabf4d81 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -8,6 +8,9 @@ from torch import nn from brevitas import nn as qnn +from brevitas.core.function_wrapper import CeilSte +from brevitas.core.function_wrapper import FloorSte +from brevitas.core.restrict_val import RoundSte from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint from brevitas.graph.quantize import layerwise_quantize from brevitas.quant.experimental.float import Fp8e4m3Act @@ -220,6 +223,7 @@ def generate_quantizers( input_quant_granularity=None, input_group_size=None, quantize_input_zero_point=False, + scale_rounding_func_type=None, device=None, weight_kwargs=None, input_kwargs=None): @@ -278,6 +282,19 @@ def generate_quantizers( 'quantize_zero_point': quantize_weight_zero_point}, **weight_float_format) + if scale_rounding_func_type is not None: + scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte} + scale_type = scale_rounding_func_dict[scale_rounding_func_type] + weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type}) + if input_quant is not None: + input_quant = input_quant.let(**{'restrict_value_float_to_int_impl': scale_type}) + if sym_input_quant is not None: + sym_input_quant = sym_input_quant.let( + **{'restrict_value_float_to_int_impl': scale_type}) + if linear_input_quant is not None: + linear_input_quant = linear_input_quant.let( + **{'restrict_value_float_to_int_impl': scale_type}) + if weight_group_dim is not None: weight_quant = weight_quant.let(**{'group_dim': weight_group_dim}) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index bef47b24f..f40a367e1 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -253,6 +253,7 @@ def main(args): input_quant_granularity=args.input_quant_granularity, input_group_size=args.input_group_size, quantize_input_zero_point=args.quantize_input_zero_point, + scale_rounding_func_type=args.scale_rounding_func_type, device=device) layer_map = generate_quant_maps( linear_input_quant=linear_input_quant, @@ -400,6 +401,12 @@ def parse_args(args): default='per_group', choices=['per_channel', 'per_tensor', 'per_group'], help='Granularity for scales/zero-point of weights. Default: per_group.') + parser.add_argument( + '--scale-rounding-func-type', + type=str, + default=None, + choices=['round', 'ceil', 'floor'], + help='Rounding function to use with Po2 scale. Default: None.') parser.add_argument( '--weight-group-dim', type=int, diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index fbfc76842..16f944e97 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -60,7 +60,7 @@ def reference_implementation_scale_factors_po2( return scale -@given(inp=float_tensor_random_size_st()) +@given(inp=float_tensor_random_size_st(max_val=1e10, min_val=-1e10)) def test_scale_factors_ptq_calibration_po2(inp): class TestModel(nn.Module): @@ -80,7 +80,6 @@ def forward(self, x): expected_scale = reference_implementation_scale_factors_po2(inp) scale = model.act.act_quant.scale() - assert torch.allclose(expected_scale, scale)