Fix/Feat (trunc avg pool): Update truncation and average pool behaviour #1042

Open

Wants to merge 35 commits into base: dev

Changes from all commits (35 commits):
0d7637b
Fix (quant_tensor): Produce valid IntQuantTensor after AvgPool functi…
nickfraser Oct 4, 2024
200cb4e
Fix (core/trunc): Fix output scaling after truncation
nickfraser Oct 4, 2024
69f0463
Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation.
nickfraser Oct 4, 2024
0f4df89
fix/trunc_avg_pool: Clamp output.
nickfraser Jan 22, 2025
f184727
style: fix
nickfraser Jan 22, 2025
5aee892
fix (trunc_avg_pool): Set default arguments for backward compatibility
nickfraser Jan 22, 2025
6fb7ced
test (trunc_int_quant): Added initial sanity-check test
nickfraser Jan 22, 2025
bc2aaa9
fix (export/torch/qcdq): Fixed output scale, and `signed` setting
nickfraser Jan 22, 2025
a61a63a
Fix (core/proxy/trunc): Moved setting of signed to the proxy
nickfraser Jan 22, 2025
e10583b
fix (qonnx/trunc): Fixed Trunc Quant QONNX export
nickfraser Jan 23, 2025
dc7b9b1
fix (trunc): Factored out scaling calculation to standalone class.
nickfraser Jan 24, 2025
5e9e106
fix typo: Updated comment in TruncAvgPool export
nickfraser Jan 24, 2025
7971700
feat (trunc/scaling): Factored out the scaling implementation.
nickfraser Jan 24, 2025
5f8d95b
test (trunc): Added signed overflow test
nickfraser Jan 24, 2025
26ad5c5
test (trunc): Added more unit tests.
nickfraser Jan 24, 2025
b1db190
fix (test/trunc): Bugfixes and tests.
nickfraser Jan 24, 2025
e2878e9
Fix: precommit
nickfraser Jan 27, 2025
2f8152f
Fix (solver/trunc): Added a ShiftRoundSaturate quantizer and update t…
nickfraser Jan 27, 2025
a1fddac
Fix (export/trunc): Updated export to generate Quant node.
nickfraser Jan 27, 2025
e48b88b
Fix (test/qonnx/trunc): Allow off-by-1 errors in test
nickfraser Jan 27, 2025
7d505e1
tests (brv_finn/avgpool): Add "lossless" tests
nickfraser Jan 28, 2025
4efa4d0
Fix (brevitas/scaling): TruncPowerOfTwoIntScaling -> PowerOfTwoIntSca…
nickfraser Jan 28, 2025
9c99d7d
Fix (scaling): Made signed an optional argument at init time.
nickfraser Jan 28, 2025
8395176
test (trunc_quant): Switched to pytest_cases.parametrize
nickfraser Jan 28, 2025
a80d7da
Fix (trunc): Fixed output zero-point calculation
nickfraser Jan 28, 2025
8e41fad
Fix (export/qonnx/trunc): Added check that zero-point is zero.
nickfraser Jan 28, 2025
6416388
Fix (export/qcdq/trunc): Pick up output scale from proxy
nickfraser Jan 28, 2025
a78408f
Fix (export/trunc): Retrieve bit_width from cache
nickfraser Jan 28, 2025
ca71825
precommit
nickfraser Jan 28, 2025
97a451d
docs (imagenet/qat): Updated accuracy with new TruncAvgPool implement…
nickfraser Jan 28, 2025
eb25fd5
test (finn/mobilenet): Allow tolerance of up-to 7 in output.
nickfraser Jan 28, 2025
c305d49
Fix (test/export/trunc): Revert export to produce a Trunc node.
nickfraser Feb 11, 2025
c6967de
fix (export/qonnx): Set QONNX OpSet.
nickfraser Mar 11, 2025
5b0ef65
Fix (export/qonnx): Set QONNX version during export
nickfraser Mar 11, 2025
14150a0
Fix style
nickfraser Mar 11, 2025
50 changes: 41 additions & 9 deletions src/brevitas/core/quant/int.py
@@ -1,15 +1,19 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 
 import brevitas
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.quant.delay import DelayWrapper
+from brevitas.core.scaling import TruncMsbScaling
 from brevitas.core.utils import StatelessBuffer
+from brevitas.function.ops import max_int
+from brevitas.function.ops import min_int
 from brevitas.function.ops_ste import round_ste
 
 
@@ -201,28 +205,56 @@ class TruncIntQuant(brevitas.jit.ScriptModule):
     """
     """
 
+    __constants__ = ['narrow_range']
+
     def __init__(
-            self, float_to_int_impl: Module, bit_width_impl: Module, quant_delay_steps: int = 0):
+            self,
+            float_to_int_impl: Module,
+            bit_width_impl: Module,
+            trunc_scaling_impl: Module = TruncMsbScaling(),
+            narrow_range: bool = False,
+            tensor_clamp_impl: Module = TensorClamp(),
+            quant_delay_steps: int = 0):
         super(TruncIntQuant, self).__init__()
+        self.narrow_range = narrow_range
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.trunc_scaling_impl = trunc_scaling_impl
         self.float_to_int_impl = float_to_int_impl
+        self.tensor_clamp_impl = tensor_clamp_impl
         self.delay_wrapper = DelayWrapper(quant_delay_steps)
 
     @brevitas.jit.script_method
-    def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor,
-                input_bit_width: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def min_int(self, bit_width: Tensor, signed: Union[bool, Tensor]):
+        return min_int(signed, self.narrow_range, bit_width)
+
+    @brevitas.jit.script_method
+    def max_int(self, bit_width: Tensor, signed: Union[bool, Tensor]):
+        return max_int(signed, self.narrow_range, bit_width)
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            x: Tensor,
+            scale: Tensor,
+            zero_point: Tensor,
+            input_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         y = x / scale
         y = y + zero_point
         y = round_ste(y)  # clean up floating point error
         output_bit_width = self.msb_clamp_bit_width_impl()
-        trunc_bit_width = input_bit_width - output_bit_width
-        trunc_scale = 2.0 ** trunc_bit_width
+        trunc_scale = self.trunc_scaling_impl(y, input_bit_width, output_bit_width, signed)
         y = y / trunc_scale
+        min_int_val = self.min_int(output_bit_width, signed)
+        max_int_val = self.max_int(output_bit_width, signed)
         y = self.float_to_int_impl(y)
-        y = y - zero_point
-        y = y * scale
+        y = self.tensor_clamp_impl(y, min_val=min_int_val, max_val=max_int_val)
+        output_scale = scale * trunc_scale
+        output_zero_point = zero_point / trunc_scale
+        y = y - output_zero_point
+        y = y * output_scale
         y = self.delay_wrapper(x, y)
-        return y, scale, zero_point, output_bit_width
+        return y, output_scale, output_zero_point, output_bit_width
 
 
 class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant):
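
For reference, a standalone numeric sketch of the patched forward arithmetic above, assuming the default TruncMsbScaling, a floor float-to-int impl, and a zero zero-point (the tensor values and bit widths are made up for illustration):

```python
import torch

# Truncate signed 8-bit integer values to 4 bits by dropping the 4 LSBs,
# then fold the dropped bits into the output scale, mirroring the new
# TruncIntQuant.forward with MSB scaling.
x_int = torch.tensor([-128., -37., 0., 90., 127.])  # integer values of the input
scale = torch.tensor(0.1)                           # input quantization scale
input_bit_width, output_bit_width = 8, 4

trunc_scale = 2.0 ** (input_bit_width - output_bit_width)  # TruncMsbScaling: 16.0
y = torch.floor(x_int / trunc_scale)                       # float_to_int_impl (floor here)
y = torch.clamp(y, -2 ** (output_bit_width - 1),           # tensor_clamp_impl over the
                2 ** (output_bit_width - 1) - 1)           # signed, non-narrow 4-bit range
output_scale = scale * trunc_scale                         # 1.6, returned as the new scale
print(y * output_scale)  # tensor([-12.8000, -4.8000, 0.0000, 8.0000, 11.2000])
```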
2 changes: 2 additions & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -18,5 +18,7 @@
 from .standalone import ParameterFromRuntimeStatsScaling
 from .standalone import ParameterFromStatsFromParameterScaling
 from .standalone import ParameterScaling
+from .standalone import TruncMsbScaling
+from .standalone import TruncScalingWrapper
 
 SCALING_STATS_REDUCE_DIM = 1
22 changes: 14 additions & 8 deletions src/brevitas/core/scaling/int_scaling.py
@@ -1,6 +1,8 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Optional, Union
+
 from torch import Tensor
 
 import brevitas
@@ -11,26 +13,30 @@
 class IntScaling(brevitas.jit.ScriptModule):
     __constants__ = ['signed', 'narrow_range']
 
-    def __init__(self, signed: bool, narrow_range: bool):
+    def __init__(self, narrow_range: bool, signed: Optional[bool] = None):
         super(IntScaling, self).__init__()
         self.signed = signed
         self.narrow_range = narrow_range
 
     @brevitas.jit.script_method
-    def forward(self, bit_width: Tensor) -> Tensor:
-        if self.signed:
-            return -min_int(self.signed, self.narrow_range, bit_width)
+    def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor:
+        is_signed = signed if signed is not None else self.signed
+        assert is_signed is not None, f"signed is not defined, signed={is_signed}"
+        if is_signed:
+            return -min_int(is_signed, self.narrow_range, bit_width)
         else:
-            return max_int(self.signed, self.narrow_range, bit_width)
+            return max_int(is_signed, self.narrow_range, bit_width)
 
 
 class PowerOfTwoIntScaling(brevitas.jit.ScriptModule):
     __constants__ = ['signed']
 
-    def __init__(self, signed: bool):
+    def __init__(self, signed: Optional[bool] = None):
         super(PowerOfTwoIntScaling, self).__init__()
         self.signed = signed
 
     @brevitas.jit.script_method
-    def forward(self, bit_width: Tensor) -> Tensor:
-        return max_int(self.signed, False, bit_width) + 1
+    def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor:
+        is_signed = signed if signed is not None else self.signed
+        assert is_signed is not None, f"signed is not defined, signed={is_signed}"
+        return max_int(is_signed, False, bit_width) + 1
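
A quick usage sketch of the relaxed `signed` handling above, assuming the patched signatures (`signed` may now be omitted at init and supplied per call instead):

```python
import torch

from brevitas.core.scaling.int_scaling import IntScaling

int_scaling = IntScaling(narrow_range=False)  # `signed` left unset at construction
bit_width = torch.tensor(8.)
print(int_scaling(bit_width, signed=True))    # -min_int(True, False, 8)  -> 128
print(int_scaling(bit_width, signed=False))   # max_int(False, False, 8)  -> 255
# Omitting `signed` both at init and call time trips the assert.
```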
49 changes: 49 additions & 0 deletions src/brevitas/core/scaling/standalone.py
@@ -12,6 +12,7 @@
 import brevitas.config as config
 from brevitas.core.function_wrapper import Identity
 from brevitas.core.function_wrapper import OverBatchOverTensorView
+from brevitas.core.function_wrapper import TensorClamp
 from brevitas.core.restrict_val import _ClampValue
 from brevitas.core.restrict_val import _RestrictClampValue
 from brevitas.core.restrict_val import _RestrictValue
@@ -469,3 +470,51 @@ def _load_from_state_dict(
         self.counter = self.collect_stats_steps + 1
         if config.IGNORE_MISSING_KEYS and value_key in missing_keys:
             missing_keys.remove(value_key)
+
+
+class TruncMsbScaling(brevitas.jit.ScriptModule):
+    """
+    """
+
+    def __init__(self) -> None:
+        super(TruncMsbScaling, self).__init__()
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            scaling_input: Tensor,
+            input_bit_width: Tensor,
+            output_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tensor:
+        return 2 ** (input_bit_width - output_bit_width)
+
+
+class TruncScalingWrapper(brevitas.jit.ScriptModule):
+    """
+    """
+
+    def __init__(
+            self,
+            trunc_int_scaling_impl: Module,
+            scaling_impl: Module,
+            tensor_clamp_impl: Module = TensorClamp()) -> None:
+        super(TruncScalingWrapper, self).__init__()
+        self.trunc_int_scaling_impl = trunc_int_scaling_impl
+        self.scaling_impl = scaling_impl
+        self.tensor_clamp_impl = tensor_clamp_impl
+
+    @brevitas.jit.script_method
+    def forward(
+            self,
+            scaling_input: Tensor,
+            input_bit_width: Tensor,
+            output_bit_width: Tensor,
+            signed: Union[bool, Tensor]) -> Tensor:
+        threshold = self.trunc_int_scaling_impl(output_bit_width, signed)
+        scale = self.scaling_impl(scaling_input, threshold)
+        msb_scale = 2 ** (input_bit_width - output_bit_width)
+        unit_scale = torch.ones_like(msb_scale)
+        max_scale = torch.where(msb_scale > unit_scale, msb_scale, unit_scale)
+        min_scale = torch.where(msb_scale < unit_scale, msb_scale, unit_scale)
+        trunc_scale = self.tensor_clamp_impl(scale, min_scale, max_scale)
+        return trunc_scale
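
A numeric sketch of TruncScalingWrapper's clamping rule above: whatever scale the wrapped `scaling_impl` proposes is clamped between the unit scale and the MSB scale, so a learned or statistics-based truncation scale can only interpolate between keeping all LSBs (1.0) and pure MSB alignment (16.0 for 8 to 4 bits). The proposed values below are made up:

```python
import torch

input_bit_width, output_bit_width = torch.tensor(8.), torch.tensor(4.)
msb_scale = 2 ** (input_bit_width - output_bit_width)  # 16.0
unit_scale = torch.ones_like(msb_scale)
for proposed in (0.5, 6.0, 40.0):
    # Equivalent to the torch.where + tensor_clamp_impl combination above
    # for the common case msb_scale >= 1.
    trunc_scale = torch.clamp(torch.tensor(proposed), min=unit_scale, max=msb_scale)
    print(f"{proposed} -> {trunc_scale.item()}")  # 0.5 -> 1.0, 6.0 -> 6.0, 40.0 -> 16.0
```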
22 changes: 14 additions & 8 deletions src/brevitas/export/common/handler/qcdq.py
@@ -773,38 +773,44 @@ class QCDQCastTruncQuantProxyHandlerMixin(QuantAxisMixin,
                                           ABC):
     handled_layer = TruncQuantProxyFromInjector
 
+    def validate(self, module):
+        assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant."
+        super(QCDQCastTruncQuantProxyHandlerMixin, self).validate(module)
+
     def prepare_for_export(self, module: TruncQuantProxyFromInjector):
         if module.is_quant_enabled:
             self.validate(module)
-            self.symbolic_kwargs = {'output_bit_width': module.bit_width()}
+            self.symbolic_kwargs = {
+                'narrow_range': module.is_narrow_range,
+                'output_scale': module.scale(),
+                'output_bit_width': module.bit_width()}
 
     def symbolic_execution(
             self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor,
             signed: Tensor):
         assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
         output_bit_width = self.symbolic_kwargs['output_bit_width']
+        narrow_range = self.symbolic_kwargs['narrow_range']
         dtype = self.int8_dtype() if signed else self.uint8_dtype()
-        trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
+        scale = self.symbolic_kwargs['output_scale']  # Input scale is ignored now
         # If original dtype of scale is (b)float16, store the original scale dtype
         # and cast the scale and the input to float32
         scale_dtype = scale.dtype
         if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
             scale = self.cast_fn(scale, torch.float32)
         if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
             x = self.cast_fn(x, torch.float32)
-        pre_scale = scale * trunc_scale
-        flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
         flat_scale = to_0dim_if_scalar(scale.flatten())
         zp = to_0dim_if_scalar(zero_point.flatten()).expand_as(flat_scale)
         zp = self.zero_point_with_dtype(signed, output_bit_width, zp)
-        x = self.quantize_fn(x, flat_pre_scale, zp, dtype, self.quant_axis(pre_scale))
+        x = self.quantize_fn(x, flat_scale, zp, dtype, self.quant_axis(scale))
         clip_symbolic_kwargs = self.int_clip_symbolic_kwargs(
-            signed=signed, narrow=False, bit_width=output_bit_width)
+            signed=signed, narrow=self.symbolic_kwargs['narrow_range'], bit_width=output_bit_width)
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
         # After dequantization, cast both output and scale to the correct dtype
         if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
             x = self.cast_fn(x, scale_dtype)
-            scale = self.cast_fn(scale, scale_dtype)
-        return x, scale, zero_point, output_bit_width
+            flat_scale = self.cast_fn(flat_scale, scale_dtype)
+        return x, flat_scale, zero_point, output_bit_width
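
To see what the exported graph computes after this change, a minimal sketch of the quantize-clip-dequantize sequence using the proxy's output scale directly (a hypothetical helper, not Brevitas code; previously a pre-scale of `scale * trunc_scale` fed the quantize side):

```python
import torch

def qcdq_trunc(x, output_scale, bit_width, signed=True, narrow=False):
    qmin = -2 ** (bit_width - 1) + int(narrow) if signed else 0
    qmax = 2 ** (bit_width - 1) - 1 if signed else 2 ** bit_width - 1 - int(narrow)
    q = torch.round(x / output_scale)  # QuantizeLinear
    q = torch.clamp(q, qmin, qmax)     # Clip, now narrow_range aware
    return q * output_scale            # DequantizeLinear

print(qcdq_trunc(torch.tensor([-13.0, 0.4, 12.9]), output_scale=1.6, bit_width=4))
# tensor([-12.8000,  0.0000,  11.2000])
```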
45 changes: 35 additions & 10 deletions src/brevitas/export/onnx/qonnx/function.py
@@ -13,7 +13,8 @@
 from brevitas.function import binary_sign
 from brevitas.quant.solver.common import solve_float_to_int_impl_from_enum
 
-DOMAIN_STRING = "onnx.brevitas"
+DOMAIN_STRING = "qonnx.custom_op.general"
+DOMAIN_VERSION = 2
 
 
 class BrevitasBinaryQuantFn(Function):
@@ -111,26 +112,50 @@ def forward(
 class BrevitasTruncFn(Function):
 
     @staticmethod
-    def symbolic(g, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
+    def symbolic(
+            g,
+            x,
+            scale,
+            zero_point,
+            input_bit_width,
+            signed,
+            narrow_range,
+            output_scale,
+            output_bit_width,
+            rounding_mode):
         ret = g.op(
             f'{DOMAIN_STRING}::Trunc',
             x,
             scale,
             zero_point,
             input_bit_width,
             output_bit_width,
-            rounding_mode_s=rounding_mode)
+            rounding_mode_s=rounding_mode,
+            signed_i=int(signed),
+            narrow_i=int(narrow_range),
+            output_scale_f=output_scale)
         ret.setType(x.type())
         return ret
 
     @staticmethod
-    def forward(ctx, x, scale, zero_point, input_bit_width, output_bit_width, rounding_mode):
-        float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
-        trunc = TruncIntQuant(
-            float_to_int_impl=float_to_int_impl(),
-            bit_width_impl=BitWidthConst(int(output_bit_width)))
-        y_tuple = trunc(x, scale, zero_point, input_bit_width)
-        return y_tuple[0]
+    def forward(
+            ctx,
+            x,
+            scale,
+            zero_point,
+            input_bit_width,
+            signed,
+            narrow_range,
+            output_scale,
+            output_bit_width,
+            rounding_mode):
+        # TODO: Restore this (fails when `signed` arg added)
+        # float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
+        # trunc = TruncIntQuant(
+        #     float_to_int_impl=float_to_int_impl(),
+        #     bit_width_impl=BitWidthConst(int(output_bit_width)))
+        # y_tuple = trunc(x, scale, zero_point, input_bit_width, signed)
        return x
 
 
 class BrevitasQuantLSTMCellFn(Function):
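
For illustration, roughly what the emitted Trunc node looks like in ONNX terms, built here with `onnx.helper` rather than through export (attribute values are placeholders; the `_s`/`_i`/`_f` suffixes in `g.op` become string/int/float attributes):

```python
from onnx import helper

node = helper.make_node(
    'Trunc',
    inputs=['x', 'scale', 'zero_point', 'input_bit_width', 'output_bit_width'],
    outputs=['y'],
    domain='qonnx.custom_op.general',  # the new DOMAIN_STRING
    rounding_mode='ROUND',
    signed=1,
    narrow=0,
    output_scale=0.0625)
print(node)
```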
13 changes: 10 additions & 3 deletions src/brevitas/export/onnx/qonnx/handler.py
@@ -227,16 +227,23 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
 class BrevitasTruncQuantProxyHandler(ONNXBaseHandler):
     handled_layer = TruncQuantProxyFromInjector
 
+    def validate(self, module):
+        assert module.zero_point() == 0, "Zero-point export not supported for TruncQuant."
+
     def prepare_for_export(self, module: TruncQuantProxyFromInjector):
+        self.validate(module)
         self.symbolic_kwargs = {
-            'output_bit_width': module.bit_width(), 'rounding_mode': module.rounding_mode}
+            'narrow_range': module.is_narrow_range,
+            'output_scale': float(module.scale().detach().cpu().numpy()),
+            'output_bit_width': module.bit_width(),
+            'rounding_mode': module.rounding_mode}
 
     def symbolic_execution(
             self, x: Tensor, scale: Tensor, zero_point: Tensor, input_bit_width: Tensor,
             signed: Tensor):
         y = BrevitasTruncFn.apply(
-            x, scale, zero_point, input_bit_width, *self.symbolic_kwargs.values())
-        return y, scale, zero_point, self.symbolic_kwargs['output_bit_width']
+            x, scale, zero_point, input_bit_width, signed, *self.symbolic_kwargs.values())
+        return y, self.symbolic_kwargs['output_scale'], zero_point, self.symbolic_kwargs['output_bit_width']
 
 
 class BrevitasQuantLSTMLayerHandler(QuantLSTMLayerHandler):
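
An end-to-end usage sketch of the QONNX export path touched above; the layer choices and arguments are assumptions for illustration, not taken from this diff:

```python
import torch

from brevitas.export import export_qonnx
from brevitas.nn import QuantIdentity, TruncAvgPool2d

# A quantized input feeding a truncating average pool, exported to QONNX,
# which should now carry the Trunc node's signed/narrow/output_scale attributes.
model = torch.nn.Sequential(
    QuantIdentity(return_quant_tensor=True),
    TruncAvgPool2d(kernel_size=2, bit_width=4))
model.eval()
export_qonnx(model, torch.randn(1, 3, 8, 8), export_path='trunc_avg_pool.onnx')
```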