From 980f2268924233fcdc570f1779515f8dbcebed04 Mon Sep 17 00:00:00 2001 From: Joseph Catanzarite Date: Mon, 23 Mar 2020 10:57:54 -0700 Subject: [PATCH] Add files via upload --- bleu_metric.ipynb | 541 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 503 insertions(+), 38 deletions(-) diff --git a/bleu_metric.ipynb b/bleu_metric.ipynb index f6c6920..2f5ee20 100644 --- a/bleu_metric.ipynb +++ b/bleu_metric.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 149, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## What is the BLEU metric?" + "# What is the BLEU metric?" ] }, { @@ -57,7 +57,7 @@ "```\n", "the the the the the the the the\n", "```\n", - "for our first example, it would score a precision of 1. for 1-grams (because `the` is in the target sentence, so all the words are considered correct). To avoid that, we put a maximum for a given words to the number of times it can be considered correct, which is the number of times that word is in the target sentence. So in our example, only 2 of the 7 the are considered correct and this clamped precision gives us 2/7 for 1-grams.\n", + "for our first example, it would score a precision of 1. for 1-grams (because `the` is in the target sentence, so all the words are considered correct). To avoid that, we put a maximum for a given words to the number of times it can be considered correct, which is the number of times that word is in the target sentence. So in our example, only 2 of the 7 `the` are considered correct and this clamped precision gives us 2/7 for 1-grams.\n", "\n", "One thing to note is that when we deal with a whole corpus, we take the sum of all the corrects over all the sentences then divide by the sum of all the predicted over all the sentences (we don't avarage precisions over each sentence).\n", "\n", @@ -67,7 +67,9 @@ "```\n", "which is the geometric average of `p1`, `p2`, `p3` and `p4` (our n-gram precision scores) multiplied by a penalty given for the length: we penalize longer predictions with the precision scores, but short ones get it easier, especially if they only contain correct words. So we apply the following penalty:\n", "```\n", - "length_penalty = 1 if len(pred) >= len(targ) else (1 - exp(len(targ)/len(pred)))\n", + "length_penalty = 1 if len(pred) >= len(targ) else (1 - exp(-len(targ)/len(pred)))\n", + "\n", + "NOte there should be a negative sign before the exponent.\n", "```\n", "\n", "And if we are taking the BLEU score of a whole corpus, we use the sum of the lengths of predicted sentences and the sum of the lengths of predicted targets." @@ -77,7 +79,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Let's code this" + "# Let's code this" ] }, { @@ -89,9 +91,23 @@ "Specifically we are going to use the Counter class, which is going to count the number of instances of each n-gram in the predicted sentence and the target one:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Bleu score for unigrams " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Suppose we have two lists of integers that represent word sequences" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 150, "metadata": {}, "outputs": [], "source": [ @@ -99,25 +115,53 @@ "pred = [1,2,3,7,5,1,1]" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Make a Counter object for each list, which is essentially a dictionary of items and number of occurrences" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 151, "metadata": {}, "outputs": [], "source": [ "cnt_pred,cnt_targ = Counter(pred),Counter(targ)" ] }, + { + "cell_type": "code", + "execution_count": 152, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(Counter({1: 3, 2: 1, 3: 1, 7: 1, 5: 1}),\n", + " Counter({1: 2, 2: 2, 3: 1, 4: 1, 5: 1}))" + ] + }, + "execution_count": 152, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cnt_pred,cnt_targ" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now the number of corrects is the number of words in `pred` that are in `targ`, with a cap that is the number of times they appear in `targ`." + "#### The Bleu score is the number of corrects (the number of words in `pred` that are in `targ`) with a cap that is the number of times they appear in `targ`." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 153, "metadata": {}, "outputs": [ { @@ -126,7 +170,7 @@ "5" ] }, - "execution_count": null, + "execution_count": 153, "metadata": {}, "output_type": "execute_result" } @@ -136,43 +180,127 @@ "corrects" ] }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_items([(1, 3), (2, 1), (3, 1), (7, 1), (5, 1)])" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cnt_pred.items()" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_values([3, 1, 1, 1, 1])" + ] + }, + "execution_count": 155, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cnt_pred.values()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This works for unigrams, which are represented as ints. \n", + "But it won't work for an ngram of more than one word, which is represented as a `list of ints`, since a `Counter` requires the objects inside to be `hashable`. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. BLEU score for ngrams\n", + "Here we develop machinery needed to calculate a BLEU score for ngrams." + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now this works very well for ints, but it won't work for a list of ints, since a `Counter` requires the objects inside to be hashable. That's why we define a custom class:" + "### 2.1 Define a class that `hashes` an ngram, i.e. maps the ngram to a (hopefully) unique integer" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 156, "metadata": {}, "outputs": [], "source": [ + "# NGram class contains methods __eq__ and __hash__\n", + "# Takes inputs ngram and max_n\n", + "# ngram is the vector embedding of a sequence of words\n", "class NGram():\n", - " def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n", + " def __init__(self, ngram, max_n=5000): \n", + " self.ngram,self.max_n = ngram,max_n\n", + " \n", + " # Test ngram equivalence\n", " def __eq__(self, other):\n", - " if len(self.ngram) != len(other.ngram): return False\n", - " return np.all(np.array(self.ngram) == np.array(other.ngram))\n", - " def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))" + " # Input: two ngrams\n", + " # Output: boolean indicator. True if the ngrams are the same, False if they are not\n", + " if len(self.ngram) != len(other.ngram): \n", + " return False\n", + " equivalence_indicator = np.all(np.array(self.ngram) == np.array(other.ngram))\n", + " return equivalence_indicator\n", + " \n", + " # Generate an integer hash for each ngram\n", + " def __hash__(self):\n", + " # suppose the ngram is \"I see\"\n", + " # then enumerate(self.ngram) returns [ (0,stoi[\"I\"]), (1,stoi[\"see\"]) ]\n", + " # then output is [ stoi[\"I\"]*5000**0, stoi[\"see\"]*5000**1 ]\n", + " hash_value = int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))\n", + " return hash_value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`max_n` should be set to the vocab size, so we're sure we don't have two different ngrams with the same hash (it needs to be an int).\n", - "\n", - "Then we can define `get_grams` to gather all the possible ngrams:" + "#### A `collision` is said to occur if the hash function maps two *different* ngrams to the *same* integer. To avoid collisions, `max_n` should be a large number, such as the vocab size, which makes collisions highly unlikely. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 Define a function to gather all the possible ngrams of a sequence of words" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 157, "metadata": {}, "outputs": [], "source": [ + "# generate all the ngrams in a sequence of words\n", "def get_grams(x, n, max_n=5000):\n", + " # input:\n", + " # x is a sequence of words or their integer representations\n", + " # n is the length of the ngram: if n = 3 then we are processing a 3-gram\n", + " # output:\n", + " # a list of all the ngrams that can be constructed from x\n", " return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]" ] }, @@ -180,16 +308,280 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "From here, it's easy to compute correctly predicted ngrams:" + "### 2.3 Example" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 158, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "7\n" + ] + }, + { + "data": { + "text/plain": [ + "[<__main__.NGram at 0x14a8cde33c8>,\n", + " <__main__.NGram at 0x14a8cde3240>,\n", + " <__main__.NGram at 0x14a8cde3438>,\n", + " <__main__.NGram at 0x14a8cde39b0>,\n", + " <__main__.NGram at 0x14a8cde3a58>,\n", + " <__main__.NGram at 0x14a8cde3ac8>,\n", + " <__main__.NGram at 0x14a8cde3b38>]" + ] + }, + "execution_count": 158, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ngram_list = get_grams([\"I\",\"see\",\"that\",\"you\",\"like\",\"to\",\"run\",\"fast\"],2,max_n=5000)\n", + "print(len(ngram_list))\n", + "ngram_list" + ] + }, + { + "cell_type": "code", + "execution_count": 159, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n" + ] + } + ], + "source": [ + "# ngram_list is a list object whose elements are ngrams\n", + "print(type(ngram_list))\n", + "\n", + "# an element of ngram_list is an NGram object\n", + "print(type(ngram_list[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 160, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " 'max_n',\n", + " 'ngram']" + ] + }, + "execution_count": 160, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# An object of the NGram class has static properties to retrieve max_n and ngram\n", + "# There are also `dunder` methods, including \n", + "# __eq__ to test equivalence between two ngrams\n", + "# __hash__ to compute the hash of the ngram\n", + "dir(ngram_list[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Let's see all the ngrams in ngram_list" + ] + }, + { + "cell_type": "code", + "execution_count": 161, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['I', 'see'],\n", + " ['see', 'that'],\n", + " ['that', 'you'],\n", + " ['you', 'like'],\n", + " ['like', 'to'],\n", + " ['to', 'run'],\n", + " ['run', 'fast']]" + ] + }, + "execution_count": 161, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[ngram_list[i].ngram for i in range(len(ngram_list))]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Test the `__eq__` method" + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "False\n", + "True\n" + ] + } + ], + "source": [ + "print(ngram_list[0].ngram.__eq__(ngram_list[1].ngram))\n", + "print(ngram_list[0].ngram.__eq__(ngram_list[0].ngram))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Test the __hash__ method\n", + "In order to `hash` each ngram to an integer, the input `x` to `get_grams()` must be a sequence of integer representations of a sequence of words, not the words themselves" + ] + }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[<__main__.NGram at 0x14a8cdeb5c0>,\n", + " <__main__.NGram at 0x14a8cdeb630>,\n", + " <__main__.NGram at 0x14a8cdeb710>,\n", + " <__main__.NGram at 0x14a8cdeb518>,\n", + " <__main__.NGram at 0x14a8cdeb0f0>,\n", + " <__main__.NGram at 0x14a8cdeb358>]" + ] + }, + "execution_count": 163, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# make all possible 3-grams from an input sequence\n", + "ngram_list = get_grams([1,5,6,10,2,1,10,6],n=3,max_n=5000)\n", + "ngram_list" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Here are all the 3-grams" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[1, 5, 6], [5, 6, 10], [6, 10, 2], [10, 2, 1], [2, 1, 10], [1, 10, 6]]" + ] + }, + "execution_count": 164, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[ngram_list[i].ngram for i in range(len(ngram_list))]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Here are the hashes of the 3-grams" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[150025001, 250030005, 50050006, 25010010, 250005002, 150050001]" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[ngram_list[i].__hash__() for i in range(len(ngram_list))]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.4 Compute the number of correctly predicted ngrams:" + ] + }, + { + "cell_type": "code", + "execution_count": 166, "metadata": {}, "outputs": [], "source": [ "def get_correct_ngrams(pred, targ, n, max_n=5000):\n", + " # inputs pred, targ are predicted and target word sequences\n", " pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)\n", " pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n", " return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)" @@ -199,59 +591,125 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The BLEU metric over two sentences can be written as:" + "#### example" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 6)" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_correct_ngrams([1,5,4,10,1,1,10,6],[1,5,6,10,2,1,10,6],3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. The BLEU metric over two sentences is implemented as follows:\n", + "\n", + "Assume we include sequences of up to 4 words, i.e., 1-grams, 2-grams, 3-grams, and 4-grams.\n", + "\n", + "Note that the BLEU metric imposes a penalty if the predicted sentence is `shorter` than the target sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 168, "metadata": {}, "outputs": [], "source": [ "def sentence_bleu(pred, targ, max_n=5000):\n", - " corrects = [get_correct_ngrams(pred,targ,n,max_n=max_n) for n in range(1,5)]\n", - " n_precs = [c/t for c,t in corrects]\n", - " len_penalty = exp(1 - len(targ)/len(pred)) if len(pred) < len(targ) else 1\n", - " return len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)" + " \n", + " # We count up to 4-grams\n", + " correct_ngrams = [get_correct_ngrams(pred, targ, ngram_size, max_n=max_n) for ngram_size in range(1,4)]\n", + " \n", + " # list with the fraction of ngrams recovered, for sequence lengths of 1, 2, 3, and 4\n", + " frac_recovered = [n_recovered/n_total for n_recovered,n_total in correct_ngrams]\n", + " \n", + " # compute the penalty if predicted sentence is *shorter* than target sentence\n", + " len_penalty = exp(1 - len(targ)/len(pred)) if len(pred) < len(targ) else 1 # between 0 and 1\n", + " \n", + " # raw score is the product of fraction recovered for each sequence length, to the power of \n", + " raw_score = (frac_recovered[0]*frac_recovered[1]*frac_recovered[2]*frac_recovered[3])\n", + " \n", + " # bleu score is product of len_penalty and raw_score ** 0.25 \n", + " bleu = len_penalty * raw_score**0.25 \n", + " return bleu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "And the BLEU metric over a corpus is:" + "## 4. The BLEU metric over a text corpus" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 171, "metadata": {}, "outputs": [], "source": [ "def corpus_bleu(preds, targs, max_n=5000):\n", + " \n", + " # Analysis includes 1-grams, 2-grams, 3-grams, 4-grams\n", + " \n", + " # Initialize\n", " pred_len,targ_len,n_precs,counts = 0,0,[0]*4,[0]*4\n", + " \n", + " # Loop over the sentences in the corpus\n", " for pred,targ in zip(preds,targs):\n", + " \n", + " # increment length of corpus, in sentences\n", " pred_len += len(pred)\n", " targ_len += len(targ)\n", + " \n", + " # extract all ngrams of each allowed length from the current sentence\n", " for i in range(4):\n", - " c,t = ngram_corrects(pred, targ, i+1, max_n=max_n)\n", - " n_precs[i] += c\n", - " counts[i] += t\n", - " n_precs = [c/t for c,t in zip(n_precs,counts)]\n", + " \n", + " # where is the function ngram_corrects() defined?\n", + " count,total = ngram_corrects(pred, targ, i+1, max_n=max_n)\n", + " n_recovered[i] += count\n", + " n_total[i] += total\n", + " \n", + " # list with fraction of ngrams recovered for each sentence in the corpus\n", + " frac_recovered = [count/total for count,total in zip(n_recovered,n_total)]\n", + " \n", + " # compute a single length penalty for the whole corpus\n", " len_penalty = exp(1 - targ_len/pred_len) if pred_len < targ_len else 1\n", - " return len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)" + " \n", + " # compute raw score\n", + " raw_score = frac_recovered [0]*frac_recovered [1]*frac_recovered [2]*frac_recovered [3]\n", + " \n", + " # compute and return bleu metric\n", + " bleu = len_penalty * raw_score ** 0.25\n", + " return bleu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This takes 11s to run on our validation set (instead of 5 hours), so we can even use it as a metric during training, we have to define it as a `Callback`" + "#### This takes 11s to run on our (Sylvain and Jeremy) validation set (instead of 5 hours), so we can even use it as a metric during training. \n", + "#### We define a `Callback`to do this:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 170, "metadata": {}, "outputs": [], "source": [ @@ -279,6 +737,13 @@ " bleu = len_penalty * ((n_precs[0]*n_precs[1]*n_precs[2]*n_precs[3]) ** 0.25)\n", " return add_metrics(last_metrics, bleu)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -297,7 +762,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.1" + "version": "3.7.3" } }, "nbformat": 4,