Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert from pybind11 to nanobind #230

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ set(XGRAMMAR_INCLUDE_PATH
)

file(GLOB_RECURSE XGRAMMAR_SOURCES_PATH "${PROJECT_SOURCE_DIR}/cpp/*.cc")
list(FILTER XGRAMMAR_SOURCES_PATH EXCLUDE REGEX "${PROJECT_SOURCE_DIR}/cpp/pybind/.*\\.cc")
list(FILTER XGRAMMAR_SOURCES_PATH EXCLUDE REGEX "${PROJECT_SOURCE_DIR}/cpp/nanobind/.*\\.cc")

add_library(xgrammar STATIC ${XGRAMMAR_SOURCES_PATH})
target_include_directories(xgrammar PUBLIC ${XGRAMMAR_INCLUDE_PATH})
Expand All @@ -76,7 +76,7 @@ else()
endif()

if(XGRAMMAR_BUILD_PYTHON_BINDINGS)
add_subdirectory(${PROJECT_SOURCE_DIR}/cpp/pybind)
add_subdirectory(${PROJECT_SOURCE_DIR}/cpp/nanobind)
install(TARGETS xgrammar_bindings DESTINATION .)
endif()

Expand Down
35 changes: 35 additions & 0 deletions cpp/nanobind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
find_package(
  Python
  COMPONENTS Development.Module
  REQUIRED
)
find_package(nanobind CONFIG REQUIRED)

# Compile this source file separately. Nanobind suggests optimizing bindings code for size, but
# this source file contains mostly core logic. See notes about size optimizations here:
# https://nanobind.readthedocs.io/en/latest/api_cmake.html#command:nanobind_add_module
add_library(python_methods STATIC)
target_sources(python_methods PRIVATE python_methods.cc)
target_link_libraries(python_methods PUBLIC xgrammar)

# Any code that uses nanobind directly lives here
nanobind_add_module(xgrammar_bindings LTO nanobind.cc)
target_link_libraries(xgrammar_bindings PRIVATE python_methods)

if(DEFINED SKBUILD_PROJECT_NAME)
  # Building wheel through scikit-build-core
  set(LIB_OUTPUT_DIRECTORY xgrammar)
else()
  # In-tree build: place the extension next to the Python package sources
  set(LIB_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/python/xgrammar)
endif()

# Pin the output directory for every build configuration. Multi-config generators
# (Visual Studio, Xcode) read the per-config LIBRARY_OUTPUT_DIRECTORY_<CONFIG>
# properties, where <CONFIG> is the configuration name upper-cased with no
# separators — e.g. RelWithDebInfo -> RELWITHDEBINFO. (The previous
# ..._REL_WITH_DEB_INFO spelling named a nonexistent property and was ignored.)
set_target_properties(
  xgrammar_bindings
  PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${LIB_OUTPUT_DIRECTORY}
             LIBRARY_OUTPUT_DIRECTORY_DEBUG ${LIB_OUTPUT_DIRECTORY}
             LIBRARY_OUTPUT_DIRECTORY_RELEASE ${LIB_OUTPUT_DIRECTORY}
             LIBRARY_OUTPUT_DIRECTORY_RELWITHDEBINFO ${LIB_OUTPUT_DIRECTORY}
             LIBRARY_OUTPUT_DIRECTORY_MINSIZEREL ${LIB_OUTPUT_DIRECTORY}
)
211 changes: 211 additions & 0 deletions cpp/nanobind/nanobind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/nanobind/nanobind.cc
*/

#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include <xgrammar/xgrammar.h>

#include "../json_schema_converter.h"
#include "../regex_converter.h"
#include "python_methods.h"

namespace nb = nanobind;
using namespace xgrammar;

/// \brief Normalize a Python list of str/bytes vocab entries into C++ strings.
///
/// Accepts a heterogeneous list where each element is either a Python `str`
/// or `bytes`, as produced by the various tokenizer serialization formats.
///
/// \param encoded_vocab Python list whose elements are str or bytes.
/// \return The vocabulary as a vector of std::string, in list order.
/// \throws nb::type_error if any element is neither str nor bytes.
std::vector<std::string> CommonEncodedVocabType(
    const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab
) {
  std::vector<std::string> encoded_vocab_strs;
  encoded_vocab_strs.reserve(encoded_vocab.size());
  for (const auto& token : encoded_vocab) {
    if (nb::bytes result; nb::try_cast(token, result)) {
      // Construct with an explicit length: bytes objects may contain embedded
      // '\0' bytes, and taking c_str() alone would truncate at the first one.
      encoded_vocab_strs.emplace_back(result.c_str(), result.size());
    } else if (nb::str result; nb::try_cast(token, result)) {
      // Python str cannot contain an interior NUL in its UTF-8 encoding
      // surrogate-escape aside, so c_str() is sufficient here.
      encoded_vocab_strs.emplace_back(result.c_str());
    } else {
      throw nb::type_error("Expected str or bytes for encoded_vocab");
    }
  }
  return encoded_vocab_strs;
}

/// \brief Expose the decoded vocabulary to Python as a list of bytes objects.
///
/// Tokens are returned as bytes (not str) because decoded tokens are raw byte
/// sequences and need not be valid UTF-8.
///
/// \param tokenizer The tokenizer whose decoded vocabulary to export.
/// \return One nb::bytes per token, in vocabulary order.
std::vector<nanobind::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
  const auto& decoded_vocab = tokenizer.GetDecodedVocab();
  std::vector<nanobind::bytes> py_result;
  py_result.reserve(decoded_vocab.size());
  for (const auto& item : decoded_vocab) {
    // Pass the explicit size: constructing nb::bytes from c_str() alone would
    // truncate tokens containing embedded '\0' bytes. (The previous pybind11
    // binding used pybind11::bytes(item), which preserved the full length.)
    py_result.emplace_back(nanobind::bytes(item.data(), item.size()));
  }
  return py_result;
}

// Python module definition for the xgrammar native bindings. Registers the
// TokenizerInfo / Grammar / CompiledGrammar / GrammarCompiler / GrammarMatcher
// classes plus the "testing" and "kernels" submodules. Core logic lives in
// python_methods.cc; this file contains only nanobind glue.
NB_MODULE(xgrammar_bindings, m) {
auto pyTokenizerInfo = nb::class_<TokenizerInfo>(m, "TokenizerInfo");
pyTokenizerInfo
.def(
"__init__",
// Placement-new into nanobind-provided storage; construction is delegated
// to TokenizerInfo_Init (declared in python_methods.h).
[](TokenizerInfo* out,
const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab,
std::string vocab_type,
std::optional<int> vocab_size,
std::optional<std::vector<int32_t>> stop_token_ids,
bool add_prefix_space) {
new (out) TokenizerInfo{TokenizerInfo_Init(
CommonEncodedVocabType(encoded_vocab),
std::move(vocab_type),
vocab_size,
std::move(stop_token_ids),
add_prefix_space
)};
},
nb::arg("encoded_vocab"),
nb::arg("vocab_type"),
// .none() permits passing None from Python for the optional arguments.
nb::arg("vocab_size").none(),
nb::arg("stop_token_ids").none(),
nb::arg("add_prefix_space")
)
.def_prop_ro("vocab_type", &TokenizerInfo_GetVocabType)
.def_prop_ro("vocab_size", &TokenizerInfo::GetVocabSize)
.def_prop_ro("add_prefix_space", &TokenizerInfo::GetAddPrefixSpace)
.def_prop_ro("decoded_vocab", &TokenizerInfo_GetDecodedVocab)
.def_prop_ro("stop_token_ids", &TokenizerInfo::GetStopTokenIds)
.def_prop_ro("special_token_ids", &TokenizerInfo::GetSpecialTokenIds)
.def("dump_metadata", &TokenizerInfo::DumpMetadata)
.def_static("from_huggingface", &TokenizerInfo::FromHuggingFace)
.def_static(
"from_vocab_and_metadata",
// Same str/bytes normalization as __init__, applied before delegating.
[](const nb::typed<nb::list, std::variant<std::string, nb::bytes>> encoded_vocab,
const std::string& metadata) {
return TokenizerInfo::FromVocabAndMetadata(
CommonEncodedVocabType(encoded_vocab), metadata
);
}
);

auto pyGrammar = nb::class_<Grammar>(m, "Grammar");
pyGrammar.def("to_string", &Grammar::ToString)
.def_static("from_ebnf", &Grammar::FromEBNF)
.def_static(
"from_json_schema",
&Grammar::FromJSONSchema,
nb::arg("schema"),
nb::arg("any_whitespace"),
nb::arg("indent").none(),
nb::arg("separators").none(),
nb::arg("strict_mode")
)
.def_static("from_regex", &Grammar::FromRegex)
.def_static("from_structural_tag", &Grammar_FromStructuralTag)
.def_static("builtin_json_grammar", &Grammar::BuiltinJSONGrammar)
.def_static("union", &Grammar::Union)
.def_static("concat", &Grammar::Concat);

auto pyCompiledGrammar = nb::class_<CompiledGrammar>(m, "CompiledGrammar");
pyCompiledGrammar.def_prop_ro("grammar", &CompiledGrammar::GetGrammar)
.def_prop_ro("tokenizer_info", &CompiledGrammar::GetTokenizerInfo);

// Compilation can be slow; every compile_* entry point releases the GIL so
// other Python threads can run while the C++ compiler works.
auto pyGrammarCompiler = nb::class_<GrammarCompiler>(m, "GrammarCompiler");
pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool>())
.def(
"compile_json_schema",
&GrammarCompiler::CompileJSONSchema,
nb::call_guard<nb::gil_scoped_release>(),
nb::arg("schema"),
nb::arg("any_whitespace"),
nb::arg("indent").none(),
nb::arg("separators").none(),
nb::arg("strict_mode")
)
.def(
"compile_builtin_json_grammar",
&GrammarCompiler::CompileBuiltinJSONGrammar,
nb::call_guard<nb::gil_scoped_release>()
)
.def(
"compile_structural_tag",
&GrammarCompiler_CompileStructuralTag,
nb::call_guard<nb::gil_scoped_release>()
)
.def(
"compile_regex", &GrammarCompiler::CompileRegex, nb::call_guard<nb::gil_scoped_release>()
)
.def(
"compile_grammar",
&GrammarCompiler::CompileGrammar,
nb::call_guard<nb::gil_scoped_release>()
)
.def("clear_cache", &GrammarCompiler::ClearCache);

auto pyGrammarMatcher = nb::class_<GrammarMatcher>(m, "GrammarMatcher");
pyGrammarMatcher
.def(
nb::init<const CompiledGrammar&, std::optional<std::vector<int>>, bool, int>(),
nb::arg("compiled_grammar"),
nb::arg("override_stop_tokens").none(),
nb::arg("terminate_without_stop_token"),
nb::arg("max_rollback_tokens")
)
.def("accept_token", &GrammarMatcher::AcceptToken)
.def("fill_next_token_bitmask", &GrammarMatcher_FillNextTokenBitmask)
.def("find_jump_forward_string", &GrammarMatcher::FindJumpForwardString)
.def("rollback", &GrammarMatcher::Rollback)
.def("is_terminated", &GrammarMatcher::IsTerminated)
.def("reset", &GrammarMatcher::Reset)
.def_prop_ro("max_rollback_tokens", &GrammarMatcher::GetMaxRollbackTokens)
.def_prop_ro("stop_token_ids", &GrammarMatcher::GetStopTokenIds)
// str overload: nanobind converts the Python str to std::string implicitly.
.def("_debug_accept_string", &GrammarMatcher::_DebugAcceptString)
// bytes overload.
// NOTE(review): c_str() without a length truncates at the first embedded
// '\0' byte; confirm whether byte inputs can contain NULs and, if so, pass
// (input_str.c_str(), input_str.size()) instead.
.def(
"_debug_accept_string",
[](GrammarMatcher& self, const nb::bytes& input_str, bool debug_print) {
return self._DebugAcceptString(input_str.c_str(), debug_print);
}
);

// Internal helpers exposed only for the Python test suite.
auto pyTestingModule = m.def_submodule("testing");
pyTestingModule
.def(
"_json_schema_to_ebnf",
// Disambiguate the overload set of JSONSchemaToEBNF for nanobind.
nb::overload_cast<
const std::string&,
bool,
std::optional<int>,
std::optional<std::pair<std::string, std::string>>,
bool>(&JSONSchemaToEBNF),
nb::arg("schema"),
nb::arg("any_whitespace"),
nb::arg("indent").none(),
nb::arg("separators").none(),
nb::arg("strict_mode")
)
.def("_regex_to_ebnf", &RegexToEBNF)
.def("_get_masked_tokens_from_bitmask", &Matcher_DebugGetMaskedTokensFromBitmask)
.def("_get_allow_empty_rule_ids", &GetAllowEmptyRuleIds)
.def(
"_generate_range_regex",
// Strips embedded NUL characters from the generated pattern —
// presumably so the result converts cleanly to a Python str; confirm
// why GenerateRangeRegex can emit '\0' in the first place.
[](std::optional<int> start, std::optional<int> end) {
std::string result = GenerateRangeRegex(start, end);
result.erase(std::remove(result.begin(), result.end(), '\0'), result.end());
return result;
},
nb::arg("start").none(),
nb::arg("end").none()
);

// CPU fallback kernel; pointers/shapes are passed as raw integers and lists
// so the binding stays framework-agnostic (no torch/numpy dependency here).
auto pyKernelsModule = m.def_submodule("kernels");
pyKernelsModule.def(
"apply_token_bitmask_inplace_cpu",
&Kernels_ApplyTokenBitmaskInplaceCPU,
nb::arg("logits_ptr"),
nb::arg("logits_shape"),
nb::arg("bitmask_ptr"),
nb::arg("bitmask_shape"),
nb::arg("indices").none()
);
}
12 changes: 1 addition & 11 deletions cpp/pybind/python_methods.cc → cpp/nanobind/python_methods.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/pybind/python_methods.cc
* \file xgrammar/nanobind/python_methods.cc
*/

#include "python_methods.h"
Expand Down Expand Up @@ -43,16 +43,6 @@ std::string TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer) {
return VOCAB_TYPE_NAMES[static_cast<int>(tokenizer.GetVocabType())];
}

std::vector<pybind11::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer) {
const auto& decoded_vocab = tokenizer.GetDecodedVocab();
std::vector<pybind11::bytes> py_result;
py_result.reserve(decoded_vocab.size());
for (const auto& item : decoded_vocab) {
py_result.emplace_back(pybind11::bytes(item));
}
return py_result;
}

bool GrammarMatcher_FillNextTokenBitmask(
GrammarMatcher& matcher,
intptr_t token_bitmask_ptr,
Expand Down
11 changes: 4 additions & 7 deletions cpp/pybind/python_methods.h → cpp/nanobind/python_methods.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
/*!
* Copyright (c) 2024 by Contributors
* \file xgrammar/pybind/python_methods.h
* \file xgrammar/nanobind/python_methods.h
* \brief The header for the support of grammar-guided generation.
*/

#ifndef XGRAMMAR_PYBIND_PYTHON_METHODS_H_
#define XGRAMMAR_PYBIND_PYTHON_METHODS_H_
#ifndef XGRAMMAR_NANOBIND_PYTHON_METHODS_H_
#define XGRAMMAR_NANOBIND_PYTHON_METHODS_H_

#include <pybind11/pybind11.h>
#include <xgrammar/xgrammar.h>

#include <optional>
Expand All @@ -28,8 +27,6 @@ TokenizerInfo TokenizerInfo_Init(

std::string TokenizerInfo_GetVocabType(const TokenizerInfo& tokenizer);

std::vector<pybind11::bytes> TokenizerInfo_GetDecodedVocab(const TokenizerInfo& tokenizer);

bool GrammarMatcher_FillNextTokenBitmask(
GrammarMatcher& matcher,
intptr_t token_bitmask_ptr,
Expand Down Expand Up @@ -65,4 +62,4 @@ CompiledGrammar GrammarCompiler_CompileStructuralTag(

} // namespace xgrammar

#endif // XGRAMMAR_PYBIND_PYTHON_METHODS_H_
#endif // XGRAMMAR_NANOBIND_PYTHON_METHODS_H_
31 changes: 0 additions & 31 deletions cpp/pybind/CMakeLists.txt

This file was deleted.

Loading