Merge branch 'huggingface:main' into text2text
jiqing-feng authored Dec 16, 2024
2 parents 24b988c + 35cf1d2 commit 5c4f9a1
Showing 13 changed files with 401 additions and 94 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -168,7 +168,7 @@ To load your IPEX model, you can just replace your `AutoModelForXxx` class with

model_id = "gpt2"
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
- + model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, export=True)
+ + model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
results = pipe("He's a dreadful magician and")
4 changes: 2 additions & 2 deletions docs/source/openvino/export.mdx
@@ -84,8 +84,8 @@ Optional arguments:
The group size to use for quantization. Recommended value is 128 and -1 uses per-column
quantization.
--backup-precision {none,int8_sym,int8_asym}
- Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight
- format. If not provided, backup precision is int8_asym. 'none' stands for original floating-
+ Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight
+ formats. If not provided, backup precision is int8_asym. 'none' stands for original floating-
point precision of the model weights, in this case weights are retained in their original
precision without any quantization. 'int8_sym' stands for 8-bit integer symmetric quantization
without zero point. 'int8_asym' stands for 8-bit integer asymmetric quantization with zero
19 changes: 19 additions & 0 deletions docs/source/openvino/optimization.mdx
@@ -166,6 +166,25 @@ calibration_dataset = quantizer.get_calibration_dataset(
The `quantize()` method applies post-training static quantization and exports the resulting quantized model to the OpenVINO Intermediate Representation (IR). The resulting graph is represented with two files: an XML file describing the network topology and a binary file describing the weights. The quantized model can be run on any target Intel device.
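
For reference, here is a minimal sketch of this step, assuming the `quantizer` and `calibration_dataset` objects created above; the save directory and the `OVModelForSequenceClassification` class are illustrative and should match your own task:

```python
from optimum.intel import OVModelForSequenceClassification

# Apply post-training static quantization and export the result to OpenVINO IR
quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_model_int8")

# The save directory now contains the IR: an XML file (topology) and a BIN file (weights)
ov_model = OVModelForSequenceClassification.from_pretrained("ov_model_int8")
```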


#### Speech-to-text Models Quantization

The speech-to-text Whisper model can be quantized without preparing a custom calibration dataset, as shown in the example below.

```python
model_id = "openai/whisper-tiny"
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(
model_id,
quantization_config=OVQuantizationConfig(
num_samples=10,
dataset="librispeech",
processor=model_id,
matmul_sq_alpha=0.95,
)
)
```

With this, the encoder, decoder, and decoder-with-past models of the Whisper pipeline will be fully quantized, including activations.
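
As a quick check, the quantized model can be used as a drop-in replacement in a `transformers` pipeline. A sketch, assuming the `ov_model` and `model_id` from the snippet above and a local audio file `sample.wav` (the file name is illustrative):

```python
from transformers import AutoProcessor, pipeline

processor = AutoProcessor.from_pretrained(model_id)
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=ov_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
print(asr_pipe("sample.wav")["text"])
```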

### Hybrid quantization

Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion (SD) models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of the activations is comparable to that of the weights.
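
One way hybrid quantization is typically requested is by passing a weight-quantization config together with a calibration dataset when loading a diffusion pipeline. A sketch (the model id, dataset name, and sample count are illustrative):

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

sd_pipe = OVStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=200),
)
image = sd_pipe("sailing ship in storm by Rembrandt", num_inference_steps=25).images[0]
```
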
6 changes: 5 additions & 1 deletion optimum/commands/export/openvino.py
@@ -123,7 +123,7 @@ def parse_args_openvino(parser: "ArgumentParser"):
choices=["none", "int8_sym", "int8_asym"],
default=None,
help=(
"Defines a backup precision for mixed-precision weight compression. Only valid for int4 weight format. "
"Defines a backup precision for mixed-precision weight compression. Only valid for 4-bit weight formats. "
"If not provided, backup precision is int8_asym. 'none' stands for original floating-point precision of "
"the model weights, in this case weights are retained in their original precision without any "
"quantization. 'int8_sym' stands for 8-bit integer symmetric quantization without zero point. 'int8_asym' "
@@ -354,6 +354,10 @@ def run(self):
from optimum.intel import OVStableDiffusion3Pipeline

model_cls = OVStableDiffusion3Pipeline
elif class_name == "FluxPipeline":
from optimum.intel import OVFluxPipeline

model_cls = OVFluxPipeline
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

6 changes: 3 additions & 3 deletions optimum/exporters/ipex/modeling_utils.py
@@ -664,9 +664,9 @@ def __init__(self, module, config) -> None:
if use_bias:
concat_bias = torch.concat(bias_list, 0).contiguous()
self.concat_linear.bias = nn.Parameter(concat_bias)
- self.q_slice = self.q_proj.out_features
- self.k_slice = self.q_slice + self.k_proj.out_features
- self.v_slice = self.k_slice + self.v_proj.out_features
+ self.q_slice = self.q_proj.weight.shape[0]
+ self.k_slice = self.q_slice + self.k_proj.weight.shape[0]
+ self.v_slice = self.k_slice + self.v_proj.weight.shape[0]
if self.module_device.type == "cpu":
if module.o_proj.__class__.__name__ not in ["LinearAllreduce"]:
self.mha_linear_add = LinearAdd(module.o_proj)
88 changes: 80 additions & 8 deletions optimum/intel/openvino/configuration.py
@@ -26,7 +26,7 @@
from optimum.configuration_utils import BaseConfig

from ..utils.import_utils import is_nncf_available
- from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
+ from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_SPEECH_TO_TEXT_DATASETS, PREDEFINED_VISUAL_LM_DATASETS


if is_nncf_available():
@@ -123,11 +123,18 @@ class OVQuantizationMethod(str, Enum):
"mistralai/Mistral-7B-v0.1": {"bits": 4, "sym": True, "group_size": 128, "ratio": 0.9},
"baichuan-inc/Baichuan2-7B-Chat": {
"bits": 4,
"sym": True,
"sym": False,
"group_size": 128,
"ratio": 0.8,
},
"baichuan-inc/Baichuan2-13B-Chat": {
"bits": 4,
"sym": False,
"group_size": 128,
"ratio": 1.0,
"dataset": "wikitext2",
"quant_method": OVQuantizationMethod.AWQ,
"scale_estimation": True,
},
"lmsys/longchat-7b-16k": {
"bits": 4,
@@ -255,6 +262,10 @@ def __init__(
sym: bool = False,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = None,
dataset: Optional[Optional[Union[str, List[str]]]] = None,
tokenizer: Optional[str] = None,
processor: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
):
"""
@@ -272,6 +283,10 @@
self.bits = bits
self.sym = sym
self.num_samples = num_samples
self.dataset = dataset
self.tokenizer = tokenizer
self.processor = processor
self.trust_remote_code = trust_remote_code

if isinstance(ignored_scope, nncf.IgnoredScope):
ignored_scope = ignored_scope.__dict__
@@ -313,6 +328,10 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
trust_remote_code (`bool`, defaults to `False`):
Allows to use custom code for the modeling hosted in the model repository. This option should only be set
for repositories you trust and in which you have read the code, as it will execute on your local machine
arbitrary code present in the model repository.
dataset (`str or List[str]`, *optional*):
The dataset used for data-aware compression with NNCF.
- For language models you can provide your own dataset in a list of strings or just use one from the list
@@ -395,10 +414,16 @@ def __init__(
backup_precision: Optional[str] = None,
**kwargs,
):
- super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples)
- self.tokenizer = tokenizer
- self.trust_remote_code = trust_remote_code
- self.dataset = dataset
+ super().__init__(
+     bits=bits,
+     sym=sym,
+     ignored_scope=ignored_scope,
+     num_samples=num_samples,
+     dataset=dataset,
+     tokenizer=tokenizer,
+     processor=processor,
+     trust_remote_code=trust_remote_code,
+ )
self.group_size = group_size or (-1 if bits == 8 else 128)
self.ratio = ratio
self.all_layers = all_layers
@@ -407,7 +432,6 @@ def __init__(
self.scale_estimation = scale_estimation
self.weight_format = weight_format
self.gptq = gptq
- self.processor = processor
self.lora_correction = lora_correction
self.backup_precision = backup_precision
self.post_init()
@@ -535,6 +559,11 @@ def __init__(
model_type: str = "transformer",
fast_bias_correction: bool = True,
overflow_fix: str = "disable",
dataset: Optional[str] = None,
tokenizer: Optional[str] = None,
processor: Optional[str] = None,
trust_remote_code: bool = False,
smooth_quant_alpha: Optional[float] = None,
**kwargs,
):
"""
@@ -557,11 +586,42 @@
Whether to apply fast or full bias correction algorithm.
overflow_fix (`str`, default to "disable"):
Parameter for controlling overflow fix setting.
dataset (`str`, *optional*):
The dataset used for quantization. For speech-to-text model quantization the allowed value is 'librispeech'.
tokenizer (`str`, *optional*):
The tokenizer used to process the dataset. You can pass either:
- A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
processor (`str`, *optional*):
A transformers processor used to process inputs for multi-modal models. You can pass either:
- A string, the *model id* of a predefined processor hosted inside a model repo on huggingface.co.
- A path to a *directory* containing files required by the processor, for instance saved
using the [`~AutoProcessor.save_pretrained`] method, e.g., `./my_model_directory/`.
trust_remote_code (`bool`, defaults to `False`):
Allows to use custom code for the modeling hosted in the model repository. This option should only be set
for repositories you trust and in which you have read the code, as it will execute on your local machine
arbitrary code present in the model repository.
smooth_quant_alpha (`float`, *optional*):
SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
reduces quantization error.
"""
- super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples)
+ super().__init__(
+     bits=bits,
+     sym=sym,
+     ignored_scope=ignored_scope,
+     num_samples=num_samples,
+     dataset=dataset,
+     tokenizer=tokenizer,
+     processor=processor,
+     trust_remote_code=trust_remote_code,
+ )
self.model_type = model_type
self.fast_bias_correction = fast_bias_correction
self.overflow_fix = overflow_fix
self.smooth_quant_alpha = smooth_quant_alpha
self.post_init()

def post_init(self):
@@ -573,6 +633,18 @@ def post_init(self):
if self.bits != 8:
raise ValueError(f"Only support 8-bit for static quantization but found {self.bits}")

if self.dataset is not None:
if self.dataset not in PREDEFINED_SPEECH_TO_TEXT_DATASETS:
raise ValueError(
f"You have entered the following string value for dataset: {self.dataset}. But it is not supported."
f" Currently you can only choose {list(PREDEFINED_SPEECH_TO_TEXT_DATASETS.keys())}."
)

if self.smooth_quant_alpha is not None and not (0 <= self.smooth_quant_alpha <= 1):
raise ValueError(
f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}"
)


class OVConfig(BaseConfig):
CONFIG_NAME = "openvino_config.json"
22 changes: 20 additions & 2 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
from pathlib import Path
@@ -35,7 +35,9 @@
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

from .. import OVConfig, OVQuantizer
from ..utils import is_transformers_version
from .configuration import OVQuantizationConfig, OVQuantizationConfigBase
from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
from .utils import OV_TO_PT_TYPE, _print_compiled_model_properties

@@ -973,9 +975,25 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
load_in_8bit: bool = False,
quantization_config: Union[dict, OVQuantizationConfigBase] = None,
**kwargs,
):
- return super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs)
compile_only = kwargs.get("compile_only", False)

if not compile_only and isinstance(quantization_config, OVQuantizationConfig):
model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(
model_id, config, load_in_8bit=False, **kwargs
)
quantization_config_copy = copy.deepcopy(quantization_config)
quantization_config_copy.processor = quantization_config.processor or model_id
OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
else:
model = super(OVModelForSpeechSeq2Seq, cls)._from_pretrained(
model_id, config, load_in_8bit=load_in_8bit, quantization_config=quantization_config, **kwargs
)

return model

class DummyWhisperModel:
def __init__(self):
(Diffs for the remaining changed files are not shown.)

0 comments on commit 5c4f9a1
