
Commit

Fix activation gradient backprop in GPTQ (sony#1197)
Add freeze_quant_params flag to base trainable quantizer with False as default.
Implement quant params freezing for STE activation quantizers.
Use activation trainable quantizers in GPTQ instead of inferable quantizers, with frozen quant params.
irenaby authored Sep 8, 2024
1 parent a514d5f commit 75e9f83
Showing 10 changed files with 111 additions and 49 deletions.
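For orientation, the change boils down to threading a single flag from the quantizer constructor into `requires_grad` at parameter-registration time. The sketch below is a simplified illustration only; `_SketchSTEActivationQuantizer` and its members are invented stand-ins, not the actual MCT classes shown in the diff.

```python
import torch
from torch import nn


class _SketchSTEActivationQuantizer:
    """Simplified stand-in for the STE activation trainable quantizers touched by this commit."""

    def __init__(self, threshold: float, freeze_quant_params: bool = False):
        # When True, quantization parameters are registered as non-trainable, so during
        # GPTQ only activation gradients flow; the thresholds themselves stay fixed.
        self.freeze_quant_params = freeze_quant_params
        self.threshold = threshold

    def initialize_quantization(self, name: str, layer: nn.Module) -> None:
        # The threshold is still owned by the holder layer as an nn.Parameter,
        # but gradient tracking follows the freeze flag.
        layer.register_parameter(
            name,
            nn.Parameter(torch.tensor(self.threshold),
                         requires_grad=not self.freeze_quant_params))


# GPTQ constructs the quantizer with frozen parameters; QAT keeps the default
# freeze_quant_params=False so the thresholds remain learnable.
gptq_quantizer = _SketchSTEActivationQuantizer(threshold=4.0, freeze_quant_params=True)
qat_quantizer = _SketchSTEActivationQuantizer(threshold=4.0)
```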
@@ -25,8 +25,9 @@
from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure import TrainingMethod, BasePytorchActivationTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.common.get_quantizer_config import \
get_trainable_quantizer_weights_config
get_trainable_quantizer_weights_config, get_trainable_quantizer_activation_config
from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import \
get_trainable_quantizer_class

@@ -68,12 +69,11 @@ def quantization_builder(n: common.BaseNode,

quant_method = n.final_activation_quantization_cfg.activation_quantization_method

quantizer_class = get_inferable_quantizer_class(quant_target=QuantizationTarget.Activation,
quantizer_class = get_trainable_quantizer_class(quant_target=QuantizationTarget.Activation,
quantizer_id=TrainingMethod.STE,
quant_method=quant_method,
quantizer_base_class=BasePyTorchInferableQuantizer)

kwargs = get_activation_inferable_quantizer_kwargs(n.final_activation_quantization_cfg)

activation_quantizers.append(quantizer_class(**kwargs))
quantizer_base_class=BasePytorchActivationTrainableQuantizer)
cfg = get_trainable_quantizer_activation_config(n, None)
activation_quantizers.append(quantizer_class(cfg, freeze_quant_params=True))

return weights_quantizers, activation_quantizers
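Because old and new lines are interleaved in the hunk above, here is how the rewritten activation branch of quantization_builder presumably reads once the change is applied (reconstructed from the diff; a sketch, not the verbatim source):

```python
quant_method = n.final_activation_quantization_cfg.activation_quantization_method

# Resolve an STE *trainable* activation quantizer instead of an inferable one,
# so activation gradients can propagate through it during GPTQ.
quantizer_class = get_trainable_quantizer_class(
    quant_target=QuantizationTarget.Activation,
    quantizer_id=TrainingMethod.STE,
    quant_method=quant_method,
    quantizer_base_class=BasePytorchActivationTrainableQuantizer)

# Build the trainable config and freeze its quantization parameters,
# since GPTQ should not learn activation thresholds/ranges.
cfg = get_trainable_quantizer_activation_config(n, None)
activation_quantizers.append(quantizer_class(cfg, freeze_quant_params=True))
```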
@@ -14,18 +14,16 @@
# ==============================================================================
from abc import ABC, abstractmethod
from enum import Enum
from typing import Union, List, Any
from inspect import signature

from model_compression_toolkit.logger import Logger
from typing import Union, List, Any

from mct_quantizers.common.base_inferable_quantizer import BaseInferableQuantizer, \
QuantizationTarget
from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig
from mct_quantizers.common.constants import QUANTIZATION_METHOD, \
QUANTIZATION_TARGET

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure.common.trainable_quantizer_config import \
TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig

VAR = 'var'
GROUP = 'group'
@@ -43,12 +41,14 @@ class VariableGroup(Enum):

class BaseTrainableQuantizer(BaseInferableQuantizer, ABC):
def __init__(self,
quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig]):
quantization_config: Union[TrainableQuantizerActivationConfig, TrainableQuantizerWeightsConfig],
freeze_quant_params: bool = False):
"""
This class is a base quantizer which validates the provided quantization config and defines an abstract function which any quantizer needs to implement.
Args:
quantization_config: quantizer config class contains all the information about the quantizer configuration.
freeze_quant_params: whether to freeze all learnable quantization parameters during training.
"""

# verify that the quantizer class inheriting from this class only has a config argument and keyword arguments
@@ -85,6 +85,7 @@ def __init__(self,
f"Unrecognized 'QuantizationTarget': {static_quantization_target}.") # pragma: no cover

self.quantizer_parameters = {}
self.freeze_quant_params = freeze_quant_params

@classmethod
def get_sig(cls):
@@ -18,7 +18,8 @@
import torch
from torch import nn

from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper
from mct_quantizers import mark_quantizer, QuantizationTarget, QuantizationMethod, PytorchQuantizationWrapper, \
PytorchActivationQuantizationHolder
from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
from model_compression_toolkit import constants as C
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
@@ -39,14 +40,15 @@ class STESymmetricActivationTrainableQuantizer(BasePytorchActivationTrainableQuantizer):
Trainable constrained quantizer to quantize a layer's activations.
"""

def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
def __init__(self, quantization_config: TrainableQuantizerActivationConfig, freeze_quant_params: bool = False):
"""
Initialize a STESymmetricActivationTrainableQuantizer object with parameters to use for symmetric or power of two quantization.
Args:
quantization_config: trainable quantizer config class
freeze_quant_params: whether to freeze learnable quantization parameters
"""
super().__init__(quantization_config)
super().__init__(quantization_config, freeze_quant_params)
self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
self.sign = quantization_config.activation_quantization_params['is_signed']
np_threshold_values = quantization_config.activation_quantization_params[C.THRESHOLD]
@@ -56,7 +58,7 @@ def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
def initialize_quantization(self,
tensor_shape: torch.Size,
name: str,
layer: PytorchQuantizationWrapper):
layer: PytorchActivationQuantizationHolder):
"""
Add quantizer parameters to the quantizer parameters dictionary
@@ -66,7 +68,7 @@ def initialize_quantization(self,
layer: Layer to quantize.
"""
layer.register_parameter(name, nn.Parameter(to_torch_tensor(self.threshold_tensor),
requires_grad=True))
requires_grad=not self.freeze_quant_params))

# save the quantizer added parameters for later calculations
self.add_quantizer_variable(THRESHOLD_TENSOR, layer.get_parameter(name), VariableGroup.QPARAMS)
@@ -36,14 +36,15 @@ class STEUniformActivationTrainableQuantizer(BasePytorchActivationTrainableQuantizer):
Trainable constrained quantizer to quantize a layer's activations.
"""

def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
def __init__(self, quantization_config: TrainableQuantizerActivationConfig, freeze_quant_params: bool = False):
"""
Initialize a STEUniformActivationTrainableQuantizer object with parameters to use for uniform quantization.
Args:
quantization_config: trainable quantizer config class
quantization_config: trainable quantizer config class.
freeze_quant_params: whether to freeze learnable quantization parameters.
"""
super().__init__(quantization_config)
super().__init__(quantization_config, freeze_quant_params)

np_min_range = quantization_config.activation_quantization_params[C.RANGE_MIN]
np_max_range = quantization_config.activation_quantization_params[C.RANGE_MAX]
@@ -56,17 +57,17 @@ def initialize_quantization(self,
name: str,
layer: PytorchQuantizationWrapper):
"""
Add quantizer parameters to the quantizer parameters dictionary
Add quantizer parameters to the quantizer parameters dictionary.
Args:
tensor_shape: tensor shape of the quantized tensor.
name: Tensor name.
layer: Layer to quantize.
"""
layer.register_parameter(name+"_"+FQ_MIN, nn.Parameter(to_torch_tensor(self.min_range_tensor),
requires_grad=True))
requires_grad=not self.freeze_quant_params))
layer.register_parameter(name+"_"+FQ_MAX, nn.Parameter(to_torch_tensor(self.max_range_tensor),
requires_grad=True))
requires_grad=not self.freeze_quant_params))

# Save the quantizer parameters for later calculations
self.add_quantizer_variable(FQ_MIN, layer.get_parameter(name+"_"+FQ_MIN), VariableGroup.QPARAMS)
@@ -46,6 +46,14 @@ def get_trainable_variables(self, group: VariableGroup) -> List[torch.Tensor]:
quantizer_parameter, parameter_group = parameter_dict[VAR], parameter_dict[GROUP]
if quantizer_parameter.requires_grad and parameter_group == group:
quantizer_trainable.append(quantizer_parameter)

# sanity check to catch inconsistent initialization
if self.freeze_quant_params and group == VariableGroup.QPARAMS and quantizer_trainable:
Logger.critical(
'Found trainable quantization params despite self.freeze_quant_params=True. '
'Quantization parameters were probably not initialized correctly in the Quantizer.'
) # pragma: no cover

return quantizer_trainable

else:
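Stepping back from the hunk above: the sanity check works because freezing is implemented purely through `requires_grad` at registration time, so frozen parameters never appear when trainable variables are collected. A small self-contained sketch of that filtering pattern in plain PyTorch (not the MCT API):

```python
import torch
from torch import nn

holder = nn.Module()
# A frozen quantization parameter (what the STE quantizers register when
# freeze_quant_params=True) next to a regular trainable parameter.
holder.register_parameter('threshold', nn.Parameter(torch.tensor(4.0), requires_grad=False))
holder.register_parameter('scale', nn.Parameter(torch.tensor(1.0), requires_grad=True))

# Collecting trainable variables by requires_grad excludes the frozen threshold,
# so it is never handed to an optimizer.
trainable = [p for _, p in holder.named_parameters() if p.requires_grad]
assert len(trainable) == 1
```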
@@ -1,22 +1,23 @@
import copy

import unittest

import numpy as np
import torch
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from mct_quantizers.pytorch.quantizers import ActivationPOTInferableQuantizer
from torch.nn import Conv2d
import numpy as np

import model_compression_toolkit as mct
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
from model_compression_toolkit.core.common.mixed_precision.bit_width_setter import set_bit_widths
from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
from model_compression_toolkit.gptq.pytorch.gptq_training import PytorchGPTQTrainer
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc
from torch.fx import symbolic_trace
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers import \
STESymmetricActivationTrainableQuantizer
from tests.common_tests.helpers.prep_graph_for_func_test import prepare_graph_with_quantization_parameters


INPUT_SHAPE = [3, 8, 8]


@@ -73,7 +74,12 @@ def test_adding_holder_instead_quantize_wrapper(self):
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer))
self.assertEquals(a.activation_holder_quantizer.identifier, TrainingMethod.STE)
# activation quantization params for gptq should be frozen (non-learnable)
self.assertTrue(a.activation_holder_quantizer.freeze_quant_params is True)
self.assertEquals(a.activation_holder_quantizer.get_trainable_variables(VariableGroup.QPARAMS), [])

for name, module in gptq_model.named_modules():
if isinstance(module, PytorchQuantizationWrapper):
self.assertTrue(len(module.weights_quantizers) > 0)
@@ -87,7 +93,7 @@ def test_adding_holder_after_relu(self):
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer))
for name, module in gptq_model.named_modules():
if isinstance(module, PytorchQuantizationWrapper):
self.assertTrue(len(module.weights_quantizers) > 0)
@@ -102,14 +108,18 @@ def test_adding_holders_after_reuse(self):
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer))
for name, module in gptq_model.named_modules():
if isinstance(module, PytorchQuantizationWrapper):
self.assertTrue(len(module.weights_quantizers) > 0)
# Test that two holders are getting inputs from reused conv2d (the layer that is wrapped)
fx_model = symbolic_trace(gptq_model)
self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2])
self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5])

# FIXME there is no reuse support and the test doesn't test what it claims to test. It doesn't even look
# at the correct layers. After moving to the trainable quantizer the test makes even less sense, since fx now
# traces all quantization operations instead of a fake_quant layer.
# fx_model = symbolic_trace(gptq_model)
# self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2])
# self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5])

def _get_gptq_model(self, input_shape, in_model):
pytorch_impl = GPTQPytorchImplemantation()
18 changes: 11 additions & 7 deletions tests/pytorch_tests/model_tests/feature_models/qat_test.py
@@ -23,26 +23,26 @@
from torch import Tensor

import model_compression_toolkit as mct
import model_compression_toolkit.trainable_infrastructure.common.training_method
from mct_quantizers import PytorchActivationQuantizationHolder, QuantizationTarget, PytorchQuantizationWrapper
from mct_quantizers.common.base_inferable_quantizer import QuantizerID
from mct_quantizers.common.get_all_subclasses import get_all_subclasses
from mct_quantizers.pytorch.quantizers import BasePyTorchInferableQuantizer
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import BasePytorchQATWeightTrainableQuantizer
from model_compression_toolkit.qat.pytorch.quantizer.base_pytorch_qat_weight_quantizer import \
BasePytorchQATWeightTrainableQuantizer
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc, \
get_op_quantization_configs
from model_compression_toolkit.trainable_infrastructure import TrainingMethod
from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers.base_activation_quantizer import \
BasePytorchActivationTrainableQuantizer
from model_compression_toolkit.trainable_infrastructure.pytorch.activation_quantizers.ste.symmetric_ste import \
STESymmetricActivationTrainableQuantizer
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import generate_pytorch_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.latest import \
get_op_quantization_configs
from tests.common_tests.helpers.generate_test_tp_model import generate_test_tp_model, \
generate_tp_model_with_activation_mp
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
from tests.pytorch_tests.tpc_pytorch import get_mp_activation_pytorch_tpc_dict
from mct_quantizers.common.base_inferable_quantizer import QuantizerID
from model_compression_toolkit.trainable_infrastructure import TrainingMethod


def dummy_train(qat_ready_model, x, y):
@@ -179,6 +179,10 @@ def compare(self, ptq_model, qat_ready_model, qat_finalized_model, input_x=None,
and self.activation_quantization_method in _q.quantization_method]
self.unit_test.assertTrue(len(q) == 1)
self.unit_test.assertTrue(isinstance(layer.activation_holder_quantizer, q[0]))
# quantization params in qat should be trainable (not frozen)
self.unit_test.assertFalse(layer.activation_holder_quantizer.freeze_quant_params)
trainable_params = layer.activation_holder_quantizer.get_trainable_variables(VariableGroup.QPARAMS)
self.unit_test.assertTrue(len(trainable_params) > 0)
elif isinstance(layer, PytorchQuantizationWrapper) and isinstance(layer.layer, nn.Conv2d):
q = [_q for _q in all_qat_weight_quantizers if _q.identifier == self.training_method
and _q.quantization_target == QuantizationTarget.Weights
@@ -105,9 +105,10 @@ def get_weights_quantization_config(self):
weights_per_channel_threshold=True,
min_threshold=0)

def get_activation_quantization_config(self):
return TrainableQuantizerActivationConfig(activation_quantization_method=QuantizationMethod.POWER_OF_TWO,
def get_activation_quantization_config(self, quant_method=QuantizationMethod.POWER_OF_TWO,
activation_quant_params=None):
return TrainableQuantizerActivationConfig(activation_quantization_method=quant_method,
activation_n_bits=8,
activation_quantization_params={},
activation_quantization_params=activation_quant_params or {},
enable_activation_quantization=True,
min_threshold=0)
@@ -34,7 +34,8 @@
from model_compression_toolkit.trainable_infrastructure.pytorch.base_pytorch_quantizer import \
BasePytorchTrainableQuantizer
from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_base_quantizer import \
TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator
TestPytorchBaseWeightsQuantizer, TestPytorchBaseActivationQuantizer, TestPytorchQuantizerWithoutMarkDecorator, \
TestPytorchSTEActivationQuantizerQParamFreeze
from tests.pytorch_tests.trainable_infrastructure_tests.trainable_pytorch.test_pytorch_get_quantizers import \
TestGetTrainableQuantizer

@@ -46,6 +47,9 @@ def test_pytorch_base_quantizer(self):
TestPytorchBaseActivationQuantizer(self).run_test()
TestPytorchQuantizerWithoutMarkDecorator(self).run_test()

def test_pytorch_ste_activation_quantizers_qparams_freeze(self):
TestPytorchSTEActivationQuantizerQParamFreeze(self).run_test()

def test_pytorch_get_quantizers(self):
TestGetTrainableQuantizer(self, quant_target=QuantizationTarget.Weights,
quant_method=QuantizationMethod.POWER_OF_TWO,
