Skip to content

Commit

Permalink
enable onnx export for INC PTQ model (#373)
Browse files Browse the repository at this point in the history
* enable onnx export for PTQ

* fix output loading quantized model check

* tests refactorization

* fix style
  • Loading branch information
echarlaix authored and changwangss committed Sep 25, 2023
1 parent 0ea6dee commit e0c1fc4
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 149 deletions.
19 changes: 12 additions & 7 deletions optimum/intel/neural_compressor/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,19 @@ def quantize(
remove_unused_columns=remove_unused_columns,
data_collator=data_collator,
)
op_type_dict = getattr(quantization_config, "op_type_dict", None)
if op_type_dict is None or "Embedding" not in op_type_dict:
logger.warning("ONNX export is no supported for model with quantized embeddings")
save_onnx_model = False

# Disable ONNX export for post-training quantized model as deprecated in neural-compressor>=2.2.0
if save_onnx_model:
logger.warning(
"ONNX export for post-training quantized model is no longer supported by neural-compressor>=2.2.0. "
"To apply quantization on an ONNX model, check out optimum.onnxruntime.ORTQuantizer"
)
save_onnx_model = False
else:
# Disable ONNX export for dynamically quantized model as deprecated in neural-compressor>=2.2.0
if save_onnx_model:
logger.warning(
"ONNX export for dynamic quantized model is no longer supported by neural-compressor>=2.2.0. "
"To apply dynamic quantization on an ONNX model, you can use optimum.onnxruntime.ORTQuantizer"
)
save_onnx_model = False

if (
quantization_config.backend == "ipex"
Expand Down
80 changes: 80 additions & 0 deletions tests/neural_compressor/test_onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa


import os
import tempfile

from neural_compressor.config import PostTrainingQuantConfig
from parameterized import parameterized
from transformers import AutoTokenizer, set_seed
from utils_tests import SEED, INCTestMixin, _generate_dataset

from optimum.intel import (
INCConfig,
INCModelForCausalLM,
INCModelForSeq2SeqLM,
INCModelForQuestionAnswering,
INCModelForSequenceClassification,
INCModelForMaskedLM,
INCModelForTokenClassification,
INCQuantizer,
)
from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification
from optimum.pipelines import ORT_SUPPORTED_TASKS

os.environ["CUDA_VISIBLE_DEVICES"] = ""
set_seed(SEED)


class OptimizationTest(INCTestMixin):
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
("text-classification", "hf-internal-testing/tiny-random-bert", 34),
)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
def test_static_quantization(self, task, model_name, expected_quantized_matmuls):
num_samples = 10
model = ORT_SUPPORTED_TASKS[task]["class"][0].auto_model_class.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
quantizer = INCQuantizer.from_pretrained(model, task=task)
calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=num_samples)
save_onnx_model = True
op_type_dict = (
{"Embedding": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}}
if save_onnx_model
else None
)
quantization_config = PostTrainingQuantConfig(approach="static", op_type_dict=op_type_dict)
with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
save_directory=tmp_dir,
save_onnx_model=save_onnx_model,
)
self.check_model_outputs(
q_model=quantizer._quantized_model,
task=task,
tokenizer=tokenizer,
save_directory=tmp_dir,
expected_quantized_matmuls=expected_quantized_matmuls,
is_static=True,
num_samples=num_samples,
load_onnx_model=save_onnx_model,
)
160 changes: 18 additions & 142 deletions tests/neural_compressor/test_optimization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,10 +14,9 @@

# ruff: noqa


import os
import tempfile
import unittest
from functools import partial

import evaluate
import numpy as np
Expand All @@ -32,22 +31,18 @@
TuningCriterion,
WeightPruningConfig,
)
from onnx import load as onnx_load
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
AutoTokenizer,
EvalPrediction,
TrainingArguments,
Seq2SeqTrainingArguments,
default_data_collator,
pipeline,
BertTokenizer,
EncoderDecoderModel,
Seq2SeqTrainingArguments,
pipeline,
set_seed,
)
from utils_tests import SEED, INCTestMixin, _generate_dataset

from optimum.intel import (
INCConfig,
Expand All @@ -58,63 +53,27 @@
INCModelForMaskedLM,
INCModelForTokenClassification,
INCQuantizer,
INCStableDiffusionPipeline,
INCTrainer,
INCSeq2SeqTrainer,
INCStableDiffusionPipeline,
)
from optimum.intel.neural_compressor.utils import _HEAD_TO_AUTOMODELS
from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
from optimum.intel.utils.constant import DIFFUSION_WEIGHTS_NAME
from optimum.onnxruntime import ORTModelForCausalLM, ORTModelForSequenceClassification
from optimum.pipelines import ORT_SUPPORTED_TASKS


os.environ["CUDA_VISIBLE_DEVICES"] = ""
set_seed(1009)

_TASK_TO_DATASET = {
"text-classification": ("glue", "sst2", "sentence"),
"text-generation": ("wikitext", "wikitext-2-raw-v1", "text"),
"text2text-generation": ("cnn_dailymail", "3.0.0", ("article", "highlights")),
}


def num_quantized_matmul_onnx_model(onnx_model):
num_quantized_matmul = 0
for initializer in onnx_model.graph.initializer:
if "MatMul" in initializer.name and "quantized" in initializer.name:
num_quantized_matmul += 1
return num_quantized_matmul
set_seed(SEED)


def _preprocess_function(examples, tokenizer, column_name):
return tokenizer(examples[column_name], padding="max_length", max_length=128, truncation=True)


def _compute_metrics(outputs, metric):
return metric.compute(predictions=np.argmax(outputs.predictions, axis=1), references=outputs.label_ids)


def _generate_dataset(quantizer, tokenizer, num_samples=10):
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[quantizer.task]
dataset = quantizer.get_calibration_dataset(
dataset_name,
dataset_config_name=dataset_config_name,
preprocess_function=partial(_preprocess_function, tokenizer=tokenizer, column_name=column_name),
num_samples=num_samples,
dataset_split="train",
)
return dataset


class OptimizationTest(unittest.TestCase):
class OptimizationTest(INCTestMixin):
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (
("text-classification", "hf-internal-testing/tiny-random-bert", 30),
("text-classification", "hf-internal-testing/tiny-random-bert", 34),
# ("text-generation", "hf-internal-testing/tiny-random-BloomForCausalLM", 1), # TODO : enable causal lm task once INC ONNX export fixed
)

SUPPORTED_ARCHITECTURES_DYNAMIC = SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS + (
("fill-mask", "hf-internal-testing/tiny-random-DistilBertForMaskedLM", 30),
("token-classification", "hf-internal-testing/tiny-random-AlbertForTokenClassification", 30),
("fill-mask", "hf-internal-testing/tiny-random-DistilBertForMaskedLM", 34),
("token-classification", "hf-internal-testing/tiny-random-AlbertForTokenClassification", 34),
)

TEXT_GENERATION_SUPPORTED_ARCHITECTURES = (
Expand Down Expand Up @@ -148,16 +107,19 @@ def test_dynamic_quantization(self, task, model_name, expected_quantized_matmuls
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
def test_static_quantization(self, task, model_name, expected_quantized_matmuls):
num_samples = 10
quantization_config = PostTrainingQuantConfig(approach="static")
model = ORT_SUPPORTED_TASKS[task]["class"][0].auto_model_class.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

quantizer = INCQuantizer.from_pretrained(model, task=task)
calibration_dataset = _generate_dataset(quantizer, tokenizer, num_samples=num_samples)
save_onnx_model = False

op_type_dict = (
{"Embedding": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}}
if save_onnx_model
else None
)
quantization_config = PostTrainingQuantConfig(approach="static", op_type_dict=op_type_dict)
with tempfile.TemporaryDirectory() as tmp_dir:
quantizer.quantize(
quantization_config=quantization_config,
Expand Down Expand Up @@ -530,89 +492,3 @@ def _compute_metrics(pred):
self.assertIsInstance(loaded_model_outputs.logits, torch.Tensor)
# Compare tensor outputs
self.assertTrue(torch.allclose(loaded_model_outputs.logits, model_outputs.logits, atol=1e-4))

def check_model_outputs(
self,
q_model,
task,
tokenizer,
save_directory,
expected_quantized_matmuls,
is_static=True,
load_onnx_model=True,
num_samples=None,
file_name=ONNX_WEIGHTS_NAME,
):
tokens = tokenizer("This is a sample input", return_tensors="pt")
inc_model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(save_directory)
model_kwargs = (
{"decoder_file_name": file_name, "use_cache": False}
if task == "text-generation"
else {"file_name": file_name}
)
inc_config = INCConfig.from_pretrained(save_directory)
self.assertEqual(inc_config.save_onnx_model, load_onnx_model)

if num_samples is not None:
self.assertEqual(inc_config.quantization["dataset_num_samples"], num_samples)

if load_onnx_model:
onnx_model = onnx_load(os.path.join(save_directory, file_name))
num_quantized_matmul = num_quantized_matmul_onnx_model(onnx_model)

if num_quantized_matmul > 0:
self.assertEqual(inc_config.quantization["is_static"], is_static)

self.assertEqual(expected_quantized_matmuls, num_quantized_matmul)
ort_model = ORT_SUPPORTED_TASKS[task]["class"][0].from_pretrained(save_directory, **model_kwargs)
ort_outputs = ort_model(**tokens)
self.assertTrue("logits" in ort_outputs)

with torch.no_grad():
model_outputs = q_model(**tokens)
inc_model_outputs = inc_model(**tokens)
self.assertTrue(torch.equal(model_outputs["logits"], inc_model_outputs["logits"]))
# self.assertTrue(torch.allclose(ort_outputs.logits, inc_model_outputs.logits, atol=1e-4))

@staticmethod
def get_trainer(
model_name,
task,
save_directory,
q_config=None,
p_config=None,
d_config=None,
save_onnx_model=True,
num_train_samples=8,
num_eval_samples=8,
):
model = ORT_SUPPORTED_TASKS[task]["class"][0].auto_model_class.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

metric = evaluate.load("accuracy")
dataset_name, dataset_config_name, column_name = _TASK_TO_DATASET[task]
dataset = load_dataset(dataset_name, dataset_config_name)
dataset = dataset.map(
partial(_preprocess_function, tokenizer=tokenizer, column_name=column_name), batched=True
)

trainer = INCTrainer(
model=model,
quantization_config=q_config,
pruning_config=p_config,
distillation_config=d_config,
task=task,
args=TrainingArguments(save_directory, num_train_epochs=2.0, do_train=True, do_eval=True),
train_dataset=dataset["train"].select(range(num_train_samples)),
eval_dataset=dataset["validation"].select(range(num_eval_samples)),
compute_metrics=partial(_compute_metrics, metric=metric),
tokenizer=tokenizer,
data_collator=default_data_collator,
)
trainer.train()
trainer.evaluate()
trainer.save_model(save_onnx_model=save_onnx_model)
trainer.model.eval()
return trainer
Loading

0 comments on commit e0c1fc4

Please sign in to comment.