forked from mlc-ai/xgrammar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompiler.cc
468 lines (409 loc) · 17.1 KB
/
compiler.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/compiler.cc
*/
#include <xgrammar/compiler.h>
#include "compiled_grammar_data_structure.h"
#include "grammar_data_structure.h"
#include "matcher_base.h"
#include "support/thread_pool.h"
#include "support/thread_safe_cache.h"
namespace xgrammar {
/******************* AdaptiveTokenMask and CompiledGrammar *******************/
/*!
 * \brief Build an adaptive token mask from the accepted / rejected / uncertain index lists.
 *
 * The storage form is chosen from the list sizes:
 *  - both lists have at least USE_BITSET_THRESHOLD entries: the accepted tokens are recorded in a
 *    bitset of vocab_size bits, indexed by the `first` member of the referenced vocabulary entry;
 *  - otherwise only the strictly shorter list is copied; on a tie the rejected list is kept.
 * The uncertain list is always copied.
 *
 * \param vocab_size Number of bits of the bitset form.
 * \param sorted_decoded_vocab Vocabulary entries that the index lists point into.
 * \param accepted_indices Positions in sorted_decoded_vocab of the accepted tokens.
 * \param rejected_indices Positions in sorted_decoded_vocab of the rejected tokens.
 * \param uncertain_indices Positions in sorted_decoded_vocab of the uncertain tokens.
 */
AdaptiveTokenMask::AdaptiveTokenMask(
    size_t vocab_size,
    const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
    const std::vector<int32_t>& accepted_indices,
    const std::vector<int32_t>& rejected_indices,
    const std::vector<int32_t>& uncertain_indices
) {
  const auto num_accepted = accepted_indices.size();
  const auto num_rejected = rejected_indices.size();
  const bool both_lists_large =
      num_accepted >= USE_BITSET_THRESHOLD && num_rejected >= USE_BITSET_THRESHOLD;

  if (both_lists_large) {
    // Neither list is short enough to be worth storing explicitly: use the bitset.
    store_type = StoreType::kAcceptedBitset;
    accepted_bitset = DynamicBitset(vocab_size);
    for (const auto sorted_idx : accepted_indices) {
      accepted_bitset.Set(sorted_decoded_vocab[sorted_idx].first, true);
    }
  } else if (num_accepted < num_rejected) {
    store_type = StoreType::kAccepted;
    this->accepted_indices = accepted_indices;
  } else {
    store_type = StoreType::kRejected;
    this->rejected_indices = rejected_indices;
  }

  this->uncertain_indices = uncertain_indices;
}
/*! \brief Return the grammar held by the implementation object. */
Grammar CompiledGrammar::GetGrammar() const {
  auto& impl = *pimpl_;
  return impl.GetGrammar();
}
/*! \brief Return the tokenizer information held by the implementation object. */
TokenizerInfo CompiledGrammar::GetTokenizerInfo() const {
  auto& impl = *pimpl_;
  return impl.GetTokenizerInfo();
}
/******************* Use GrammarMatcher to generate the AdaptiveTokenMaskCache *******************/
/*!
 * \brief A GrammarMatcherBase specialization used at compile time to compute the
 * AdaptiveTokenMask of a single rule position.
 */
class GrammarMatcherForCompiler : public GrammarMatcherBase {
public:
// Do not expand the initial rule position: we want to find the accepted/rejected tokens
// that exactly start from the initial rule position.
// NOTE(review): the `false` passed to GrammarMatcherBase is presumably the "expand initial
// rule position" flag described above -- confirm against matcher_base.h.
GrammarMatcherForCompiler(const Grammar& grammar, RulePosition init_rule_position)
: GrammarMatcherBase(grammar, init_rule_position, false),
init_rule_id(init_rule_position.rule_id) {}
/*!
 * \brief Get the adaptive token mask for the RulePosition this matcher was constructed with.
 * \param vocab_size Forwarded to the AdaptiveTokenMask constructor.
 * \param sorted_decoded_vocab The vocabulary entries to classify; the `second` member of each
 * entry is the token string that gets matched.
 * \param consider_parent_rule Whether to consider the parent rule. If false, there will be
 * no uncertain tokens. Useful for the root rule.
 * \return The mask built from the accepted, rejected and uncertain token indices.
 */
AdaptiveTokenMask GetAdaptiveTokenMask(
size_t vocab_size,
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
bool consider_parent_rule
);
private:
/*!
 * \brief Check if a token can pass the lookahead assertion of the initial rule.
 * \param token The token string being classified.
 * \param can_reach_end_stack For each matched prefix length of the token, whether the rule can
 * end there.
 */
bool IsTokenPassLookaheadAssertion(
const std::string& token, const std::vector<bool>& can_reach_end_stack
);
// The id of the initial rule (taken from the initial rule position).
int32_t init_rule_id;
// Temporary data for GetAdaptiveTokenMask. Kept as members and cleared at the start of each
// call.
// Indices (into the sorted vocabulary) of tokens in each category.
std::vector<int32_t> tmp_accepted_indices_;
std::vector<int32_t> tmp_rejected_indices_;
std::vector<int32_t> tmp_uncertain_indices_;
// Entry k: whether the rule can end after the first k currently matched characters.
std::vector<bool> tmp_can_reach_end_stack_;
// Entry k: logical OR of tmp_can_reach_end_stack_[0..k].
std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
};
/*!
 * \brief Check whether a token is compatible with the lookahead assertion of the initial rule.
 *
 * For every prefix length i at which the rule can end (can_reach_end_stack[i] is true), the
 * remaining suffix token[i..] is matched against the lookahead assertion. The token passes when,
 * for some such i, either the assertion is completed inside the suffix or the whole suffix is
 * consumed by the assertion.
 *
 * The matcher state is restored before returning: every accepted character and the pushed
 * lookahead state are rolled back.
 *
 * \param token The token string being classified.
 * \param can_reach_end_stack Entry i tells whether the rule can end after the first i characters
 * of the token; valid indices are 0 .. size() - 1.
 * \return true if the rule has no lookahead assertion or the token passes it, false otherwise.
 */
bool GrammarMatcherForCompiler::IsTokenPassLookaheadAssertion(
    const std::string& token, const std::vector<bool>& can_reach_end_stack
) {
  auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id;
  if (lookahead_assertion_id == -1) {
    // No assertion attached to this rule: nothing to check.
    return true;
  }
  auto lookahead_rule_position = RulePosition(-1, lookahead_assertion_id, 0);
  PushInitialState(lookahead_rule_position, true);
  int token_len = token.size();

  // Find all positions that can come to an end. Then check if the suffix from that position
  // can be accepted by the lookahead assertion.
  // Start from the last valid index (size() - 1): indexing the vector at size() is out of bounds.
  for (int i = static_cast<int>(can_reach_end_stack.size()) - 1; i >= 0; --i) {
    if (!can_reach_end_stack[i]) {
      continue;
    }
    int last_accept_pos = i - 1;
    for (int pos = i; pos < token_len; ++pos) {
      if (!AcceptChar(token[pos])) {
        break;
      }
      last_accept_pos = pos;
      // Case 1. The whole rule is finished.
      if (CanReachEnd()) {
        // accepted chars: pos - i + 1
        // we need to rollback the pushed initial state as well
        RollbackChars(pos - i + 2);
        return true;
      }
    }
    // Case 2. The whole token is accepted
    if (last_accept_pos == token_len - 1) {
      // accepted chars: last_accept_pos - i + 1, plus the pushed initial state
      RollbackChars(last_accept_pos - i + 2);
      return true;
    }
    // Case 3. The token is not accepted. Undo the accepted chars and check the next position.
    RollbackChars(last_accept_pos - i + 1);
  }
  // No split point worked: undo the pushed initial state.
  RollbackChars(1);
  return false;
}
/*!
 * \brief Classify every entry of sorted_decoded_vocab as accepted, rejected or uncertain when
 * matched from this matcher's initial rule position, and pack the result into an
 * AdaptiveTokenMask.
 *
 * A token is
 *  - accepted when all of its characters are accepted by the matcher;
 *  - uncertain when it is not accepted, some matched prefix of it can end the rule,
 *    consider_parent_rule is true and the token passes the lookahead assertion;
 *  - rejected otherwise.
 *
 * Matching is incremental: the characters shared with the previous token are not matched again.
 * NOTE(review): this only pays off if sorted_decoded_vocab is ordered so that tokens with common
 * prefixes are adjacent, as the parameter name suggests -- confirm with the caller.
 *
 * \param vocab_size Forwarded to the AdaptiveTokenMask constructor.
 * \param sorted_decoded_vocab Entries whose `second` member is the token string to match.
 * \param consider_parent_rule If false, no token is classified as uncertain.
 * \return The mask built from the three index lists.
 */
AdaptiveTokenMask GrammarMatcherForCompiler::GetAdaptiveTokenMask(
size_t vocab_size,
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
bool consider_parent_rule
) {
tmp_accepted_indices_.clear();
tmp_rejected_indices_.clear();
tmp_uncertain_indices_.clear();
// For every character in the current token, stores whether it is possible to reach the end of
// the rule when matching until this character. Store it in a stack for later rollback.
// Entry 0 describes the state before any character is matched, so both stacks always hold
// prev_matched_size + 1 entries.
tmp_can_reach_end_stack_.assign({CanReachEnd()});
// Running OR of the stack above: back() tells whether any matched prefix can end the rule.
tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()});
// Number of characters of the previous token that are currently accepted by the matcher.
int prev_matched_size = 0;
for (int i = 0; i < static_cast<int>(sorted_decoded_vocab.size()); ++i) {
const auto& token = sorted_decoded_vocab[i].second;
bool accepted = true;
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
// by finding the longest common prefix with the previous token.
if (i > 0) {
const auto& prev_token = sorted_decoded_vocab[i - 1].second;
int lcp_len =
std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first -
token.begin();
if (lcp_len > prev_matched_size) {
// Case 1. The common prefix is rejected by the matcher in the last token. Reject directly.
accepted = false;
} else if (lcp_len < prev_matched_size) {
// Case 2. The common prefix is shorter than the previous matched size. Rollback
// the non-common part.
RollbackChars(prev_matched_size - lcp_len);
// Drop the stack entries that belong to the rolled-back characters.
tmp_can_reach_end_stack_.erase(
tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len),
tmp_can_reach_end_stack_.end()
);
tmp_can_reach_end_prefix_or_stack_.erase(
tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len),
tmp_can_reach_end_prefix_or_stack_.end()
);
}
// After this line the matcher has exactly prev_matched_size characters of `token` accepted.
prev_matched_size = std::min(prev_matched_size, lcp_len);
}
if (accepted) {
// Accept the rest chars one by one
for (int j = prev_matched_size; j < static_cast<int>(token.size()); ++j) {
if (!AcceptChar(token[j], false)) {
accepted = false;
break;
}
tmp_can_reach_end_stack_.push_back(CanReachEnd());
tmp_can_reach_end_prefix_or_stack_.push_back(
tmp_can_reach_end_stack_.back() || tmp_can_reach_end_prefix_or_stack_.back()
);
prev_matched_size = j + 1;
}
}
// Whether the rule can end after some matched prefix of this token.
bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back();
if (accepted) {
tmp_accepted_indices_.push_back(i);
} else if (can_reach_end && consider_parent_rule &&
IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) {
// 1. If the current rule is the root rule (consider_parent_rule=false), there are no
// uncertain tokens. Not accepted tokens are just rejected.
// 2. If a token cannot pass the lookahead assertion, it is rejected.
tmp_uncertain_indices_.push_back(i);
} else {
tmp_rejected_indices_.push_back(i);
}
}
// Rollback the last matched part
RollbackChars(prev_matched_size);
return AdaptiveTokenMask(
vocab_size,
sorted_decoded_vocab,
tmp_accepted_indices_,
tmp_rejected_indices_,
tmp_uncertain_indices_
);
}
/*!
 * \brief Compile a grammar for a tokenizer by computing an adaptive token mask for the rule
 * positions of the grammar.
 *
 * A mask is computed for:
 *  1. every character class or character class star element, once for each
 *     left_utf8_bytes = 0, 1, 2, 3;
 *  2. every byte string element, once for each element_in_string = 0, 1, 2, ...
 * Rule reference elements and empty-string alternatives are skipped.
 *
 * \param grammar The grammar to compile.
 * \param tokenizer_info The tokenizer whose vocabulary is classified.
 * \param max_threads When greater than 1, elements are processed on a thread pool of this size;
 * otherwise everything runs on the calling thread.
 * \return The compiled grammar. If the vocabulary size is 0, the mask cache is left empty.
 */
CompiledGrammar MultiThreadCompileGrammar(
    const Grammar& grammar, const TokenizerInfo& tokenizer_info, int max_threads
) {
  using RuleExprType = Grammar::Impl::RuleExprType;

  auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
  compiled_grammar_impl->grammar = grammar;
  compiled_grammar_impl->tokenizer_info = tokenizer_info;

  // Nothing to precompute for an empty vocabulary.
  if (tokenizer_info.GetVocabSize() == 0) {
    return CompiledGrammar(compiled_grammar_impl);
  }

  // TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
  // The pool and the mutex are only constructed when more than one thread is requested, so that
  // the single-threaded path never touches ThreadPool or std::mutex, which throw errors at
  // runtime in WebAssembly.
  const bool use_thread_pool = max_threads > 1;
  std::optional<ThreadPool> thread_pool;
  std::optional<std::mutex> adaptive_token_mask_cache_mutex;
  if (use_thread_pool) {
    thread_pool.emplace(max_threads);
    adaptive_token_mask_cache_mutex.emplace();
  }

  auto root_rule_id = grammar->GetRootRuleId();
  const int num_rules = static_cast<int>(grammar->NumRules());
  for (int32_t rule_id = 0; rule_id < num_rules; ++rule_id) {
    auto rule = grammar->GetRule(rule_id);
    auto rule_body = grammar->GetRuleExpr(rule.body_expr_id);
    XGRAMMAR_DCHECK(rule_body.type == RuleExprType::kChoices);

    for (auto sequence_id : rule_body) {
      auto sequence = grammar->GetRuleExpr(sequence_id);
      if (sequence.type == RuleExprType::kEmptyStr) {
        continue;
      }
      XGRAMMAR_DCHECK(sequence.type == RuleExprType::kSequence);

      for (int element_id = 0; element_id < sequence.size(); ++element_id) {
        auto element = grammar->GetRuleExpr(sequence[element_id]);
        if (element.type == RuleExprType::kRuleRef) {
          continue;
        }

        // The work for one element, shared by the pooled and the inline execution paths.
        // Loop variables are captured by value because the task may run after the loop advanced.
        auto process_element = [&, rule_id, sequence_id, element_id, element]() {
          // Compute the mask of one rule position and publish it in the shared cache.
          auto add_adaptive_token_mask = [&](const RulePosition& rule_position) {
            auto grammar_matcher = GrammarMatcherForCompiler(grammar, rule_position);
            auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
                tokenizer_info.GetVocabSize(),
                tokenizer_info.GetSortedDecodedVocab(),
                rule_id != root_rule_id
            );
            auto store_mask = [&]() {
              compiled_grammar_impl->adaptive_token_mask_cache[rule_position] =
                  cur_adaptive_token_mask_cache;
            };
            if (use_thread_pool) {
              // Several tasks write to the same cache: serialize the insertion.
              std::lock_guard<std::mutex> lock(adaptive_token_mask_cache_mutex.value());
              store_mask();
            } else {
              store_mask();
            }
          };

          auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id);
          if (element.type == RuleExprType::kByteString) {
            // One mask per position inside the byte string.
            for (int idx = 0; idx < element.size(); ++idx) {
              cur_rule_position.element_in_string = idx;
              add_adaptive_token_mask(cur_rule_position);
            }
            return;
          }

          XGRAMMAR_DCHECK(
              element.type == RuleExprType::kCharacterClassStar ||
              element.type == RuleExprType::kCharacterClass
          );
          // One mask per number of pending UTF-8 bytes.
          for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) {
            cur_rule_position.left_utf8_bytes = left_utf8_bytes;
            add_adaptive_token_mask(cur_rule_position);
          }
        };

        // Execute depending on whether we use thread_pool
        if (use_thread_pool) {
          thread_pool->Execute([process_element]() { process_element(); });
        } else {
          process_element();
        }
      }
    }
  }

  // Wait for all pooled tasks before the locals they reference go out of scope.
  if (use_thread_pool) {
    thread_pool->Join();
  }
  return CompiledGrammar(compiled_grammar_impl);
}
/******************* GrammarCompiler::Impl *******************/
/*!
 * \brief Private implementation of GrammarCompiler.
 *
 * Holds the tokenizer information, the thread budget and one ThreadSafeCache per kind of
 * compilation request. When caching is enabled each cache is constructed with a callback that
 * compiles the requested grammar through MultiThreadCompileGrammar; when caching is disabled the
 * callbacks are empty and the Compile* methods compile directly.
 *
 * \note The cache callbacks capture `this`, so an Impl object must not be relocated while its
 * caches are in use.
 */
class GrammarCompiler::Impl {
 public:
  /*!
   * \param tokenizer_info The tokenizer every grammar is compiled against (copied).
   * \param max_threads Forwarded to MultiThreadCompileGrammar.
   * \param cache_enabled Whether compiled grammars are cached.
   */
  Impl(const TokenizerInfo& tokenizer_info, int max_threads, bool cache_enabled)
      : tokenizer_info_(tokenizer_info),
        max_threads_(max_threads),
        cache_enabled_(cache_enabled),
        compile_json_schema_cache_(GetCompileJSONSchemaCacheFunc(cache_enabled_)),
        compile_builtin_json_grammar_cache_(GetCompileBuiltinJSONGrammarCacheFunc(cache_enabled_)),
        compile_grammar_cache_(GetCompileGrammarCacheFunc(cache_enabled_)) {}

  /*! \brief Compile the builtin JSON grammar. */
  CompiledGrammar CompileBuiltinJSONGrammar();

  /*! \brief Compile the grammar derived from a JSON schema with the given formatting options. */
  CompiledGrammar CompileJSONSchema(
      const std::string& schema,
      bool any_whitespace,
      std::optional<int> indent,
      std::optional<std::pair<std::string, std::string>> separators,
      bool strict_mode = true
  );

  /*! \brief Compile an arbitrary grammar. */
  CompiledGrammar CompileGrammar(const Grammar& grammar);

  /*! \brief Clear cached compiled grammars. */
  void ClearCache();

 private:
  /*!
   * \brief The key of the JSON schema cache:
   * (schema, any_whitespace, indent, separators, strict_mode).
   */
  using SchemaKey =
      std::tuple<std::string, bool, std::optional<int>, std::pair<std::string, std::string>, bool>;

  /*! \brief Build the compile callback of the JSON schema cache (empty if caching is disabled). */
  std::function<CompiledGrammar(const SchemaKey&)> GetCompileJSONSchemaCacheFunc(bool cache_enabled
  ) {
    if (!cache_enabled) {
      return nullptr;
    }
    return [this](const SchemaKey& key) {
      // Bind by reference: the key holds the whole schema text, there is no need to copy it.
      const auto& [schema, any_whitespace, indent, separators, strict_mode] = key;
      auto grammar =
          Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
      return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
    };
  }

  /*!
   * \brief Build the compile callback of the builtin JSON grammar cache (empty if caching is
   * disabled).
   */
  std::function<CompiledGrammar()> GetCompileBuiltinJSONGrammarCacheFunc(bool cache_enabled) {
    if (!cache_enabled) {
      return nullptr;
    }
    return [this]() {
      return MultiThreadCompileGrammar(
          Grammar::BuiltinJSONGrammar(), tokenizer_info_, max_threads_
      );
    };
  }

  /*! \brief The key of the grammar cache: (grammar string, root rule name). */
  using GrammarKey = std::pair<std::string, std::string>;

  /*! \brief Build the compile callback of the grammar cache (empty if caching is disabled). */
  std::function<CompiledGrammar(const GrammarKey&)> GetCompileGrammarCacheFunc(bool cache_enabled) {
    if (!cache_enabled) {
      return nullptr;
    }
    return [this](const GrammarKey& key) {
      // Bind by reference: the key holds the whole grammar text, there is no need to copy it.
      const auto& [grammar_str, root_rule_name] = key;
      return MultiThreadCompileGrammar(
          Grammar::FromEBNF(grammar_str, root_rule_name), tokenizer_info_, max_threads_
      );
    };
  }

  /*! \brief The tokenizer information every grammar is compiled against. */
  const TokenizerInfo tokenizer_info_;
  /*! \brief The maximum number of threads to use. */
  const int max_threads_;
  /*! \brief Whether the cache is enabled. */
  const bool cache_enabled_;
  /*! \brief The cache for the compiled grammar of a JSON schema. */
  ThreadSafeCache<SchemaKey, CompiledGrammar> compile_json_schema_cache_;
  /*! \brief The cache for the compiled grammar for JSON. */
  ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
  /*! \brief The cache for the compiled grammar for bnf grammar. */
  ThreadSafeCache<GrammarKey, CompiledGrammar> compile_grammar_cache_;
};
/*!
 * \brief Compile the builtin JSON grammar.
 * \return The entry of the builtin JSON grammar cache when caching is enabled, otherwise a
 * freshly compiled grammar.
 */
CompiledGrammar GrammarCompiler::Impl::CompileBuiltinJSONGrammar() {
  if (cache_enabled_) {
    return compile_builtin_json_grammar_cache_.Get();
  }
  return MultiThreadCompileGrammar(Grammar::BuiltinJSONGrammar(), tokenizer_info_, max_threads_);
}
/*!
 * \brief Compile the grammar derived from a JSON schema.
 *
 * With caching enabled, the cache key is built from all arguments. A missing `separators`
 * argument is replaced in the key by (", ", ": ") when `indent` is not set and by (",", ": ")
 * when it is. Without caching, the arguments are passed to Grammar::FromJSONSchema unchanged.
 *
 * \param schema The JSON schema text.
 * \param any_whitespace Forwarded to Grammar::FromJSONSchema.
 * \param indent Forwarded to Grammar::FromJSONSchema.
 * \param separators Forwarded to Grammar::FromJSONSchema.
 * \param strict_mode Forwarded to Grammar::FromJSONSchema.
 * \return The compiled grammar.
 */
CompiledGrammar GrammarCompiler::Impl::CompileJSONSchema(
    const std::string& schema,
    bool any_whitespace,
    std::optional<int> indent,
    std::optional<std::pair<std::string, std::string>> separators,
    bool strict_mode
) {
  if (cache_enabled_) {
    // Resolve the default separators so that the key is fully specified.
    const char* default_item_separator = indent.has_value() ? "," : ", ";
    auto separators_value = separators.value_or(std::make_pair(default_item_separator, ": "));
    auto key = std::make_tuple(schema, any_whitespace, indent, separators_value, strict_mode);
    return compile_json_schema_cache_.Get(key);
  }

  auto grammar = Grammar::FromJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
  return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
}
/*!
 * \brief Compile an arbitrary grammar.
 *
 * With caching enabled, the grammar is looked up by its string form together with the name of
 * its root rule; otherwise it is compiled directly.
 *
 * \param grammar The grammar to compile.
 * \return The compiled grammar.
 */
CompiledGrammar GrammarCompiler::Impl::CompileGrammar(const Grammar& grammar) {
  if (cache_enabled_) {
    auto key = std::make_pair(grammar.ToString(), grammar->GetRootRule().name);
    return compile_grammar_cache_.Get(key);
  }
  return MultiThreadCompileGrammar(grammar, tokenizer_info_, max_threads_);
}
void GrammarCompiler::Impl::ClearCache() {
compile_builtin_json_grammar_cache_.Clear();
compile_json_schema_cache_.Clear();
}
/******************* GrammarCompiler *******************/
/*!
 * \brief Create a grammar compiler bound to one tokenizer.
 * \param tokenizer_info The tokenizer information grammars are compiled against.
 * \param max_threads The maximum number of threads used for a compilation.
 * \param cache_enabled Whether compiled grammars are cached.
 */
GrammarCompiler::GrammarCompiler(
const TokenizerInfo& tokenizer_info, int max_threads, bool cache_enabled
)
: pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads, cache_enabled)) {}
/*!
 * \brief Compile the grammar derived from a JSON schema.
 *
 * All arguments are forwarded unchanged to the implementation object.
 */
CompiledGrammar GrammarCompiler::CompileJSONSchema(
    const std::string& schema,
    bool any_whitespace,
    std::optional<int> indent,
    std::optional<std::pair<std::string, std::string>> separators,
    bool strict_mode
) {
  auto& impl = *pimpl_;
  return impl.CompileJSONSchema(schema, any_whitespace, indent, separators, strict_mode);
}
/*! \brief Compile the builtin JSON grammar; forwards to the implementation object. */
CompiledGrammar GrammarCompiler::CompileBuiltinJSONGrammar() {
  auto& impl = *pimpl_;
  return impl.CompileBuiltinJSONGrammar();
}
/*! \brief Compile an arbitrary grammar; forwards to the implementation object. */
CompiledGrammar GrammarCompiler::CompileGrammar(const Grammar& grammar) {
  auto& impl = *pimpl_;
  return impl.CompileGrammar(grammar);
}
void GrammarCompiler::ClearCache() { pimpl_->ClearCache(); }
} // namespace xgrammar