Support newer merges format in tokenizer.json files #392

Merged (2 commits) on Oct 23, 2024
11 changes: 9 additions & 2 deletions rten-text/src/tokenizers.rs
@@ -22,7 +22,7 @@ use crate::split::SliceExt;
mod bpe;
mod json;
mod wordpiece;
- pub use bpe::{patterns, Bpe, BpeError};
+ pub use bpe::{merge_pairs_from_lines, patterns, Bpe, BpeError};
pub use wordpiece::{WordPiece, WordPieceOptions};

/// Input sequences for [`Tokenizer::encode`].
@@ -317,14 +317,21 @@ impl Tokenizer {
.collect()
})
.unwrap_or_default();
- let merges: Vec<_> = model.merges.iter().map(|s| s.as_str()).collect();
+ let merges: Vec<(&str, &str)> = match &model.merges {
+ json::MergeList::Legacy(lines) => bpe::merge_pairs_from_lines(lines),
+ json::MergeList::Tuple(pairs) => pairs
+ .iter()
+ .map(|(a, b)| (a.as_str(), b.as_str()))
+ .collect(),
+ };
let encoder = Bpe::new(
&merges,
bpe::patterns::GPT2,
Some(model.vocab),
added_tokens,
)
.map_err(FromJsonError::BpeError)?;

let tokenizer = Tokenizer::new(
encoder,
TokenizerOptions {
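The change above normalizes both merge representations into `(&str, &str)` pairs before constructing the `Bpe` encoder. Below is a minimal standalone sketch of that normalization; the `MergeList` enum and the `merge_pairs` helper are local stand-ins for illustration only, since the real `json::MergeList` is crate-private.

```rust
// Local stand-in for the crate-private `json::MergeList`; the real type is
// defined in rten-text/src/tokenizers/json.rs (see the last file in this diff).
enum MergeList {
    /// Pairs stored as a JSON array of two-element arrays (newer format).
    Tuple(Vec<(String, String)>),
    /// Pairs stored as "<token_a> <token_b>" strings (legacy format).
    Legacy(Vec<String>),
}

/// Normalize either representation into borrowed (a, b) pairs, mirroring the
/// `match` added to the tokenizer's JSON loading code above.
fn merge_pairs(merges: &MergeList) -> Vec<(&str, &str)> {
    match merges {
        MergeList::Legacy(lines) => lines
            .iter()
            .filter_map(|line| {
                // Skip "#version" headers and blank lines, as
                // `merge_pairs_from_lines` does in bpe.rs below.
                if line.starts_with("#version") || line.trim().is_empty() {
                    None
                } else {
                    line.split_once(' ')
                }
            })
            .collect(),
        MergeList::Tuple(pairs) => pairs
            .iter()
            .map(|(a, b)| (a.as_str(), b.as_str()))
            .collect(),
    }
}

fn main() {
    let legacy = MergeList::Legacy(vec!["#version: 0.2".into(), "h e".into()]);
    let tuple = MergeList::Tuple(vec![("h".into(), "e".into())]);
    // Both inputs describe the same single merge rule.
    assert_eq!(merge_pairs(&legacy), merge_pairs(&tuple));
    println!("both formats yield: {:?}", merge_pairs(&tuple));
}
```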
51 changes: 37 additions & 14 deletions rten-text/src/tokenizers/bpe.rs
@@ -177,19 +177,17 @@ impl BpeBuilder {
/// `merges` contains entries of the BPE merge table. Each entry is a
/// space-separated pair of tokens. Each token is a sequence of byte values
/// encoded using the scheme described in [`char_to_byte`].
- fn add_merges(&mut self, merges: &[EncodedByteSlice]) -> Result<(), BpeError> {
+ fn add_merges(
+ &mut self,
+ merges: &[(EncodedByteSlice, EncodedByteSlice)],
+ ) -> Result<(), BpeError> {
// The first 256 ranks are assigned to individual byte values.
let mut rank = 256 + self.ranks.len() as u32;
self.ranks.reserve(merges.len());
self.token_ranks.reserve(merges.len());

- for entry in merges.iter() {
- if entry.starts_with("#version") || entry.trim().is_empty() {
- continue;
- }
-
- let invalid_entry = || BpeError::InvalidMergeEntry(entry.to_string());
- let (a, b) = entry.split_once(' ').ok_or_else(invalid_entry)?;
+ for (a, b) in merges.iter().copied() {
+ let invalid_entry = || BpeError::InvalidMergeEntry(format!("{} {}", a, b));
let a_rank = self.get_token_rank(a).ok_or_else(invalid_entry)?;
let b_rank = self.get_token_rank(b).ok_or_else(invalid_entry)?;
self.ranks.insert((a_rank, b_rank), rank);
@@ -215,6 +213,25 @@ pub mod patterns {
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
}

/// Parse a list of space-separated BPE merge entries into pairs of tokens.
///
/// Lines that are empty or contain only a `#version` marker are ignored.
pub fn merge_pairs_from_lines(
lines: &[impl AsRef<str>],
) -> Vec<(EncodedByteSlice, EncodedByteSlice)> {
lines
.iter()
.filter_map(|line| {
let line = line.as_ref();
if line.starts_with("#version") || line.trim().is_empty() {
None
} else {
line.split_once(' ')
}
})
.collect()
}

/// Byte Pair Encoding tokenizer used by GPT-2 [^1] and subsequently used by
/// many other models.
///
@@ -267,7 +284,9 @@ impl Bpe {
/// Create a new Byte Pair Encoding tokenizer.
///
/// `merges` are the ordered entries of the merge list. Each entry is a
- /// space-separated pair of strings representing byte sequences.
+ /// pair of strings representing byte sequences. See also
+ /// [`merge_pairs_from_lines`] which can be used to extract pairs from
+ /// the space-separated format used in eg. `merges.txt` files.
///
/// `pattern` is a regex used to split input text into pieces before BPE
/// encoding is applied. The supported syntax is that supported by the
@@ -284,7 +303,7 @@
/// do have a mapping in `vocab`. These are used for special purposes such
/// as representing the end of output.
pub fn new(
- merges: &[EncodedByteSlice],
+ merges: &[(EncodedByteSlice, EncodedByteSlice)],
pattern: &str,
vocab: Option<HashMap<EncodedBytes, TokenId>>,
added_tokens: HashMap<TokenId, String>,
@@ -473,7 +492,7 @@ mod tests {
use std::collections::HashMap;

use super::patterns::GPT2 as GPT2_SPLIT_PATTERN;
- use super::{Bpe, EncodedBytes};
+ use super::{merge_pairs_from_lines, Bpe, EncodedBytes};
use crate::tokenizers::{TokenId, Tokenizer};

// The first ~25 lines of the merge list from GPT 2.
@@ -573,7 +592,8 @@ in g";
} in cases
{
let merges: Vec<&str> = merges.lines().collect();
- let encoder = Bpe::new(&merges, GPT2_SPLIT_PATTERN, None, HashMap::new()).unwrap();
+ let merge_pairs = merge_pairs_from_lines(&merges);
+ let encoder = Bpe::new(&merge_pairs, GPT2_SPLIT_PATTERN, None, HashMap::new()).unwrap();
let tokenizer = Tokenizer::new(encoder, Default::default());
let encoded = tokenizer.encode(text.into(), Default::default()).unwrap();
assert_eq!(
@@ -610,7 +630,8 @@ in g";
];

let merges: Vec<&str> = MINI_GPT2.lines().collect();
- let encoder = Bpe::new(&merges, GPT2_SPLIT_PATTERN, None, added_tokens()).unwrap();
+ let merge_pairs = merge_pairs_from_lines(&merges);
+ let encoder = Bpe::new(&merge_pairs, GPT2_SPLIT_PATTERN, None, added_tokens()).unwrap();
let tokenizer = Tokenizer::new(encoder, Default::default());

for Case { input, encoded_str } in cases {
@@ -666,7 +687,9 @@ in g";
} in cases
{
let merges: Vec<&str> = MINI_GPT2.lines().collect();
- let encoder = Bpe::new(&merges, GPT2_SPLIT_PATTERN, vocab, added_tokens()).unwrap();
+ let merge_pairs = merge_pairs_from_lines(&merges);
+ let encoder =
+ Bpe::new(&merge_pairs, GPT2_SPLIT_PATTERN, vocab, added_tokens()).unwrap();
let tokenizer = Tokenizer::new(encoder, Default::default());

let encoded = tokenizer.encode(text.into(), Default::default()).unwrap();
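For callers outside the crate, the migration is mechanical: split old space-separated entries with the newly exported `merge_pairs_from_lines` and pass the resulting pairs to `Bpe::new`. A rough usage sketch follows, assuming `Bpe`, `Tokenizer`, `patterns`, and `merge_pairs_from_lines` are all reachable under `rten_text::tokenizers` as the re-export in the first hunk suggests; the merge entries below are illustrative, not a real model's merge table.

```rust
use std::collections::HashMap;

// Paths assume the public re-export shown in tokenizers.rs above.
use rten_text::tokenizers::{merge_pairs_from_lines, patterns, Bpe, Tokenizer};

fn main() {
    // A handful of entries in the space-separated `merges.txt` format;
    // real models ship tens of thousands of rules.
    let merge_lines: Vec<&str> = vec!["#version: 0.2", "h e", "t h", "th e"];

    // Split each "a b" line into an (a, b) pair; the "#version" header and
    // any blank lines are dropped.
    let merge_pairs = merge_pairs_from_lines(&merge_lines);

    // Build the BPE encoder with the GPT-2 split pattern, no explicit vocab
    // and no added tokens, then wrap it in a `Tokenizer`, as the tests above do.
    let encoder = Bpe::new(&merge_pairs, patterns::GPT2, None, HashMap::new()).unwrap();
    let tokenizer = Tokenizer::new(encoder, Default::default());

    // Encode a short string with default options.
    let _encoded = tokenizer.encode("the".into(), Default::default()).unwrap();
}
```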
13 changes: 11 additions & 2 deletions rten-text/src/tokenizers/json.rs
@@ -33,13 +33,22 @@ pub(crate) struct WordPieceModel {
pub vocab: HashMap<String, TokenId>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum MergeList {
/// Pairs represented as a JSON array.
Tuple(Vec<(String, String)>),
/// Pairs represented as `<token_a> [SPACE] <token_b>`.
Legacy(Vec<String>),
}

#[derive(Deserialize)]
pub(crate) struct BpeModel {
/// Mapping from token text to token ID.
pub vocab: HashMap<String, TokenId>,

- /// List of `<token_a> [SPACE] <token_b>` containing tokens to merge.
- pub merges: Vec<String>,
+ /// List of pairs of tokens to merge.
+ pub merges: MergeList,
}

#[derive(Deserialize)]
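The `#[serde(untagged)]` attribute is what lets both on-disk layouts deserialize into the same `merges` field: serde tries each variant in declaration order and keeps the first that matches. Below is a small self-contained sketch of that behaviour using a standalone copy of the enum (the real `MergeList` is crate-private); it assumes `serde` with the `derive` feature and `serde_json` are available.

```rust
use serde::Deserialize;

// Standalone copy of the crate-private `MergeList`, for illustration only.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum MergeList {
    /// Newer tokenizer.json files: an array of ["a", "b"] pairs.
    Tuple(Vec<(String, String)>),
    /// Older tokenizer.json files: an array of "a b" strings.
    Legacy(Vec<String>),
}

fn main() {
    let new_format: MergeList =
        serde_json::from_str(r#"[["h", "e"], ["t", "h"]]"#).unwrap();
    let old_format: MergeList =
        serde_json::from_str(r#"["h e", "t h"]"#).unwrap();

    // An array of two-element arrays matches `Tuple`; an array of plain
    // strings fails that variant and falls through to `Legacy`.
    assert!(matches!(new_format, MergeList::Tuple(_)));
    assert!(matches!(old_format, MergeList::Legacy(_)));
    println!("{:?}\n{:?}", new_format, old_format);
}
```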