- 2.1. Install Dependency
- 3.1. Script
- 3.2. Parameters of Data Augmentation
- 3.3. Supported Augmenter
- 3.4. Text Generation Augmenter
- 3.5. Augmenter Arguments

Data Augmentation is a tool for augmenting NLP datasets in machine learning projects. It integrates nlpaug and other methods from Intel Labs.
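
2.1. Install Dependency

```shell
pip install nlpaug
pip install transformers
```

Then install Intel Extension for Transformers from source:

```shell
git clone https://github.com/intel/intel-extension-for-transformers.git itrex
cd itrex
pip install -v .
```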
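
3.1. Script

Please refer to the example. The snippet below shows basic usage: it augments `dev.csv` with the `TextGenerationAug` augmenter and loads the result with the `datasets` package.

```python
import os

from datasets import load_dataset
from intel_extension_for_transformers.utils.data_augmentation import DataAugmentation

result_path = "augmented"  # output directory; choose your own
os.makedirs(result_path, exist_ok=True)
aug = DataAugmentation(augmenter_type="TextGenerationAug")
aug.input_dataset = "dev.csv"
aug.output_path = os.path.join(result_path, "test1.csv")
aug.augmenter_arguments = {'model_name_or_path': 'gpt2-medium'}
aug.data_augment()
# The augmented file is tab-separated.
raw_datasets = load_dataset("csv", data_files=aug.output_path, delimiter="\t", split="train")
print(len(raw_datasets))
```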

3.2. Parameters of Data Augmentation

Parameter | Type | Description | Default value |
---|---|---|---|
augmenter_type | String | Augmenter type (see "Supported Augmenter") | NA |
input_dataset | String | Dataset name, or path to a CSV or JSON file | None |
output_path | String | Path and file name of the augmented data file | "save_path/augmented_dataset.csv" |
data_config_or_task_name | String | GLUE task name or dataset configuration name | None |
augmenter_arguments | Dict | Augmenter-specific parameters; each augmenter accepts different arguments | None |
column_names | String | Column to augment; used when loading data with the `datasets` package | "sentence" |
split | String | Dataset split to augment, e.g. 'validation' or 'train' | "validation" |
num_samples | Integer | Number of augmented samples to generate | 1 |
device | String | Deployment device, "cuda" or "cpu" | 1 |

3.3. Supported Augmenter

augmenter_type | augmenter_arguments | default value |
---|---|---|
"TextGenerationAug" | Refer to the "Text Generation Augmenter" section of this document | NA |
"KeyboardAug" | Refer to nlpaug `KeyboardAug` | NA |
"OcrAug" | Refer to nlpaug `OcrAug` | NA |
"SpellingAug" | Refer to nlpaug `SpellingAug` | NA |
"ContextualWordEmbsForSentenceAug" | Refer to nlpaug `ContextualWordEmbsForSentenceAug` | NA |

3.4. Text Generation Augmenter

The text generation augmenter provides a recipe for data augmentation based on conditional text generation with auto-regressive transformer models (such as GPT, GPT-2, Transformer-XL, XLNet, and CTRL) to automatically generate labeled data. Our approach follows the algorithms described in *Not Enough Data? Deep Learning to the Rescue!* and *Natural Language Generation for Effective Knowledge Distillation*.

First, we fine-tune an auto-regressive model on the training set. Each sample contains both a label and a sentence.
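
Prepare the dataset. The script below writes the SST-2 `train` and `validation` splits to `SST-2/train.txt` and `SST-2/validation.txt`, one `label<TAB>sentence<EOS>` sample per line:

```python
import os

from datasets import load_dataset
from intel_extension_for_transformers.utils.utils import EOS

os.makedirs('SST-2', exist_ok=True)  # the output directory must exist before writing
for split in ('train', 'validation'):
    dataset = load_dataset('glue', 'sst2', split=split)
    with open(os.path.join('SST-2', split + '.txt'), 'w') as fw:
        for d in dataset:
            # One sample per line: label, tab, sentence, end-of-sequence token.
            fw.write(str(d['label']) + '\t' + d['sentence'] + EOS + '\n')
```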
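
Each line written by the preparation script has the shape `label<TAB>sentence<EOS>`. The sketch below round-trips that format using only the standard library. `format_sample` and `parse_sample` are hypothetical helpers, and `"</s>"` is an assumed placeholder for the stop token; the actual value of the `EOS` constant imported above may differ.

```python
from typing import Tuple

# Assumption: placeholder for the EOS constant used above; the real value may differ.
EOS_PLACEHOLDER = "</s>"


def format_sample(label: int, sentence: str, eos: str = EOS_PLACEHOLDER) -> str:
    """Serialize one sample as 'label<TAB>sentence<EOS>' followed by a newline."""
    if "\t" in sentence or "\n" in sentence:
        raise ValueError("sentence must not contain tabs or newlines")
    return f"{label}\t{sentence}{eos}\n"


def parse_sample(line: str, eos: str = EOS_PLACEHOLDER) -> Tuple[int, str]:
    """Inverse of format_sample: recover (label, sentence) from one line."""
    line = line.rstrip("\n")
    if not line.endswith(eos):
        raise ValueError("line does not end with the stop token")
    body = line[: len(line) - len(eos)]
    label, sentence = body.split("\t", 1)
    return int(label), sentence


line = format_sample(1, "a gripping, funny film")
print(repr(line))
print(parse_sample(line))
```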

Fine-tune the causal language model. You can use the script `run_clm.py` from the transformers examples to fine-tune GPT-2 (`gpt2-medium`) on the SST-2 task; the training loss is the causal language modeling loss.

```shell
DATASET=SST-2
TRAIN_FILE=$DATASET/train.txt
VALIDATION_FILE=$DATASET/validation.txt
MODEL=gpt2-medium
MODEL_DIR=model/$MODEL-$DATASET

python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path $MODEL \
    --train_file $TRAIN_FILE \
    --validation_file $VALIDATION_FILE \
    --do_train \
    --do_eval \
    --output_dir $MODEL_DIR \
    --overwrite_output_dir
```
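
For reference, the causal language modeling loss is the average next-token negative log-likelihood. Writing one block of training tokens as $x_1, \dots, x_T$ (`run_clm.py` concatenates the training lines and splits them into fixed-size blocks) and the model parameters as $\theta$:

$$
\mathcal{L}(\theta) = -\frac{1}{T-1} \sum_{t=2}^{T} \log p_\theta\left(x_t \mid x_1, \dots, x_{t-1}\right)
$$

Because every training line begins with its label, minimizing this loss also teaches the model to continue a label with a matching sentence, which the generation step relies on.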

Second, we generate labeled data. Given class labels sampled from the training set, the fine-tuned language model generates the corresponding sentences with the script below:

```python
from intel_extension_for_transformers.utils.data_augmentation import DataAugmentation

aug = DataAugmentation(augmenter_type="TextGenerationAug")
aug.input_dataset = "/your/original/training_set.csv"
aug.output_path = "/your/augmented/dataset.csv"
aug.augmenter_arguments = {'model_name_or_path': '/your/fine-tuned/model'}
aug.data_augment()
```

This data augmentation algorithm is useful in several scenarios, such as model distillation.
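
The generation step can be pictured as follows: sample class labels in proportion to their frequency in the training set, prompt the language model with each label followed by a tab, and keep the continuation up to the stop token. The stdlib-only sketch below shows that control flow under stated assumptions; it is not ITREX's internal code. `generate_fn` is a hypothetical stand-in for the fine-tuned model (prompt in, continuation out), and the prompt format assumes the `label<TAB>sentence<EOS>` layout used for fine-tuning.

```python
import random
from collections import Counter
from typing import Callable, List, Sequence, Tuple


def sample_labels(train_labels: Sequence[int], n: int, seed: int = 0) -> List[int]:
    """Draw n labels with probability proportional to their training-set frequency."""
    rng = random.Random(seed)
    counts = Counter(train_labels)
    labels = list(counts)
    return rng.choices(labels, weights=[counts[lab] for lab in labels], k=n)


def truncate_at_stop(text: str, stop_token: str) -> str:
    """Keep the text before the first stop token (all of it if the token is absent)."""
    idx = text.find(stop_token)
    return text if idx == -1 else text[:idx]


def generate_labeled_data(
    train_labels: Sequence[int],
    generate_fn: Callable[[str], str],  # hypothetical: prompt -> continuation text
    stop_token: str,
    num_return_sentences: int = -1,
    seed: int = 0,
) -> List[Tuple[int, str]]:
    """Create (label, sentence) pairs by prompting the model with a label and a tab."""
    # As documented for num_return_sentences, -1 means "as many as the input samples".
    n = len(train_labels) if num_return_sentences == -1 else num_return_sentences
    samples = []
    for label in sample_labels(train_labels, n, seed):
        continuation = generate_fn(f"{label}\t")
        sentence = truncate_at_stop(continuation, stop_token).strip()
        if sentence:  # skip empty generations
            samples.append((label, sentence))
    return samples


# Stand-in for the fine-tuned model: always returns the same continuation.
def fake_model(prompt: str) -> str:
    return "a warm, well-acted story</s>1\tleftover text"


print(generate_labeled_data([0, 1, 1, 1], fake_model, stop_token="</s>"))
```

Cutting at the stop token matters because an auto-regressive model keeps generating past the end of the sentence, often starting a new `label<TAB>` line.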

3.5. Augmenter Arguments

Parameter | Type | Description | Default value |
---|---|---|---|
"model_name_or_path" | String | Name or path of the language model used to generate data, e.g. the fine-tuned model from the previous step | NA |
"stop_token" | String | Stop token used in the input data file | EOS |
"num_return_sentences" | Integer | Total number of samples to generate; -1 means the same number as the input samples | -1 |
"temperature" | Float | Sampling temperature of the causal language model | 1.0 |
"k" | Float | Top-k sampling parameter | 0.0 |
"p" | Float | Top-p (nucleus) sampling parameter | 0.9 |
"repetition_penalty" | Float | Repetition penalty (1.0 means no penalty) | 1.0 |