From 9fe90d461f756b63f77eaa2c9d37498515a0e5b1 Mon Sep 17 00:00:00 2001
From: Hk669
Date: Fri, 7 Jun 2024 11:43:46 +0530
Subject: [PATCH] fix: special_tokens in the encode method to support the
 special tokens in the vocab

---
 .../pretrained/wi17k_base/wi17k_base.json |  8 +++++--
 bpetokenizer/tokenizer.py                 | 23 ++++++++++++++-----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/bpetokenizer/pretrained/wi17k_base/wi17k_base.json b/bpetokenizer/pretrained/wi17k_base/wi17k_base.json
index aca5ff9..39f8142 100644
--- a/bpetokenizer/pretrained/wi17k_base/wi17k_base.json
+++ b/bpetokenizer/pretrained/wi17k_base/wi17k_base.json
@@ -17065,7 +17065,9 @@
        "(17306, 195)": 17307,
        "(17307, 163)": 17308,
        "(1012, 7365)": 17309,
-        "(9137, 336)": 17310
+        "(9137, 336)": 17310,
+        "(32, 32)": 17320,
+        "(17320, 32)": 17321
    },
    "vocab": {
        "0": "\\u0000",
@@ -34380,6 +34382,8 @@
        "17309": " differs",
        "17311": " def",
        "17312": "_stats",
-        "17313": " get"
+        "17313": " get",
+        "17320": "  ",
+        "17321": "   "
    }
}
\ No newline at end of file
diff --git a/bpetokenizer/tokenizer.py b/bpetokenizer/tokenizer.py
index 7d7be60..89f5b1f 100644
--- a/bpetokenizer/tokenizer.py
+++ b/bpetokenizer/tokenizer.py
@@ -33,11 +33,13 @@ def __init__(self, pattern=None, special_tokens=None):
        self.special_tokens = {} if special_tokens is None else special_tokens
        self.inverse_special_tokens = {} if special_tokens is None else {v: k for k, v in special_tokens.items()}
        self.vocab_size = len(self.vocab) if self.vocab else 0
+        self.inverse_merges = {int(v): k for k, v in self.merges.items()} if self.merges else {}


    @classmethod
    def from_pretrained(cls, tokenizer_name: str, verbose=False):
+        """Loads a pretrained tokenizer by name from the pretrained directory."""
        tokenizer = cls()
        pretrained_dir = 'bpetokenizer/pretrained'
        tokenizer_file = os.path.join(pretrained_dir, tokenizer_name, f'{tokenizer_name}.json')
@@ -60,6 +62,10 @@ def train(self, texts, vocab_size, verbose=False, min_frequency=1) -> None:
            vocab_size: int (the size of the vocab, gpt4 vocab size is around 100k)
            verbose: bool (to get extra visibility and the overview of internal processes)
            min_frequency: int (the minimum frequency of the pair to be merged and added into the vocab as a new token)
+
+        internal_args:
+            text_chunks: list[str]
+            pair: tuple[int, int]
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
@@ -150,13 +156,16 @@ def encode(self, text, special_tokens="none") -> list:
        else:
            raise ValueError(f"invalid special tokens argument: {special_tokens}")

-
-        text_chunks = re.findall(self.compiled_pattern, text)
+        if not special:
+            # shortcut: if no special tokens, just use the ordinary encoding
+            return self.encode_ord(text)
+        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
+        text_chunks = re.split(special_pattern, text)
        ids = []
        for chunk in text_chunks:
            if chunk in self.inverse_vocab:
                ids.append(self.inverse_vocab[chunk])
-            elif chunk in self.special_tokens:
+            elif special and chunk in self.special_tokens:
                ids.append(self.special_tokens[chunk])
            else:
                chunk_ids = self._encode(chunk.encode("utf-8"))
@@ -164,19 +173,21 @@ def encode(self, text, special_tokens="none") -> list:

        return ids


-    def decode(self, ids) -> str:
+    def decode(self, ids, verbose=False) -> str:
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:     #str conversion because vocab keys are strings when loaded from json
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) # special tokens are not encoded in vocab
-            elif idx in self.merges:
-                pair = self.merges[idx]
+            elif idx in self.inverse_merges:
+                pair = self.inverse_merges[idx]
                part_bytes.append(self.vocab[pair[0]] + self.vocab[pair[1]])
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
+        if verbose:
+            print("---\nText bytes: ", text_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
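
For reviewers, a small round-trip sketch of the patched encode/decode path. It is illustrative only: it assumes the package exposes a BPETokenizer class at the top level, that encode() accepts special_tokens="all" alongside the default "none", and it uses "<|endoftext|>" as a stand-in for whichever special token the wi17k_base vocab actually registers; none of these names are confirmed by the diff itself.

# Round-trip sketch for the new special-token handling (assumed names noted above).
from bpetokenizer import BPETokenizer  # assumed import path

tokenizer = BPETokenizer.from_pretrained("wi17k_base")

# "<|endoftext|>" is a placeholder special token; the run of spaces exercises
# the new merges (32, 32) -> 17320 and (17320, 32) -> 17321 added to the vocab.
text = "def get_stats():   pass<|endoftext|>"

# encode() now splits the text on a regex built from the special tokens,
# maps the special chunks through self.special_tokens, and BPE-encodes the rest.
ids = tokenizer.encode(text, special_tokens="all")
print(ids)

# decode() now resolves merged ids through self.inverse_merges before failing.
print(tokenizer.decode(ids, verbose=True))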