Skip to content

Commit

Permalink
Add feature Activation Bias Correction (#1256)
Browse files Browse the repository at this point in the history
* Add feature Activation Bias Correction
  • Loading branch information
lapid92 authored Nov 6, 2024
1 parent ca780b3 commit 104445e
Show file tree
Hide file tree
Showing 15 changed files with 828 additions and 58 deletions.
73 changes: 46 additions & 27 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_hessian_scores_calculator(self,
Returns: HessianScoresCalculator to use for the hessian approximation scores computation for this request.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_hessian_scores_calculator method.') # pragma: no cover

@abstractmethod
Expand All @@ -77,7 +77,7 @@ def to_numpy(self, tensor: Any) -> np.ndarray:
Returns:
Numpy array converted from the input tensor.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_numpy method.') # pragma: no cover

@abstractmethod
Expand All @@ -90,7 +90,7 @@ def to_tensor(self, tensor: np.ndarray) -> Any:
Returns:
Framework's tensor converted from the input Numpy array.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s to_tensor method.') # pragma: no cover

@abstractmethod
Expand All @@ -106,7 +106,7 @@ def model_reader(self,
Returns:
Graph representing the input model.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_reader method.') # pragma: no cover

@abstractmethod
Expand All @@ -131,7 +131,7 @@ def model_builder(self,
Returns:
A tuple with the model and additional relevant supporting objects.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_builder method.') # pragma: no cover

@abstractmethod
Expand All @@ -148,7 +148,7 @@ def run_model_inference(self,
Returns:
The frameworks model's output.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s run_model_inference method.') # pragma: no cover

@abstractmethod
Expand All @@ -167,9 +167,28 @@ def shift_negative_correction(self,
Returns:
Graph after SNC.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_shift_negative_correction method.') # pragma: no cover

@abstractmethod
def compute_activation_bias_correction(self,
graph: Graph,
quant_config: QuantizationConfig,
fw_info: FrameworkInfo) -> Graph:
"""
Compute activation bias correction on a graph.
Args:
graph: Graph to apply activation bias correction on.
quant_config: QuantizationConfig of how the model should be quantized.
fw_info: FrameworkInfo object with information about the specific framework's model.
Returns:
Graph after activation bias correction computing.
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s compute_activation_bias_correction method.') # pragma: no cover

@abstractmethod
def get_substitutions_channel_equalization(self,
quant_config: QuantizationConfig,
Expand All @@ -184,7 +203,7 @@ def get_substitutions_channel_equalization(self,
Returns:
A list of the framework substitutions used after we collect statistics.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover

@abstractmethod
Expand All @@ -194,7 +213,7 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
Returns: A list of the framework substitutions used to prepare the graph.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover

@abstractmethod
Expand All @@ -208,23 +227,23 @@ def get_substitutions_pre_statistics_collection(self, quant_config: Quantization
Returns: A list of the framework substitutions used before we collect statistics.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover

@abstractmethod
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: linear collapsing substitution
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
Expand All @@ -239,15 +258,15 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf
Returns:
A list of the framework substitutions used for statistics correction.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover

@abstractmethod
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
"""
Returns: A list of the framework substitutions used for residual collapsing
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover


Expand All @@ -263,7 +282,7 @@ def get_substitutions_post_statistics_collection(self, quant_config: Quantizatio
Returns:
A list of the framework substitutions used after we collect statistics.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_post_statistics_collection method.') # pragma: no cover

@abstractmethod
Expand All @@ -272,7 +291,7 @@ def get_substitutions_virtual_weights_activation_coupling(self) -> List[common.B
Returns: A list of Keras substitutions used to build a virtual graph with composed activation-weights pairs.
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_virtual_weights_activation_coupling '
f'method.') # pragma: no cover

Expand All @@ -288,7 +307,7 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
Returns:
A list of the framework substitutions used after we apply second moment statistics.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_after_second_moment_correction '
f'method.') # pragma: no cover

Expand Down Expand Up @@ -316,7 +335,7 @@ def get_sensitivity_evaluator(self,
A function that computes the metric.
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_sensitivity_evaluator method.') # pragma: no cover

def get_node_prior_info(self, node: BaseNode,
Expand All @@ -334,7 +353,7 @@ def get_node_prior_info(self, node: BaseNode,
NodePriorInfo with information about the node.
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_prior_info method.') # pragma: no cover

def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
Expand All @@ -345,7 +364,7 @@ def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool
Returns: True if the node should be considered an interest point, False otherwise.
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover

def get_mp_node_distance_fn(self, n: BaseNode,
Expand All @@ -364,7 +383,7 @@ def get_mp_node_distance_fn(self, n: BaseNode,
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover


Expand All @@ -381,7 +400,7 @@ def is_output_node_compatible_for_hessian_score_computation(self,
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s is_output_node_compatible_for_hessian_score_computation method.') # pragma: no cover

@abstractmethod
Expand All @@ -398,7 +417,7 @@ def get_node_mac_operations(self,
Returns: The MAC count of the operation
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_mac_operations method.') # pragma: no cover

@abstractmethod
Expand All @@ -419,7 +438,7 @@ def apply_second_moment_correction(self,
Returns:
A Graph after second moment correction.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s apply_second_moment_correction method.') # pragma: no cover

@abstractmethod
Expand All @@ -436,7 +455,7 @@ def sensitivity_eval_inference(self,
Returns:
The output of the model inference on the given input.
"""
raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s sensitivity_eval_inference method.') # pragma: no cover

def get_inferable_quantizers(self, node: BaseNode):
Expand All @@ -452,9 +471,9 @@ def get_inferable_quantizers(self, node: BaseNode):
"""

raise NotImplementedError(f'{self.__class__.__name__} have to implement the '
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_inferable_quantizers method.') # pragma: no cover

@staticmethod
def convert_data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def __init__(self,
self.activation_error_method = qc.activation_error_method
self.activation_n_bits = op_cfg.activation_n_bits
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
self.activation_bias_correction_term = None
self.enable_activation_quantization = op_cfg.enable_activation_quantization
self.quantization_preserving = op_cfg.quantization_preserving
self.signedness = op_cfg.signedness
self.activation_channel_equalization = qc.activation_channel_equalization
self.input_scaling = qc.input_scaling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class QuantizationConfig:
shift_negative_threshold_recalculation: bool = False
shift_negative_params_search: bool = False
concat_threshold_update: bool = False
activation_bias_correction: bool = False
activation_bias_correction_threshold: float = 0.0


# Default quantization configuration the library use.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.core import CoreConfig, QuantizationConfig
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
core_config: CoreConfig,
fw_impl: FrameworkImplementation) -> Graph:
"""
Get a graph, where each node has a final activation quantization configuration (with an activation bias
correction term in it), and apply the activation bias correction for each node in the graph.
Args:
graph: Graph to apply activation bias correction to.
core_config: CoreConfig containing parameters of how the model should be quantized.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
Returns:
Graph with activation bias correction apply to it's nodes.
"""

for n in graph.nodes:
# Activation bias correction is only relevant for nodes with kernel op
kernel_attr = graph.fw_info.get_kernel_op_attributes(n.type)[0]
if core_config.quantization_config.activation_bias_correction and kernel_attr is not None and \
n.final_activation_quantization_cfg.activation_bias_correction_term is not None:
# If activation bias correction is enabled in n.quantization_cfg, an activation bias correction term was
# calculated during model preparation, and is used now in the node's bias term.
_apply_activation_bias_correction_to_node(n, fw_impl, core_config.quantization_config)
return graph


def _apply_activation_bias_correction_to_node(node: BaseNode,
fw_impl: FrameworkImplementation,
qc: QuantizationConfig):
"""
Set new bias to node using the activation bias correction term that is stored in the
final activation quantization configuration.
Args:
node: Node to set its corrected bias after activation bias correction.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
qc: QuantizationConfig containing parameters of how the model should be quantized.
"""
correction = node.final_activation_quantization_cfg.activation_bias_correction_term
bias = node.get_weights_by_keys(fw_impl.constants.BIAS) # get original bias from node's weights

if bias is None:
# If the layer has no bias, we set the bias as -correction.
node.set_weights_by_keys(fw_impl.constants.BIAS, - correction)

# Mark the use_bias attribute of the node.
node.framework_attr[fw_impl.constants.USE_BIAS] = True

# Configure the quantization of the bias as disabled.
node.final_weights_quantization_cfg.set_attr_config(fw_impl.constants.BIAS,
WeightsAttrQuantizationConfig(
qc,
AttributeQuantizationConfig(
enable_weights_quantization=False)))
else:
# If the layer has bias, we subtract the correction from original bias
node.set_weights_by_keys(fw_impl.constants.BIAS, bias - correction)
Loading

0 comments on commit 104445e

Please sign in to comment.