diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index c1534618302..74920d0d697 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -54,6 +54,9 @@ class SparseGPTModifier(Modifier):
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
+    :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
+        when applying SparseGPT; this is useful when starting from a previously
+        pruned model. Defaults to False.
     """

     sparsity: Union[float, List[float]] = 0.0
@@ -68,6 +71,7 @@ class SparseGPTModifier(Modifier):
     prunem_: Optional[int] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    preserve_sparsity_mask: bool = False

     def on_initialize_structure(self, state: State, **kwargs):
         """
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index 4825eed1a92..ec9dfd90d23 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -203,6 +203,7 @@ def _compression_arguments(self, sparsity):
             "prunem": self.prunem_,
             "blocksize": self.block_size,
             "percdamp": self.dampening_frac,
+            "preserve_sparsity_mask": self.preserve_sparsity_mask,
         }

     def _compression_class(self):
diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index d8a95f18853..0079071bd0e 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -84,6 +84,7 @@ def fasterprune(
         prunem: int = 0,
         blocksize: int = 128,
         percdamp: float = 0.01,
+        preserve_sparsity_mask: bool = False,
     ):
         """
         Run pruning and quantization(if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
         :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
+        :param preserve_sparsity_mask: Extend or ignore the base sparsity mask
         """
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@ def fasterprune(
         Hinv = self.H

         mask = None
+        if preserve_sparsity_mask:
+            # compute the existing sparsity mask from the current zeros in W
+            mask = torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )

         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@ def fasterprune(
             if prunen == 0:
                 if mask is not None:
                     mask1 = mask[:, i1:i2]
+                    if int(W1.numel() * sparsity) > mask1.sum():
+                        # target sparsity is higher than base sparsity, extend mask1
+                        tmp = (
+                            (~mask[:, i1:i2])
+                            * W1**2
+                            / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                        )
+                        thresh = torch.sort(tmp.flatten())[0][
+                            int(tmp.numel() * sparsity)
+                        ]
+                        mask1 = tmp <= thresh
+                    else:
+                        raise ValueError(
+                            "The target sparsity is lower than the sparsity "
+                            "of the base model. Please retry "
+                            "after setting preserve_sparsity_mask=False"
+                        )
                 else:
                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                     mask1 = tmp <= thresh
             else:
-                mask1 = torch.zeros_like(W1) == 1
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    mask1 = torch.zeros_like(W1) == 1

             for i in range(count):
                 w = W1[:, i]
@@ -154,6 +183,10 @@ def fasterprune(
                         W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
+
+                    if mask is not None:
+                        tmp = tmp * (~mask[:, i : (i + prunem)])
+
                     mask1.scatter_(
                         1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                     )
@@ -174,7 +207,12 @@ def fasterprune(
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2

-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_sparsity_mask:
+                # respect the sparsity of the other groups
+                # not strictly needed, but kept for explicitness
+                W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+            else:
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
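
The core of the sgpt_wrapper.py change is the mask-extension step: weights that are already zero stay pruned, and only the remaining dense weights are scored when the target sparsity exceeds the base sparsity. Below is a minimal standalone sketch of that logic; the tensor names, shapes, and the ~30% base sparsity are illustrative assumptions, not part of the patch.

import torch

# illustrative block of weights and a matching inverse-Hessian block
W1 = torch.randn(64, 128)
W1[torch.rand_like(W1) < 0.3] = 0.0  # pretend the checkpoint was pruned to ~30%
Hinv1 = torch.eye(128)
target_sparsity = 0.5

base_mask = W1 == 0  # True where the weight is already zero

if int(W1.numel() * target_sparsity) > base_mask.sum():
    # SparseGPT saliency w^2 / diag(Hinv)^2, zeroed on already-pruned weights,
    # so the existing zeros always fall below the pruning threshold
    scores = (~base_mask) * W1**2 / torch.diag(Hinv1).reshape((1, -1)) ** 2
    thresh = torch.sort(scores.flatten())[0][int(scores.numel() * target_sparsity)]
    mask1 = scores <= thresh  # extended mask at the target sparsity
else:
    raise ValueError("target sparsity is below the base sparsity of the layer")
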
diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 215560b230b..3dce40cecc5 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -14,6 +14,7 @@
 import time

+from sparseml.modifiers.utils import SPARSITY_THRESHOLD
 from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
@@ -92,6 +93,7 @@ def fasterprune(
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
         W = self.layer.weight.data.clone()
+        from sparseml.pytorch.utils.helpers import tensor_sparsity

         if isinstance(self.layer, nn.Conv2d):
             W = W.flatten(1)
@@ -115,6 +117,17 @@ def fasterprune(
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H

+        sparsity = tensor_sparsity(W)
+        mask = (
+            torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
+            if sparsity >= SPARSITY_THRESHOLD
+            else None
+        )
+
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -126,11 +139,22 @@ def fasterprune(
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]

+            if sparsity >= SPARSITY_THRESHOLD:
+                tmp = (
+                    (~mask[:, i1:i2])
+                    * W1**2
+                    / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                )
+                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
+                mask1 = tmp <= thresh
+
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
                 q = w.clone()
+                if sparsity >= SPARSITY_THRESHOLD:
+                    q[mask1[:, i]] = 0

                 if hasattr(self.layer, "weight_fake_quant"):
                     scale = self.layer.weight_fake_quant.scale
diff --git a/src/sparseml/modifiers/utils/__init__.py b/src/sparseml/modifiers/utils/__init__.py
index 0c44f887a47..39d1132f697 100644
--- a/src/sparseml/modifiers/utils/__init__.py
+++ b/src/sparseml/modifiers/utils/__init__.py
@@ -11,3 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# flake8: noqa
+
+from .constants import *
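
On the GPTQ side the behavior is automatic rather than opt-in: the wrapper measures the layer's existing sparsity and, once it crosses SPARSITY_THRESHOLD, keeps the zeros in place while quantizing. A rough sketch of that gate, using a simple stand-in for tensor_sparsity (assumed here to return the fraction of exactly-zero elements):

import torch

SPARSITY_THRESHOLD = 0.05  # mirrors sparseml.modifiers.utils.constants


def fraction_zero(t: torch.Tensor) -> float:
    # stand-in for tensor_sparsity: fraction of elements that are exactly zero
    return (t == 0).float().mean().item()


W = torch.randn(64, 128)
W[torch.rand_like(W) < 0.5] = 0.0  # an already-pruned layer

sparsity = fraction_zero(W)
mask = (W == 0) if sparsity >= SPARSITY_THRESHOLD else None

for i in range(W.shape[1]):
    q = W[:, i].clone()
    if mask is not None:
        q[mask[:, i]] = 0  # masked weights stay zero through quantization
    # ... fake-quantize q and write it back here ...
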
diff --git a/src/sparseml/modifiers/utils/constants.py b/src/sparseml/modifiers/utils/constants.py
new file mode 100644
index 00000000000..3801c2e9ea9
--- /dev/null
+++ b/src/sparseml/modifiers/utils/constants.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__all__ = ["SPARSITY_THRESHOLD"]
+
+SPARSITY_THRESHOLD: float = 0.05
diff --git a/tests/sparseml/transformers/obcq/test_consecutive_runs.py b/tests/sparseml/transformers/obcq/test_consecutive_runs.py
index 04b78ec82b8..7bcfc8b7efe 100644
--- a/tests/sparseml/transformers/obcq/test_consecutive_runs.py
+++ b/tests/sparseml/transformers/obcq/test_consecutive_runs.py
@@ -114,7 +114,7 @@ def setUp(self):
         self.output_second = Path(self.output) / "test_2"

     def test_consecutive_runs_small(self):
-        self._test_consecutive_runs(tolerance=1e-1)
+        self._test_consecutive_runs(tolerance=1e-3)


 @requires_gpu
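
Taken together, these changes let an already-pruned checkpoint be pushed to a higher sparsity without destroying its existing mask. A minimal usage sketch; the sparsity value is a placeholder and in practice the modifier is normally configured through a recipe rather than constructed directly:

from sparseml.modifiers.obcq.base import SparseGPTModifier

# hypothetical configuration: take an already-sparse checkpoint to 70% sparsity
# while keeping every currently-zero weight at zero
modifier = SparseGPTModifier(
    sparsity=0.7,
    block_size=128,
    dampening_frac=0.01,
    preserve_sparsity_mask=True,
)

With preserve_sparsity_mask left at its default of False, the masks are computed from scratch and the previous behavior is unchanged.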