Skip to content

Commit

Permalink
Add Timeout to Regex used in the tokenizers (#7284)
Browse files Browse the repository at this point in the history
* Add Timeout to Regex used in the tokenizers

* Address the feedback
  • Loading branch information
tarekgh authored Nov 4, 2024
1 parent 7cce753 commit 5c50319
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary<string
_specialTokensReverse.Add(item.Value, item.Key);
}

// We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time. Note that `specialTokens` is treated as a constant after the tokenizer is created.
_specialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/Microsoft.ML.Tokenizers/Model/TiktokenTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1175,23 +1175,23 @@ private static (Dictionary<string, int> SpecialTokens, Regex Regex, string Vocab
internal const string R50kBaseTypeName = "Microsoft.ML.Tokenizers.R50kBaseTokenizerData, Microsoft.ML.Tokenizers.Data.R50kBase, Version=1.0.0.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51";

#if NET7_0_OR_GREATER
[GeneratedRegex(Cl100kBaseRegexPattern)]
[GeneratedRegex(Cl100kBaseRegexPattern, RegexOptions.None, PreTokenizer.DefaultTimeOutInMilliseconds)]
private static partial Regex Cl100kBaseRegex();

[GeneratedRegex(P50kBaseRegexPattern)]
[GeneratedRegex(P50kBaseRegexPattern, RegexOptions.None, PreTokenizer.DefaultTimeOutInMilliseconds)]
internal static partial Regex P50kBaseRegex();

[GeneratedRegex(O200kBaseRegexPattern)]
[GeneratedRegex(O200kBaseRegexPattern, RegexOptions.None, PreTokenizer.DefaultTimeOutInMilliseconds)]
internal static partial Regex O200kBaseRegex();
#else
private static Regex? _cl100kBaseRegex;
private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled);
private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(PreTokenizer.DefaultTimeOutInMilliseconds));

private static Regex? _p50kBaseRegex;
internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled);
internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(PreTokenizer.DefaultTimeOutInMilliseconds));

private static Regex? _o200kBaseRegex;
internal static Regex O200kBaseRegex() => _o200kBaseRegex ??= new Regex(O200kBaseRegexPattern, RegexOptions.Compiled);
internal static Regex O200kBaseRegex() => _o200kBaseRegex ??= new Regex(O200kBaseRegexPattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(PreTokenizer.DefaultTimeOutInMilliseconds));
#endif

private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, (int Id, string Token)> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
Expand Down
15 changes: 9 additions & 6 deletions src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ public abstract partial class PreTokenizer
}
}

// Default regex match timeout: 30 seconds is a reasonable upper bound for processing any text and finding a match.
internal const int DefaultTimeOutInMilliseconds = 30_000;

private const string WhiteSpaceOrPunctuationPattern = @"\w+|[\p{P}]";
private static PreTokenizer? _whiteSpaceOrPunctuationPreTokenizer;
#if NET7_0_OR_GREATER
[GeneratedRegex(WhiteSpaceOrPunctuationPattern)]
[GeneratedRegex(WhiteSpaceOrPunctuationPattern, RegexOptions.None, DefaultTimeOutInMilliseconds)]
private static partial Regex WhiteSpaceOrPunctuationRegex();
#else
private static Regex WhiteSpaceOrPunctuationRegex() => new Regex(WhiteSpaceOrPunctuationPattern, RegexOptions.Compiled);
private static Regex WhiteSpaceOrPunctuationRegex() => new Regex(WhiteSpaceOrPunctuationPattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(DefaultTimeOutInMilliseconds));
#endif

/// <summary>
Expand All @@ -69,10 +72,10 @@ public static PreTokenizer CreateWhiteSpaceOrPunctuationPreTokenizer(IReadOnlyDi
private static PreTokenizer? _wordOrNonWordPreTokenizer;

#if NET7_0_OR_GREATER
[GeneratedRegex(WordOrNonWordPattern)]
[GeneratedRegex(WordOrNonWordPattern, RegexOptions.None, DefaultTimeOutInMilliseconds)]
private static partial Regex WordOrNonWordRegex();
#else
private static Regex WordOrNonWordRegex() => new Regex(WordOrNonWordPattern, RegexOptions.Compiled);
private static Regex WordOrNonWordRegex() => new Regex(WordOrNonWordPattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(DefaultTimeOutInMilliseconds));
#endif

/// <summary>
Expand All @@ -96,10 +99,10 @@ public static PreTokenizer CreateWordOrNonWordPreTokenizer(IReadOnlyDictionary<s
private static PreTokenizer? _whiteSpacePreTokenizer;

#if NET7_0_OR_GREATER
[GeneratedRegex(WhiteSpacePattern)]
[GeneratedRegex(WhiteSpacePattern, RegexOptions.None, DefaultTimeOutInMilliseconds)]
private static partial Regex WhiteSpaceRegex();
#else
private static Regex WhiteSpaceRegex() => new Regex(WhiteSpacePattern, RegexOptions.Compiled);
private static Regex WhiteSpaceRegex() => new Regex(WhiteSpacePattern, RegexOptions.Compiled, TimeSpan.FromMilliseconds(DefaultTimeOutInMilliseconds));
#endif

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public RegexPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? specialT

if (specialTokensEncoder is { Count: > 0 })
{
// We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time. Note that `specialTokensEncoder` is treated as a constant after the pre-tokenizer is created.
_specialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
}
}
Expand Down

0 comments on commit 5c50319

Please sign in to comment.