diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..36e2df362 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +models/ diff --git a/README.md b/README.md new file mode 100644 index 000000000..fe5a33e09 --- /dev/null +++ b/README.md @@ -0,0 +1,29 @@ +# gpt-2 + +Code and samples from the paper "Language Models are Unsupervised Multitask Learners" + +## Installation + +Download the model data: +``` +gsutil rsync -r gs://gpt-2/models/ models/ +``` + +Install python packages: +``` +pip install -r requirements.txt +``` + +## Sample generation + +| WARNING: Samples are unfiltered and may contain offensive content. | +| --- | + +To generate unconditional samples from the small model: +``` +python3 src/main.py | tee samples +``` +There are various flags for controlling the samples: +``` +python3 src/main.py --top_k 40 --temperature 0.7 | tee samples +``` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..43934f2fc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +fire>=0.1.3 +tensorflow>=1.12 +regex==2017.4.5 diff --git a/src/encoder.py b/src/encoder.py new file mode 100644 index 000000000..285d643ff --- /dev/null +++ b/src/encoder.py @@ -0,0 +1,120 @@ +"""Byte pair encoding utilities""" + +import os +import json +import regex as re +from functools import lru_cache + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + +class Encoder: + def __init__(self, encoder, bpe_merges, errors='replace'): + self.encoder = encoder + self.decoder = {v:k for k,v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode_text(self, text): + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def encode(self, texts): + return [self.encode_text(text) for text in texts] + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + +def get_encoder(model_name): + with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f: + encoder = json.load(f) + with open(os.path.join('models', model_name, 'vocab.bpe'), 'r') as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) diff --git a/src/main.py b/src/main.py new file mode 100755 index 000000000..98efe8165 --- /dev/null +++ b/src/main.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +import fire +import json +import os +import numpy as np +import tensorflow as tf + +from src import model, sample, encoder + +def sample_model( + model_name='117M', + seed=None, + nsamples=0, + batch_size=1, + length=None, + temperature=1, + top_k=0, +): + np.random.seed(seed) + tf.set_random_seed(seed) + + enc = encoder.get_encoder(model_name) + hparams = model.default_hparams() + with open(os.path.join('models', model_name, 'hparams.json')) as f: + hparams.override_from_dict(json.load(f)) + + if length is None: + length = hparams.n_ctx + elif length > hparams.n_ctx: + raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}") + + with tf.Session(graph=tf.Graph()) as sess: + output = sample.sample_sequence( + hparams=hparams, length=length, + start_token=enc.encoder['<|endoftext|>'], + batch_size=batch_size, + temperature=temperature, top_k=top_k + )[:, 1:] + + saver = tf.train.Saver() + ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name)) + saver.restore(sess, ckpt) + + generated = 0 + while nsamples == 0 or generated < nsamples: + out = sess.run(output) + for i in range(batch_size): + generated += batch_size + text = enc.decode(out[i]) + print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) + print(f"{text}") + +if __name__ == '__main__': + fire.Fire(sample_model) + diff --git a/src/model.py b/src/model.py new file mode 100644 index 000000000..230b83cc2 --- /dev/null +++ b/src/model.py @@ -0,0 +1,174 @@ +import numpy as np +import tensorflow as tf +from tensorflow.contrib.training import HParams + +def default_hparams(): + return HParams( + n_vocab=0, + n_ctx=1024, + n_embd=768, + n_head=12, + n_layer=12, + ) + +def shape_list(x): + """Deal with dynamic shape in tensorflow cleanly.""" + static = x.shape.as_list() + dynamic = tf.shape(x) + return [dynamic[i] if s is None else s for i, s in enumerate(static)] + +def softmax(x, axis=-1): + x = x - tf.reduce_max(x, axis=axis, keepdims=True) + ex = tf.exp(x) + return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) + +def gelu(x): + return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) + +def norm(x, scope, *, axis=-1, epsilon=1e-5): + """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" + with tf.variable_scope(scope): + n_state = x.shape[-1].value + g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) + b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) + u = tf.reduce_mean(x, axis=axis, keepdims=True) + s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) + x = (x - u) * tf.rsqrt(s + epsilon) + x = x*g + b + return x + +def split_states(x, n): + """Reshape the last dimension of x into [n, x.shape[-1]/n].""" + *start, m = shape_list(x) + return tf.reshape(x, start + [n, m//n]) + +def merge_states(x): + """Smash the last two dimensions of x into a single dimension.""" + *start, a, b = shape_list(x) + return tf.reshape(x, start + [a*b]) + +def conv1d(x, scope, nf, *, w_init_stdev=0.02): + with tf.variable_scope(scope): + *start, nx = shape_list(x) + w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) + b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) + c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) + return c + +def attention_mask(nd, ns, *, dtype): + """1's in the lower triangle, counting from the lower right corner. + + Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. + """ + i = tf.range(nd)[:,None] + j = tf.range(ns) + m = i >= j - ns + nd + return tf.cast(m, dtype) + + +def attn(x, scope, n_state, *, past, hparams): + assert x.shape.ndims == 3 # Should be [batch, sequence, features] + assert n_state % hparams.n_head == 0 + if past is not None: + assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] + + def split_heads(x): + # From [batch, sequence, features] to [batch, heads, sequence, features] + return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) + + def merge_heads(x): + # Reverse of split_heads + return merge_states(tf.transpose(x, [0, 2, 1, 3])) + + def mask_attn_weights(w): + # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. + _, _, nd, ns = shape_list(w) + b = attention_mask(nd, ns, dtype=w.dtype) + b = tf.reshape(b, [1, 1, nd, ns]) + w = w*b - tf.cast(1e10, w.dtype)*(1-b) + return w + + def multihead_attn(q, k, v): + # q, k, v have shape [batch, heads, sequence, features] + w = tf.matmul(q, k, transpose_b=True) + w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype)) + + w = mask_attn_weights(w) + w = softmax(w) + a = tf.matmul(w, v) + return a + + with tf.variable_scope(scope): + c = conv1d(x, 'c_attn', n_state*3) + q, k, v = map(split_heads, tf.split(c, 3, axis=2)) + present = tf.stack([k, v], axis=1) + if past is not None: + pk, pv = tf.unstack(past, axis=1) + k = tf.concat([pk, k], axis=-2) + v = tf.concat([pv, v], axis=-2) + a = multihead_attn(q, k, v) + a = merge_heads(a) + a = conv1d(a, 'c_proj', n_state) + return a, present + + +def mlp(x, scope, n_state, *, hparams): + with tf.variable_scope(scope): + nx = x.shape[-1].value + h = gelu(conv1d(x, 'c_fc', n_state)) + h2 = conv1d(h, 'c_proj', nx) + return h2 + + +def block(x, scope, *, past, hparams): + with tf.variable_scope(scope): + nx = x.shape[-1].value + a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams) + x = x + a + m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams) + x = x + m + return x, present + +def past_shape(*, hparams, batch_size=None, sequence=None): + return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head] + +def expand_tile(value, size): + """Add a new axis of given size.""" + value = tf.convert_to_tensor(value, name='value') + ndims = value.shape.ndims + return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims) + +def positions_for(tokens, past_length): + batch_size = tf.shape(tokens)[0] + nsteps = tf.shape(tokens)[1] + return expand_tile(past_length + tf.range(nsteps), batch_size) + + +def model(hparams, X, past=None, scope='model', reuse=False): + with tf.variable_scope(scope, reuse=reuse): + results = {} + batch, sequence = shape_list(X) + + wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd], + initializer=tf.random_normal_initializer(stddev=0.01)) + wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd], + initializer=tf.random_normal_initializer(stddev=0.02)) + past_length = 0 if past is None else tf.shape(past)[-2] + h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length)) + + # Transformer + presents = [] + pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer + assert len(pasts) == hparams.n_layer + for layer, past in enumerate(pasts): + h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) + presents.append(present) + results['present'] = tf.stack(presents, axis=1) + h = norm(h, 'ln_f') + + # Language model loss. Do tokens