Move DelayWrapper logic to Proxy
Related to Xilinx#1023

Move `DelayWrapper` logic to Proxy classes.

* Add `DelayWrapper` instantiation in the `WeightQuantProxyFromInjectorBase` class in `src/brevitas/proxy/parameter_quant.py`.
* Modify the `forward` method in `WeightQuantProxyFromInjectorBase` so that `DelayWrapper` decides whether the quantized output or the original tensor is returned (see the sketch after this list).
* Remove `DelayWrapper` instantiation and usage from the `IntQuant` and `DecoupledIntQuant` classes in `src/brevitas/core/quant/int_base.py`.
* Add tests in `tests/brevitas/proxy/test_proxy.py` to verify the new `DelayWrapper` behavior in the proxy classes.
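
For context, `DelayWrapper` implements the `quant_delay_steps` feature: for the first N training steps the unquantized value is passed through, and only afterwards is the quantized value returned. A minimal sketch of that behavior (an illustration only, not the Brevitas implementation; `_SketchDelayWrapper` is a made-up name):

```python
import torch
from torch import nn


class _SketchDelayWrapper(nn.Module):
    """Illustrative stand-in; not the actual brevitas.core.quant.delay.DelayWrapper."""

    def __init__(self, quant_delay_steps: int = 0):
        super().__init__()
        self.quant_delay_steps = quant_delay_steps
        self.step = 0

    def forward(self, x: torch.Tensor, y_quant: torch.Tensor) -> torch.Tensor:
        # While still inside the delay window (training only), return the
        # unquantized input unchanged; afterwards return the quantized value.
        if self.training and self.step < self.quant_delay_steps:
            self.step += 1
            return x
        return y_quant
```

After this change the proxy calls the wrapper with the raw parameter and the output of `tensor_quant`, instead of the wrapper being applied inside `IntQuant`.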
vishwamartur committed Nov 3, 2024
1 parent 4617f7b commit 8626a51
Showing 3 changed files with 30 additions and 10 deletions.
11 changes: 2 additions & 9 deletions src/brevitas/core/quant/int_base.py
@@ -8,7 +8,6 @@
 import brevitas
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int

@@ -53,14 +52,12 @@ def __init__(
             signed: bool,
             input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            tensor_clamp_impl: Module = TensorClamp()):
         super(IntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl

     @brevitas.jit.script_method
@@ -87,7 +84,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso
         y_int = self.to_int(scale, zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y


@@ -129,14 +125,12 @@ def __init__(
             signed: bool,
             input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            tensor_clamp_impl: Module = TensorClamp()):
         super(DecoupledIntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl

     @brevitas.jit.script_method
@@ -172,5 +166,4 @@ def forward(
         y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y
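
With the wrapper gone from the core ops, code that builds `IntQuant` (or `DecoupledIntQuant`) directly simply drops the `quant_delay_steps` argument. An illustrative post-change construction, with keyword arguments taken from the diff above and arbitrary values (not part of this commit):

```python
from brevitas.core.function_wrapper import RoundSte, TensorClamp
from brevitas.core.function_wrapper.misc import Identity
from brevitas.core.quant.int_base import IntQuant

# quant_delay_steps is no longer an argument here; the delay now lives in the proxy.
int_quant = IntQuant(
    narrow_range=True,
    signed=True,
    input_view_impl=Identity(),
    float_to_int_impl=RoundSte(),
    tensor_clamp_impl=TensorClamp())
```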
5 changes: 4 additions & 1 deletion src/brevitas/proxy/parameter_quant.py
@@ -16,6 +16,7 @@
 from brevitas import config
 from brevitas import is_dynamo_compiling
 from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
 from brevitas.quant_tensor import _unpack_quant_tensor
@@ -94,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_weight_metadata_only = False
         self.cache_class = None # To be redefined by each class
         self.quant_tensor_class = None # To be redefined by each class
+        self.delay_wrapper = DelayWrapper(quant_injector.quant_delay_steps)

     @property
     def input_view_impl(self):
@@ -136,7 +138,8 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                 else:
                     out = self.create_quant_tensor(out)
             else:
-                out = self.tensor_quant(x)
+                quantized_out = self.tensor_quant(x)
+                out = self.delay_wrapper(x, quantized_out)
                 if is_dynamo_compiling():
                     out = out[0]
                 else:
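Since the proxy now builds its `DelayWrapper` from `quant_injector.quant_delay_steps`, the delay is configured on the quantizer and picked up by the proxy. A rough usage sketch (the `DelayedInt8Weight` subclass and the step count are illustrative, not part of this commit):

```python
from brevitas.nn import QuantLinear
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat


class DelayedInt8Weight(Int8WeightPerTensorFloat):
    # Read by the proxy as quant_injector.quant_delay_steps (see the __init__ diff above).
    quant_delay_steps = 100


layer = QuantLinear(10, 5, weight_quant=DelayedInt8Weight)
layer.train()
# For roughly the first 100 weight-quant calls in training mode the proxy is
# expected to return the float weight; after that the quantized weight is used.
```

The new tests below instead set `delay_wrapper.quant_delay_steps` directly on the proxy to exercise the same behavior.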
24 changes: 24 additions & 0 deletions tests/brevitas/proxy/test_proxy.py
@@ -80,3 +80,27 @@ def test_dynamic_act_proxy(self):

         model.act_quant.disable_quant = True
         assert model.act_quant.bit_width() is None
+
+    def test_delay_wrapper_in_weight_proxy(self):
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat)
+        assert model.weight_quant.delay_wrapper is not None
+
+        model.weight_quant.delay_wrapper.quant_delay_steps = 5
+        for _ in range(5):
+            quantized_out = model.weight_quant(model.weight)
+            assert torch.equal(quantized_out, model.weight)
+
+        quantized_out = model.weight_quant(model.weight)
+        assert not torch.equal(quantized_out, model.weight)
+
+    def test_delay_wrapper_in_bias_proxy(self):
+        model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling)
+        assert model.bias_quant.delay_wrapper is not None
+
+        model.bias_quant.delay_wrapper.quant_delay_steps = 5
+        for _ in range(5):
+            quantized_out = model.bias_quant(model.bias)
+            assert torch.equal(quantized_out, model.bias)
+
+        quantized_out = model.bias_quant(model.bias)
+        assert not torch.equal(quantized_out, model.bias)
