
Commit

remove unnecessary files
dbogunowicz committed May 6, 2024
1 parent b05ad00 commit 31d86d5
Showing 42 changed files with 1,443 additions and 986 deletions.
14 changes: 1 addition & 13 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -179,22 +179,10 @@ def fasterprune(
                     fake_quantize,
                 )
 
-                while scale.ndim < 2:
-                    scale = scale.unsqueeze(1)
-                    zero_point = zero_point.unsqueeze(1)
-
-                while q.ndim < 2:
-                    q = q.unsqueeze(1)
                 q = fake_quantize(
-                    q,
-                    scale[:, i],
-                    zero_point[:, i],
-                    self.layer.quantization_scheme.weights,
+                    q, scale, zero_point, self.layer.quantization_scheme.weights
                 )
 
-                while q.ndim != 1:
-                    q.squeeze()
-
                 Q1[:, i] = q
                 Losses1[:, i] = (w - q) ** 2 / d**2

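A note on the hunk above: fake quantization simulates the quantize-dequantize round trip so the pruning decisions in fasterprune account for quantization error, and the refactor passes scale and zero point through without manual reshaping. A minimal, generic sketch of a fake-quantize step (illustrative only; the real fake_quantize used here takes the layer's quantization scheme and handles broadcasting internally, and its exact signature is not shown in this commit):

import torch

def fake_quantize_ref(w: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
                      qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # quantize to the integer grid, clamp to the representable range, then dequantize
    q = torch.clamp(torch.round(w / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale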
6 changes: 0 additions & 6 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -20,7 +20,6 @@
 
 from sparseml.core import Event, EventType, State
 from sparseml.modifiers.quantization.base import QuantizationModifier
-from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
     freeze_bn_stats,
@@ -74,16 +73,11 @@ def __init__(self, **kwargs):
 
     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
-        # before the structure is modified to support quantization,
-        # we need to potentially modify the model architecture
-        module = modify_model(module)
         self._enable_module_qat(module)
         state.model.model.apply(torch.quantization.disable_observer)
 
     def on_initialize(self, state: State, **kwargs) -> bool:
         raise_if_torch_quantization_not_available()
         module = state.model.model
-        module = modify_model(module)
         if self.end and self.end != -1:
             raise ValueError(
                 "end_epoch is disabled for QuantizationModifier and can only be set to"
2 changes: 2 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/base.py
@@ -43,12 +43,14 @@ class vLLMQuantizationModifier(Modifier):
         not be updated. Leave None to not disable observers during QAT. Default is None
     :param num_calibration_steps: Number of steps to run post training calibration for.
         When None, the entire calibration_dataloader is used
+    :param post_oneshot_calibration: Whether to rerun calibration on finalization
     """
 
     config_groups: Dict[str, QuantizationScheme]
     ignore: List[str] = Field(default_factory=list)
     disable_quantization_observer_epoch: Optional[float] = None
     num_calibration_steps: Optional[int] = None
+    post_oneshot_calibration: Optional[bool] = False
 
     def create_init_config(self) -> QuantizationConfig:
         return QuantizationConfig(
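A hypothetical sketch of how the new flag might be set when constructing this modifier directly (the config_groups, ignore, and step values below are placeholders for illustration, not taken from this commit):

from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier

modifier = vLLMQuantizationModifier(
    config_groups={},                # placeholder; real recipes define QuantizationScheme groups here
    ignore=["lm_head"],              # placeholder module names to skip
    num_calibration_steps=512,       # placeholder step count
    post_oneshot_calibration=True,   # new flag: rerun calibration when the modifier is finalized
)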
8 changes: 5 additions & 3 deletions src/sparseml/modifiers/quantization_vllm/pytorch.py
@@ -23,7 +23,6 @@
     set_module_for_calibration,
 )
 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
 
@@ -51,7 +50,6 @@ class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
 
     def on_initialize_structure(self, state: State, **kwargs):
         module = state.model.model
-        module = modify_model(module)
         self._apply_modifier_to_model(module)
         module.apply(freeze_module_quantization)
 
@@ -64,7 +62,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.calibration_dataloader_ = state.data.calib
         module = state.model.model
-        module = modify_model(module)
 
         # intialize quantization in appropriate modules
         self._apply_modifier_to_model(module)
@@ -77,6 +74,11 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
+        module = state.model.model
+        if self.post_oneshot_calibration:
+            module.apply(set_module_for_calibration)
+            self._calibrate_if_possible(module)
+        module.apply(freeze_module_quantization)
         return True
 
     def on_start(self, state: State, event: Event, **kwargs):
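For readers unfamiliar with the lifecycle the new on_finalize relies on: set_module_for_calibration re-enables observers so that forward passes can collect activation ranges, and freeze_module_quantization locks the resulting scales and zero points. A rough, generic sketch of that flow (the helper names mirror the imports shown at the top of this file; the dict-style batch handling is an assumption for illustration):

def run_post_oneshot_calibration(module, dataloader, num_steps=None):
    module.apply(set_module_for_calibration)      # observers start collecting ranges
    for i, batch in enumerate(dataloader):
        if num_steps is not None and i >= num_steps:
            break
        module(**batch)                           # calibration is forward-only, no gradients
    module.apply(freeze_module_quantization)      # lock scales / zero points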
4 changes: 3 additions & 1 deletion src/sparseml/modifiers/smoothquant/base.py
@@ -16,6 +16,8 @@
 from dataclasses import dataclass
 from typing import Dict, Generic, List, Optional, Tuple, TypeVar
 
+from pydantic import Field
+
 from sparseml.core import Modifier
 from sparseml.core.model import ModifiableModel
 from sparseml.core.model.base import LT
@@ -96,7 +98,7 @@ class SmoothQuantModifier(Modifier):
         use the whole dataset
     """
 
-    smoothing_strength: float = 0.5
+    smoothing_strength: float = Field(validation_alias="alpha", default=0.5)
     mappings: List[Tuple]
     ignore: Optional[List[str]] = None
     num_calibration_steps: Optional[int] = None
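The switch to pydantic's Field(validation_alias=...) lets recipes keep using the conventional SmoothQuant name "alpha" for this knob. A standalone sketch of the mechanism (this toy model is not the real modifier; whether the field name itself is also accepted depends on populate_by_name, enabled here for illustration):

from pydantic import BaseModel, ConfigDict, Field

class ToySmoothQuantConfig(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    smoothing_strength: float = Field(default=0.5, validation_alias="alpha")

print(ToySmoothQuantConfig(alpha=0.8).smoothing_strength)               # 0.8
print(ToySmoothQuantConfig(smoothing_strength=0.3).smoothing_strength)  # 0.3
print(ToySmoothQuantConfig().smoothing_strength)                        # 0.5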
1 change: 0 additions & 1 deletion src/sparseml/transformers/sparsification/__init__.py
@@ -19,7 +19,6 @@
 
 # flake8: noqa
 
-from .modification import *
 from .question_answering import *
 from .sparse_config import *
 from .sparse_model import *
@@ -83,42 +83,28 @@ def save_pretrained_wrapper(
# state_dict gets passed in as a kwarg for FSDP models
state_dict = kwargs.get("state_dict", None)

# check if we are in the old quantization framework
if qat_active(model) and not is_model_quantized(model):
if qat_active(model) or is_model_quantized(model):
_LOGGER.info(
"Compression for models quantized with QuantizationModifer is not "
"supported. Save will be run without compression and no sparsity "
"statistics will be calculated. To save a quantized model in a "
"compressed state please use vLLMQuantizationModifier instead."
"Compression for quantized models is not yet supported. Save will "
"be run without compression and no sparsity statistics will be "
"calculated."
)

original_save_pretrained.__get__(model, model_class)(
save_directory, **kwargs
)

return

elif qat_active(model): # quantized in new framework
_LOGGER.info(
"Sparsity compression for quantized models is not yet supported. "
"No sparsity statistics will be calculated and no sparsity config "
"will be saved."
)

original_save_pretrained.__get__(model, model_class)(
save_directory, **kwargs
)
if is_model_quantized(model):
quant_config = QuantizationConfig.from_pretrained(model)
quant_config_data = quant_config.dict()
config_file_path = os.path.join(save_directory, CONFIG_NAME)

quant_config = QuantizationConfig.from_pretrained(model)
quant_config_data = quant_config.model_dump(exclude_unset=True)
config_file_path = os.path.join(save_directory, CONFIG_NAME)

# add the sparsity config to the model's config file
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)
# add the sparsity config to the model's config file
with open(config_file_path, "r") as config_file:
config_data = json.load(config_file)
config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
with open(config_file_path, "w") as config_file:
json.dump(config_data, config_file, indent=2, sort_keys=True)

return

@@ -140,7 +126,7 @@ def save_pretrained_wrapper(
                 "calculation of compression statistics set "
                 "skip_compression_stats=True"
             )
-            sparsity_config = SparsityConfigMetadata.from_pretrained(
+            sparsity_config = SparsityConfigMetadata.infer_config_from_model(
                 model, state_dict=state_dict, compress=save_compressed
             )

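One detail worth calling out in this wrapper: original_save_pretrained.__get__(model, model_class) uses the descriptor protocol to bind the saved-off, unbound save_pretrained function back onto the model instance before calling it. A minimal standalone illustration of that binding trick (the class and functions here are made up for the example):

class Greeter:
    def hello(self):
        return f"hello from {type(self).__name__}"

original_hello = Greeter.hello          # plain function, not bound to an instance

def patched_hello(self):
    # __get__ rebinds the original function to this instance, then we call it
    return "patched! " + original_hello.__get__(self, Greeter)()

Greeter.hello = patched_hello
print(Greeter().hello())                # -> "patched! hello from Greeter"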
21 changes: 7 additions & 14 deletions src/sparseml/transformers/sparsification/modification/__init__.py
@@ -11,18 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 # flake8: noqa
-# isort:skip_file
-
-# the modification module that adds modifications
-# for transformers models to enable quantization
-
-# import all the modification functions for the different models
-from .modifying_bert import modify
-from .modifying_llama import modify
-from .modifying_mistral import modify
-from .modifying_distilbert import modify
-from .modifying_mobilebert import modify
-from .modifying_opt import modify
-from .modifying_qwen2_moe import modify
+from .modify_model import modify_model
+from .modifying_bert import *
+from .modifying_distilbert import *
+from .modifying_llama import *
+from .modifying_mistral import *
+from .modifying_mobilebert import *
+from .modifying_opt import *
@@ -0,0 +1,124 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Set of helper objects that are used to modify
the HuggingFace transformer models
"""

import torch


__all__ = [
    "QuantizableIdentity",
    "QuantizableMatMul",
    "QuantizableBatchMatmul",
    "QATMatMul",
    "QATLinear",
]


class QuantizableIdentity(torch.nn.Module):
    """
    Identity model that is introduced to be used
    together with QuantizableMatMul to allow for
    SparseML quantization scheme
    """

    def forward(self, x):
        return x


class QuantizableMatMul(torch.nn.Module):
    """
    Wrapper around torch.matmul with distinct inputs/output class
    instances that could be quantized through SparseML recipe
    :param left_input_cls: class instance that is used to quantize the left input
    :param right_input_cls: class instance that is used to quantize the right input
    :param output_cls: class instance that is used to quantize the output (optional)
    :return: the output of the matrix multiplication
    """

    def __init__(self, left_input_cls, right_input_cls, output_cls=None):
        super().__init__()
        self.left_input = left_input_cls()
        self.right_input = right_input_cls()
        self.output = output_cls() if output_cls is not None else None

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        out = torch.matmul(self.left_input(a), self.right_input(b))
        if self.output is not None:
            return self.output(out)
        return out


class QuantizableBatchMatmul(QuantizableMatMul):
    """
    Wrapper around torch.bmm with distinct inputs/output class
    instances that could be quantized through SparseML recipe
    :param left_input_cls: class instance that is used to quantize the left input
    :param right_input_cls: class instance that is used to quantize the right input
    :param output_cls: class instance that is used to quantize the output (optional)
    :return: the output of the batch matrix multiplication
    """

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        out = torch.bmm(self.left_input(a), self.right_input(b))
        if self.output is not None:
            return self.output(out)
        return out


class QATMatMul(torch.nn.Module):
    """
    Behaves like normal torch.matmul unless a SparseML QuantizationModifier
    is initialized (Quantization-Aware-Training is invoked)
    """

    def __init__(self):
        super().__init__()

        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor):
        return torch.matmul(a, b)


class QATLinear(torch.nn.Module):
    """
    Behaves like normal torch.nn.Linear unless a SparseML QuantizationModifier
    is initialized (Quantization-Aware-Training is invoked)
    When initialized does not quantize inputs. Only weights are quantized
    (inputs may come quantized)
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 0,
            "num_outputs": 1,
        }

        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor):
        return self.linear(x)
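A hypothetical usage sketch (not part of this commit): exposing attention matmuls as named submodules so a quantization recipe can attach observers to their inputs and outputs, assuming QuantizableMatMul and QuantizableIdentity above are importable from this new module:

import torch

class TinySelfAttention(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = torch.nn.Linear(dim, dim)
        self.k_proj = torch.nn.Linear(dim, dim)
        self.v_proj = torch.nn.Linear(dim, dim)
        # matmuls become submodules, so a quantization scheme can target them by name
        self.attn_weights_matmul = QuantizableMatMul(QuantizableIdentity, QuantizableIdentity)
        self.attn_output_matmul = QuantizableMatMul(QuantizableIdentity, QuantizableIdentity)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        scores = self.attn_weights_matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        return self.attn_output_matmul(torch.softmax(scores, dim=-1), v)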
