# newtoken.py
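"""
Regex-based BPE tokenizer.

Text is first split into chunks with a regex pattern (the GPT-4 split pattern
used in the demo below), and byte-pair-encoding-style merges are then learned
over character ids within each chunk. The class exposes train(), encode(), and
decode(); a small demo runs under __main__.
"""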
import matplotlib.pyplot as plt
import regex  # third-party package: supports the \p{..} classes and possessive quantifiers in the split pattern
from collections import defaultdict

class RegexTokenizer:
    def __init__(self, split_pattern):
        """
        Initialize the tokenizer with a specific regex splitting pattern.

        Args:
            split_pattern (str): Regex pattern for tokenization
        """
        self.vocab = {}          # token string -> token id
        self.inverse_vocab = {}  # token id -> token string
        self.split_pattern = split_pattern

    def _tokenize(self, text):
        """Split text into chunks using the regex pattern."""
        return regex.findall(self.split_pattern, text)

    def get_stats(self, ids):
        """Compute frequency of adjacent token pairs."""
        stats = defaultdict(int)
        for i in range(len(ids) - 1):
            # Skip pairs that span the -1 separator so merges never
            # cross a regex chunk boundary
            if ids[i] < 0 or ids[i + 1] < 0:
                continue
            stats[(ids[i], ids[i + 1])] += 1
        return stats
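
    # For example, with no separators present:
    #   get_stats([104, 105, 104, 105, 104]) -> {(104, 105): 2, (105, 104): 2}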

    def merge_vocab(self, ids, pair, new_id):
        """Replace every occurrence of `pair` in `ids` with `new_id`."""
        new_ids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
                new_ids.append(new_id)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids
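
    # For example, merging the pair (104, 105) into a new id 256:
    #   merge_vocab([104, 105, 32, 104, 105], (104, 105), 256) -> [256, 32, 256]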

    def train(self, text, vocab_size, verbose=False):
        """
        Train the tokenizer using Byte Pair Encoding.

        Args:
            text (str): Training text
            vocab_size (int): Desired vocabulary size
            verbose (bool, optional): Print detailed info. Defaults to False.

        Returns:
            dict: Trained vocabulary (token string -> token id)
        """
        # Tokenize text using the regex pattern
        tokens = self._tokenize(text)

        # Convert tokens to character-level ids, with -1 as a separator
        # between regex chunks
        ids = []
        for token in tokens:
            ids.extend([ord(c) for c in token] + [-1])

        # Initialize the vocabulary with the unique characters in the text
        unique_chars = set(ids) - {-1}
        self.vocab = {chr(char): char for char in unique_chars}
        self.inverse_vocab = {char: chr(char) for char in unique_chars}

        # Merged tokens get fresh ids above the character range so they
        # never collide with the ord() values used for single characters
        next_id = max(unique_chars | {255}) + 1

        # BPE training loop
        while len(self.vocab) < vocab_size:
            # Get pair frequencies
            stats = self.get_stats(ids)
            if not stats:
                break

            # Find the most frequent pair
            pair = max(stats, key=stats.get)

            # Create the new merged token
            new_token = self.inverse_vocab[pair[0]] + self.inverse_vocab[pair[1]]
            new_token_id = next_id
            next_id += 1
            self.vocab[new_token] = new_token_id
            self.inverse_vocab[new_token_id] = new_token

            # Merge the pair throughout the id sequence
            ids = self.merge_vocab(ids, pair, new_token_id)

            # Optional verbose output
            if verbose:
                print(f"Merged {pair} into {new_token!r}, vocab size now: {len(self.vocab)}")

        # Visualize merged tokens
        if verbose:
            plt.figure(figsize=(10, 5))
            plt.bar(range(len(self.vocab)), [1] * len(self.vocab))
            plt.title('Vocabulary Tokens')
            plt.xlabel('Token Index')
            plt.ylabel('Token Presence')
            plt.tight_layout()
            plt.show()

        return self.vocab

    def encode(self, text):
        """
        Encode text into token ids.

        Args:
            text (str): Text to encode

        Returns:
            list: Encoded token ids
        """
        # Split text with the same regex pattern used during training
        tokens = self._tokenize(text)

        encoded_tokens = []
        for token in tokens:
            # Start from character-level ids
            token_ids = [ord(c) for c in token]

            # Apply learned merges, earliest-learned (lowest id) first
            while len(token_ids) > 1:
                stats = self.get_stats(token_ids)
                if not stats:
                    break

                # Rank each adjacent pair by the id of the token it would
                # merge into; pairs with no learned merge rank as infinity
                def merge_id(p):
                    merged = (self.inverse_vocab.get(p[0], chr(p[0])) +
                              self.inverse_vocab.get(p[1], chr(p[1])))
                    return self.vocab.get(merged, float('inf'))

                pair = min(stats, key=merge_id)
                new_id = merge_id(pair)

                # Stop once no adjacent pair corresponds to a learned merge
                if new_id == float('inf'):
                    break

                token_ids = self.merge_vocab(token_ids, pair, new_id)

            encoded_tokens.extend(token_ids)
        return encoded_tokens

    def decode(self, ids):
        """Decode a list of token ids back into a string."""
        decoded_tokens = []
        for token_id in ids:
            if token_id < 0:
                # Skip separator tokens
                continue
            elif token_id in self.inverse_vocab:
                # Known tokens: single characters or learned BPE merges
                decoded_tokens.append(self.inverse_vocab[token_id])
            elif token_id < 0x110000:
                # Characters never seen during training
                decoded_tokens.append(chr(token_id))
            else:
                # Fallback for unrecognized ids
                decoded_tokens.append('�')  # Unicode replacement character
        return ''.join(decoded_tokens)
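
    # For example, plain character ids decode straight back to text:
    #   decode([72, 105, 33]) -> 'Hi!'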


# Example usage
if __name__ == "__main__":
    # GPT-4 regex splitting pattern (the cl100k_base split pattern)
    GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

    # Read test text
    try:
        with open('tests/taylorswift.txt', 'r', encoding='utf-8') as f:
            text = f.read()
    except FileNotFoundError:
        # Fallback text if the file is not found
        text = """
        Taylor Swift is an American singer-songwriter.
        Her narrative songwriting, often centered on her personal life,
        has received widespread critical praise and media coverage.
        """

    # Initialize and train the tokenizer
    tokenizer = RegexTokenizer(GPT4_SPLIT_PATTERN)
    vocab = tokenizer.train(text, vocab_size=100, verbose=True)

    sample_text = "Taylor's amazing song, 'Blank Space', sold 1,000,000 copies!"

    # Show the raw regex split
    regex_tokens = tokenizer._tokenize(sample_text)
    print("\nRegex Tokens:")
    print(regex_tokens)

    # Encode and decode
    encoded = tokenizer.encode(sample_text)
    decoded = tokenizer.decode(encoded)
    print("\nEncoding Test:")
    print(f"Original: {sample_text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")

    # Print the first 20 vocabulary entries
    print("\nVocabulary:")
    for token, token_id in list(vocab.items())[:20]:
        print(f"{token!r}: {token_id}")