diff --git a/.github/workflows/test_generation.yml b/.github/workflows/test_generation.yml new file mode 100644 index 0000000000..0753061282 --- /dev/null +++ b/.github/workflows/test_generation.yml @@ -0,0 +1,37 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions +name: Intel Generation Utils - Test + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: [3.8, 3.9] + os: [ubuntu-latest] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install optimum[exporters] + pip install .[tests] + - name: Test with Pytest + run: | + pytest tests/generation/ diff --git a/examples/neural_compressor/text-generation/README.md b/examples/neural_compressor/text-generation/README.md new file mode 100644 index 0000000000..2a151ecf7b --- /dev/null +++ b/examples/neural_compressor/text-generation/README.md @@ -0,0 +1,30 @@ + + +## Language generation + +Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py). + +The original generation task only supported the PyTorch eager model. By calling the `TorchScriptModelForCausalLM` class, we can now support a TorchScript model for generation tasks. + +Example usage: + +```bash +python run_generation.py \ + --model_type=gpt2 \ + --model_name_or_path=gpt2 \ + --jit +``` diff --git a/examples/neural_compressor/text-generation/requirements.txt b/examples/neural_compressor/text-generation/requirements.txt new file mode 100644 index 0000000000..c495ec0c41 --- /dev/null +++ b/examples/neural_compressor/text-generation/requirements.txt @@ -0,0 +1,3 @@ +sentencepiece != 0.1.92 +protobuf +torch >= 2.0.0 diff --git a/examples/neural_compressor/text-generation/run_generation.py b/examples/neural_compressor/text-generation/run_generation.py new file mode 100755 index 0000000000..11ad363ad3 --- /dev/null +++ b/examples/neural_compressor/text-generation/run_generation.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
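+#
+# Note: when the `--jit` flag is passed, the model is loaded through
+# `optimum.intel.generation.modeling.TorchScriptModelForCausalLM`, which traces the model
+# with TorchScript and reuses the traced, frozen graph for every decoding step; without
+# `--jit`, the eager-mode class from `MODEL_CLASSES` is used as in the upstream script.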
+""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) +""" + + +import argparse +import logging + +import numpy as np +import torch +from transformers import ( + CTRLLMHeadModel, + CTRLTokenizer, + GPT2LMHeadModel, + GPT2Tokenizer, + OpenAIGPTLMHeadModel, + OpenAIGPTTokenizer, + TransfoXLLMHeadModel, + TransfoXLTokenizer, + XLMTokenizer, + XLMWithLMHeadModel, + XLNetLMHeadModel, + XLNetTokenizer, +) + +from optimum.intel.generation.modeling import TorchScriptModelForCausalLM + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + +MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop + +MODEL_CLASSES = { + "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), + "ctrl": (CTRLLMHeadModel, CTRLTokenizer), + "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + "xlnet": (XLNetLMHeadModel, XLNetTokenizer), + "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), + "xlm": (XLMWithLMHeadModel, XLMTokenizer), +} + +# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia +# in https://github.com/rusiaaman/XLNet-gen#methodology +# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e +PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family +(except for Alexei and Maria) are discovered. +The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the +remainder of the story. 1883 Western Siberia, +a young Grigori Rasputin is asked by his father and a group of men to perform magic. +Rasputin has a vision and denounces one of the men as a horse thief. Although his +father initially slaps him for making such an accusation, Rasputin watches as the +man is chased outside and beaten. Twenty years later, Rasputin sees a vision of +the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, +with people, even a bishop, begging for his blessing. """ + + +def set_seed(args): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +# +# Functions to prepare models' input +# + + +def prepare_ctrl_input(args, _, tokenizer, prompt_text): + if args.temperature > 0.7: + logger.info("CTRL typically works better with lower temperatures (and lower top_k).") + + encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) + if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): + logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") + return prompt_text + + +def prepare_xlm_input(args, model, tokenizer, prompt_text): + # kwargs = {"language": None, "mask_token_id": None} + + # Set the language + use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb + if hasattr(model.config, "lang2id") and use_lang_emb: + available_languages = model.config.lang2id.keys() + if args.xlm_language in available_languages: + language = args.xlm_language + else: + language = None + while language not in available_languages: + language = input("Using XLM. 
Select language in " + str(list(available_languages)) + " >>> ") + + model.config.lang_id = model.config.lang2id[language] + # kwargs["language"] = tokenizer.lang2id[language] + + # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers + # XLM masked-language modeling (MLM) models need masked token + # is_xlm_mlm = "mlm" in args.model_name_or_path + # if is_xlm_mlm: + # kwargs["mask_token_id"] = tokenizer.mask_token_id + + return prompt_text + + +def prepare_xlnet_input(args, _, tokenizer, prompt_text): + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text + return prompt_text + + +def prepare_transfoxl_input(args, _, tokenizer, prompt_text): + prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX + prompt_text = prefix + prompt_text + return prompt_text + + +PREPROCESSING_FUNCTIONS = { + "ctrl": prepare_ctrl_input, + "xlm": prepare_xlm_input, + "xlnet": prepare_xlnet_input, + "transfo-xl": prepare_transfoxl_input, +} + + +def adjust_length_to_model(length, max_sequence_length): + if length < 0 and max_sequence_length > 0: + length = max_sequence_length + elif 0 < max_sequence_length < length: + length = max_sequence_length # No generation bigger than model size + elif length < 0: + length = MAX_LENGTH # avoid infinite loop + return length + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--length", type=int, default=20) + parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="temperature of 1.0 has no effect, lower tend toward greedy sampling", + ) + parser.add_argument( + "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" + ) + parser.add_argument("--k", type=int, default=0) + parser.add_argument("--p", type=float, default=0.9) + + parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") + parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") + parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") + + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") + parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") + + parser.add_argument( + "--output_dir", + default=None, + type=str, + help="Output directory where to save the resulting model", + ) + args = parser.parse_args() + + 
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() + + logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") + + set_seed(args) + + # Initialize the model and tokenizer + try: + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + except KeyError: + raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)") + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + if args.jit: + model = TorchScriptModelForCausalLM.from_pretrained(args.model_name_or_path, export=True) + else: + model = model_class.from_pretrained(args.model_name_or_path) + + if args.output_dir is not None and args.jit: + model.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + model.to(args.device) + + args.length = adjust_length_to_model( + args.length, + max_sequence_length=model.config.max_position_embeddings + if hasattr(model.config, "max_position_embeddings") + else 0, + ) + logger.info(args) + + prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") + + # Different models need different input formatting and/or extra arguments + requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() + if requires_preprocessing: + prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) + preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) + + if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + tokenizer_kwargs = {"add_space_before_punct_symbol": True} + else: + tokenizer_kwargs = {} + + encoded_prompt = tokenizer.encode( + preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs + ) + else: + prefix = args.prefix if args.prefix else args.padding_text + encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") + encoded_prompt = encoded_prompt.to(args.device) + + if encoded_prompt.size()[-1] == 0: + input_ids = None + else: + input_ids = encoded_prompt + + output_sequences = model.generate( + input_ids=input_ids, + max_length=args.length + len(encoded_prompt[0]), + temperature=args.temperature, + top_k=args.k, + top_p=args.p, + repetition_penalty=args.repetition_penalty, + do_sample=True, + num_return_sequences=args.num_return_sequences, + ) + + # Remove the batch dimension when returning multiple sequences + if len(output_sequences.shape) > 2: + output_sequences.squeeze_() + + generated_sequences = [] + + for generated_sequence_idx, generated_sequence in enumerate(output_sequences): + print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") + generated_sequence = generated_sequence.tolist() + + # Decode text + text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) + + # Remove all text after the stop token + text = text[: text.find(args.stop_token) if args.stop_token else None] + + # Add the prompt at the beginning of the sequence. 
Remove the excess text that was used for pre-processing + total_sequence = ( + prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] + ) + + generated_sequences.append(total_sequence) + print(total_sequence) + + return generated_sequences + + +if __name__ == "__main__": + main() diff --git a/optimum/intel/generation/__init__.py b/optimum/intel/generation/__init__.py new file mode 100644 index 0000000000..362bbb31e9 --- /dev/null +++ b/optimum/intel/generation/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .modeling import TorchScriptModelForCausalLM diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py new file mode 100644 index 0000000000..5bdec03927 --- /dev/null +++ b/optimum/intel/generation/modeling.py @@ -0,0 +1,355 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
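+
+# `TorchScriptModelForCausalLM` wraps a `torch.jit`-traced (and frozen) decoder-only model
+# behind the `OptimizedModel` and `GenerationMixin` interfaces so that it can be used with
+# `generate()` and `transformers` pipelines. The traced graph always expects a
+# `past_key_values` input, so the first decoding step is fed an empty (zero-length) cache,
+# and Bloom's flattened `[batch_size * num_heads, ...]` cache layout is converted to and
+# from the standard `[batch_size, num_heads, ...]` layout when beam search reorders the cache.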
+ +import inspect +import logging +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch +from huggingface_hub import hf_hub_download +from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import WEIGHTS_NAME + +from optimum.exporters import TasksManager +from optimum.modeling_base import OptimizedModel +from optimum.utils import NormalizedConfigManager + +from ..utils.import_utils import is_torch_version, is_transformers_version + + +if is_transformers_version("<", "4.25.0"): + from transformers.generation_utils import GenerationMixin +else: + from transformers.generation import GenerationMixin + + +logger = logging.getLogger(__name__) + + +class TorchScriptModelForCausalLM(OptimizedModel, GenerationMixin): + auto_model_class = AutoModelForCausalLM + export_feature = "text-generation" + main_input_name = "input_ids" + base_model_prefix = "torch_script_model" + + def __init__( + self, + model, + config: PretrainedConfig = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + use_cache: bool = True, + **kwargs, + ): + self.model = model + self.config = config + self.model_save_dir = model_save_dir + self.preprocessors = kwargs.get("preprocessors", []) + self.use_cache = use_cache + self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model.to(self._device) + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + + if is_transformers_version("<=", "4.25.1"): + self.generation_config = None + else: + from transformers import GenerationConfig + + self.generation_config = GenerationConfig.from_model_config(config) + + # Avoid warnings when creating a transformers pipeline + AutoConfig.register(self.base_model_prefix, AutoConfig) + self.auto_model_class.register(AutoConfig, self.__class__) + + @staticmethod + def load_model(file_name: Union[str, Path]): + model = torch.jit.load(file_name) + torch.jit.freeze(model.eval()) + return model + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + torch.jit.save(self.model, os.path.join(save_directory, WEIGHTS_NAME)) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + **kwargs, + ) -> CausalLMOutputWithPast: + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + if self.use_cache: + if past_key_values is None: + nb_pkv = 2 + num_layers = self.normalized_config.num_layers + num_attention_heads = self.normalized_config.num_attention_heads + hidden_size = self.normalized_config.hidden_size + d_k = hidden_size // num_attention_heads + + if self.config.model_type != "bloom": + new_shape = [input_ids.shape[0], num_attention_heads, 0, d_k] + empty_tensor = torch.empty(size=new_shape) + past_key_values = tuple(tuple(empty_tensor for _ in range(nb_pkv)) for _ in range(num_layers)) + pkv = tuple(empty_tensor for _ in range(nb_pkv)) + else: + pkv = () + for nb_pkv in range(nb_pkv): + if nb_pkv % 2 == 0: + new_shape = [input_ids.shape[0] * num_attention_heads, d_k, 0] + else: + new_shape = [input_ids.shape[0] * num_attention_heads, 0, d_k] + pkv = pkv + (torch.empty(size=new_shape),) + past_key_values = tuple(tuple(pkv) for _ in range(num_layers)) + + 
inputs["past_key_values"] = past_key_values + outputs = self.model(**inputs) + + return CausalLMOutputWithPast(logits=outputs[0], past_key_values=outputs[1] if self.use_cache else None) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: PretrainedConfig, + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = WEIGHTS_NAME, + local_files_only: bool = False, + use_cache: bool = True, + **kwargs, + ): + if not getattr(config, "torchscript", False): + raise ValueError("`torchscript` should be set to True to load TorchScript model") + + # Load the model from local directory + if os.path.isdir(model_id): + file_name = os.path.join(model_id, file_name) + model = cls.load_model(file_name) + model_save_dir = model_id + # Download the model from the hub + else: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + model_save_dir = Path(model_cache_path).parent + model = cls.load_model(model_cache_path) + + # IPEX jit model need 2 iterations to convert model to int8 model + onnx_config_class = TasksManager.get_exporter_config_constructor( + model_type=config.model_type.replace("_", "-"), + exporter="onnx", + task=cls.export_feature, + ) + onnx_config = onnx_config_class(config, use_past=use_cache) + model_inputs = onnx_config.generate_dummy_inputs(framework="pt") + for i in range(2): + model(**model_inputs) + + return cls( + model, + config=config, + model_save_dir=model_save_dir, + use_cache=use_cache, + **kwargs, + ) + + @classmethod + def _from_transformers( + cls, + model_id: str, + config: PretrainedConfig, + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + use_cache: bool = True, + **kwargs, + ): + if is_torch_version("<", "2.0.0"): + raise ImportError("`torch>=2.0.0` is needed to trace your model") + + task = cls.export_feature + model_kwargs = { + "revision": revision, + "use_auth_token": use_auth_token, + "cache_dir": cache_dir, + "subfolder": subfolder, + "local_files_only": local_files_only, + "force_download": force_download, + } + + model = TasksManager.get_model_from_task(task, model_id, **model_kwargs) + model.config.return_dict = False + signature = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call) + onnx_config_class = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) + onnx_config = onnx_config_class(model.config, use_past=use_cache) + dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") + model_inputs = { + key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None + } + + if use_cache: + traced_model = torch.jit.trace(model, example_inputs=tuple(model_inputs.values())) + else: + traced_model = torch.jit.trace(model, example_kwarg_inputs=model_inputs) + traced_model = torch.jit.freeze(traced_model.eval()) + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME) + config.torchscript = True + + return cls._from_pretrained( + model_id=save_dir_path, + 
config=config, + use_cache=use_cache, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + local_files_only=local_files_only, + **kwargs, + ) + + def can_generate(self) -> bool: + return True + + @property + def device(self) -> torch.device: + return self._device + + def to(self, device: Union[torch.device, str]): + self._device = device if isinstance(device, torch.device) else torch.device(device) + self.model.to(self._device) + return self + + # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + past_key_values = past_key_values or kwargs.get("past", None) + + if self.use_cache and past_key_values is not None: + input_ids = input_ids[:, -1:] + + # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed + if past_key_values is not None and self.config.model_type == "bloom": + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": self.use_cache, + "position_ids": None, + "attention_mask": kwargs.get("attention_mask", None), + "token_type_ids": None, + } + + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + if self.config.model_type == "bloom": + return self._reorder_cache_bloom(past_key_values, beam_idx) + + # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache + def _reorder_cache_bloom( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called for bloom architecture. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) + for layer_past in past_key_values + for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return self._convert_to_bloom_cache(reordered_past) + + # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache + @staticmethod + def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache + def _convert_to_standard_cache( + self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) + """ + if self.config.model_type != "bloom": + return past_key_value + + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py new file mode 100644 index 0000000000..ee6c57bc3d --- /dev/null +++ b/tests/generation/test_modeling.py @@ -0,0 +1,147 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
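+
+# These tests export each tiny test architecture with `export=True` and check that the
+# traced model matches eager `transformers` outputs (logits and greedy generation), works
+# in a `text-generation` pipeline, handles batched prompts, and that enabling the
+# `use_cache` past-key-values path is at least `SPEEDUP_CACHE` times faster than running
+# without it.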
+ +import tempfile +import time +import unittest + +import torch +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, pipeline, set_seed + +from optimum.intel.generation.modeling import TorchScriptModelForCausalLM + + +MODEL_NAMES = { + "bloom": "hf-internal-testing/tiny-random-BloomModel", + "gptj": "hf-internal-testing/tiny-random-gptj", + "gpt2": "hf-internal-testing/tiny-random-gpt2", + "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", +} + +SEED = 42 + + +class Timer(object): + def __enter__(self): + self.elapsed = time.perf_counter() + return self + + def __exit__(self, type, value, traceback): + self.elapsed = (time.perf_counter() - self.elapsed) * 1e3 + + +class ModelingIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = ( + "bloom", + "gpt2", + "gptj", + "gpt_neo", + ) + GENERATION_LENGTH = 100 + SPEEDUP_CACHE = 1.2 + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + model = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True) + self.assertIsInstance(model.config, PretrainedConfig) + trfs_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample", return_tensors="pt") + outputs = model(**tokens) + self.assertIsInstance(outputs.logits, torch.Tensor) + with torch.no_grad(): + trfs_outputs = trfs_model(**tokens) + # Compare outputs with original transformers model + atol = 1e-1 if model_arch == "bloom" else 1e-4 + self.assertTrue(torch.allclose(outputs.logits, trfs_outputs.logits, atol=atol)) + # Compare outputs with loaded model + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + loaded_model = TorchScriptModelForCausalLM.from_pretrained(tmpdirname) + loaded_model_outputs = loaded_model(**tokens) + self.assertTrue(torch.equal(outputs.logits, loaded_model_outputs.logits)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers_generate(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + model = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True) + self.assertIsInstance(model.config, PretrainedConfig) + trfs_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample", return_tensors="pt") + outputs = model.generate(**tokens, do_sample=False, num_beams=1, temperature=0.9, min_length=20, max_length=20) + self.assertIsInstance(outputs, torch.Tensor) + with torch.no_grad(): + trfs_outputs = trfs_model.generate( + **tokens, do_sample=False, num_beams=1, temperature=0.9, min_length=20, max_length=20 + ) + # Compare outputs with original transformers model + self.assertTrue(torch.equal(outputs, trfs_outputs)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_pipeline(self, model_arch): + model_id = MODEL_NAMES[model_arch] + model = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True) + model.to("cpu") + tokenizer = AutoTokenizer.from_pretrained(model_id) + pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device="cpu") + outputs = pipe("This is a sample", max_length=10) + self.assertEqual(pipe.device, model.device) + self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def 
test_multiple_inputs(self, model_arch): + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + model = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"] + tokens = tokenizer(texts, padding=True, return_tensors="pt") + outputs = model.generate(**tokens, max_new_tokens=20, num_beams=2) + self.assertIsInstance(outputs, torch.Tensor) + self.assertEqual(outputs.shape[0], 3) + + def test_compare_with_and_without_past_key_values(self): + model_id = MODEL_NAMES["gpt2"] + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + + model_with_pkv = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True) + # Warmup + _ = model_with_pkv.generate(**tokens) + with Timer() as with_pkv_timer: + outputs_model_with_pkv = model_with_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 + ) + + model_without_pkv = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False) + # Warmup + _ = model_without_pkv.generate(**tokens) + with Timer() as without_pkv_timer: + outputs_model_without_pkv = model_without_pkv.generate( + **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1 + ) + self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) + self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH) + self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH) + self.assertTrue( + without_pkv_timer.elapsed / with_pkv_timer.elapsed > self.SPEEDUP_CACHE, + f"With pkv latency: {with_pkv_timer.elapsed:.3f} ms, without pkv latency: {without_pkv_timer.elapsed:.3f} ms," + f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}", + ) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 2bfdedbec3..caac969722 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -451,7 +451,7 @@ def test_pipeline(self, model_arch): pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) outputs = pipe("This is a sample", max_length=10) self.assertEqual(pipe.device, model.device) - self.assertTrue(all(["This is a sample" in item["generated_text"] for item in outputs])) + self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES)
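
For reference, a minimal usage sketch of the API introduced by this diff (the checkpoint name, prompt, and output directory below are illustrative; the calls mirror the README and the tests above):

```python
from transformers import AutoTokenizer, pipeline

from optimum.intel.generation.modeling import TorchScriptModelForCausalLM

model_id = "gpt2"  # illustrative checkpoint; any supported decoder-only model should work

# Trace and freeze the model with TorchScript (requires torch >= 2.0.0), keeping the KV cache.
model = TorchScriptModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# The wrapper implements GenerationMixin, so it plugs into the usual generation utilities.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("This is a sample", max_length=20))

# The traced graph (and its config) can be saved and reloaded without re-exporting.
model.save_pretrained("gpt2-torchscript")  # illustrative output directory
reloaded = TorchScriptModelForCausalLM.from_pretrained("gpt2-torchscript")
```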