Fix comments
lapid92 committed Nov 6, 2024
1 parent 2967f17 commit 26a850d
Showing 3 changed files with 23 additions and 7 deletions.
@@ -91,8 +91,10 @@ def compute_activation_bias_correction(graph: Graph,
         Graph with activation bias correction term for each node.
     """

-    # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1).
-    # This feature supports only Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1).
+    # Retrieve the 'kernel_size' value if it exists and ensure it is None, 1, or (1, 1). This feature supports only
+    # Dense/Linear layers and convolution layers with kernel size of 1 or (1, 1).
+    # For Dense/Linear layers, which lack a 'kernel_size' attribute, the result will be None, and no restriction
+    # applies in that case.
     if linear_node.framework_attr.get(kernel_size) not in [None, 1, (1, 1)]:
         # If the kernel size is not 1 or (1, 1), return the current graph unmodified
         return graph
@@ -136,17 +138,18 @@ def compute_activation_bias_correction(graph: Graph,
     # size matching the number of output channels.
     if kernel is not None:

-        # Get the axes that are not the output channel
+        # Get the axes that are not the output channel.
         output_channel_index, input_channel_index = fw_info.kernel_channels_mapping.get(linear_node.type)
         axis_not_output_channel = list(range(len(kernel.shape)))
         axis_not_output_channel.remove(output_channel_index)

-        # special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters
+        # Special case of depthwise_conv2d in tensorflow, where we have a depth multiplier for the filters.
         if output_channel_index == input_channel_index:
-            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index
+            axis_not_output_channel.remove(3)  # 3 is the depth multiplier index.

         activation_bias_correction_term = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
-        linear_node.final_activation_quantization_cfg.activation_bias_correction_term = activation_bias_correction_term.flatten()
+        linear_node.final_activation_quantization_cfg.activation_bias_correction_term = (
+            activation_bias_correction_term.flatten())
     return graph
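
Note (not part of this commit): the correction term computed above ends up with one value per output channel. Below is a minimal NumPy sketch of that computation; the kernel shape, layout, and the scalar mean_diff are made-up illustration values, not taken from MCT.

```python
import numpy as np

# Hypothetical 1x1 conv kernel in TF layout (kh, kw, c_in, c_out).
kernel = np.random.randn(1, 1, 8, 16)
mean_diff = 0.03          # assumed scalar gap between float and quantized activation means
output_channel_index = 3  # c_out axis in this layout

# Sum the kernel over every axis except the output-channel axis, then scale by the
# mean difference -> one correction value per output channel.
axis_not_output_channel = [ax for ax in range(kernel.ndim) if ax != output_channel_index]
correction = mean_diff * np.sum(kernel, axis=tuple(axis_not_output_channel))
print(correction.shape)   # (16,)
```

For TensorFlow depthwise convolutions, where the output- and input-channel indices coincide, the depth-multiplier axis (index 3) is excluded from the summation as well, as the code in the hunk above does.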


@@ -69,17 +69,23 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         bias = float_linear_layers[-1].bias
         bias_after_activation_bias_correction = quantized_linear_layers[-1].layer.bias
+
+        y = float_model.predict(input_x)
+        y_hat = quantized_model.predict(input_x)
+
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+
         if getattr(float_linear_layers[-1], KERNEL_SIZE, None) in [None, 1, (1, 1)]:
             if self.activation_bias_correction_threshold > 1e8:
                 self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction),
                                           msg=f"Error in activation bias correction: expected no change in the bias "
                                               f"value in case of activation_bias_correction_threshold "
                                               f"{self.activation_bias_correction_threshold}.")

             else:
                 self.unit_test.assertFalse(np.array_equal(bias, bias_after_activation_bias_correction),
                                            msg=f"Error in activation bias correction: expected a change in the bias "
                                                f"value.")
         else:
             self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction),
                                       msg=f"Error in activation bias correction: expected no change in the bias value "
-                                          f"in case of conv with kernel 2.")
+                                          f"in case of conv with kernel different than 1 or (1, 1).")
@@ -19,6 +19,7 @@

 import model_compression_toolkit as mct
 from model_compression_toolkit.core.pytorch.constants import KERNEL_SIZE
+from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, set_model
 from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest

 """
@@ -109,6 +110,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         bias = float_model.linear_layer.bias.cpu().detach().numpy()
         bias_after_activation_bias_correction = quantized_model.linear_layer.layer.bias.cpu().detach().numpy()

+        set_model(float_model)
+        y = float_model(to_torch_tensor(input_x[0]))
+        y_hat = quantized_model(to_torch_tensor(input_x[0]))
+
+        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
+
         if getattr(float_model.linear_layer, KERNEL_SIZE, None) in [None, 1, (1, 1)]:
             if self.activation_bias_correction_threshold > 1e8:
                 self.unit_test.assertTrue(np.array_equal(bias, bias_after_activation_bias_correction),
