Skip to content

Commit

Permalink
Enable TorchScript export, loading and inference for causal lm models (
Browse files Browse the repository at this point in the history
…#283)

* add original example

* add jit model for generation

* enable generation

* add requirements and README

* enable caching

* move modeling to ipex subpackage

* fix import

* rename model class

* Update code and move modeling.py to utils folder. (#290)

* Update code and move modeling.py to utils folder.

* move modeling.py to generation folder

* Fixed typo

* Update code for support instance of fx model to TracedModelForCausalLM

* Update code

* Fixed typo

* Update readme

* Add generation for bloom models (#295)

* add bloom generation

* Fix generation

* refactorization

* set back input ids to prepare_inputs

* Add tests

* fix dates

* remove unecessary list comprehension

* remove unecessary list comprehension

* rename class

* fix dependencies

* fix dependency

---------

Co-authored-by: Cheng, Penghui <[email protected]>
  • Loading branch information
echarlaix and PenghuiCheng authored Apr 20, 2023
1 parent 08f6330 commit 993c1a5
Show file tree
Hide file tree
Showing 8 changed files with 899 additions and 1 deletion.
37 changes: 37 additions & 0 deletions .github/workflows/test_generation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Intel Generation Utils - Test

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install optimum[exporters]
pip install .[tests]
- name: Test with Pytest
run: |
pytest tests/generation/
30 changes: 30 additions & 0 deletions examples/neural_compressor/text-generation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<!---
Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

## Language generation

Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py).

The original generation task only supported the PyTorch eager model. By calling the `TorchScriptModelForCausalLM` class, we can now support a TorchScript model for generation tasks.

Example usage:

```bash
python run_generation.py \
--model_type=gpt2 \
--model_name_or_path=gpt2 \
--jit
```
3 changes: 3 additions & 0 deletions examples/neural_compressor/text-generation/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sentencepiece != 0.1.92
protobuf
torch >= 2.0.0
310 changes: 310 additions & 0 deletions examples/neural_compressor/text-generation/run_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import logging

import numpy as np
import torch
from transformers import (
CTRLLMHeadModel,
CTRLTokenizer,
GPT2LMHeadModel,
GPT2Tokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
TransfoXLLMHeadModel,
TransfoXLTokenizer,
XLMTokenizer,
XLMWithLMHeadModel,
XLNetLMHeadModel,
XLNetTokenizer,
)

from optimum.intel.generation.modeling import TorchScriptModelForCausalLM


logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
"ctrl": (CTRLLMHeadModel, CTRLTokenizer),
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"xlnet": (XLNetLMHeadModel, XLNetTokenizer),
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
if args.temperature > 0.7:
logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
# kwargs = {"language": None, "mask_token_id": None}

# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
if hasattr(model.config, "lang2id") and use_lang_emb:
available_languages = model.config.lang2id.keys()
if args.xlm_language in available_languages:
language = args.xlm_language
else:
language = None
while language not in available_languages:
language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

model.config.lang_id = model.config.lang2id[language]
# kwargs["language"] = tokenizer.lang2id[language]

# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
# XLM masked-language modeling (MLM) models need masked token
# is_xlm_mlm = "mlm" in args.model_name_or_path
# if is_xlm_mlm:
# kwargs["mask_token_id"] = tokenizer.mask_token_id

return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
prompt_text = prefix + prompt_text
return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
prompt_text = prefix + prompt_text
return prompt_text


PREPROCESSING_FUNCTIONS = {
"ctrl": prepare_ctrl_input,
"xlm": prepare_xlm_input,
"xlnet": prepare_xlnet_input,
"transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
if length < 0 and max_sequence_length > 0:
length = max_sequence_length
elif 0 < max_sequence_length < length:
length = max_sequence_length # No generation bigger than model size
elif length < 0:
length = MAX_LENGTH # avoid infinite loop
return length


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)

parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--p", type=float, default=0.9)

parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")

parser.add_argument(
"--output_dir",
default=None,
type=str,
help="Output directory where to save the resulting model",
)
args = parser.parse_args()

args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

set_seed(args)

# Initialize the model and tokenizer
try:
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
except KeyError:
raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)

if args.jit:
model = TorchScriptModelForCausalLM.from_pretrained(args.model_name_or_path, export=True)
else:
model = model_class.from_pretrained(args.model_name_or_path)

if args.output_dir is not None and args.jit:
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)

model.to(args.device)

args.length = adjust_length_to_model(
args.length,
max_sequence_length=model.config.max_position_embeddings
if hasattr(model.config, "max_position_embeddings")
else 0,
)
logger.info(args)

prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

# Different models need different input formatting and/or extra arguments
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)

if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}

encoded_prompt = tokenizer.encode(
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
)
else:
prefix = args.prefix if args.prefix else args.padding_text
encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
encoded_prompt = encoded_prompt.to(args.device)

if encoded_prompt.size()[-1] == 0:
input_ids = None
else:
input_ids = encoded_prompt

output_sequences = model.generate(
input_ids=input_ids,
max_length=args.length + len(encoded_prompt[0]),
temperature=args.temperature,
top_k=args.k,
top_p=args.p,
repetition_penalty=args.repetition_penalty,
do_sample=True,
num_return_sequences=args.num_return_sequences,
)

# Remove the batch dimension when returning multiple sequences
if len(output_sequences.shape) > 2:
output_sequences.squeeze_()

generated_sequences = []

for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
generated_sequence = generated_sequence.tolist()

# Decode text
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

# Remove all text after the stop token
text = text[: text.find(args.stop_token) if args.stop_token else None]

# Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
total_sequence = (
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
)

generated_sequences.append(total_sequence)
print(total_sequence)

return generated_sequences


if __name__ == "__main__":
main()
Loading

0 comments on commit 993c1a5

Please sign in to comment.