Move Keras QAT activation quantizers to trainable infrastructure (#1240)
Move the STE and LSQ activation quantizers in Keras from QAT to the new trainable infrastructure module.
Add a 'freeze_quantization_params' flag to align them with the PyTorch quantizers (even though this flag has no effect in Keras).
Rename the Trainable QAT quantizer base class to the Weight Trainable quantizer.

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored Oct 14, 2024
1 parent dc678aa commit fa2b3b7
Showing 18 changed files with 717 additions and 504 deletions.
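
For orientation, the moves in this commit can be summarized by their import paths. A minimal sketch, using only paths that appear in the diffs below (assumes TF is installed so the FOUND_TF branches are active):

# Old location (removed in this commit):
#   from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import \
#       BaseKerasQATTrainableQuantizer

# New locations:
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
    BaseKerasQATWeightTrainableQuantizer        # renamed, weight-only QAT base class
from model_compression_toolkit.trainable_infrastructure.keras.activation_quantizers import \
    BaseKerasActivationTrainableQuantizer       # activation quantizers now derive from this
from model_compression_toolkit.trainable_infrastructure.keras.quantizer_utils import \
    symmetric_lsq_quantizer, uniform_lsq_quantizer  # LSQ helpers moved out of the QAT package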
model_compression_toolkit/qat/keras/quantizer/{base_keras_qat_quantizer.py → base_keras_qat_weight_quantizer.py}
@@ -22,24 +22,14 @@

 if FOUND_TF:

-    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
+    class BaseKerasQATWeightTrainableQuantizer(BaseKerasTrainableQuantizer):
         """
         A base class for trainable Keras quantizer for QAT.
         """
-
-        def __init__(self,
-                     quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
-            """
-            Initializes BaseKerasQATTrainableQuantizer object.
-            Args:
-                quantization_config: quantizer config class contains all the information about a quantizer configuration.
-            """
-
-            super().__init__(quantization_config)
+        pass

 else:  # pragma: no cover
-    class BaseKerasQATTrainableQuantizer(BaseKerasTrainableQuantizer):
+    class BaseKerasQATWeightTrainableQuantizer(BaseKerasTrainableQuantizer):
         def __init__(self,
                      quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
131 changes: 5 additions & 126 deletions model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py
@@ -28,47 +28,18 @@
 from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
 from model_compression_toolkit import constants as C

-from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
-    TrainableQuantizerActivationConfig
-from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
-    ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import BaseKerasQATWeightTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
+from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
 from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 from model_compression_toolkit.qat.keras.quantizer.quant_utils import ste_round, grad_scale


-def symmetric_lsq_quantizer(x: tf.Tensor,
-                            thresholds: tf.Tensor,
-                            num_bits: int,
-                            sign: bool,
-                            min_int: int,
-                            max_int: int,
-                            scale_factor: float) -> tf.Tensor:
-    """
-    Symmetric quantizer according to LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
-    Args:
-        x: input to quantize
-        thresholds: thresholds of quantization levels
-        num_bits: number of bits for quantization
-        sign: whether x is signed or not
-        min_int: min clipping integer value
-        max_int: max clipping integer value
-        scale_factor: grad scale of LSQ algorithm
-    Returns:
-        A quantized tensor
-    """
-    delta = thresholds / (2 ** (num_bits - int(sign)))
-    delta_scaled = grad_scale(delta, scale_factor)
-    rounded = ste_round(x / delta_scaled)
-    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
-    quantized = delta_scaled * clipped
-    return quantized
+from model_compression_toolkit.trainable_infrastructure.keras.quantizer_utils import symmetric_lsq_quantizer


 @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
                 identifier=TrainingMethod.LSQ)
-class LSQWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
+class LSQWeightQATQuantizer(BaseKerasQATWeightTrainableQuantizer):
     """
     Trainable constrained quantizer to quantize layer's weights.
     """
@@ -159,95 +130,3 @@ def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer]:
                                                input_rank=len(self.threshold_shape))


-
-@mark_quantizer(quantization_target=QuantizationTarget.Activation,
-                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
-                identifier=TrainingMethod.LSQ)
-class LSQActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
-    """
-    Trainable constrained quantizer to quantize layer activations.
-    """
-
-    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
-        """
-        Initialize a LSQActivationQATQuantizer object with parameters to use
-        for the quantization.
-        Args:
-            quantization_config: trainable quantizer config class
-        """
-        super().__init__(quantization_config)
-        self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
-        self.threshold_values = float(quantization_config.activation_quantization_params[C.THRESHOLD])
-        self.threshold_shape = np.asarray(self.threshold_values).shape
-        self.sign = quantization_config.activation_quantization_params[SIGNED]
-        self.num_bits = quantization_config.activation_n_bits
-        n_pos_bits = self.num_bits - int(self.sign)
-        self.min_int = -int(self.sign) * (2 ** n_pos_bits)
-        self.max_int = (2 ** n_pos_bits) - 1
-        if self.power_of_two:
-            self.threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD))))
-
-    def initialize_quantization(self,
-                                tensor_shape: TensorShape,
-                                name: str,
-                                layer: KerasTrainableQuantizationWrapper):
-        """
-        Add quantizer parameters to the quantizer parameters dictionary
-        Args:
-            tensor_shape: tensor shape of the quantized tensor.
-            name: Tensor name.
-            layer: Layer to quantize.
-        """
-        ptq_threshold_tensor = layer.add_weight(
-            name + THRESHOLD_TENSOR,
-            shape=(),
-            initializer=tf.keras.initializers.Constant(1.0),
-            trainable=True)
-        ptq_threshold_tensor.assign(self.threshold_values)
-
-        # save the quantizer added parameters for later calculations
-        self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
-
-    def __call__(self,
-                 inputs: tf.Tensor,
-                 training: bool):
-        """
-        Quantize a tensor.
-        Args:
-            inputs: Input tensor to quantize.
-            training: Whether the graph is in training mode.
-        Returns:
-            The quantized tensor.
-        """
-        thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR)
-        n_channels = inputs.shape[-1]
-        scale_factor = 1.0 / np.sqrt(self.max_int * n_channels)
-        q_tensor = symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, self.sign, self.min_int, self.max_int, scale_factor)
-        return q_tensor
-
-    def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
-        """
-        Convert quantizer to inferable quantizer.
-        Returns:
-            BaseKerasInferableQuantizer object.
-        """
-        if self.power_of_two:
-            thresholds = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()))
-            return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
-                                                   # In activation quantization is per-tensor only - thus we pass
-                                                   # the threshold as a list with a len of 1
-                                                   threshold=[thresholds],
-                                                   signed=self.sign)
-        else:
-            thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
-            return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
-                                                         # In activation quantization is per-tensor only - thus we
-                                                         # pass the threshold as a list with a len of 1
-                                                         threshold=[thresholds],
-                                                         signed=self.sign)
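
One detail worth noting in the removed class: under power-of-two quantization, both __init__ and convert2inferable round the learned threshold up to the nearest power of two, so the representable range never shrinks. A quick worked example of that rounding:

import numpy as np

for t in (0.7, 1.0, 1.2, 3.9):
    pot = 2.0 ** np.ceil(np.log2(t))
    print(t, '->', pot)   # 0.7 -> 1.0, 1.0 -> 1.0, 1.2 -> 2.0, 3.9 -> 4.0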
125 changes: 4 additions & 121 deletions model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py
@@ -16,6 +16,8 @@
 import tensorflow as tf
 from tensorflow.python.framework.tensor_shape import TensorShape
 from model_compression_toolkit.constants import RANGE_MIN, RANGE_MAX
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
+    BaseKerasQATWeightTrainableQuantizer
 from model_compression_toolkit.trainable_infrastructure.common.constants import FQ_MIN, FQ_MAX
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod
@@ -26,47 +28,18 @@

 from model_compression_toolkit import constants as C

-from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
-from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
-    TrainableQuantizerActivationConfig
 from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
 from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import fix_range_to_include_zero
 from model_compression_toolkit.qat.keras.quantizer.quant_utils import ste_round, grad_scale, adjust_range_to_include_zero


-def uniform_lsq_quantizer(x: tf.Tensor,
-                          min_range: tf.Tensor,
-                          max_range: tf.Tensor,
-                          num_bits: int,
-                          min_int: int,
-                          max_int: int,
-                          scale_factor: float) -> tf.Tensor:
-    """
-    Uniform quantizer according to LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
-    Args:
-        x: input to quantize
-        min_range: min range of quantization values
-        max_range: min range of quantization values
-        num_bits: number of bits for quantization
-        min_int: min clipping integer value
-        max_int: max clipping integer value
-        scale_factor: grad scale of LSQ algorithm
-    Returns:
-        A quantized tensor
-    """
-    min_range, max_range = adjust_range_to_include_zero(min_range, max_range, num_bits)
-    delta = (max_range - min_range) / (2 ** num_bits - 1)
-    delta_scaled = grad_scale(delta, scale_factor)
-    rounded = ste_round((x - min_range) / delta_scaled)
-    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
-    quantized = delta_scaled * clipped + min_range
-    return quantized
+from model_compression_toolkit.trainable_infrastructure.keras.quantizer_utils import uniform_lsq_quantizer


 @mark_quantizer(quantization_target=QuantizationTarget.Weights,
                 quantization_method=[QuantizationMethod.UNIFORM],
                 identifier=TrainingMethod.LSQ)
-class LSQUniformWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
+class LSQUniformWeightQATQuantizer(BaseKerasQATWeightTrainableQuantizer):
     """
     Trainable constrained quantizer to quantize layer's weights.
     """
@@ -158,93 +131,3 @@ def convert2inferable(self) -> BaseKerasInferableQuantizer:
                                                channel_axis=self.channel_axis,
                                                input_rank=len(self.min_max_shape))


-
-@mark_quantizer(quantization_target=QuantizationTarget.Activation,
-                quantization_method=[QuantizationMethod.UNIFORM],
-                identifier=TrainingMethod.LSQ)
-class LSQUniformActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
-    """
-    Trainable constrained quantizer to quantize layer activations.
-    """
-
-    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
-        """
-        Initialize a LSQUniformActivationQATQuantizer object with parameters to use
-        for the quantization.
-        Args:
-            quantization_config: trainable quantizer config class
-        """
-        super().__init__(quantization_config)
-
-        self.num_bits = quantization_config.activation_n_bits
-        self.min_range = np.array(quantization_config.activation_quantization_params[C.RANGE_MIN])
-        self.max_range = np.array(quantization_config.activation_quantization_params[C.RANGE_MAX])
-        self.min_int = 0
-        self.max_int = 2 ** self.num_bits - 1
-
-    def initialize_quantization(self,
-                                tensor_shape: TensorShape,
-                                name: str,
-                                layer: KerasTrainableQuantizationWrapper):
-        """
-        Add quantizer parameters to the quantizer parameters dictionary
-        Args:
-            tensor_shape: tensor shape of the quantized tensor.
-            name: Tensor name.
-            layer: Layer to quantize.
-        """
-        fq_min = layer.add_weight(
-            name + FQ_MIN,
-            shape=(),
-            initializer=tf.keras.initializers.Constant(-1.0),
-            trainable=True)
-        fq_min.assign(self.min_range)
-
-        fq_max = layer.add_weight(
-            name + FQ_MAX,
-            shape=(),
-            initializer=tf.keras.initializers.Constant(1.0),
-            trainable=True)
-        fq_max.assign(self.max_range)
-
-        # save the quantizer added parameters for later calculations
-        self.add_quantizer_variable(FQ_MIN, fq_min, VariableGroup.QPARAMS)
-        self.add_quantizer_variable(FQ_MAX, fq_max, VariableGroup.QPARAMS)
-
-    def __call__(self,
-                 inputs: tf.Tensor,
-                 training: bool):
-        """
-        Quantize a tensor.
-        Args:
-            inputs: Input tensor to quantize.
-            training: Whether the graph is in training mode.
-        Returns:
-            The quantized tensor.
-        """
-        min_range = self.get_quantizer_variable(FQ_MIN)
-        max_range = self.get_quantizer_variable(FQ_MAX)
-        n_channels = inputs.shape[-1]
-        scale_factor = 1.0 / np.sqrt(self.max_int * n_channels)
-        q_tensor = uniform_lsq_quantizer(inputs, min_range, max_range, self.num_bits, self.min_int, self.max_int, scale_factor)
-        return q_tensor
-
-    def convert2inferable(self) -> BaseKerasInferableQuantizer:
-        """
-        Convert quantizer to inferable quantizer.
-        Returns:
-            BaseKerasInferableQuantizer object.
-        """
-        min_range, max_range = fix_range_to_include_zero(self.get_quantizer_variable(FQ_MIN).numpy(),
-                                                         self.get_quantizer_variable(FQ_MAX).numpy(),
-                                                         self.num_bits)
-        return ActivationUniformInferableQuantizer(num_bits=self.num_bits,
-                                                   # In activation quantization is per-tensor only - thus we pass
-                                                   # the min/max as lists with a len of 1
-                                                   min_range=[min_range],
-                                                   max_range=[max_range])
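
The convert2inferable above passes the learned range through fix_range_to_include_zero before export: the uniform grid must contain zero exactly so that zero padding and ReLU outputs quantize without bias. A rough, illustrative stand-in (the real helper is the one imported in this file and may differ in details):

import numpy as np

def include_zero_sketch(rmin: float, rmax: float, n_bits: int):
    # Shift the grid so that 0 lands on one of its 2**n_bits points.
    scale = (rmax - rmin) / (2 ** n_bits - 1)
    zero_idx = np.round(-rmin / scale)          # grid index closest to zero
    rmin_adj = -zero_idx * scale
    return rmin_adj, rmin_adj + scale * (2 ** n_bits - 1)

print(include_zero_sketch(-0.95, 2.05, 8))     # ≈ (-0.9529, 2.0471); zero is now on the grid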
@@ -15,17 +15,18 @@
 from typing import Tuple, Dict, List, Callable

 from model_compression_toolkit.core import common
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.qat.common.qat_config import QATConfig
-from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
 from mct_quantizers import QuantizationTarget, KerasActivationQuantizationHolder
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
+    BaseKerasQATWeightTrainableQuantizer
 from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
     get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config, \
     get_trainable_quantizer_quantization_candidates
 from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
     get_trainable_quantizer_class
+from model_compression_toolkit.trainable_infrastructure.keras.activation_quantizers import \
+    BaseKerasActivationTrainableQuantizer


 def get_activation_quantizer_holder(n: common.BaseNode,
@@ -55,7 +56,7 @@ def get_activation_quantizer_holder(n: common.BaseNode,
 def quantization_builder(n: common.BaseNode,
                          qat_config: QATConfig,
                          kernel_attr: str = None,
-                         ) -> Tuple[Dict[str, BaseKerasQATTrainableQuantizer], List[BaseKerasQATTrainableQuantizer]]:
+                         ) -> Tuple[Dict[str, BaseKerasQATWeightTrainableQuantizer], List[BaseKerasActivationTrainableQuantizer]]:
     """
     Build quantizers for a node according to its quantization configuration.
@@ -82,7 +83,7 @@ def quantization_builder(n: common.BaseNode,
         quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Weights,
                                                         qat_config.weight_training_method,
                                                         quant_method,
-                                                        BaseKerasQATTrainableQuantizer)
+                                                        BaseKerasQATWeightTrainableQuantizer)

         weight_quantizers.update({kernel_attr: quantizer_class(get_trainable_quantizer_weights_config(n,
                                                                                                       attr_name=kernel_attr,
@@ -98,7 +99,7 @@ def quantization_builder(n: common.BaseNode,
         quantizer_class = get_trainable_quantizer_class(QuantizationTarget.Activation,
                                                         qat_config.activation_training_method,
                                                         quant_method,
-                                                        BaseKerasQATTrainableQuantizer)
+                                                        BaseKerasActivationTrainableQuantizer)

         activation_quantizers = [quantizer_class(get_trainable_quantizer_activation_config(n, aq_cand),
                                                  **qat_config.activation_quantizer_params_override)] * len(output_shapes)
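
With the split base classes, the builder now resolves weight and activation quantizers against different bases. A hedged sketch of the lookup it performs (names and signature taken from this diff; the specific enum values are an assumption for illustration):

from mct_quantizers import QuantizationTarget, QuantizationMethod
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
    get_trainable_quantizer_class
from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_weight_quantizer import \
    BaseKerasQATWeightTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.keras.activation_quantizers import \
    BaseKerasActivationTrainableQuantizer

weight_cls = get_trainable_quantizer_class(QuantizationTarget.Weights,
                                           TrainingMethod.LSQ,
                                           QuantizationMethod.SYMMETRIC,
                                           BaseKerasQATWeightTrainableQuantizer)
activation_cls = get_trainable_quantizer_class(QuantizationTarget.Activation,
                                               TrainingMethod.LSQ,
                                               QuantizationMethod.SYMMETRIC,
                                               BaseKerasActivationTrainableQuantizer)
print(weight_cls.__name__, activation_cls.__name__)  # e.g. LSQWeightQATQuantizer and the moved activation quantizer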
(Diff for the remaining changed files was not loaded.)