Optimize word splitter with state machine, replacing regex #58
base: wbrown.fix-streaming-add-md
Conversation
Looks like a very ambitious and cool feature! A few requests to make sure this code is clean and idiomatic.
tests: Cache load encoders when not benchmarking
Great job! Some comments that are mostly nitpicks and good just to document, because it's already doing well when profiled. Once these are resolved I'll approve this.
The only thing I'll mention is that traverseRegexTree is doing an enormous amount of heavy lifting, and would benefit significantly from modularizing into small functions. I understand if you're concerned about function call overhead being a potential bottleneck in this case (although if they're small enough, your compiler should still inline them), so at the very least I would very heavily comment each group-able code chunk.
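As an aside on the inlining point: the Go compiler reports its inlining decisions, so whether extracted helpers stay inlined is easy to verify. A minimal, self-contained sketch (matchRune is a hypothetical helper, not the PR's code):

```go
// Build with inlining diagnostics to see the compiler's decisions:
//
//	go build -gcflags=-m .
//
// gc will print "can inline matchRune", confirming that small leaf
// helpers extracted from a large traversal cost no call overhead.
package main

import "fmt"

// matchRune is the kind of tiny helper a big traversal loop could be
// decomposed into; it is well under the inliner's budget.
func matchRune(r, lo, hi rune) bool {
	return r >= lo && r <= hi
}

func main() {
	fmt.Println(matchRune('k', 'a', 'z')) // true
}
```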
word := wordsBuffer[idx-1]
return &word
What's the purpose of this change?
func isNewLine(r rune) bool {
	// While \n is often considered a whitespace, we treat it as a symbol
	// to ensure it is always a separate token.
	return r == '\n'
\n is technically considered whitespace, as you say. It might be a good idea to change the name of isWhitespace to account for the weird implication that isWhitespace('\n') is falsy.
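One way to resolve the naming, as a sketch (isNonNewlineWhitespace is a hypothetical name, not from the PR): make the '\n' carve-out explicit in the predicate's name.

```go
package main

import (
	"fmt"
	"unicode"
)

// isNewLine treats '\n' as a symbol so it always becomes its own token.
func isNewLine(r rune) bool {
	return r == '\n'
}

// isNonNewlineWhitespace names the '\n' exclusion explicitly, avoiding
// the surprise of a whitespace predicate that is falsy for '\n'.
func isNonNewlineWhitespace(r rune) bool {
	return unicode.IsSpace(r) && r != '\n'
}

func main() {
	fmt.Println(isNewLine('\n'), isNonNewlineWhitespace(' ')) // true true
	fmt.Println(isNonNewlineWhitespace('\n'))                 // false
}
```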
// Process replacements and normalization
for replaced, replacement := range encoder.replacements {
	line = strings.ReplaceAll(line, replaced, replacement)
// AppendBatch appends a batch of words to the wordBatch and flushes
Suggested change:
-// AppendBatch appends a batch of words to the wordBatch and flushes
+// appendBatch appends a batch of words to the wordBatch and flushes
	line = strings.ReplaceAll(line, replaced, replacement)
// AppendBatch appends a batch of words to the wordBatch and flushes
// the batch if it is full.
// ForceFlush forces the batch to be flushed.
Suggested change:
-// ForceFlush forces the batch to be flushed.
+// forceFlush forces the batch to be flushed.

Although technically this should be documented on the definition of forceFlush itself rather than here.
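On that note, Go's convention is that the doc comment lives on the definition and begins with the identifier, so the suggested rename would read like this at forceFlush itself (the receiver type and body here are hypothetical stand-ins):

```go
package tokenizer

// wordBatch is a stand-in type for this sketch.
type wordBatch struct {
	words []string
}

// forceFlush forces the batch to be flushed.
// (Doc comment on the definition, starting with the function name,
// per Go documentation convention.)
func (wb *wordBatch) forceFlush() {
	wb.words = wb.words[:0] // placeholder body
}
```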
if len(v) == 0 {
	continue
}
if runes[i] == []rune(k)[0] {
Interesting idea to only compare strings if the first runes for each match. Do you know if this is faster than some direct string comparison function from the standard library or Go's builtins?
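Not something the PR has to answer, but a quick way to check empirically is a micro-benchmark; this is a hypothetical sketch (save as, say, splitter_bench_test.go and run `go test -bench=.`), with illustrative keys and input:

```go
package tokenizer

import (
	"strings"
	"testing"
)

var (
	keys  = []string{"'s", "'t", "'re", "'ve", "'m", "'ll", "'d"}
	input = "'ll go"
	sink  bool
)

// BenchmarkFirstRuneGate mirrors the snippet above: compare first runes
// and only fall through to the full comparison on a match.
func BenchmarkFirstRuneGate(b *testing.B) {
	first := []rune(input)[0]
	for i := 0; i < b.N; i++ {
		for _, k := range keys {
			if first == []rune(k)[0] {
				sink = strings.HasPrefix(input, k)
			}
		}
	}
}

// BenchmarkHasPrefix relies on the standard library's prefix check alone.
func BenchmarkHasPrefix(b *testing.B) {
	for i := 0; i < b.N; i++ {
		for _, k := range keys {
			sink = strings.HasPrefix(input, k)
		}
	}
}
```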
	return root
}

func (runeTree *RegexNode) createTree(AST *syntax.Regexp, ASTPath []string) {
I really do not recommend calling your receiver runeTree. The name only adds confusion with the receiver for RuneNode also being called runeTree, and usually struct receivers are abbreviated anyway, so it would be rt or something similar. These are different structs and as such really shouldn't take the same receiver name.
runeArray: sub.Rune,
parent:    runeTree,
children:  make([]*RegexNode, 0),
terminal:  sub.Op == syntax.OpCharClass,
The assumption here is that termination always comes from a syntax.OpCharClass op. Why is this?
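For context on where that op comes from: when regexp/syntax parses a pattern, character classes like [a-z] become syntax.OpCharClass nodes, which can be inspected directly. A small sketch:

```go
package main

import (
	"fmt"
	"regexp/syntax"
)

func main() {
	// Parse a tiny pattern; the + wraps a character class.
	ast, err := syntax.Parse(`[a-z]+`, syntax.Perl)
	if err != nil {
		panic(err)
	}
	// The root is the repetition; its child is the char class.
	fmt.Println(ast.Op == syntax.OpPlus)             // true
	fmt.Println(ast.Sub[0].Op == syntax.OpCharClass) // true
	fmt.Println(ast.Sub[0].Rune)                     // [97 122] (lo, hi pairs)
}
```

Whether every terminal in the tokenizer patterns really ends in a char class is a property of those specific patterns, not of regexp/syntax in general.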
func (runeTree *RegexNode) PrintTree() {
	// Print the tree
	sb := strings.Builder{}
	runeTree.string(0, &sb)
	fmt.Printf("%s\n", sb.String())
}
Suggested change:
-func (runeTree *RegexNode) PrintTree() {
-	// Print the tree
-	sb := strings.Builder{}
-	runeTree.string(0, &sb)
-	fmt.Printf("%s\n", sb.String())
-}
+func (runeTree *RegexNode) String() string {
+	// Print the tree
+	sb := strings.Builder{}
+	runeTree.string(0, &sb)
+	return sb.String()
+}
currentPath = append(currentPath, parentIndex)

// If not already in the map, add the current path
pathCopy := make([]int, len(currentPath))
copy(pathCopy, currentPath)
Another slight nitpick here; take it with a grain of salt, as profiling looks good.
Instead of copying, you could've also created pathCopy with something like:

pathCopy := append([]int{}, currentPath...)
pathCopy = append(pathCopy, parentIndex)

which might be slightly more idiomatic, but I suspect the performance difference would be negligible. Just documenting this.
}
level += 1
thisNodeMap := matchVars.pathMap[matchVars.currentNodeIdx]
lastNodeMap := make([]int, 0)
Tiny nitpick again, but wouldn't this technically be a wasted allocation when currentNodeIdx = 0?
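For what it's worth, Go's nil slices make the lazy version trivial: append works on a nil slice, so declaring without make defers the allocation until first use. A standalone sketch (not the PR's code):

```go
package main

import "fmt"

func main() {
	var lastNodeMap []int // nil: len 0, cap 0, no backing array yet
	fmt.Println(lastNodeMap == nil, len(lastNodeMap)) // true 0

	// The first append performs the only allocation; if this branch
	// never runs (e.g. currentNodeIdx == 0), nothing is allocated.
	lastNodeMap = append(lastNodeMap, 42)
	fmt.Println(lastNodeMap) // [42]
}
```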
Summary
Profiling the dataset_tokenizer, we find that it is mainly bottlenecked on the lines -> words portion handled by the Go regexp package in the wordSplitter function. In this PR, we propose that this process can be optimized by replacing the regex with a simple state machine that handles the splitting of lines into words.

Changes
This state machine replaces the regex line in gpt_bpe.go's makeWordSplitter function. The state machine operates mostly in rune-space instead of string-space, which helps with computation times. It is created by decomposing the provided or default regex pattern with the regexp/syntax Go package, and implements a subset of regex features that should support all tokenizer word splitters.

A modified runetree, dubbed a Contraction Tree, was created in runetree.go to represent the tree of choices for contractions, as part of the word splitter change.
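To sketch the core idea (a minimal illustration, not the PR's implementation, and it omits the contraction handling described above): classify each rune, extend the current word while the class is unchanged, and let a single leading space attach to the following word, mirroring the ` ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+` alternatives of the GPT-2 pattern.

```go
package main

import (
	"fmt"
	"unicode"
)

type runeClass int

const (
	classLetter runeClass = iota
	classNumber
	classSpace
	classOther
)

func classify(r rune) runeClass {
	switch {
	case unicode.IsLetter(r):
		return classLetter
	case unicode.IsNumber(r):
		return classNumber
	case unicode.IsSpace(r):
		return classSpace
	default:
		return classOther
	}
}

// splitWords walks the line once in rune-space, emitting a word each
// time the rune class changes; a single ' ' may prefix the next word,
// mirroring the " ?\p{L}+"-style alternatives of the GPT-2 pattern.
func splitWords(line string) []string {
	runes := []rune(line)
	var words []string
	for i := 0; i < len(runes); {
		j := i
		// Attach one leading space to a following non-space run.
		if runes[j] == ' ' && j+1 < len(runes) && classify(runes[j+1]) != classSpace {
			j++
		}
		c := classify(runes[j])
		for j++; j < len(runes) && classify(runes[j]) == c; j++ {
		}
		words = append(words, string(runes[i:j]))
		i = j
	}
	return words
}

func main() {
	fmt.Printf("%q\n", splitWords("Hello world, it's 2024!"))
	// ["Hello" " world" "," " it" "'" "s" " 2024" "!"]
	// (the real splitter's contraction tree would keep "'s" together)
}
```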
Results
(Rough performance)

Before the changes, when tested on a Linux VDI, the benchmark yielded roughly 1.7 million words/second and the dataset test yielded around 4 million tokens/second.

After the changes, the benchmark yields roughly 2.25-2.5 million words/second using gpt2 as the baseline, and 8.5-9.5 million tokens/second on the dataset test.

The dataset test was run with 2 reader threads, 16 tokenizer threads, and streaming encode on the Gutenberg dataset, by building dataset_tokenizer and running it on the command line.

Wes tested this setup with 64 tokenizer threads on a CPU node and reached as high as 67 million tokens/second, over 3x the previous maximum; we speculate that this starts to bump up against OS file-operation rate limits.