From 3c44dcb347727bcad8639c4623b52629bb165be1 Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Wed, 27 Nov 2024 11:00:16 +0000 Subject: [PATCH] CU-8696nbm03: Remove unigram table (#503) * CU-8696nbm03: Remove use of unigram table * CU-8696nbm03: Fix usage of new unigram table alternative * CU-8696nbm03: Remove unigram table from loaded vocabs * CU-8696nbm03: Add tests for unigram table usage/negative sampling frequency * CU-8696nbm03: Add small comment to tests * CU-8696nbm03: Calculate frequencies upon load if not present * CU-8696nbm03: Update comment regarding probability calculatioons * CU-8696nbm03: Remove commented test case * CU-8696n7w95: Fix docstring issue * CU-8696nbm03: Fix serialisation tests * CU-8696nbm03: Add python 3.9-friendly method for getting the total of a counter --- medcat/vocab.py | 51 +++++++++++++----------- tests/test_vocab.py | 44 ++++++++++++++++++++ tests/utils/saving/test_serialization.py | 3 +- 3 files changed, 72 insertions(+), 26 deletions(-) diff --git a/medcat/vocab.py b/medcat/vocab.py index 56bd1e0d9..88350c945 100644 --- a/medcat/vocab.py +++ b/medcat/vocab.py @@ -1,6 +1,10 @@ import numpy as np import pickle from typing import Optional, List, Dict +import logging + + +logger = logging.getLogger(__name__) class Vocab(object): @@ -22,7 +26,7 @@ def __init__(self) -> None: self.vocab: Dict = {} self.index2word: Dict = {} self.vec_index2word: Dict = {} - self.unigram_table: np.ndarray = np.array([]) + self.cum_probs = np.array([]) def inc_or_add(self, word: str, cnt: int = 1, vec: Optional[np.ndarray] = None) -> None: """Add a word or increase its count. @@ -172,32 +176,29 @@ def add_words(self, path: str, replace: bool = True) -> None: self.add_word(word, cnt, vec, replace) - def make_unigram_table(self, table_size: int = 100000000) -> None: + def make_unigram_table(self, table_size: int = -1) -> None: """Make unigram table for negative sampling, look at the paper if interested in details. Args: table_size (int): - The size of the table (Defaults to 100 000 000) + The size of the table - no longer needed (Defaults to -1) """ + if table_size != -1: + logger.warning("Unigram table size is no longer necessary since " + "there is now a simpler approach that doesn't require " + "the creation of a massive array. So therefore, there " + "is no need to pass the `table_size` parameter anymore.") freqs = [] - unigram_table = [] - - words = list(self.vec_index2word.values()) - for word in words: + for word in self.vec_index2word.values(): freqs.append(self[word]) - freqs = np.array(freqs) - freqs = np.power(freqs, 3/4) - sm = np.sum(freqs) + # Power and normalize frequencies + freqs = np.array(freqs) ** (3/4) + freqs /= freqs.sum() - for ind in self.vec_index2word.keys(): - word = self.vec_index2word[ind] - f_ind = words.index(word) - p = freqs[f_ind] / sm - unigram_table.extend([ind] * int(p * table_size)) - - self.unigram_table = np.array(unigram_table) + # Calculate cumulative probabilities + self.cum_probs = np.cumsum(freqs) def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]: """Get N negative samples. @@ -208,17 +209,14 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) - ignore_punct_and_num (bool): Whether to ignore punctuation and numbers. (Default value = False) - Raises: - Exception: If no unigram table is present. - Returns: List[int]: Indices for words in this vocabulary. """ - if len(self.unigram_table) == 0: - raise Exception("No unigram table present, please run the function vocab.make_unigram_table() first.") - inds = np.random.randint(0, len(self.unigram_table), n) - inds = self.unigram_table[inds] + if len(self.cum_probs) == 0: + self.make_unigram_table() + random_vals = np.random.rand(n) + inds = np.searchsorted(self.cum_probs, random_vals).tolist() if ignore_punct_and_num: # Do not return anything that does not have letters in it @@ -253,4 +251,9 @@ def load(cls, path: str) -> "Vocab": with open(path, 'rb') as f: vocab = cls() vocab.__dict__ = pickle.load(f) + if not hasattr(vocab, 'cum_probs'): + # NOTE: this is not too expensive, only around 0.05s + vocab.make_unigram_table() + if hasattr(vocab, 'unigram_table'): + del vocab.unigram_table return vocab diff --git a/tests/test_vocab.py b/tests/test_vocab.py index f51cef140..5e4f8e25e 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -2,6 +2,8 @@ import shutil import unittest from medcat.vocab import Vocab +from collections import Counter +import numpy as np class CATTests(unittest.TestCase): @@ -36,5 +38,47 @@ def test_save_and_load(self): self.assertEqual(["house", "dog", "test"], list(vocab.vocab.keys())) +class VocabUnigramTableTests(unittest.TestCase): + EXAMPLE_DATA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), + "..", "examples", "vocab_data.txt") + UNIGRAM_TABLE_SIZE = 10_000 + # found that this seed had the closest frequency at the sample size we're at + RANDOM_SEED = 4976 + NUM_SAMPLES = 20 # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes + NUM_TIMES = 200 + # based on the counts on vocab_data.txt and the one set in setUpClass + EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845] + TOLERANCE = 0.001 + + @classmethod + def setUpClass(cls): + cls.vocab = Vocab() + cls.vocab.add_words(cls.EXAMPLE_DATA_PATH) + cls.vocab.add_word("test", cnt=1310, vec=[1.42, 1.44, 1.55]) + cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE) + + def setUp(self): + np.random.seed(self.RANDOM_SEED) + + @classmethod + def _get_freqs(cls) -> list[float]: + c = Counter() + for _ in range(cls.NUM_TIMES): + got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES) + c += Counter(got) + total = sum(c[i] for i in c) + got_freqs = [c[i]/total for i in range(len(cls.EXPECTED_FREQUENCIES))] + return got_freqs + + def assert_accurate_enough(self, got_freqs: list[float]): + self.assertTrue( + np.max(np.abs(np.array(got_freqs) - self.EXPECTED_FREQUENCIES)) < self.TOLERANCE + ) + + def test_negative_sampling(self): + got_freqs = self._get_freqs() + self.assert_accurate_enough(got_freqs) + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py index cb26312f0..d3eec090d 100644 --- a/tests/utils/saving/test_serialization.py +++ b/tests/utils/saving/test_serialization.py @@ -138,8 +138,7 @@ def test_round_trip(self): self.assertEqual(cat.vocab.index2word, self.undertest.vocab.index2word) self.assertEqual(cat.vocab.vec_index2word, self.undertest.vocab.vec_index2word) - self.assertEqual(cat.vocab.unigram_table, - self.undertest.vocab.unigram_table) + self.assertTrue((cat.vocab.cum_probs == self.undertest.vocab.cum_probs).all()) for name in SPECIALITY_NAMES: if name in ONE2MANY: # ignore cui2many and name2many