diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index f6e504e7b05..b9cf5d9e465 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -15,9 +15,10 @@ import logging from typing import Any, Dict, List, Optional, Union +from sparseml.core import Modifier from sparseml.core.factory import ModifierFactory +from sparseml.core.model.base import ModifiableModel from sparseml.core.state import State -from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier __all__ = ["SparseGPTModifier"] @@ -25,7 +26,7 @@ _LOGGER = logging.getLogger(__name__) -class SparseGPTModifier(WandaPruningModifier): +class SparseGPTModifier(Modifier): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -41,19 +42,35 @@ class SparseGPTModifier(WandaPruningModifier): - on_finalize - LayerCompressor.revert_layer_wrappers() + :param sparsity: Sparsity to compress model to + :param mask_structure: String to define the structure of the mask to apply. + Must be of the form N:M where N, M are integers that define a custom block + shape. Defaults to 0:0 which represents an unstructured mask. + :param sequential_update: Whether or not to update weights sequentially by layer, + True saves on GPU memory + :param targets: list of layer names to compress during OBCQ, or '__ALL__' + to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass :param quantize: Whether or not to quantize weights during SparseGPT. 
Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not already exist in the recipe - :param sparsity: Sparsity to compress model to :param dampening_frac: Amount of dampening to apply to H, as a fraction of the diagonal norm """ + sparsity: Union[float, List[float]] = 0.0 + sparsity_profile: Optional[str] = None + owl_m: Optional[int] = None + owl_lmbda: Optional[float] = None + mask_structure: str = "0:0" + sequential_update: Optional[bool] = False + targets: Union[str, List[str], None] = None + compressible_layers_: Optional[List] = None + prunen_: Optional[int] = None + prunem_: Optional[int] = None block_size: int = 128 quantize: Union[bool, Dict] = False - sparsity: Union[float, List[float]] = 0.0 dampening_frac: Optional[float] = 0.01 quantization_modifier_: Any = None @@ -112,6 +129,39 @@ def on_initialize_structure(self, state: State, **kwargs): if self.quantization_modifier_: self.quantization_modifier_.on_initialize_structure(state, **kwargs) + def compressible_layers(self) -> Dict: + """ + Retrieves the modules corresponding to a list of + compressible layer names + + :precondition: self.model is set and is a `ModifiableModel` + :precondition: The `ModifiableModel` implements a `get_layers` + method + :return: dictionary of modules to compress + """ + if not isinstance(self.model, ModifiableModel): + raise ValueError( + "`self.model` must be a ModifiableModel to use " + f"the {self.__class__.__qualname__} modifier but got " + f"{type(self.model)} instead" + ) + + return self.model.get_layers(self.targets) + + def _validate_layerwise_sparsity(self): + if isinstance(self.sparsity, float): + # single sparsity will be applied to all layers + return + + target_layers = list(self.compressible_layers_.keys()) + + if len(target_layers) != len(self.sparsity): + raise ValueError( + "Number of layer targets must match the number of " + f"sparsities. 
Got {len(target_layers)} layers and " + f"{len(self.sparsity)} sparsities" + ) + def _build_quant_modifier_from_dict(self, quant_config, framework): modifier_type = list(quant_config.keys())[0] modifier_args = quant_config[modifier_type] @@ -122,3 +172,14 @@ def _build_quant_modifier_from_dict(self, quant_config, framework): allow_experimental=True, **modifier_args, ) + + def on_finalize(self, state: State, **kwargs): + """ + Nothing to do on finalize, on this level. + Quantization Modifier if any will be finalized in the subclass + + :param state: session state storing input model and calibration data + :param kwargs: additional arguments + :return: True + """ + return True diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index de1eef74189..646292a1b20 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -13,13 +13,19 @@ # limitations under the License. import logging -from typing import List, Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import numpy as np +import torch +from tqdm import tqdm from sparseml.core.model import ModifiableModel from sparseml.core.state import State from sparseml.modifiers.obcq.base import SparseGPTModifier from sparseml.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper -from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch +from sparseml.modifiers.utils.layer_compressor import LayerCompressor +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward +from sparseml.utils.pytorch.module import get_prunable_layers __all__ = ["SparseGPTModifierPyTorch"] @@ -27,7 +33,7 @@ _LOGGER = logging.getLogger(__name__) -class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier): +class SparseGPTModifierPyTorch(SparseGPTModifier): """ Pytorch implementation of SparseGPT @@ -40,14 +46,25 @@ class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, 
SparseGPTModifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() - - on_finalize - - LayerCompressor.revert_layer_wrappers() + - LayerCompressor.revert_layer_wrappers() + + | Sample yaml: + | test_stage: + | obcq_modifiers: + | SparseGPTModifier: + | sparsity: 0.5 + | mask_structure: "2:4" + | sequential_update: True + | dampening_frac: 0.001 + | targets: __ALL__ + | block_size: 128 + | quantize: False :param model: Pytorch model to perform OBCQ on, in-place """ model: Optional[ModifiableModel] = None - layer_compressors: List = None + layer_compressors_: Optional[List[Any]] = None def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -65,7 +82,99 @@ def on_initialize(self, state: "State", **kwargs) -> bool: "quantization must be enabled." ) - return super(SparseGPTModifierPyTorch, self).on_initialize(state, **kwargs) + modifiable_model = state.model + calibration_dataloader = state.data.calib + + if self.targets is None: + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. 
For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + self.targets = modifiable_model.get_no_split_params() + + self.initialize_compression(modifiable_model, calibration_dataloader) + self.apply_compression(calibration_dataloader) + + return True + + def initialize_compression( + self, + model: ModifiableModel, + dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, + ): + """ + Setup for SparseGPT, initializes the model, device, + and other parameters, also initializes the + compressible layers of model, and sets the device + + :param model: model to initialize for compression + """ + self.model = model + self.compressible_layers_ = self.compressible_layers() + self.model = self.model.model + self.layer_compressors_ = [] + self._infer_mask_block_size() + + if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl": + _LOGGER.info( + "Inferring layer-wise sparsities from " + f"{len(dataloader)} calibration samples..." 
+ ) + self.sparsity = self._infer_layer_sparsity(dataloader) + self._validate_layerwise_sparsity() + + for idx, (name, layer) in enumerate(self.compressible_layers_.items()): + _LOGGER.info(f"Preparing {name} for compression") + if isinstance(self.sparsity, Dict): + layer_sparsity = self.sparsity[name] + elif isinstance(self.sparsity, List): + layer_sparsity = self.sparsity[idx] + else: # float + layer_sparsity = self.sparsity + args = self._pruning_arguments(layer_sparsity) + comp_cls = self._compression_class() + compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) + if not self.sequential_update: + # add all batch processing hooks before the forward pass + compressor.pre_compress() + self.layer_compressors_.append(compressor) + + @torch.no_grad() + def apply_compression( + self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None + ) -> Dict: + """ + Run SparseGPT on the loaded model, using dataloader as calibration data + + :param dataloader: calibration data for SparseGPT + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + _LOGGER.info( + f"Running {class_name} calibration with " f"{len(dataloader)} samples..." 
+ ) + if not self.sequential_update: + # in non-sequential mode we run one forward batch for all modules + run_calibration_forward(self.model, dataloader, mask_padding=True) + + num_layers = len(self.compressible_layers_) + for idx, layer_compressor in enumerate(self.layer_compressors_): + layer_sparsity = layer_compressor.args["sparsity"] + _LOGGER.info( + f"\n===== Compressing layer {idx+1}/{num_layers} " + f"to sparsity {layer_sparsity} =====" + ) + + # Prune/quantize using SparseGPT + if self.sequential_update: + # in sequential mode we run one forward pass for each module we + # want to compress, this will be really slow but allows compression in + # earlier layers to affect later layers + layer_compressor.pre_compress() + _LOGGER.info(f"Calibrating {layer_compressor.name}...") + run_calibration_forward(self.model, dataloader, mask_padding=True) + layer_compressor.compress() + layer_compressor.post_compress() + layer_compressor.revert_layer_wrappers() + torch.cuda.empty_cache() def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -98,3 +207,96 @@ def _compression_class(self): :return: wrapper class used for root modules of this compression class """ return SparseGptWrapper + + def _infer_mask_block_size(self): + """ + Infer the mask block size from the mask structure. + Parses mask_structure of the form N:M where N, M are integers that + define a custom block shape; and sets prunen_ and prunem_ accordingly. 
+ + :post-condition: prunen_ and prunem_ are set + """ + if self.mask_structure is None: + raise ValueError("mask_structure must be defined") + + self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":"))) + + def _infer_layer_sparsity(self, calibration_dataloader): + acts = _get_activations(self.model, calibration_dataloader) + sparsegpt_groups = {} + for name, layer in self.compressible_layers_.items(): + prunable_layers = get_prunable_layers(layer) + z = [ + m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0) + for n, m in prunable_layers.items() + ] + sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z]) + + acts = None + del acts + torch.cuda.empty_cache() + + outlier_ratios = {} + for group in sparsegpt_groups: + threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m + outlier_ratios[group] = ( + 100 + * (sparsegpt_groups[group] > threshold).sum().item() + / sparsegpt_groups[group].numel() + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + for k in outlier_ratios: + outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * ( + 1 + / (outlier_ratios_arr.max() - outlier_ratios_arr.min()) + * self.owl_lmbda + * 2 + ) + outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios]) + sparsities = { + k: 1 + - ( + outlier_ratios[k] + - np.mean(outlier_ratios_arr) + + (1 - float(self.sparsity)) + ) + for k in outlier_ratios + } + _LOGGER.info(f"OWL sparsities for sp={self.sparsity} are:") + for k in sparsities: + _LOGGER.info(f"Sparsity for {k}: {sparsities[k]}") + return sparsities + + +@torch.no_grad() +def _get_activations(model, data_loader, nsamples=128): + import functools + + model.eval() + acts = {} + + def save_acts(module, input, name): + if isinstance(input, tuple): + input = input[0] + if name not in acts: + acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt() + else: + acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 
1)).sqrt() + + hooks = [] + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear) and "lm_head" not in name: + hooks.append( + mod.register_forward_pre_hook(functools.partial(save_acts, name=name)) + ) + device = next(model.parameters()).device + for batch in tqdm(data_loader): + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + batch = None + torch.cuda.empty_cache() + + for h in hooks: + h.remove() + + return acts diff --git a/src/sparseml/modifiers/pruning/wanda/base.py b/src/sparseml/modifiers/pruning/wanda/base.py index 26cb6db5bf7..d59621cc09d 100644 --- a/src/sparseml/modifiers/pruning/wanda/base.py +++ b/src/sparseml/modifiers/pruning/wanda/base.py @@ -37,8 +37,8 @@ class WandaPruningModifier(Modifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() - on_finalize - - LayerCompressor.revert_layer_wrappers() :param sparsity: Sparsity to compress model to :param mask_structure: String to define the structure of the mask to apply. 
diff --git a/src/sparseml/modifiers/pruning/wanda/pytorch.py b/src/sparseml/modifiers/pruning/wanda/pytorch.py index 8d7e8ff3b76..6203e73f600 100644 --- a/src/sparseml/modifiers/pruning/wanda/pytorch.py +++ b/src/sparseml/modifiers/pruning/wanda/pytorch.py @@ -44,8 +44,18 @@ class WandaPruningModifierPyTorch(WandaPruningModifier): - run_calibration_forward() - LayerCompressor.compress() - LayerCompressor.post_compress() + - LayerCompressor.revert_layer_wrappers() - on_finalize - - LayerCompressor.revert_layer_wrappers() + + | Sample yaml: + | test_stage: + | wanda_modifiers: + | WandaPruningModifier: + | sparsity: 0.05 + | mask_structure: "2:4" + | sequential_update: True + | targets: __ALL__ + :param model: `ModifiableModel` to perform WANDA on, in-place """ @@ -141,7 +151,7 @@ def apply_compression( f"to sparsity {layer_sparsity} =====" ) - # Prune/quantize using SparseGPT + # Prune/quantize using the layer compressor if self.sequential_update: # in sequential mode we run one forward pass for each module we # want to compress, this will be really slow but allows compression in