From 4581968193699de14b56527296262dd76ab43557 Mon Sep 17 00:00:00 2001 From: Guanheng George Zhang <6156351+zhangguanheng66@users.noreply.github.com> Date: Fri, 9 Aug 2019 09:17:33 -0700 Subject: [PATCH] Apply Transformer model for the word language problem in pytorch/examples (#555) * Use append to accelerate data loading process. * First transformer model working for word language model. * Work for GPU (all the model and data have to be sent to cuda) * Transformer model GPU activated nhead=1 nlayers=1 d_ff=64 test loss 6.55 * Use lr=5.0 test loss 4.8 Encoder/decoder embeddings normalized by sqrt(d_model). test loss 3.84 lr=5.0 Encoder/decoder embeddings normalized by sqrt(d_model). test loss 4.68 lr=20.0 Remove print out. Revise main.py file. Load the best training model through epochs. Update README.md file to include the transformer model. Update the README.md file. Use PositionalEncoding in transformer. test loss 0.30 lr=5.0 * Update main.py to have mask on source sequences. Update generate.py to generate text with transformer.pt model. Add CUDA function to generate.py when running transformer model. Add generate_subsequent_mask() in Transformer Generate transformer model in main.py. Revise generate.py working for both RNN and Transformer models. Remove decoder_data Add some changes because of transformer.py. * No need to provide Trnasform args for generating text. Change d_ff to dim_feedforward. Remove Embeddings and PositionalEncoder out of transformer.py. * Replace tabs with spaces. * Update transformer model in model.py. * Recycle RNN arguments for Transformer model. * Update README.md file. * Remove model.generator in main.py. * Update the warnings in transformer model. * Fix a small bug in model.py. * Remove keyword arguments for consistence. * Create a new function generate_square_subsequent_mask inside the TransformerSeq2Seq * Remove unnecessary attributes. * A minor change. * Move src_mask and tgt_mask as the members of the module. * Move transformer check to model.py * Move masks inside forward function. * User TransformerEncoder for word language model. * Remove Generator module in Transformer. * Merge RNN and Transformer model in model.py * Minor changes. * Minor changes to address reviewer's comments. * Remove reset_parameter function. * Split RNN and Transformer model to keep code readable. * Minor changes. --- word_language_model/README.md | 54 +++++++++++-------- word_language_model/data.py | 9 ++-- word_language_model/generate.py | 19 +++++-- word_language_model/main.py | 36 +++++++++---- word_language_model/model.py | 93 +++++++++++++++++++++++++++++++++ 5 files changed, 170 insertions(+), 41 deletions(-) diff --git a/word_language_model/README.md b/word_language_model/README.md index 9f37c8a73f..be13a20517 100644 --- a/word_language_model/README.md +++ b/word_language_model/README.md @@ -4,11 +4,15 @@ This example trains a multi-layer RNN (Elman, GRU, or LSTM) on a language modeli By default, the training script uses the Wikitext-2 dataset, provided. The trained model can then be used by the generate script to generate new text. -```bash -python main.py --cuda --epochs 6 # Train a LSTM on Wikitext-2 with CUDA -python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA -python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs -python generate.py # Generate samples from the trained LSTM model. +```bash +python main.py --cuda --epochs 6 # Train a LSTM on Wikitext-2 with CUDA +python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA +python main.py --cuda --epochs 6 --model Transformer --lr 5 + # Train a Transformer model on Wikitext-2 with CUDA +python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs +python generate.py # Generate samples from the trained LSTM model. +python generate.py --cuda --model Transformer + # Generate samples from the trained Transformer model. ``` The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) @@ -21,24 +25,28 @@ The `main.py` script accepts the following arguments: ```bash optional arguments: - -h, --help show this help message and exit - --data DATA location of the data corpus - --model MODEL type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU) - --emsize EMSIZE size of word embeddings - --nhid NHID number of hidden units per layer - --nlayers NLAYERS number of layers - --lr LR initial learning rate - --clip CLIP gradient clipping - --epochs EPOCHS upper epoch limit - --batch_size N batch size - --bptt BPTT sequence length - --dropout DROPOUT dropout applied to layers (0 = no dropout) - --decay DECAY learning rate decay per epoch - --tied tie the word embedding and softmax weights - --seed SEED random seed - --cuda use CUDA - --log-interval N report interval - --save SAVE path to save the final model + -h, --help show this help message and exit + --data DATA location of the data corpus + --model MODEL type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU) + --emsize EMSIZE size of word embeddings + --nhid NHID number of hidden units per layer + --nlayers NLAYERS number of layers + --lr LR initial learning rate + --clip CLIP gradient clipping + --epochs EPOCHS upper epoch limit + --batch_size N batch size + --bptt BPTT sequence length + --dropout DROPOUT dropout applied to layers (0 = no dropout) + --decay DECAY learning rate decay per epoch + --tied tie the word embedding and softmax weights + --seed SEED random seed + --cuda use CUDA + --log-interval N report interval + --save SAVE path to save the final model + --transformer_head N the number of heads in the encoder/decoder of the transformer model + --transformer_encoder_layers N the number of layers in the encoder of the transformer model + --transformer_decoder_layers N the number of layers in the decoder of the transformer model + --transformer_d_ff N the number of nodes on the hidden layer in feed forward nn ``` With these arguments, a variety of models can be tested. diff --git a/word_language_model/data.py b/word_language_model/data.py index 6f917c2c46..cda6e90dff 100644 --- a/word_language_model/data.py +++ b/word_language_model/data.py @@ -38,12 +38,13 @@ def tokenize(self, path): # Tokenize file content with open(path, 'r', encoding="utf8") as f: - ids = torch.LongTensor(tokens) - token = 0 + idss = [] for line in f: words = line.split() + [''] + ids = [] for word in words: - ids[token] = self.dictionary.word2idx[word] - token += 1 + ids.append(self.dictionary.word2idx[word]) + idss.append(torch.tensor(ids).type(torch.int64)) + ids = torch.cat(idss) return ids diff --git a/word_language_model/generate.py b/word_language_model/generate.py index 1f9a38dc5a..fb8543eab6 100644 --- a/word_language_model/generate.py +++ b/word_language_model/generate.py @@ -49,16 +49,25 @@ corpus = data.Corpus(args.data) ntokens = len(corpus.dictionary) -hidden = model.init_hidden(1) +if model.model_type != 'Transformer': + hidden = model.init_hidden(1) input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device) with open(args.outf, 'w') as outf: with torch.no_grad(): # no tracking history for i in range(args.words): - output, hidden = model(input, hidden) - word_weights = output.squeeze().div(args.temperature).exp().cpu() - word_idx = torch.multinomial(word_weights, 1)[0] - input.fill_(word_idx) + if model.model_type == 'Transformer': + output = model(input, False) + word_weights = output[-1].squeeze().div(args.temperature).exp().cpu() + word_idx = torch.multinomial(word_weights, 1)[0] + word_tensor = torch.Tensor([[word_idx]]).long().to(device) + input = torch.cat([input, word_tensor], 0) + else: + output, hidden = model(input, hidden) + word_weights = output.squeeze().div(args.temperature).exp().cpu() + word_idx = torch.multinomial(word_weights, 1)[0] + input.fill_(word_idx) + word = corpus.dictionary.idx2word[word_idx] outf.write(word + ('\n' if i % 20 == 19 else ' ')) diff --git a/word_language_model/main.py b/word_language_model/main.py index f2a447837c..03bfaeb87e 100644 --- a/word_language_model/main.py +++ b/word_language_model/main.py @@ -14,7 +14,7 @@ parser.add_argument('--data', type=str, default='./data/wikitext-2', help='location of the data corpus') parser.add_argument('--model', type=str, default='LSTM', - help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)') + help='type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU, Transformer)') parser.add_argument('--emsize', type=int, default=200, help='size of word embeddings') parser.add_argument('--nhid', type=int, default=200, @@ -45,6 +45,10 @@ help='path to save the final model') parser.add_argument('--onnx-export', type=str, default='', help='path to export the final model in onnx format') + +parser.add_argument('--nhead', type=int, default=2, + help='the number of heads in the encoder/decoder of the transformer model') + args = parser.parse_args() # Set the random seed manually for reproducibility. @@ -92,7 +96,10 @@ def batchify(data, bsz): ############################################################################### ntokens = len(corpus.dictionary) -model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device) +if args.model == 'Transformer': + model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device) +else: + model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device) criterion = nn.CrossEntropyLoss() @@ -102,6 +109,7 @@ def batchify(data, bsz): def repackage_hidden(h): """Wraps hidden states in new Tensors, to detach them from their history.""" + if isinstance(h, torch.Tensor): return h.detach() else: @@ -130,14 +138,18 @@ def evaluate(data_source): model.eval() total_loss = 0. ntokens = len(corpus.dictionary) - hidden = model.init_hidden(eval_batch_size) + if args.model != 'Transformer': + hidden = model.init_hidden(eval_batch_size) with torch.no_grad(): for i in range(0, data_source.size(0) - 1, args.bptt): data, targets = get_batch(data_source, i) - output, hidden = model(data, hidden) + if args.model == 'Transformer': + output = model(data) + else: + output, hidden = model(data, hidden) + hidden = repackage_hidden(hidden) output_flat = output.view(-1, ntokens) total_loss += len(data) * criterion(output_flat, targets).item() - hidden = repackage_hidden(hidden) return total_loss / (len(data_source) - 1) @@ -147,14 +159,18 @@ def train(): total_loss = 0. start_time = time.time() ntokens = len(corpus.dictionary) - hidden = model.init_hidden(args.batch_size) + if args.model != 'Transformer': + hidden = model.init_hidden(args.batch_size) for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)): data, targets = get_batch(train_data, i) # Starting each batch, we detach the hidden state from how it was previously produced. # If we didn't, the model would try backpropagating all the way to start of the dataset. - hidden = repackage_hidden(hidden) model.zero_grad() - output, hidden = model(data, hidden) + if args.model == 'Transformer': + output = model(data) + else: + hidden = repackage_hidden(hidden) + output, hidden = model(data, hidden) loss = criterion(output.view(-1, ntokens), targets) loss.backward() @@ -217,7 +233,9 @@ def export_onnx(path, batch_size, seq_len): model = torch.load(f) # after load the rnn params are not a continuous chunk of memory # this makes them a continuous chunk, and will speed up forward pass - model.rnn.flatten_parameters() + # Currently, only rnn model supports flatten_parameters function. + if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU']: + model.rnn.flatten_parameters() # Run on test data. test_loss = evaluate(test_data) diff --git a/word_language_model/model.py b/word_language_model/model.py index ca51ae5017..a77cdbf2fe 100644 --- a/word_language_model/model.py +++ b/word_language_model/model.py @@ -1,4 +1,7 @@ +import math +import torch import torch.nn as nn +import torch.nn.functional as F class RNNModel(nn.Module): """Container module with an encoder, a recurrent module, and a decoder.""" @@ -55,3 +58,93 @@ def init_hidden(self, bsz): weight.new_zeros(self.nlayers, bsz, self.nhid)) else: return weight.new_zeros(self.nlayers, bsz, self.nhid) + +# Temporarily leave PositionalEncoding module here. Will be moved somewhere else. +class PositionalEncoding(nn.Module): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). + Examples: + >>> pos_encoder = PositionalEncoding(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + Shape: + x: [sequence length, batch size, embed dim] + output: [sequence length, batch size, embed dim] + Examples: + >>> output = pos_encoder(x) + """ + + x = x + self.pe[:x.size(0), :] + return self.dropout(x) + +class TransformerModel(nn.Module): + """Container module with an encoder, a recurrent or transformer module, and a decoder.""" + + def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): + super(TransformerModel, self).__init__() + try: + from torch.nn import TransformerEncoder, TransformerEncoderLayer + except: + raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.') + self.model_type = 'Transformer' + self.src_mask = None + self.pos_encoder = PositionalEncoding(ninp, dropout) + encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + self.encoder = nn.Embedding(ntoken, ninp) + self.ninp = ninp + self.decoder = nn.Linear(ninp, ntoken) + + self.init_weights() + + def _generate_square_subsequent_mask(self, sz): + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + return mask + + def init_weights(self): + initrange = 0.1 + self.encoder.weight.data.uniform_(-initrange, initrange) + self.decoder.bias.data.zero_() + self.decoder.weight.data.uniform_(-initrange, initrange) + + def forward(self, src, has_mask=True): + if has_mask: + device = src.device + if self.src_mask is None or self.src_mask.size(0) != len(src): + mask = self._generate_square_subsequent_mask(len(src)).to(device) + self.src_mask = mask + else: + self.src_mask = None + + src = self.encoder(src) * math.sqrt(self.ninp) + src = self.pos_encoder(src) + output = self.transformer_encoder(src, self.src_mask) + output = self.decoder(output) + return F.log_softmax(output, dim=-1)