From 3ebae9500ac1c4533e4f7cf246efe464c99b5427 Mon Sep 17 00:00:00 2001 From: irenab Date: Wed, 25 Sep 2024 23:44:11 +0300 Subject: [PATCH 1/5] initial sample layer attention implementation for torch --- .../core/common/hessian/__init__.py | 4 +- .../common/hessian/hessian_info_service.py | 80 +++++- .../common/hessian/hessian_scores_request.py | 17 +- ...ation_hessian_scores_calculator_pytorch.py | 254 ++++++++++++------ .../gptq/common/gptq_config.py | 9 +- .../gptq/common/gptq_training.py | 39 ++- .../gptq/pytorch/gptq_loss.py | 35 ++- .../gptq/pytorch/gptq_training.py | 42 ++- .../gptq/pytorch/quantization_facade.py | 26 +- .../model_tests/feature_models/gptq_test.py | 41 ++- .../model_tests/test_feature_models_runner.py | 9 + 11 files changed, 434 insertions(+), 122 deletions(-) diff --git a/model_compression_toolkit/core/common/hessian/__init__.py b/model_compression_toolkit/core/common/hessian/__init__.py index 27578918f..d1a43fe7a 100644 --- a/model_compression_toolkit/core/common/hessian/__init__.py +++ b/model_compression_toolkit/core/common/hessian/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, HessianMode, HessianScoresGranularity +from model_compression_toolkit.core.common.hessian.hessian_scores_request import ( + HessianScoresRequest, HessianMode, HessianScoresGranularity, HessianEstimationDistribution +) from model_compression_toolkit.core.common.hessian.hessian_info_service import HessianInfoService import model_compression_toolkit.core.common.hessian.hessian_info_utils as hessian_utils diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index ba5e2c143..b1d683f58 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import hashlib import numpy as np from functools import partial from tqdm import tqdm -from typing import Callable, List, Dict, Any, Tuple +from typing import Callable, List, Dict, Any, Tuple, TYPE_CHECKING from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \ HessianScoresGranularity, HessianMode from model_compression_toolkit.logger import Logger +if TYPE_CHECKING: + from model_compression_toolkit.core.common import BaseNode class HessianInfoService: @@ -228,10 +231,55 @@ def compute(self, return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \ and len(next_iter_remain_samples[0]) > 0 else None + def _compute_trackable_per_sample_hessian(self, + hessian_scores_request: HessianScoresRequest, + representative_dataset_gen) -> Dict[str, Dict['BaseNode', np.ndarray]]: + """ + Compute hessian score per image hash. + + Args: + hessian_scores_request: hessian scores request + representative_dataset_gen: representative dataset generator + + Returns: + A dict of Hessian scores per image hash per layer {image hash: {layer: score}} + """ + topo_sorted_nodes_names = [x.name for x in self.graph.get_topo_sorted_nodes()] + hessian_scores_request.target_nodes.sort(key=lambda x: topo_sorted_nodes_names.index(x.name)) + + hessian_score_by_image_hash = {} + + for inputs_batch in representative_dataset_gen: + if not isinstance(inputs_batch, list): + raise TypeError('Expected representative data generator to yield a list of inputs') + if len(inputs_batch) > 1: + raise NotImplementedError('Per sample hessian computation is not supported for networks with multiple inputs') + # Get the framework-specific calculator Hessian-approximation scores + fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph, + input_images=inputs_batch, + hessian_scores_request=hessian_scores_request, + num_iterations_for_approximation=self.num_iterations_for_approximation) + hessian_scores = fw_hessian_calculator.compute() + for b in range(inputs_batch[0].shape[0]): + img_hash = self.calc_image_hash(inputs_batch[0][b]) + hessian_score_by_image_hash[img_hash] = { + node: score for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) + } + + return hessian_score_by_image_hash + + @staticmethod + def calc_image_hash(image): + if len(image.shape) != 3: + raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}') + image_bytes = image.astype(np.float32).tobytes() + return hashlib.md5(image_bytes).hexdigest() + def fetch_hessian(self, hessian_scores_request: HessianScoresRequest, required_size: int, - batch_size: int = 1) -> List[List[np.ndarray]]: + batch_size: int = 1, + per_sample_hash: bool = False) -> List[List[np.ndarray]]: """ Fetches the computed approximations of the Hessian-based scores for the given request and required size. @@ -240,6 +288,7 @@ def fetch_hessian(self, hessian_scores_request: Configuration for which to fetch the approximation. required_size: Number of approximations required. batch_size: The Hessian computation batch size. + per_sample_hash: Whether to compute hessian per sample hash. Returns: List[List[np.ndarray]]: For each target node, returns a list of computed approximations. @@ -247,7 +296,6 @@ def fetch_hessian(self, The inner list length dependent on the granularity (1 for per-tensor, OC for per-output-channel when the requested node has OC output-channels, etc.) """ - if len(hessian_scores_request.target_nodes) == 0: return [] @@ -263,6 +311,32 @@ def fetch_hessian(self, for node in hessian_scores_request.target_nodes ] + if per_sample_hash: + if required_size is not None: + raise ValueError('required_size cannot be specified with per_sample_hash') + + def gen_single(orig_gen): + # convert original generator into generator that yields sample by sample + for batch in orig_gen: + for i in range(batch[0].shape[0]): + yield [inp[i] for inp in batch] + + def gen_new_batch(): + # convert sample by sample generator into the required batch + samples = [] + for sample in gen_single(self.representative_dataset_gen()): + samples.append(sample) + if len(samples) == batch_size: + yield [np.stack(d, axis=0) for d in zip(*samples)] + samples = [] + if samples: + yield [np.stack(d, axis=0) for d in zip(*samples)] + + return self._compute_trackable_per_sample_hessian(hessian_scores_request, gen_new_batch()) + + if required_size is None: + raise ValueError('required_size must be specified if per_sample_hash is False') + # Ensure the saved info has the required number of approximations self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size) diff --git a/model_compression_toolkit/core/common/hessian/hessian_scores_request.py b/model_compression_toolkit/core/common/hessian/hessian_scores_request.py index cfd9242ad..b982f9b8a 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_scores_request.py +++ b/model_compression_toolkit/core/common/hessian/hessian_scores_request.py @@ -40,6 +40,14 @@ class HessianScoresGranularity(Enum): PER_TENSOR = 2 +class HessianEstimationDistribution(str, Enum): + """ + Distribution for Hutchinson estimator random vector + """ + GAUSSIAN = 'gaussian' + RADEMACHER = 'rademacher' + + class HessianScoresRequest: """ Request configuration for the Hessian-approximation scores. @@ -53,7 +61,8 @@ class HessianScoresRequest: def __init__(self, mode: HessianMode, granularity: HessianScoresGranularity, - target_nodes: List): + target_nodes: List, + distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN): """ Attributes: mode (HessianMode): Mode of Hessian-approximation score (w.r.t weights or activations). @@ -64,6 +73,7 @@ def __init__(self, self.mode = mode # w.r.t activations or weights self.granularity = granularity # per element, per layer, per channel self.target_nodes = target_nodes + self.distribution = distribution def __eq__(self, other): # Checks if the other object is an instance of HessianScoresRequest @@ -71,9 +81,10 @@ def __eq__(self, other): return isinstance(other, HessianScoresRequest) and \ self.mode == other.mode and \ self.granularity == other.granularity and \ - self.target_nodes == other.target_nodes + self.target_nodes == other.target_nodes and \ + self.distribution == other.distribution def __hash__(self): # Computes the hash based on the attributes. # The use of a tuple here ensures that the hash is influenced by all the attributes. - return hash((self.mode, self.granularity, tuple(self.target_nodes))) \ No newline at end of file + return hash((self.mode, self.granularity, tuple(self.target_nodes), self.distribution)) diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py index 0ff6a5a73..94a866150 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py @@ -21,7 +21,8 @@ from model_compression_toolkit.constants import MIN_HESSIAN_ITER, HESSIAN_COMP_TOLERANCE, HESSIAN_NUM_ITERATIONS from model_compression_toolkit.core.common import Graph -from model_compression_toolkit.core.common.hessian import HessianScoresRequest, HessianScoresGranularity +from model_compression_toolkit.core.common.hessian import (HessianScoresRequest, HessianScoresGranularity, + HessianEstimationDistribution) from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder from model_compression_toolkit.core.pytorch.hessian.hessian_scores_calculator_pytorch import \ HessianScoresCalculatorPytorch @@ -55,6 +56,52 @@ def __init__(self, hessian_scores_request=hessian_scores_request, num_iterations_for_approximation=num_iterations_for_approximation) + def forward_pass(self): + model_output_nodes = [ot.node for ot in self.graph.get_outputs()] + + if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0: + Logger.critical("Activation Hessian approximation cannot be computed for model outputs. " + "Exclude output nodes from Hessian request targets.") + + grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes + model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model() + model.eval() + + # Run model inference + # Set inputs to track gradients during inference + for input_tensor in self.input_images: + input_tensor.requires_grad_() + input_tensor.retain_grad() + + outputs = model(*self.input_images) + + if len(outputs) != len(grad_model_outputs): # pragma: no cover + Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. " + f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.") + + # Extracting the intermediate activation tensors and the model real output. + # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap. + num_target_nodes = len(self.hessian_request.target_nodes) + # Extract activation tensors of nodes for which we want to compute Hessian + target_activation_tensors = outputs[:num_target_nodes] + # Extract the model outputs + output_tensors = outputs[num_target_nodes:] + device = output_tensors[0].device + + # Concat outputs + # First, we need to unfold all outputs that are given as list, to extract the actual output tensors + output = self.concat_tensors(output_tensors) + return output, target_activation_tensors + + def _generate_random_vector(self, shape, distribution: HessianEstimationDistribution, device): + if distribution == HessianEstimationDistribution.GAUSSIAN: + return torch.randn(shape, device=device) + + if distribution == HessianEstimationDistribution.RADEMACHER: + return torch.where(torch.randint(0, 2, shape), 1, -1).to(device) + + raise ValueError(f'Unknown distribution {distribution}') + def compute(self) -> List[np.ndarray]: """ Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations. @@ -62,91 +109,126 @@ def compute(self) -> List[np.ndarray]: Returns: List[np.ndarray]: Scores based on the approximated Hessian for the requested nodes. """ + output, target_activation_tensors = self.forward_pass() + if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR: + hessian_scores = self._compute_per_tensor(output, target_activation_tensors) + elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL: + hessian_scores = self._compute_per_channel(output, target_activation_tensors) + else: + raise NotImplementedError(f'{HessianScoresGranularity.PER_ELEMENT} is not supported') + + # Convert results to list of numpy arrays + hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores] + return hessian_results + + def _compute_per_tensor(self, output, target_activation_tensors): + assert self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR + ipts_hessian_approx_scores = [torch.tensor([0.0], requires_grad=True, device=output.device) + for _ in range(len(target_activation_tensors))] + prev_mean_results = None + for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations + # Getting a random vector with normal distribution + v = self._generate_random_vector(output.shape, self.hessian_request.distribution, output.device) + f_v = torch.sum(v * output) + for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor + # Computing the hessian-approximation scores by getting the gradient of (output * v) + hess_v = autograd.grad(outputs=f_v, + inputs=ipt_tensor, + retain_graph=True, + allow_unused=True)[0] + + if hess_v is None: + # In case we have an output node, which is an interest point, but it is not differentiable, + # we consider its Hessian to be the initial value 0. + continue # pragma: no cover + + # Mean over all dims but the batch (CXHXW for conv) + hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape)))) + + # Update node Hessian approximation mean over random iterations + ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) + + # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation + if j > MIN_HESSIAN_ITER: + if prev_mean_results is not None: + new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1) + relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) / + (torch.abs(new_mean_res) + 1e-6)) + max_delta = torch.max(relative_delta_per_node) + if max_delta < HESSIAN_COMP_TOLERANCE: + break + prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1) + + # add extra dimension to preserve previous behaviour + ipts_hessian_approx_scores = [torch.unsqueeze(t, -1) for t in ipts_hessian_approx_scores] + return ipts_hessian_approx_scores + + def _compute_per_channel(self, output, target_activation_tensors): + assert self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL + ipts_hessian_approx_scores = [torch.tensor(0.0, requires_grad=True, device=output.device) + for _ in range(len(target_activation_tensors))] + + for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations + # Getting a random vector with normal distribution + v = self._generate_random_vector(output.shape, self.hessian_request.distribution, output.device) + f_v = torch.sum(v * output) + for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor + # Computing the hessian-approximation scores by getting the gradient of (output * v) + hess_v = autograd.grad(outputs=f_v, + inputs=ipt_tensor, + retain_graph=True, + allow_unused=True)[0] + + hessian_approx_scores = hess_v ** 2 + rank = len(hess_v.shape) + if rank > 2: + hessian_approx_scores = torch.mean(hess_v, dim=tuple(range(2, rank))) + + # Update node Hessian approximation mean over random iterations + ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) + + return ipts_hessian_approx_scores + + +class SampleLayerHessianScoresCalculatorPytorch(ActivationHessianScoresCalculatorPytorch): - model_output_nodes = [ot.node for ot in self.graph.get_outputs()] - - if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0: - Logger.critical("Activation Hessian approximation cannot be computed for model outputs. " - "Exclude output nodes from Hessian request targets.") - - grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes - model, _ = FloatPyTorchModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model() - model.eval() - - # Run model inference - # Set inputs to track gradients during inference - for input_tensor in self.input_images: - input_tensor.requires_grad_() - input_tensor.retain_grad() - - outputs = model(*self.input_images) - - if len(outputs) != len(grad_model_outputs): # pragma: no cover - Logger.critical(f"Mismatch in expected and actual model outputs for activation Hessian approximation. " - f"Expected {len(grad_model_outputs)} outputs, received {len(outputs)}.") - - # Extracting the intermediate activation tensors and the model real output. - # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap. - num_target_nodes = len(self.hessian_request.target_nodes) - # Extract activation tensors of nodes for which we want to compute Hessian - target_activation_tensors = outputs[:num_target_nodes] - # Extract the model outputs - output_tensors = outputs[num_target_nodes:] - device = output_tensors[0].device - - # Concat outputs - # First, we need to unfold all outputs that are given as list, to extract the actual output tensors - output = self.concat_tensors(output_tensors) - - ipts_hessian_approx_scores = [torch.tensor([0.0], - requires_grad=True, - device=device) - for _ in range(len(target_activation_tensors))] - prev_mean_results = None - for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations - # Getting a random vector with normal distribution - v = torch.randn(output.shape, device=device) - f_v = torch.sum(v * output) - for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor - # Computing the hessian-approximation scores by getting the gradient of (output * v) - hess_v = autograd.grad(outputs=f_v, - inputs=ipt_tensor, - retain_graph=True, - allow_unused=True)[0] - - if hess_v is None: - # In case we have an output node, which is an interest point, but it is not differentiable, - # we consider its Hessian to be the initial value 0. - continue # pragma: no cover - - # Mean over all dims but the batch (CXHXW for conv) - hessian_approx_scores = torch.sum(hess_v ** 2.0, dim=tuple(d for d in range(1, len(hess_v.shape)))) - - # Update node Hessian approximation mean over random iterations - ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) - - # If the change to the maximal mean Hessian approximation is insignificant we stop the calculation - if j > MIN_HESSIAN_ITER: - if prev_mean_results is not None: - new_mean_res = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1) - relative_delta_per_node = (torch.abs(new_mean_res - prev_mean_results) / - (torch.abs(new_mean_res) + 1e-6)) - max_delta = torch.max(relative_delta_per_node) - if max_delta < HESSIAN_COMP_TOLERANCE: - break - prev_mean_results = torch.mean(torch.stack(ipts_hessian_approx_scores), dim=1) - - # Convert results to list of numpy arrays - hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores] - # Extend the Hessian tensors shape to align with expected return type - # TODO: currently, only per-tensor Hessian is available for activation. - # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately. - hessian_results = [h[..., np.newaxis] for h in hessian_results] - - return hessian_results - - else: # pragma: no cover - Logger.critical(f"PyTorch activation Hessian's approximation scores does not support " - f"{self.hessian_request.granularity} granularity.") + def compute(self) -> List[np.ndarray]: + """ + Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations. + Returns: + List[np.ndarray]: List of approximated Hessian scores of shape (n_samples X n_channels) + for the requested nodes. + """ + output, target_activation_tensors = self.forward_pass() + device = output.device + + # Score is initialize to a scalar. Upon first update it will get the correct shape. + ipts_hessian_approx_scores = [torch.tensor(0., requires_grad=True, device=device) + for _ in range(len(target_activation_tensors))] + + for j in tqdm(range(self.num_iterations_for_approximation), + "Hessian random iterations"): # Approximation iterations + v = torch.randint_like(output, high=2, device=device) + v[v == 0] = -1 + out_v = torch.sum(v * output) + for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor + # Computing the hessian-approximation scores by getting the gradient of (output * v) + hess_v = autograd.grad(outputs=out_v, + inputs=ipt_tensor, + retain_graph=True, + allow_unused=True)[0] + + hessian_approx_scores = hess_v**2 + + rank = len(hess_v.shape) + if rank > 2: + hessian_approx_scores = torch.mean(hess_v, dim=range(2, rank)) + + # Update node Hessian approximation mean over random iterations + ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) + + # Convert results to list of numpy arrays + hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores] + return hessian_results diff --git a/model_compression_toolkit/gptq/common/gptq_config.py b/model_compression_toolkit/gptq/common/gptq_config.py index dcd806a93..972bd742e 100644 --- a/model_compression_toolkit/gptq/common/gptq_config.py +++ b/model_compression_toolkit/gptq/common/gptq_config.py @@ -17,6 +17,7 @@ from typing import Callable, Any, Dict, Optional from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE +from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT @@ -39,17 +40,21 @@ class GPTQHessianScoresConfig: Configuration to use for computing the Hessian-based scores for GPTQ loss metric. Args: - hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores. + hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores. + If None, compute Hessian for all images. norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1). log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores. scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores. hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective. + per_sample (bool): Whether to use per sample attention score. """ - hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES + hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES norm_scores: bool = True log_norm: bool = True scale_log_norm: bool = False hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE + per_sample: bool = False + estimator_distribution: HessianEstimationDistribution = HessianEstimationDistribution.GAUSSIAN @dataclass diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 01265031f..920e4d70a 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import copy +import hashlib from abc import ABC, abstractmethod import numpy as np from typing import Callable, List, Any, Dict @@ -143,7 +144,12 @@ def compute_hessian_based_weights(self) -> np.ndarray: return np.asarray([1 / num_nodes for _ in range(num_nodes)]) # Fetch hessian approximations for each target node - compare_point_to_hessian_approx_scores = self._fetch_hessian_approximations() + # TODO this smells like a bug. In hessian calculation target nodes are topo sorted and results are returned + # in the same order. Maybe topo sort doesn't do anything and it works? + # TODO also target nodes are replaced for reuse. Does this work correctly? + approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR) + compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)} + # Process the fetched hessian approximations to gather them per images hessian_approx_score_by_image = ( self._process_hessian_approximations(compare_point_to_hessian_approx_scores)) @@ -172,29 +178,40 @@ def compute_hessian_based_weights(self) -> np.ndarray: # If log normalization is not enabled, return the mean of the approximations across images return np.mean(hessian_approx_score_by_image, axis=0) - def _fetch_hessian_approximations(self) -> Dict[BaseNode, List[List[float]]]: + def _compute_sample_layer_attention_scores(self) -> Dict[str, Dict[BaseNode, np.ndarray]]: + """ + Compute sample layer attention scores per image hash per layer. + + Returns: + A dictionary {img_hash: {layer: score}} where score is the + + """ + hessian_score_per_image_per_layer = self._fetch_hessian_approximations(HessianScoresGranularity.PER_OUTPUT_CHANNEL) + for layers_score in hessian_score_per_image_per_layer.values(): + for k, t in layers_score.items(): + layers_score[k] = t.max(axis=1) + return hessian_score_per_image_per_layer + + def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]: """ Fetches hessian approximations for each target node. Returns: Mapping of target nodes to their hessian approximations. """ - approximations = {} hessian_scores_request = HessianScoresRequest( mode=HessianMode.ACTIVATION, - granularity=HessianScoresGranularity.PER_TENSOR, - target_nodes=self.compare_points + granularity=granularity, + target_nodes=self.compare_points, + distribution=self.gptq_config.hessian_weights_config.estimator_distribution ) node_approximations = self.hessian_service.fetch_hessian( hessian_scores_request=hessian_scores_request, required_size=self.gptq_config.hessian_weights_config.hessians_num_samples, - batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size + batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size, + per_sample_hash=self.gptq_config.hessian_weights_config.per_sample ) - - for i, target_node in enumerate(self.compare_points): - approximations[target_node] = node_approximations[i] - - return approximations + return node_approximations def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List: """ diff --git a/model_compression_toolkit/gptq/pytorch/gptq_loss.py b/model_compression_toolkit/gptq/pytorch/gptq_loss.py index 2f74fda50..749b079e0 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_loss.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_loss.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== from typing import List + import torch + def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch.Tensor: """ Compute the MSE of two tensors. @@ -25,7 +27,7 @@ def mse_loss(y: torch.Tensor, x: torch.Tensor, normalized: bool = True) -> torch Returns: The MSE of two tensors. """ - loss = torch.nn.MSELoss()(x,y) + loss = torch.nn.MSELoss()(x, y) return loss / torch.mean(torch.square(x)) if normalized else loss @@ -62,3 +64,34 @@ def multiple_tensors_mse_loss(y_list: List[torch.Tensor], else: return torch.mean(torch.stack(loss_values_list)) + +def sample_layer_attention_loss(y_list: List[torch.Tensor], + x_list: List[torch.Tensor], + fxp_w_list, + flp_w_list, + act_bn_mean, + act_bn_std, + loss_weights: torch.Tensor) -> torch.Tensor: + """ + Compute Sample Layer Attention loss between two lists of tensors. + + Args: + y_list: First list of tensors. + x_list: Second list of tensors. + fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface). + loss_weights: A list of weights for each layer. Each weight is a vector of shape (batch,) + + Returns: + Sample Layer Attention loss (a scalar). + """ + loss = 0 + layers_mean_w = [] + + for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)): + norm = (y - x).pow(2).sum(1) + if len(norm.shape) > 1: + norm = norm.flatten(1).mean(1) + loss += torch.mean(w * norm) + layers_mean_w.append(w.mean()) + loss = loss / torch.stack(layers_mean_w).max() + return loss diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index d5a98ac93..b78b19679 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Tuple, Union, Dict import numpy as np from torch.nn import Module @@ -105,8 +105,16 @@ def _get_total_grad_steps(): self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights, trainable_bias, trainable_threshold) - - self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights()) + hessian_cfg = self.gptq_config.hessian_weights_config + self.weights_for_average_loss = None # for fixed layer weights + self.hessian_score_per_image_per_layer = None # for sample-layer attention + if hessian_cfg.per_sample: + assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and + hessian_cfg.scale_log_norm is False), hessian_cfg + # self.hessian_score_per_image_per_layer = self._fetch_hessian_approximations() + self.hessian_score_per_image_per_layer = self._compute_sample_layer_attention_scores() + else: + self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights()) self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps) @@ -210,13 +218,15 @@ def train(self, representative_data_gen: Callable): def compute_gradients(self, y_float: List[torch.Tensor], - input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, List[np.ndarray]]: + input_tensors: List[torch.Tensor], + weights_for_average_loss) -> Tuple[torch.Tensor, List[np.ndarray]]: """ Get outputs from both teacher and student networks. Compute the observed error, and use it to compute the gradients and applying them to the student weights. Args: y_float: A list of reference tensor from the floating point network. input_tensors: A list of Input tensors to pass through the networks. + weights_for_average_loss: Weights for loss. Either per layer, or per layer per sample. Returns: Loss and gradients. """ @@ -231,7 +241,7 @@ def compute_gradients(self, self.flp_weights_list, self.compare_points_mean, self.compare_points_std, - self.weights_for_average_loss) + weights_for_average_loss) reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor) @@ -264,7 +274,9 @@ def micro_training_loop(self, input_data = [d * self.input_scale for d in data] input_tensor = to_torch_tensor(input_data) y_float = self.float_model(input_tensor) # running float model - loss_value, grads = self.compute_gradients(y_float, input_tensor) + # weights are either (layers,) or (batch X layers X channels) + weights = to_torch_tensor(self._get_samples_weights_for_loss(input_tensor)) + loss_value, grads = self.compute_gradients(y_float, input_tensor, weights) # Run one step of gradient descent by updating the value of the variables to minimize the loss. for (optimizer, _) in self.optimizer_with_param: optimizer.step() @@ -276,6 +288,24 @@ def micro_training_loop(self, self.loss_list.append(loss_value.item()) Logger.debug(f'last loss value: {self.loss_list[-1]}') + # TODO move to common after ctor refactor + def _get_samples_weights_for_loss(self, input_tensors: List[torch.Tensor]): + if self.hessian_score_per_image_per_layer is None: + assert self.weights_for_average_loss is not None + return self.weights_for_average_loss + + if len(input_tensors) > 1: + raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs') + + scores = [] + batch = input_tensors[0].detach().cpu().numpy() + img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch] + for img_hash in img_hashes: + img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash] + img_scores = np.stack(list(img_scores_per_layer.values()), axis=0) + scores.append(img_scores) + return np.stack(scores, axis=0) + def update_graph(self) -> Graph: """ Update a graph using GPTQ after minimizing the loss between the float model's output diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 46e9a25b1..8022a5552 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -18,6 +18,7 @@ from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH from model_compression_toolkit.core import CoreConfig from model_compression_toolkit.core.analyzer import analyzer_model_quantization +from model_compression_toolkit.core.common.hessian import HessianScoresGranularity, HessianEstimationDistribution from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \ MixedPrecisionQuantizationConfig from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \ @@ -43,7 +44,7 @@ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL - from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss + from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.fully_quantized_model_builder import get_exportable_pytorch_model import torch from torch.nn import Module @@ -55,11 +56,12 @@ def get_pytorch_gptq_config(n_epochs: int, optimizer: Optimizer = None, optimizer_rest: Optimizer = None, - loss: Callable = multiple_tensors_mse_loss, + loss: Callable = None, log_function: Callable = None, use_hessian_based_weights: bool = True, regularization_factor: float = REG_DEFAULT, hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE, + use_hessian_sample_attention: bool = False, gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False, ) -> GradientPTQConfig: """ @@ -74,6 +76,7 @@ def get_pytorch_gptq_config(n_epochs: int, use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss. regularization_factor (float): A floating point number that defines the regularization factor. hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ. + use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss. gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. @@ -105,6 +108,23 @@ def get_pytorch_gptq_config(n_epochs: int, bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM) + if use_hessian_sample_attention: + if not use_hessian_based_weights: + raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.') + hessian_weights_config = GPTQHessianScoresConfig( + hessians_num_samples=None, + norm_scores=False, + log_norm=False, + scale_log_norm=False, + hessian_batch_size=hessian_batch_size, + per_sample=True, + estimator_distribution=HessianEstimationDistribution.RADEMACHER + ) + loss = loss or sample_layer_attention_loss + else: + hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size) + loss = loss or multiple_tensors_mse_loss + if isinstance(gradual_activation_quantization, bool): gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig): @@ -117,7 +137,7 @@ def get_pytorch_gptq_config(n_epochs: int, log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, use_hessian_based_weights=use_hessian_based_weights, regularization_factor=regularization_factor, - hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size), + hessian_weights_config=hessian_weights_config, gradual_activation_quantization_config=gradual_quant_config) def pytorch_gradient_post_training_quantization(model: Module, diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py index 082aaac26..552fff89c 100644 --- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py @@ -56,8 +56,10 @@ def forward(self, inp): class GPTQBaseTest(BasePytorchFeatureNetworkTest): def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationMethod.SYMMETRIC, rounding_type=RoundingType.STE, per_channel=True, - hessian_weights=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, - num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False): + hessian_weights=True, norm_scores=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, + num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False, + hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, sample_layer_attention=False, + loss=multiple_tensors_mse_loss, hessian_batch_size=1): super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter) self.seed = 0 self.rounding_type = rounding_type @@ -65,12 +67,17 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM self.weights_quant_method = weights_quant_method self.per_channel = per_channel self.hessian_weights = hessian_weights + self.norm_scores = norm_scores self.log_norm_weights = log_norm_weights self.scaled_log_norm = scaled_log_norm self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \ rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict(default_value=1)} \ if rounding_type == RoundingType.STE else None self.gradual_activation_quantization = gradual_activation_quantization + self.hessian_num_samples = hessian_num_samples + self.sample_layer_attention = sample_layer_attention + self.loss = loss + self.hessian_batch_size = hessian_batch_size def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING, @@ -88,6 +95,21 @@ def get_tpc(self): def gptq_compare(self, ptq_model, gptq_model, input_x=None): pass + def get_representative_data_gen_experimental_fixed_images(self): + # data generator that generates same images in each epoch (in different order) + dataset = [] + for _ in range(self.num_calibration_iter): + dataset.append(self.generate_inputs()) + dataset = [np.concatenate(d) for d in zip(*dataset)] + batch_size = int(np.ceil(dataset[0].shape[0] / self.num_calibration_iter)) + + def gen(): + indices = np.random.permutation(range(dataset[0].shape[0])) + shuffled_dataset = [d[indices] for d in dataset] + for i in range(self.num_calibration_iter): + yield [d[batch_size*i: batch_size*(i+1)] for d in shuffled_dataset] + return gen + def run_test(self): # Create model self.float_model = self.create_networks() @@ -95,8 +117,9 @@ def run_test(self): # Run MCT with PTQ np.random.seed(self.seed) + data_generator = self.get_representative_data_gen_experimental_fixed_images() ptq_model, _ = mct.ptq.pytorch_post_training_quantization(self.float_model, - self.representative_data_gen_experimental, + data_generator, core_config=self.get_core_config(), target_platform_capabilities=self.get_tpc()) @@ -104,7 +127,7 @@ def run_test(self): np.random.seed(self.seed) gptq_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization( self.float_model, - self.representative_data_gen_experimental, + data_generator, core_config=self.get_core_config(), target_platform_capabilities=self.get_tpc(), gptq_config=self.get_gptq_config()) @@ -123,11 +146,17 @@ def get_gptq_config(self): gradual_act_cfg = GradualActivationQuantizationConfig() if self.gradual_activation_quantization else None return GradientPTQConfig(5, optimizer=torch.optim.Adam([torch.Tensor([])], lr=1e-4), optimizer_rest=torch.optim.Adam([torch.Tensor([])], lr=1e-4), - loss=multiple_tensors_mse_loss, train_bias=True, rounding_type=self.rounding_type, + loss=self.loss, train_bias=True, rounding_type=self.rounding_type, use_hessian_based_weights=self.hessian_weights, optimizer_bias=torch.optim.Adam([torch.Tensor([])], lr=0.4), hessian_weights_config=GPTQHessianScoresConfig(log_norm=self.log_norm_weights, - scale_log_norm=self.scaled_log_norm), + scale_log_norm=self.scaled_log_norm, + norm_scores=self.norm_scores, + per_sample=self.sample_layer_attention, + hessians_num_samples=self.hessian_num_samples, + hessian_batch_size=self.hessian_batch_size), + + gptq_quantizer_params_override=self.override_params, gradual_activation_quantization_config=gradual_act_cfg) diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 2210fbc21..bd0c2d975 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -24,6 +24,7 @@ from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.gptq.common.gptq_config import RoundingType +from model_compression_toolkit.gptq.pytorch.gptq_loss import sample_layer_attention_loss from model_compression_toolkit.target_platform_capabilities import constants as C from model_compression_toolkit.trainable_infrastructure import TrainingMethod from tests.pytorch_tests.model_tests.feature_models.add_net_test import AddNetTest @@ -654,6 +655,14 @@ def test_gptq_with_gradual_activation(self): GPTQLearnRateZeroTest(self, rounding_type=RoundingType.SoftQuantizer, gradual_activation_quantization=True).run_test() + def test_gptq_with_sample_layer_attention(self): + kwargs = dict(sample_layer_attention=True, loss=sample_layer_attention_loss, + hessian_weights=True, rounding_type=RoundingType.SoftQuantizer, + hessian_num_samples=None, norm_scores=False, log_norm_weights=False, scaled_log_norm=False) + GPTQAccuracyTest(self, **kwargs).run_test() + GPTQAccuracyTest(self, hessian_batch_size=16, **kwargs).run_test() + GPTQAccuracyTest(self, hessian_batch_size=5, gradual_activation_quantization=True, **kwargs).run_test() + def test_qat(self): """ This test checks the QAT feature. From 19cad97b20ffd13ecff978dc958fe751550fd303 Mon Sep 17 00:00:00 2001 From: irenab Date: Sun, 29 Sep 2024 15:17:33 +0300 Subject: [PATCH 2/5] fix + calc batch hessian on the fly --- .../common/hessian/hessian_info_service.py | 43 +++++++++-------- ...ation_hessian_scores_calculator_pytorch.py | 48 +------------------ .../gptq/common/gptq_training.py | 33 +++++++++---- .../gptq/pytorch/gptq_loss.py | 6 ++- .../gptq/pytorch/gptq_training.py | 21 ++++---- 5 files changed, 66 insertions(+), 85 deletions(-) diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index b1d683f58..117d33738 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -231,15 +231,16 @@ def compute(self, return next_iter_remain_samples if next_iter_remain_samples is not None and len(next_iter_remain_samples) > 0 \ and len(next_iter_remain_samples[0]) > 0 else None - def _compute_trackable_per_sample_hessian(self, - hessian_scores_request: HessianScoresRequest, - representative_dataset_gen) -> Dict[str, Dict['BaseNode', np.ndarray]]: + def compute_trackable_per_sample_hessian(self, + hessian_scores_request: HessianScoresRequest, + inputs_batch: List[np.ndarray]) -> Dict[str, Dict['BaseNode', np.ndarray]]: """ - Compute hessian score per image hash. + Compute hessian score per image hash. We compute the score directly for images rather than via data generator, + as data generator might yield different images each time, depending on how it was defined, Args: hessian_scores_request: hessian scores request - representative_dataset_gen: representative dataset generator + inputs_batch: a list containing a batch of inputs. Returns: A dict of Hessian scores per image hash per layer {image hash: {layer: score}} @@ -249,22 +250,22 @@ def _compute_trackable_per_sample_hessian(self, hessian_score_by_image_hash = {} - for inputs_batch in representative_dataset_gen: - if not isinstance(inputs_batch, list): - raise TypeError('Expected representative data generator to yield a list of inputs') - if len(inputs_batch) > 1: - raise NotImplementedError('Per sample hessian computation is not supported for networks with multiple inputs') - # Get the framework-specific calculator Hessian-approximation scores - fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph, - input_images=inputs_batch, - hessian_scores_request=hessian_scores_request, - num_iterations_for_approximation=self.num_iterations_for_approximation) - hessian_scores = fw_hessian_calculator.compute() - for b in range(inputs_batch[0].shape[0]): - img_hash = self.calc_image_hash(inputs_batch[0][b]) - hessian_score_by_image_hash[img_hash] = { - node: score for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) - } + if not isinstance(inputs_batch, list): + raise TypeError('Expected a list of inputs') + if len(inputs_batch) > 1: + raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') + + # Get the framework-specific calculator Hessian-approximation scores + fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph, + input_images=inputs_batch, + hessian_scores_request=hessian_scores_request, + num_iterations_for_approximation=self.num_iterations_for_approximation) + hessian_scores = fw_hessian_calculator.compute() + for b in range(inputs_batch[0].shape[0]): + img_hash = self.calc_image_hash(inputs_batch[0][b]) + hessian_score_by_image_hash[img_hash] = { + node: score[b] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) + } return hessian_score_by_image_hash diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py index 94a866150..52b00e899 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py @@ -98,7 +98,7 @@ def _generate_random_vector(self, shape, distribution: HessianEstimationDistribu return torch.randn(shape, device=device) if distribution == HessianEstimationDistribution.RADEMACHER: - return torch.where(torch.randint(0, 2, shape), 1, -1).to(device) + return torch.where(torch.randint(0, 2, shape, device=device).to(torch.bool), 1, -1).to(device) raise ValueError(f'Unknown distribution {distribution}') @@ -177,8 +177,7 @@ def _compute_per_channel(self, output, target_activation_tensors): # Computing the hessian-approximation scores by getting the gradient of (output * v) hess_v = autograd.grad(outputs=f_v, inputs=ipt_tensor, - retain_graph=True, - allow_unused=True)[0] + retain_graph=True)[0] hessian_approx_scores = hess_v ** 2 rank = len(hess_v.shape) @@ -189,46 +188,3 @@ def _compute_per_channel(self, output, target_activation_tensors): ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) return ipts_hessian_approx_scores - - -class SampleLayerHessianScoresCalculatorPytorch(ActivationHessianScoresCalculatorPytorch): - - def compute(self) -> List[np.ndarray]: - """ - Compute the scores that are based on the approximation of the Hessian w.r.t the requested target nodes' activations. - - Returns: - List[np.ndarray]: List of approximated Hessian scores of shape (n_samples X n_channels) - for the requested nodes. - """ - output, target_activation_tensors = self.forward_pass() - device = output.device - - # Score is initialize to a scalar. Upon first update it will get the correct shape. - ipts_hessian_approx_scores = [torch.tensor(0., requires_grad=True, device=device) - for _ in range(len(target_activation_tensors))] - - for j in tqdm(range(self.num_iterations_for_approximation), - "Hessian random iterations"): # Approximation iterations - v = torch.randint_like(output, high=2, device=device) - v[v == 0] = -1 - out_v = torch.sum(v * output) - for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor - # Computing the hessian-approximation scores by getting the gradient of (output * v) - hess_v = autograd.grad(outputs=out_v, - inputs=ipt_tensor, - retain_graph=True, - allow_unused=True)[0] - - hessian_approx_scores = hess_v**2 - - rank = len(hess_v.shape) - if rank > 2: - hessian_approx_scores = torch.mean(hess_v, dim=range(2, rank)) - - # Update node Hessian approximation mean over random iterations - ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) - - # Convert results to list of numpy arrays - hessian_results = [torch_tensor_to_numpy(h) for h in ipts_hessian_approx_scores] - return hessian_results diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 920e4d70a..9205531dd 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -178,7 +178,7 @@ def compute_hessian_based_weights(self) -> np.ndarray: # If log normalization is not enabled, return the mean of the approximations across images return np.mean(hessian_approx_score_by_image, axis=0) - def _compute_sample_layer_attention_scores(self) -> Dict[str, Dict[BaseNode, np.ndarray]]: + def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict[BaseNode, np.ndarray]]: """ Compute sample layer attention scores per image hash per layer. @@ -186,10 +186,21 @@ def _compute_sample_layer_attention_scores(self) -> Dict[str, Dict[BaseNode, np. A dictionary {img_hash: {layer: score}} where score is the """ - hessian_score_per_image_per_layer = self._fetch_hessian_approximations(HessianScoresGranularity.PER_OUTPUT_CHANNEL) + request = self._build_hessian_request(HessianScoresGranularity.PER_OUTPUT_CHANNEL) + hessian_batch_size = self.gptq_config.hessian_weights_config.hessian_batch_size + + hessian_score_per_image_per_layer = {} + # TODO Is it really needed if we compute on the fly per batch? Also if hessian batch is larger its ignored. + # If hessian batch is smaller than inputs batch, split it to hessian batches. + for i in range(0, inputs_batch[0].shape[0], hessian_batch_size): + inputs = [t[i: i+hessian_batch_size] for t in inputs_batch] + hessian_score_per_image_per_layer.update( + self.hessian_service.compute_trackable_per_sample_hessian(request, inputs) + ) + # hessian_score_per_image_per_layer = self._fetch_hessian_approximations(HessianScoresGranularity.PER_OUTPUT_CHANNEL) for layers_score in hessian_score_per_image_per_layer.values(): for k, t in layers_score.items(): - layers_score[k] = t.max(axis=1) + layers_score[k] = t.max(axis=0) # layer score is (channels,) return hessian_score_per_image_per_layer def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]: @@ -199,12 +210,8 @@ def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) - Returns: Mapping of target nodes to their hessian approximations. """ - hessian_scores_request = HessianScoresRequest( - mode=HessianMode.ACTIVATION, - granularity=granularity, - target_nodes=self.compare_points, - distribution=self.gptq_config.hessian_weights_config.estimator_distribution - ) + hessian_scores_request = self._build_hessian_request(granularity) + node_approximations = self.hessian_service.fetch_hessian( hessian_scores_request=hessian_scores_request, required_size=self.gptq_config.hessian_weights_config.hessians_num_samples, @@ -213,6 +220,14 @@ def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) - ) return node_approximations + def _build_hessian_request(self, granularity): + return HessianScoresRequest( + mode=HessianMode.ACTIVATION, + granularity=granularity, + target_nodes=self.compare_points, + distribution=self.gptq_config.hessian_weights_config.estimator_distribution + ) + def _process_hessian_approximations(self, approximations: Dict[BaseNode, List[List[float]]]) -> List: """ Processes the fetched hessian approximations by image. diff --git a/model_compression_toolkit/gptq/pytorch/gptq_loss.py b/model_compression_toolkit/gptq/pytorch/gptq_loss.py index 749b079e0..8d0d25494 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_loss.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_loss.py @@ -88,10 +88,14 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor], layers_mean_w = [] for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)): - norm = (y - x).pow(2).sum(1) + # norm = (y - x).pow(2).sum(1) + norm = (y - x).pow(2).mean(1) if len(norm.shape) > 1: norm = norm.flatten(1).mean(1) loss += torch.mean(w * norm) layers_mean_w.append(w.mean()) + + # loss = loss / len(x_list) loss = loss / torch.stack(layers_mean_w).max() return loss + diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index b78b19679..1d50cf62c 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -111,8 +111,10 @@ def _get_total_grad_steps(): if hessian_cfg.per_sample: assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and hessian_cfg.scale_log_norm is False), hessian_cfg - # self.hessian_score_per_image_per_layer = self._fetch_hessian_approximations() - self.hessian_score_per_image_per_layer = self._compute_sample_layer_attention_scores() + # TODO if a representative dataset is fixed (same images in each epoch) we can precalculate. + # However if images differ between epochs, we have to calculate their hessians each time and pre-calculation + # will be a waste. Currently it is calculated on-demand during the training loop. + self.hessian_score_per_image_per_layer = {} else: self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights()) @@ -271,11 +273,10 @@ def micro_training_loop(self, for _ in epochs_pbar: with tqdm(data_function(), position=1, leave=False) as data_pbar: for data in data_pbar: + weights = to_torch_tensor(self._get_samples_weights_for_loss(data)) input_data = [d * self.input_scale for d in data] input_tensor = to_torch_tensor(input_data) y_float = self.float_model(input_tensor) # running float model - # weights are either (layers,) or (batch X layers X channels) - weights = to_torch_tensor(self._get_samples_weights_for_loss(input_tensor)) loss_value, grads = self.compute_gradients(y_float, input_tensor, weights) # Run one step of gradient descent by updating the value of the variables to minimize the loss. for (optimizer, _) in self.optimizer_with_param: @@ -290,21 +291,25 @@ def micro_training_loop(self, # TODO move to common after ctor refactor def _get_samples_weights_for_loss(self, input_tensors: List[torch.Tensor]): - if self.hessian_score_per_image_per_layer is None: - assert self.weights_for_average_loss is not None + if self.weights_for_average_loss is not None: + assert self.hessian_score_per_image_per_layer is None return self.weights_for_average_loss + # assert self.hessian_score_per_image_per_layer if len(input_tensors) > 1: raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs') scores = [] - batch = input_tensors[0].detach().cpu().numpy() + batch = input_tensors[0] img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch] for img_hash in img_hashes: + if img_hash not in self.hessian_score_per_image_per_layer: + score_per_image_layer_per = self._compute_sample_layer_attention_scores(input_tensors) + self.hessian_score_per_image_per_layer.update(score_per_image_layer_per) img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash] img_scores = np.stack(list(img_scores_per_layer.values()), axis=0) scores.append(img_scores) - return np.stack(scores, axis=0) + return np.stack(scores, axis=1) # layers X images def update_graph(self) -> Graph: """ From 4b84a901a23d953db82ad5ffe350f70d80453a06 Mon Sep 17 00:00:00 2001 From: irenab Date: Tue, 1 Oct 2024 16:22:15 +0300 Subject: [PATCH 3/5] small fixes --- .../common/hessian/hessian_info_service.py | 31 +---------- ...ation_hessian_scores_calculator_pytorch.py | 28 ++++++---- .../gptq/common/gptq_training.py | 21 ++++---- .../gptq/pytorch/gptq_loss.py | 6 +-- .../gptq/pytorch/gptq_training.py | 52 ++++++++++++------- .../quantizer/regularization_factory.py | 2 +- .../soft_rounding/soft_quantizer_reg.py | 28 +++++----- .../model_tests/test_feature_models_runner.py | 10 ++-- 8 files changed, 88 insertions(+), 90 deletions(-) diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index 117d33738..dd6106fac 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -279,8 +279,7 @@ def calc_image_hash(image): def fetch_hessian(self, hessian_scores_request: HessianScoresRequest, required_size: int, - batch_size: int = 1, - per_sample_hash: bool = False) -> List[List[np.ndarray]]: + batch_size: int = 1) -> List[List[np.ndarray]]: """ Fetches the computed approximations of the Hessian-based scores for the given request and required size. @@ -289,7 +288,6 @@ def fetch_hessian(self, hessian_scores_request: Configuration for which to fetch the approximation. required_size: Number of approximations required. batch_size: The Hessian computation batch size. - per_sample_hash: Whether to compute hessian per sample hash. Returns: List[List[np.ndarray]]: For each target node, returns a list of computed approximations. @@ -297,6 +295,7 @@ def fetch_hessian(self, The inner list length dependent on the granularity (1 for per-tensor, OC for per-output-channel when the requested node has OC output-channels, etc.) """ + if len(hessian_scores_request.target_nodes) == 0: return [] @@ -312,32 +311,6 @@ def fetch_hessian(self, for node in hessian_scores_request.target_nodes ] - if per_sample_hash: - if required_size is not None: - raise ValueError('required_size cannot be specified with per_sample_hash') - - def gen_single(orig_gen): - # convert original generator into generator that yields sample by sample - for batch in orig_gen: - for i in range(batch[0].shape[0]): - yield [inp[i] for inp in batch] - - def gen_new_batch(): - # convert sample by sample generator into the required batch - samples = [] - for sample in gen_single(self.representative_dataset_gen()): - samples.append(sample) - if len(samples) == batch_size: - yield [np.stack(d, axis=0) for d in zip(*samples)] - samples = [] - if samples: - yield [np.stack(d, axis=0) for d in zip(*samples)] - - return self._compute_trackable_per_sample_hessian(hessian_scores_request, gen_new_batch()) - - if required_size is None: - raise ValueError('required_size must be specified if per_sample_hash is False') - # Ensure the saved info has the required number of approximations self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size) diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py index 52b00e899..c79af62ac 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py @@ -93,12 +93,25 @@ def forward_pass(self): output = self.concat_tensors(output_tensors) return output, target_activation_tensors - def _generate_random_vector(self, shape, distribution: HessianEstimationDistribution, device): + def _generate_random_vectors_batch(self, shape, distribution: HessianEstimationDistribution, device) -> torch.Tensor: + """ + Generate a batch of random vectors for Hutchinson estimation + + Args: + shape: target shape + distribution: distribution to sample from + device: target device + + Returns: + Random tensor + """ if distribution == HessianEstimationDistribution.GAUSSIAN: return torch.randn(shape, device=device) if distribution == HessianEstimationDistribution.RADEMACHER: - return torch.where(torch.randint(0, 2, shape, device=device).to(torch.bool), 1, -1).to(device) + v = torch.randint(high=2, size=shape, device=device) + v[v == 0] = -1 + return v raise ValueError(f'Unknown distribution {distribution}') @@ -116,7 +129,7 @@ def compute(self) -> List[np.ndarray]: elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL: hessian_scores = self._compute_per_channel(output, target_activation_tensors) else: - raise NotImplementedError(f'{HessianScoresGranularity.PER_ELEMENT} is not supported') + raise NotImplementedError(f'{self.hessian_request.granularity} is not supported') # Convert results to list of numpy arrays hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores] @@ -129,7 +142,7 @@ def _compute_per_tensor(self, output, target_activation_tensors): prev_mean_results = None for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations # Getting a random vector with normal distribution - v = self._generate_random_vector(output.shape, self.hessian_request.distribution, output.device) + v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device) f_v = torch.sum(v * output) for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor # Computing the hessian-approximation scores by getting the gradient of (output * v) @@ -170,19 +183,16 @@ def _compute_per_channel(self, output, target_activation_tensors): for _ in range(len(target_activation_tensors))] for j in tqdm(range(self.num_iterations_for_approximation), "Hessian random iterations"): # Approximation iterations - # Getting a random vector with normal distribution - v = self._generate_random_vector(output.shape, self.hessian_request.distribution, output.device) + v = self._generate_random_vectors_batch(output.shape, self.hessian_request.distribution, output.device) f_v = torch.sum(v * output) for i, ipt_tensor in enumerate(target_activation_tensors): # Per Interest point activation tensor - # Computing the hessian-approximation scores by getting the gradient of (output * v) hess_v = autograd.grad(outputs=f_v, inputs=ipt_tensor, retain_graph=True)[0] - hessian_approx_scores = hess_v ** 2 rank = len(hess_v.shape) if rank > 2: - hessian_approx_scores = torch.mean(hess_v, dim=tuple(range(2, rank))) + hessian_approx_scores = torch.mean(hessian_approx_scores, dim=tuple(range(2, rank))) # Update node Hessian approximation mean over random iterations ipts_hessian_approx_scores[i] = (j * ipts_hessian_approx_scores[i] + hessian_approx_scores) / (j + 1) diff --git a/model_compression_toolkit/gptq/common/gptq_training.py b/model_compression_toolkit/gptq/common/gptq_training.py index 9205531dd..29e0ada10 100644 --- a/model_compression_toolkit/gptq/common/gptq_training.py +++ b/model_compression_toolkit/gptq/common/gptq_training.py @@ -144,8 +144,7 @@ def compute_hessian_based_weights(self) -> np.ndarray: return np.asarray([1 / num_nodes for _ in range(num_nodes)]) # Fetch hessian approximations for each target node - # TODO this smells like a bug. In hessian calculation target nodes are topo sorted and results are returned - # in the same order. Maybe topo sort doesn't do anything and it works? + # TODO this smells like a potential bug. In hessian calculation target nodes are topo sorted and results are returned # TODO also target nodes are replaced for reuse. Does this work correctly? approximations = self._fetch_hessian_approximations(HessianScoresGranularity.PER_TENSOR) compare_point_to_hessian_approx_scores = {node: score for node, score in zip(self.compare_points, approximations)} @@ -182,25 +181,26 @@ def _compute_sample_layer_attention_scores(self, inputs_batch) -> Dict[str, Dict """ Compute sample layer attention scores per image hash per layer. + Args: + inputs_batch: a list containing a batch of inputs. + Returns: - A dictionary {img_hash: {layer: score}} where score is the + A dictionary with a structure {img_hash: {layer: score}}. """ request = self._build_hessian_request(HessianScoresGranularity.PER_OUTPUT_CHANNEL) hessian_batch_size = self.gptq_config.hessian_weights_config.hessian_batch_size hessian_score_per_image_per_layer = {} - # TODO Is it really needed if we compute on the fly per batch? Also if hessian batch is larger its ignored. - # If hessian batch is smaller than inputs batch, split it to hessian batches. + # If hessian batch is smaller than inputs batch, split it to hessian batches. If hessian batch is larger, + # it's currently ignored (TODO) for i in range(0, inputs_batch[0].shape[0], hessian_batch_size): inputs = [t[i: i+hessian_batch_size] for t in inputs_batch] hessian_score_per_image_per_layer.update( self.hessian_service.compute_trackable_per_sample_hessian(request, inputs) ) - # hessian_score_per_image_per_layer = self._fetch_hessian_approximations(HessianScoresGranularity.PER_OUTPUT_CHANNEL) - for layers_score in hessian_score_per_image_per_layer.values(): - for k, t in layers_score.items(): - layers_score[k] = t.max(axis=0) # layer score is (channels,) + for img_hash, v in hessian_score_per_image_per_layer.items(): + hessian_score_per_image_per_layer[img_hash] = {k: t.max(axis=0) for k, t in v.items()} return hessian_score_per_image_per_layer def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) -> Dict[BaseNode, List[List[float]]]: @@ -215,8 +215,7 @@ def _fetch_hessian_approximations(self, granularity: HessianScoresGranularity) - node_approximations = self.hessian_service.fetch_hessian( hessian_scores_request=hessian_scores_request, required_size=self.gptq_config.hessian_weights_config.hessians_num_samples, - batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size, - per_sample_hash=self.gptq_config.hessian_weights_config.per_sample + batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size ) return node_approximations diff --git a/model_compression_toolkit/gptq/pytorch/gptq_loss.py b/model_compression_toolkit/gptq/pytorch/gptq_loss.py index 8d0d25494..9e97cb3cf 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_loss.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_loss.py @@ -79,7 +79,7 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor], y_list: First list of tensors. x_list: Second list of tensors. fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface). - loss_weights: A list of weights for each layer. Each weight is a vector of shape (batch,) + loss_weights: layer-sample weights tensor of shape (layers, batch) Returns: Sample Layer Attention loss (a scalar). @@ -88,14 +88,12 @@ def sample_layer_attention_loss(y_list: List[torch.Tensor], layers_mean_w = [] for i, (y, x, w) in enumerate(zip(y_list, x_list, loss_weights)): - # norm = (y - x).pow(2).sum(1) - norm = (y - x).pow(2).mean(1) + norm = (y - x).pow(2).sum(1) if len(norm.shape) > 1: norm = norm.flatten(1).mean(1) loss += torch.mean(w * norm) layers_mean_w.append(w.mean()) - # loss = loss / len(x_list) loss = loss / torch.stack(layers_mean_w).max() return loss diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index 1d50cf62c..37e591674 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -106,17 +106,16 @@ def _get_total_grad_steps(): trainable_bias, trainable_threshold) hessian_cfg = self.gptq_config.hessian_weights_config - self.weights_for_average_loss = None # for fixed layer weights + self.use_sample_layer_attention = hessian_cfg.per_sample + self.hessian_score_per_layer = None # for fixed layer weights self.hessian_score_per_image_per_layer = None # for sample-layer attention - if hessian_cfg.per_sample: + if self.use_sample_layer_attention: assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and hessian_cfg.scale_log_norm is False), hessian_cfg - # TODO if a representative dataset is fixed (same images in each epoch) we can precalculate. - # However if images differ between epochs, we have to calculate their hessians each time and pre-calculation - # will be a waste. Currently it is calculated on-demand during the training loop. + # Per sample hessian scores are calculated on-demand during the training loop self.hessian_score_per_image_per_layer = {} else: - self.weights_for_average_loss = to_torch_tensor(self.compute_hessian_based_weights()) + self.hessian_score_per_layer = to_torch_tensor(self.compute_hessian_based_weights()) self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps) @@ -221,14 +220,16 @@ def train(self, representative_data_gen: Callable): def compute_gradients(self, y_float: List[torch.Tensor], input_tensors: List[torch.Tensor], - weights_for_average_loss) -> Tuple[torch.Tensor, List[np.ndarray]]: + distill_loss_weights: torch.Tensor, + round_reg_weights: torch.Tensor) -> Tuple[torch.Tensor, List[np.ndarray]]: """ Get outputs from both teacher and student networks. Compute the observed error, and use it to compute the gradients and applying them to the student weights. Args: y_float: A list of reference tensor from the floating point network. input_tensors: A list of Input tensors to pass through the networks. - weights_for_average_loss: Weights for loss. Either per layer, or per layer per sample. + distill_loss_weights: Weights for the distillation loss. + round_reg_weights: Weight for the rounding regularization loss. Returns: Loss and gradients. """ @@ -243,9 +244,8 @@ def compute_gradients(self, self.flp_weights_list, self.compare_points_mean, self.compare_points_std, - weights_for_average_loss) - - reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor) + distill_loss_weights) + reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, round_reg_weights) loss_value += reg_value @@ -273,11 +273,11 @@ def micro_training_loop(self, for _ in epochs_pbar: with tqdm(data_function(), position=1, leave=False) as data_pbar: for data in data_pbar: - weights = to_torch_tensor(self._get_samples_weights_for_loss(data)) + distill_weights, reg_weights = to_torch_tensor(self._get_loss_weights(data)) input_data = [d * self.input_scale for d in data] input_tensor = to_torch_tensor(input_data) y_float = self.float_model(input_tensor) # running float model - loss_value, grads = self.compute_gradients(y_float, input_tensor, weights) + loss_value, grads = self.compute_gradients(y_float, input_tensor, distill_weights, reg_weights) # Run one step of gradient descent by updating the value of the variables to minimize the loss. for (optimizer, _) in self.optimizer_with_param: optimizer.step() @@ -289,13 +289,22 @@ def micro_training_loop(self, self.loss_list.append(loss_value.item()) Logger.debug(f'last loss value: {self.loss_list[-1]}') - # TODO move to common after ctor refactor - def _get_samples_weights_for_loss(self, input_tensors: List[torch.Tensor]): - if self.weights_for_average_loss is not None: - assert self.hessian_score_per_image_per_layer is None - return self.weights_for_average_loss + def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fetches weights for distillation and round regularization parts of loss. + + Args: + input_tensors: list containing a batch of inputs. + + Returns: + A tuple of two tensors: + - weights for distillation loss + - weights for rounding regularization loss + + """ + if self.use_sample_layer_attention is False: + return self.hessian_score_per_layer, torch.ones_like(self.hessian_score_per_layer) - # assert self.hessian_score_per_image_per_layer if len(input_tensors) > 1: raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs') @@ -309,7 +318,10 @@ def _get_samples_weights_for_loss(self, input_tensors: List[torch.Tensor]): img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash] img_scores = np.stack(list(img_scores_per_layer.values()), axis=0) scores.append(img_scores) - return np.stack(scores, axis=1) # layers X images + + layer_sample_weights = np.stack(scores, axis=1) # layers X images + layer_weights = layer_sample_weights.mean(axis=1) + return layer_sample_weights, layer_weights def update_graph(self) -> Graph: """ diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py b/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py index e4aef7932..e79c835c4 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py @@ -41,4 +41,4 @@ def get_regularization(gptq_config: GradientPTQConfig, get_total_grad_steps_fn: scheduler = LinearAnnealingScheduler(t_start=t_start, t_end=total_gradient_steps, initial_val=20, target_val=2) return SoftQuantizerRegularization(scheduler) else: - return lambda m, e_reg: 0 + return lambda *args, **kwargs: 0 diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py index b08c54faf..c8117e002 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py @@ -40,32 +40,36 @@ def __init__(self, beta_scheduler: Callable[[int], float]): self.count_iter = 0 - def __call__(self, model: nn.Module, entropy_reg: float): + def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor = None): """ Returns the soft quantizer regularization value for SoftRounding. Args: model: A model to be quantized with SoftRounding. entropy_reg: Entropy value to scale the quantizer regularization. + layer_weights: a vector of layer weights. If None, each layers has a weight of 1. Returns: Regularization value. """ + layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)] - soft_reg_aux: List[torch.Tensor] = [] - b = self.beta_scheduler(self.count_iter) - for layer in model.modules(): - if isinstance(layer, PytorchQuantizationWrapper): - kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer), - fw_info=DEFAULT_PYTORCH_INFO) - - st = layer.weights_quantizers[kernel_attribute].get_soft_targets() - soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()) + if layer_weights is None: + layer_weights = torch.ones((len(layers),)) + if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers): + raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.') + max_w = layer_weights.max() + b = self.beta_scheduler(self.count_iter) reg = 0 + for layer, w in zip(layers, layer_weights): + kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer), + fw_info=DEFAULT_PYTORCH_INFO) - for sq in soft_reg_aux: - reg += sq + st = layer.weights_quantizers[kernel_attribute].get_soft_targets() + soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum() + reg += w * soft_loss + reg = reg / max_w self.count_iter += 1 return entropy_reg * reg diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index bd0c2d975..29e228d9b 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -657,11 +657,13 @@ def test_gptq_with_gradual_activation(self): def test_gptq_with_sample_layer_attention(self): kwargs = dict(sample_layer_attention=True, loss=sample_layer_attention_loss, - hessian_weights=True, rounding_type=RoundingType.SoftQuantizer, - hessian_num_samples=None, norm_scores=False, log_norm_weights=False, scaled_log_norm=False) + hessian_weights=True, hessian_num_samples=None, + norm_scores=False, log_norm_weights=False, scaled_log_norm=False) GPTQAccuracyTest(self, **kwargs).run_test() - GPTQAccuracyTest(self, hessian_batch_size=16, **kwargs).run_test() - GPTQAccuracyTest(self, hessian_batch_size=5, gradual_activation_quantization=True, **kwargs).run_test() + GPTQAccuracyTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test() + GPTQAccuracyTest(self, hessian_batch_size=5, rounding_type=RoundingType.SoftQuantizer, + gradual_activation_quantization=True, **kwargs).run_test() + GPTQAccuracyTest(self, rounding_type=RoundingType.STE, **kwargs) def test_qat(self): """ From ba6433c9d8f823d51e372b0cb7726ea135951d72 Mon Sep 17 00:00:00 2001 From: irenab Date: Sun, 6 Oct 2024 11:31:42 +0300 Subject: [PATCH 4/5] improve coverage --- .../core/common/hessian/hessian_info_service.py | 10 +++++----- .../activation_hessian_scores_calculator_pytorch.py | 4 ++-- .../gptq/pytorch/quantization_facade.py | 11 ++++++----- .../quantizer/soft_rounding/soft_quantizer_reg.py | 2 +- .../model_tests/feature_models/gptq_test.py | 7 +++++-- .../model_tests/test_feature_models_runner.py | 2 ++ 6 files changed, 21 insertions(+), 15 deletions(-) diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index dd6106fac..7905aac7b 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.hessian.hessian_scores_request import HessianScoresRequest, \ HessianScoresGranularity, HessianMode from model_compression_toolkit.logger import Logger -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from model_compression_toolkit.core.common import BaseNode @@ -251,9 +251,9 @@ def compute_trackable_per_sample_hessian(self, hessian_score_by_image_hash = {} if not isinstance(inputs_batch, list): - raise TypeError('Expected a list of inputs') + raise TypeError('Expected a list of inputs') # pragma: no cover if len(inputs_batch) > 1: - raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') + raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover # Get the framework-specific calculator Hessian-approximation scores fw_hessian_calculator = self.fw_impl.get_hessian_scores_calculator(graph=self.graph, @@ -271,7 +271,7 @@ def compute_trackable_per_sample_hessian(self, @staticmethod def calc_image_hash(image): - if len(image.shape) != 3: + if len(image.shape) != 3: # pragma: no cover raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}') image_bytes = image.astype(np.float32).tobytes() return hashlib.md5(image_bytes).hexdigest() @@ -296,7 +296,7 @@ def fetch_hessian(self, OC for per-output-channel when the requested node has OC output-channels, etc.) """ - if len(hessian_scores_request.target_nodes) == 0: + if len(hessian_scores_request.target_nodes) == 0: # pragma: no cover return [] if required_size == 0: diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py index c79af62ac..f55be9e14 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py @@ -113,7 +113,7 @@ def _generate_random_vectors_batch(self, shape, distribution: HessianEstimationD v[v == 0] = -1 return v - raise ValueError(f'Unknown distribution {distribution}') + raise ValueError(f'Unknown distribution {distribution}') # pragma: no cover def compute(self) -> List[np.ndarray]: """ @@ -129,7 +129,7 @@ def compute(self) -> List[np.ndarray]: elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL: hessian_scores = self._compute_per_channel(output, target_activation_tensors) else: - raise NotImplementedError(f'{self.hessian_request.granularity} is not supported') + raise NotImplementedError(f'{self.hessian_request.granularity} is not supported') # pragma: no cover # Convert results to list of numpy arrays hessian_results = [torch_tensor_to_numpy(h) for h in hessian_scores] diff --git a/model_compression_toolkit/gptq/pytorch/quantization_facade.py b/model_compression_toolkit/gptq/pytorch/quantization_facade.py index 8022a5552..90ef9ba67 100644 --- a/model_compression_toolkit/gptq/pytorch/quantization_facade.py +++ b/model_compression_toolkit/gptq/pytorch/quantization_facade.py @@ -109,8 +109,9 @@ def get_pytorch_gptq_config(n_epochs: int, bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM) if use_hessian_sample_attention: - if not use_hessian_based_weights: + if not use_hessian_based_weights: # pragma: no cover raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.') + hessian_weights_config = GPTQHessianScoresConfig( hessians_num_samples=None, norm_scores=False, @@ -129,9 +130,9 @@ def get_pytorch_gptq_config(n_epochs: int, gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig): gradual_quant_config = gradual_activation_quantization - else: + else: # pragma: no cover raise TypeError(f'gradual_activation_quantization argument should be bool or ' - f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') # pragma: no cover + f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}') return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss, log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer, @@ -205,11 +206,11 @@ def pytorch_gradient_post_training_quantization(model: Module, """ - if core_config.is_mixed_precision_enabled: + if core_config.is_mixed_precision_enabled: # pragma: no cover if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig): Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. " "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' " - "or provide a valid mixed-precision configuration.") # pragma: no cover + "or provide a valid mixed-precision configuration.") tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO) diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py index c8117e002..48645f896 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py @@ -56,7 +56,7 @@ def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Te if layer_weights is None: layer_weights = torch.ones((len(layers),)) if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers): - raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.') + raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.') # pragma: no cover max_w = layer_weights.max() b = self.beta_scheduler(self.count_iter) diff --git a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py index 552fff89c..b1ec5e2d0 100644 --- a/tests/pytorch_tests/model_tests/feature_models/gptq_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/gptq_test.py @@ -21,6 +21,7 @@ import mct_quantizers from model_compression_toolkit import DefaultDict from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES +from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR, MAX_LSB_STR from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest @@ -59,7 +60,7 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM hessian_weights=True, norm_scores=True, log_norm_weights=True, scaled_log_norm=False, params_learning=True, num_calibration_iter=GPTQ_HESSIAN_NUM_SAMPLES, gradual_activation_quantization=False, hessian_num_samples=GPTQ_HESSIAN_NUM_SAMPLES, sample_layer_attention=False, - loss=multiple_tensors_mse_loss, hessian_batch_size=1): + loss=multiple_tensors_mse_loss, hessian_batch_size=1, estimator_distribution=HessianEstimationDistribution.GAUSSIAN): super().__init__(unit_test, input_shape=(3, 16, 16), num_calibration_iter=num_calibration_iter) self.seed = 0 self.rounding_type = rounding_type @@ -78,6 +79,7 @@ def __init__(self, unit_test, weights_bits=8, weights_quant_method=QuantizationM self.sample_layer_attention = sample_layer_attention self.loss = loss self.hessian_batch_size = hessian_batch_size + self.estimator_distribution = estimator_distribution def get_quantization_config(self): return mct.core.QuantizationConfig(mct.core.QuantizationErrorMethod.NOCLIPPING, @@ -154,7 +156,8 @@ def get_gptq_config(self): norm_scores=self.norm_scores, per_sample=self.sample_layer_attention, hessians_num_samples=self.hessian_num_samples, - hessian_batch_size=self.hessian_batch_size), + hessian_batch_size=self.hessian_batch_size, + estimator_distribution=self.estimator_distribution), gptq_quantizer_params_override=self.override_params, diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 29e228d9b..8eface3cd 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -21,6 +21,7 @@ import torch from torch import nn import model_compression_toolkit as mct +from model_compression_toolkit.core.common.hessian import HessianEstimationDistribution from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting from model_compression_toolkit.core.common.network_editors import NodeTypeFilter, NodeNameFilter from model_compression_toolkit.gptq.common.gptq_config import RoundingType @@ -658,6 +659,7 @@ def test_gptq_with_gradual_activation(self): def test_gptq_with_sample_layer_attention(self): kwargs = dict(sample_layer_attention=True, loss=sample_layer_attention_loss, hessian_weights=True, hessian_num_samples=None, + estimator_distribution=HessianEstimationDistribution.RADEMACHER, norm_scores=False, log_norm_weights=False, scaled_log_norm=False) GPTQAccuracyTest(self, **kwargs).run_test() GPTQAccuracyTest(self, hessian_batch_size=16, rounding_type=RoundingType.SoftQuantizer, **kwargs).run_test() From a329151e12767755a71060f5bca99ae159b5ad54 Mon Sep 17 00:00:00 2001 From: irenab Date: Sun, 6 Oct 2024 20:32:42 +0300 Subject: [PATCH 5/5] foxes per code review --- .../common/hessian/hessian_info_service.py | 22 ++++++++++++++----- ...ation_hessian_scores_calculator_pytorch.py | 3 ++- .../gptq/pytorch/gptq_training.py | 17 ++++++++------ .../soft_rounding/soft_quantizer_reg.py | 6 ++--- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index 7905aac7b..ee0e067d7 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -250,8 +250,8 @@ def compute_trackable_per_sample_hessian(self, hessian_score_by_image_hash = {} - if not isinstance(inputs_batch, list): - raise TypeError('Expected a list of inputs') # pragma: no cover + if not inputs_batch or not isinstance(inputs_batch, list): + raise TypeError('Expected a non-empty list of inputs') # pragma: no cover if len(inputs_batch) > 1: raise NotImplementedError('Per-sample hessian computation is not supported for networks with multiple inputs') # pragma: no cover @@ -261,17 +261,27 @@ def compute_trackable_per_sample_hessian(self, hessian_scores_request=hessian_scores_request, num_iterations_for_approximation=self.num_iterations_for_approximation) hessian_scores = fw_hessian_calculator.compute() - for b in range(inputs_batch[0].shape[0]): - img_hash = self.calc_image_hash(inputs_batch[0][b]) + for i in range(inputs_batch[0].shape[0]): + img_hash = self.calc_image_hash(inputs_batch[0][i]) hessian_score_by_image_hash[img_hash] = { - node: score[b] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) + node: score[i] for node, score in zip(hessian_scores_request.target_nodes, hessian_scores) } return hessian_score_by_image_hash @staticmethod def calc_image_hash(image): - if len(image.shape) != 3: # pragma: no cover + """ + Calculates hash for an input image. + + Args: + image: input 3d image (without batch). + + Returns: + Image hash. + + """ + if not len(image.shape) == 3: # pragma: no cover raise ValueError(f'Expected 3d image (without batch) for image hash calculation, got {len(image.shape)}') image_bytes = image.astype(np.float32).tobytes() return hashlib.md5(image_bytes).hexdigest() diff --git a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py index f55be9e14..c7d496650 100644 --- a/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +++ b/model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py @@ -93,7 +93,8 @@ def forward_pass(self): output = self.concat_tensors(output_tensors) return output, target_activation_tensors - def _generate_random_vectors_batch(self, shape, distribution: HessianEstimationDistribution, device) -> torch.Tensor: + def _generate_random_vectors_batch(self, shape: tuple, distribution: HessianEstimationDistribution, + device: torch.device) -> torch.Tensor: """ Generate a batch of random vectors for Hutchinson estimation diff --git a/model_compression_toolkit/gptq/pytorch/gptq_training.py b/model_compression_toolkit/gptq/pytorch/gptq_training.py index 37e591674..1e6f30c44 100644 --- a/model_compression_toolkit/gptq/pytorch/gptq_training.py +++ b/model_compression_toolkit/gptq/pytorch/gptq_training.py @@ -110,8 +110,9 @@ def _get_total_grad_steps(): self.hessian_score_per_layer = None # for fixed layer weights self.hessian_score_per_image_per_layer = None # for sample-layer attention if self.use_sample_layer_attention: - assert (hessian_cfg.norm_scores is False and hessian_cfg.log_norm is False and - hessian_cfg.scale_log_norm is False), hessian_cfg + # normalization is currently not supported, make sure the config reflects it. + if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm: + raise NotImplementedError() # Per sample hessian scores are calculated on-demand during the training loop self.hessian_score_per_image_per_layer = {} else: @@ -308,18 +309,20 @@ def _get_loss_weights(self, input_tensors: List[torch.Tensor]) -> Tuple[torch.Te if len(input_tensors) > 1: raise NotImplementedError('Sample-Layer attention is not currently supported for networks with multiple inputs') - scores = [] + image_scores = [] batch = input_tensors[0] img_hashes = [self.hessian_service.calc_image_hash(img) for img in batch] for img_hash in img_hashes: + # If sample-layer attention score for the image is not found, compute and store it for the whole batch. if img_hash not in self.hessian_score_per_image_per_layer: - score_per_image_layer_per = self._compute_sample_layer_attention_scores(input_tensors) - self.hessian_score_per_image_per_layer.update(score_per_image_layer_per) + score_per_image_per_layer = self._compute_sample_layer_attention_scores(input_tensors) + self.hessian_score_per_image_per_layer.update(score_per_image_per_layer) img_scores_per_layer: Dict[BaseNode, np.ndarray] = self.hessian_score_per_image_per_layer[img_hash] + # fetch image scores for all layers and combine them into a single tensor img_scores = np.stack(list(img_scores_per_layer.values()), axis=0) - scores.append(img_scores) + image_scores.append(img_scores) - layer_sample_weights = np.stack(scores, axis=1) # layers X images + layer_sample_weights = np.stack(image_scores, axis=1) # layers X images layer_weights = layer_sample_weights.mean(axis=1) return layer_sample_weights, layer_weights diff --git a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py index 48645f896..776a8bbbb 100644 --- a/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +++ b/model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py @@ -40,21 +40,19 @@ def __init__(self, beta_scheduler: Callable[[int], float]): self.count_iter = 0 - def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor = None): + def __call__(self, model: nn.Module, entropy_reg: float, layer_weights: torch.Tensor): """ Returns the soft quantizer regularization value for SoftRounding. Args: model: A model to be quantized with SoftRounding. entropy_reg: Entropy value to scale the quantizer regularization. - layer_weights: a vector of layer weights. If None, each layers has a weight of 1. + layer_weights: a vector of layer weights. Returns: Regularization value. """ layers = [m for m in model.modules() if isinstance(m, PytorchQuantizationWrapper)] - if layer_weights is None: - layer_weights = torch.ones((len(layers),)) if len(layer_weights.shape) != 1 or layer_weights.shape[0] != len(layers): raise ValueError(f'Expected weights to be a vector of length {len(layers)}, received {layer_weights.shape}.') # pragma: no cover max_w = layer_weights.max()