Skip to content

Commit

Permalink
CU-8696nbm03: Remove unigram table (#503)
Browse files Browse the repository at this point in the history
* CU-8696nbm03: Remove use of unigram table

* CU-8696nbm03: Fix usage of new unigram table alternative

* CU-8696nbm03: Remove unigram table from loaded vocabs

* CU-8696nbm03: Add tests for unigram table usage/negative sampling frequency

* CU-8696nbm03: Add small comment to tests

* CU-8696nbm03: Calculate frequencies upon load if not present

* CU-8696nbm03: Update comment regarding probability calculations

* CU-8696nbm03: Remove commented test case

* CU-8696n7w95: Fix docstring issue

* CU-8696nbm03: Fix serialisation tests

* CU-8696nbm03: Add python 3.9-friendly method for getting the total of a counter
  • Loading branch information
mart-r authored Nov 27, 2024
1 parent b96310b commit 3c44dcb
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 26 deletions.
51 changes: 27 additions & 24 deletions medcat/vocab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import numpy as np
import pickle
from typing import Optional, List, Dict
import logging


logger = logging.getLogger(__name__)


class Vocab(object):
Expand All @@ -22,7 +26,7 @@ def __init__(self) -> None:
self.vocab: Dict = {}
self.index2word: Dict = {}
self.vec_index2word: Dict = {}
self.unigram_table: np.ndarray = np.array([])
self.cum_probs = np.array([])

def inc_or_add(self, word: str, cnt: int = 1, vec: Optional[np.ndarray] = None) -> None:
"""Add a word or increase its count.
Expand Down Expand Up @@ -172,32 +176,29 @@ def add_words(self, path: str, replace: bool = True) -> None:

self.add_word(word, cnt, vec, replace)

def make_unigram_table(self, table_size: int = -1) -> None:
    """Prepare negative-sampling probabilities for this vocabulary.

    Replaces the classic word2vec "unigram table" (a massive array with
    each word index repeated proportionally to its smoothed frequency)
    with the cumulative distribution of the smoothed frequencies.
    Sampling is then done by drawing uniform random numbers and binary
    searching `self.cum_probs` (see `get_negative_samples`), which is
    equivalent but needs only one entry per word with a vector.

    Args:
        table_size (int):
            Deprecated and ignored.
            The size of the table - no longer needed (Defaults to -1)
    """
    if table_size != -1:
        # Warn (rather than raise) so old callers keep working.
        logger.warning("Unigram table size is no longer necessary since "
                       "there is now a simpler approach that doesn't require "
                       "the creation of a massive array. So therefore, there "
                       "is no need to pass the `table_size` parameter anymore.")
    # Only words that have vectors take part in negative sampling.
    freqs = []
    for word in self.vec_index2word.values():
        freqs.append(self[word])

    # Power and normalize frequencies
    # (3/4 smoothing as in the word2vec paper; flattens the distribution
    # so rare words are sampled a bit more often).
    freqs = np.array(freqs) ** (3/4)
    freqs /= freqs.sum()

    # Calculate cumulative probabilities; the last entry is 1.0 so a
    # uniform draw in [0, 1) always maps to a valid word index.
    self.cum_probs = np.cumsum(freqs)

def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]:
"""Get N negative samples.
Expand All @@ -208,17 +209,14 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -
ignore_punct_and_num (bool):
Whether to ignore punctuation and numbers. (Default value = False)
Raises:
Exception: If no unigram table is present.
Returns:
List[int]:
Indices for words in this vocabulary.
"""
if len(self.unigram_table) == 0:
raise Exception("No unigram table present, please run the function vocab.make_unigram_table() first.")
inds = np.random.randint(0, len(self.unigram_table), n)
inds = self.unigram_table[inds]
if len(self.cum_probs) == 0:
self.make_unigram_table()
random_vals = np.random.rand(n)
inds = np.searchsorted(self.cum_probs, random_vals).tolist()

if ignore_punct_and_num:
# Do not return anything that does not have letters in it
Expand Down Expand Up @@ -253,4 +251,9 @@ def load(cls, path: str) -> "Vocab":
with open(path, 'rb') as f:
vocab = cls()
vocab.__dict__ = pickle.load(f)
if not hasattr(vocab, 'cum_probs'):
# NOTE: this is not too expensive, only around 0.05s
vocab.make_unigram_table()
if hasattr(vocab, 'unigram_table'):
del vocab.unigram_table
return vocab
44 changes: 44 additions & 0 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import shutil
import unittest
from medcat.vocab import Vocab
from collections import Counter
import numpy as np


class CATTests(unittest.TestCase):
Expand Down Expand Up @@ -36,5 +38,47 @@ def test_save_and_load(self):
self.assertEqual(["house", "dog", "test"], list(vocab.vocab.keys()))


class VocabUnigramTableTests(unittest.TestCase):
    """Checks that negative sampling draws words at the expected frequencies.

    Builds a small vocabulary from the example data file (plus one extra
    word added in setUpClass), then samples many negatives and compares
    the observed per-word frequencies against precomputed expected values.
    """
    # Example vocab data shipped with the repository, relative to this file.
    EXAMPLE_DATA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                     "..", "examples", "vocab_data.txt")
    # Passed to make_unigram_table; the size is deprecated/ignored there.
    UNIGRAM_TABLE_SIZE = 10_000
    # found that this seed had the closest frequency at the sample size we're at
    RANDOM_SEED = 4976
    NUM_SAMPLES = 20  # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes
    # Number of sampling rounds; total draws = NUM_TIMES * NUM_SAMPLES.
    NUM_TIMES = 200
    # based on the counts on vocab_data.txt and the one set in setUpClass
    EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845]
    # Maximum allowed absolute deviation per word frequency.
    TOLERANCE = 0.001

    @classmethod
    def setUpClass(cls):
        # Build the vocab once for the whole test class (sampling is read-only).
        cls.vocab = Vocab()
        cls.vocab.add_words(cls.EXAMPLE_DATA_PATH)
        cls.vocab.add_word("test", cnt=1310, vec=[1.42, 1.44, 1.55])
        cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE)

    def setUp(self):
        # Re-seed before every test so sampling is deterministic.
        np.random.seed(self.RANDOM_SEED)

    @classmethod
    def _get_freqs(cls) -> list[float]:
        """Sample negatives repeatedly and return observed per-word frequencies."""
        c = Counter()
        for _ in range(cls.NUM_TIMES):
            got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES)
            c += Counter(got)
        # Counter.total() is 3.10+; this sum works on Python 3.9 as well.
        total = sum(c[i] for i in c)
        got_freqs = [c[i]/total for i in range(len(cls.EXPECTED_FREQUENCIES))]
        return got_freqs

    def assert_accurate_enough(self, got_freqs: list[float]):
        # Largest absolute deviation from the expected frequencies must be small.
        self.assertTrue(
            np.max(np.abs(np.array(got_freqs) - self.EXPECTED_FREQUENCIES)) < self.TOLERANCE
        )

    def test_negative_sampling(self):
        got_freqs = self._get_freqs()
        self.assert_accurate_enough(got_freqs)


if __name__ == '__main__':
unittest.main()
3 changes: 1 addition & 2 deletions tests/utils/saving/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def test_round_trip(self):
self.assertEqual(cat.vocab.index2word, self.undertest.vocab.index2word)
self.assertEqual(cat.vocab.vec_index2word,
self.undertest.vocab.vec_index2word)
self.assertEqual(cat.vocab.unigram_table,
self.undertest.vocab.unigram_table)
self.assertTrue((cat.vocab.cum_probs == self.undertest.vocab.cum_probs).all())
for name in SPECIALITY_NAMES:
if name in ONE2MANY:
# ignore cui2many and name2many
Expand Down

0 comments on commit 3c44dcb

Please sign in to comment.