From 784373a0254be2f6f7801b9f6e89c6223a16e4bc Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 14 Mar 2025 13:08:29 -0700 Subject: [PATCH] Apply clang-format linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- include/pytorch/tokenizers/base64.h | 40 +++++--- .../pytorch/tokenizers/bpe_tokenizer_base.h | 40 ++++---- include/pytorch/tokenizers/error.h | 66 ++++++------- include/pytorch/tokenizers/hf_tokenizer.h | 14 +-- .../pytorch/tokenizers/llama2c_tokenizer.h | 20 ++-- include/pytorch/tokenizers/log.h | 62 +++++++++---- include/pytorch/tokenizers/pre_tokenizer.h | 50 +++++----- include/pytorch/tokenizers/result.h | 93 ++++++++++--------- include/pytorch/tokenizers/sentencepiece.h | 20 ++-- include/pytorch/tokenizers/tiktoken.h | 52 +++++++---- include/pytorch/tokenizers/token_decoder.h | 8 +- include/pytorch/tokenizers/tokenizer.h | 24 +++-- src/bpe_tokenizer_base.cpp | 69 +++++++------- src/hf_tokenizer.cpp | 63 +++++++------ src/llama2c_tokenizer.cpp | 69 ++++++++------ src/pre_tokenizer.cpp | 45 ++++----- src/sentencepiece.cpp | 24 ++--- src/tiktoken.cpp | 76 ++++++++------- src/token_decoder.cpp | 6 +- 19 files changed, 473 insertions(+), 368 deletions(-) diff --git a/include/pytorch/tokenizers/base64.h b/include/pytorch/tokenizers/base64.h index 3dfebc7..9034d7c 100644 --- a/include/pytorch/tokenizers/base64.h +++ b/include/pytorch/tokenizers/base64.h @@ -36,7 +36,7 @@ namespace base64 { using tokenizers::Error; using tokenizers::Result; -Result decode(const std::string_view &input); +Result decode(const std::string_view& input); namespace detail { @@ -68,9 +68,12 @@ inline Error validate(uint32_t v) { return Error::Ok; } -inline Error decode(const std::string_view &input, std::string &output) { - TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure, - "input length must be 4, got %zu", input.size()); +inline Error decode(const std::string_view& input, std::string& output) { + TK_CHECK_OR_RETURN_ERROR( + input.size() == 4, + Base64DecodeFailure, + "input length must be 4, got %zu", + input.size()); uint32_t val = 0; @@ -100,10 +103,14 @@ inline Error decode(const std::string_view &input, std::string &output) { return Error::Ok; } -inline Error decode_1_padding(const std::string_view &input, - std::string &output) { - TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure, - "input length must be 3, got %zu", input.size()); +inline Error decode_1_padding( + const std::string_view& input, + std::string& output) { + TK_CHECK_OR_RETURN_ERROR( + input.size() == 3, + Base64DecodeFailure, + "input length must be 3, got %zu", + input.size()); uint32_t val = 0; @@ -127,10 +134,14 @@ inline Error decode_1_padding(const std::string_view &input, return Error::Ok; } -inline Error decode_2_padding(const std::string_view &input, - std::string &output) { - TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure, - "input length must be 2, got %zu", input.size()); +inline Error decode_2_padding( + const std::string_view& input, + std::string& output) { + TK_CHECK_OR_RETURN_ERROR( + input.size() == 2, + Base64DecodeFailure, + "input length must be 2, got %zu", + input.size()); uint32_t val = 0; @@ -150,12 +161,13 @@ inline Error decode_2_padding(const std::string_view &input, } // namespace detail -inline tokenizers::Result decode(const std::string_view &input) { +inline tokenizers::Result decode(const std::string_view& input) { TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input"); // Faster than `input.size() % 4`. 
TK_CHECK_OR_RETURN_ERROR( - (input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure, + (input.size() & 3) == 0 && input.size() >= 4, + Base64DecodeFailure, "input length must be larger than 4 and is multiple of 4, got %zu", input.size()); diff --git a/include/pytorch/tokenizers/bpe_tokenizer_base.h b/include/pytorch/tokenizers/bpe_tokenizer_base.h index b3ff663..587e663 100644 --- a/include/pytorch/tokenizers/bpe_tokenizer_base.h +++ b/include/pytorch/tokenizers/bpe_tokenizer_base.h @@ -32,27 +32,29 @@ using Decoder = std::unordered_map; using Re2UPtr = std::unique_ptr; class BPETokenizerBase : public Tokenizer { -public: - Result> encode(const std::string &input, int8_t bos, - int8_t eos) const override; + public: + Result> + encode(const std::string& input, int8_t bos, int8_t eos) const override; - Result decode(uint64_t prev_token, - uint64_t token) const override; + Result decode(uint64_t prev_token, uint64_t token) + const override; -protected: + protected: explicit BPETokenizerBase() {} - virtual ~BPETokenizerBase() override {} + virtual ~BPETokenizerBase() {} std::pair, re2::StringPiece> - split_with_allowed_special_token_(re2::StringPiece &input, - const Encoder &allowed_special) const; + split_with_allowed_special_token_( + re2::StringPiece& input, + const Encoder& allowed_special) const; - Result, uint64_t>> - encode_with_special_token_(const std::string &text, - const Encoder &allowed_special) const; + Result, uint64_t>> encode_with_special_token_( + const std::string& text, + const Encoder& allowed_special) const; - Result> byte_pair_encode_(const std::string &piece, - const Encoder &encoder) const; + Result> byte_pair_encode_( + const std::string& piece, + const Encoder& encoder) const; // Protected members that can be overloaded by other BPE tokenizers Re2UPtr special_token_regex_; @@ -61,11 +63,13 @@ class BPETokenizerBase : public Tokenizer { Decoder decoder_; Decoder special_token_decoder_; -private: - virtual Error _encode(re2::StringPiece &input, std::vector &ret, - uint64_t &last_piece_token_len) const = 0; + private: + virtual Error _encode( + re2::StringPiece& input, + std::vector& ret, + uint64_t& last_piece_token_len) const = 0; - virtual void _decode(re2::StringPiece input, std::string &ret) const = 0; + virtual void _decode(re2::StringPiece input, std::string& ret) const = 0; }; } // namespace detail diff --git a/include/pytorch/tokenizers/error.h b/include/pytorch/tokenizers/error.h index 330ef24..7823f16 100644 --- a/include/pytorch/tokenizers/error.h +++ b/include/pytorch/tokenizers/error.h @@ -70,12 +70,12 @@ enum class Error : error_code_t { * @param[in] message__ Format string for the log error message. * @param[in] ... Optional additional arguments for the format string. */ -#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \ - { \ - if (!(cond__)) { \ - TK_LOG(Error, message__, ##__VA_ARGS__); \ - return ::tokenizers::Error::error__; \ - } \ +#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \ + { \ + if (!(cond__)) { \ + TK_LOG(Error, message__, ##__VA_ARGS__); \ + return ::tokenizers::Error::error__; \ + } \ } /** @@ -86,13 +86,13 @@ enum class Error : error_code_t { * @param[in] ... Optional format string for the log error message and its * arguments. */ -#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \ +#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__) // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ - TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \ - 4, 3, 2, 1) \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \ + TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ + __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ (__VA_ARGS__) /** @@ -119,43 +119,43 @@ enum class Error : error_code_t { * TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1 * TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2 */ -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \ - _7, _8, _9, _10, N, ...) \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \ - do { \ - const auto et_error__ = (error__); \ - if (et_error__ != ::tokenizers::Error::Ok) { \ - return et_error__; \ - } \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \ + do { \ + const auto et_error__ = (error__); \ + if (et_error__ != ::tokenizers::Error::Ok) { \ + return et_error__; \ + } \ } while (0) // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \ - do { \ - const auto et_error__ = (error__); \ - if (et_error__ != ::tokenizers::Error::Ok) { \ - TK_LOG(Error, message__, ##__VA_ARGS__); \ - return et_error__; \ - } \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \ + do { \ + const auto et_error__ = (error__); \ + if (et_error__ != ::tokenizers::Error::Ok) { \ + TK_LOG(Error, message__, ##__VA_ARGS__); \ + return et_error__; \ + } \ } while (0) // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead. 
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 -#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \ +#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \ TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 diff --git a/include/pytorch/tokenizers/hf_tokenizer.h b/include/pytorch/tokenizers/hf_tokenizer.h index 0f12255..4f8301a 100644 --- a/include/pytorch/tokenizers/hf_tokenizer.h +++ b/include/pytorch/tokenizers/hf_tokenizer.h @@ -27,7 +27,7 @@ namespace tokenizers { class HFTokenizer : public detail::BPETokenizerBase { -public: + public: /*-- Public Interface --*/ /** @@ -39,13 +39,15 @@ class HFTokenizer : public detail::BPETokenizerBase { /** * Load the model data into the */ - Error load(const std::string &tokenizer_path) override; + Error load(const std::string& tokenizer_path) override; -private: - Error _encode(re2::StringPiece &input, std::vector &ret, - uint64_t &last_piece_token_len) const override; + private: + Error _encode( + re2::StringPiece& input, + std::vector& ret, + uint64_t& last_piece_token_len) const override; - void _decode(re2::StringPiece input, std::string &ret) const override; + void _decode(re2::StringPiece input, std::string& ret) const override; PreTokenizer::Ptr _pretokenizer; TokenDecoder::Ptr _decoder; diff --git a/include/pytorch/tokenizers/llama2c_tokenizer.h b/include/pytorch/tokenizers/llama2c_tokenizer.h index 011a1d3..6163b55 100644 --- a/include/pytorch/tokenizers/llama2c_tokenizer.h +++ b/include/pytorch/tokenizers/llama2c_tokenizer.h @@ -7,33 +7,33 @@ */ // @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude #pragma once -#include #include +#include namespace tokenizers { struct TokenIndex { - const char *str; + const char* str; int32_t id; }; // A simple Byte Pair Encoding (BPE) Tokenizer. Note that the current C++ code // won't work with this class, it needs to go through tokenizer.py first. 
class Llama2cTokenizer : public Tokenizer { -public: + public: explicit Llama2cTokenizer(); ~Llama2cTokenizer() override; - Error load(const std::string &tokenizer_path) override; + Error load(const std::string& tokenizer_path) override; - Result> encode(const std::string &input, int8_t bos, - int8_t eos) const override; + Result> + encode(const std::string& input, int8_t bos, int8_t eos) const override; - Result decode(uint64_t prev_token, - uint64_t token) const override; + Result decode(uint64_t prev_token, uint64_t token) + const override; -private: - std::unique_ptr vocab_ = nullptr; + private: + std::unique_ptr vocab_ = nullptr; std::unique_ptr vocab_scores_ = nullptr; std::unique_ptr sorted_vocab_ = nullptr; unsigned int max_token_length_ = 0; diff --git a/include/pytorch/tokenizers/log.h b/include/pytorch/tokenizers/log.h index 505caa4..0282a2c 100644 --- a/include/pytorch/tokenizers/log.h +++ b/include/pytorch/tokenizers/log.h @@ -40,7 +40,7 @@ #include #define TK_PRINTFLIKE(_string_index, _va_index) _Printf_format_string_ #else -#define TK_PRINTFLIKE(_string_index, _va_index) \ +#define TK_PRINTFLIKE(_string_index, _va_index) \ __attribute__((format(printf, _string_index, _va_index))) #endif @@ -122,8 +122,12 @@ typedef enum { * @param[in] length Message string length. */ inline void TK_INTERNAL_PLATFORM_WEAKNESS tk_pal_emit_log_message( - tk_pal_log_level_t level, const char *filename, const char *function, - size_t line, const char *message, size_t length) { + tk_pal_log_level_t level, + const char* filename, + const char* function, + size_t line, + const char* message, + size_t length) { // Use a format similar to glog and folly::logging, except: // - Print time since et_pal_init since we don't have wall time // - Don't include the thread ID, to avoid adding a threading dependency @@ -131,8 +135,13 @@ inline void TK_INTERNAL_PLATFORM_WEAKNESS tk_pal_emit_log_message( // // Clients who want to change the format or add other fields can override this // weak implementation of et_pal_emit_log_message. - fprintf(TK_LOG_OUTPUT_FILE, "%c tokenizers:%s:%zu] %s\n", level, filename, - line, message); + fprintf( + TK_LOG_OUTPUT_FILE, + "%c tokenizers:%s:%zu] %s\n", + level, + filename, + line, + message); fflush(TK_LOG_OUTPUT_FILE); } @@ -203,8 +212,13 @@ static constexpr tk_pal_log_level_t kLevelToPal[size_t(LogLevel::NumLevels)] = { * @param[in] args Variable argument list. */ TK_PRINTFLIKE(5, 0) -inline void vlogf(LogLevel level, const char *filename, const char *function, - size_t line, const char *format, va_list args) { +inline void vlogf( + LogLevel level, + const char* filename, + const char* function, + size_t line, + const char* format, + va_list args) { // Maximum length of a log message. static constexpr size_t kMaxLogMessageLength = 256; char buf[kMaxLogMessageLength]; @@ -217,8 +231,8 @@ inline void vlogf(LogLevel level, const char *filename, const char *function, tk_pal_log_level_t pal_level = (int(level) >= 0 && level < LogLevel::NumLevels) - ? kLevelToPal[size_t(level)] - : tk_pal_log_level_t::kUnknown; + ? kLevelToPal[size_t(level)] + : tk_pal_log_level_t::kUnknown; tk_pal_emit_log_message(pal_level, filename, function, line, buf, len); } @@ -235,8 +249,13 @@ inline void vlogf(LogLevel level, const char *filename, const char *function, * @param[in] format Format string. */ TK_PRINTFLIKE(5, 6) -inline void logf(LogLevel level, const char *filename, const char *function, - size_t line, const char *format, ...) 
{ +inline void logf( + LogLevel level, + const char* filename, + const char* function, + size_t line, + const char* format, + ...) { #if TK_LOG_ENABLED va_list args; va_start(args, format); @@ -257,14 +276,19 @@ inline void logf(LogLevel level, const char *filename, const char *function, * @param[in] _level Log severity level. * @param[in] _format Log message format string. */ -#define TK_LOG(_level, _format, ...) \ - do { \ - const auto _log_level = ::tokenizers::LogLevel::_level; \ - if (static_cast(_log_level) >= \ - static_cast(::tokenizers::LogLevel::TK_MIN_LOG_LEVEL)) { \ - ::tokenizers::internal::logf(_log_level, TK_SHORT_FILENAME, TK_FUNCTION, \ - TK_LINE, _format, ##__VA_ARGS__); \ - } \ +#define TK_LOG(_level, _format, ...) \ + do { \ + const auto _log_level = ::tokenizers::LogLevel::_level; \ + if (static_cast(_log_level) >= \ + static_cast(::tokenizers::LogLevel::TK_MIN_LOG_LEVEL)) { \ + ::tokenizers::internal::logf( \ + _log_level, \ + TK_SHORT_FILENAME, \ + TK_FUNCTION, \ + TK_LINE, \ + _format, \ + ##__VA_ARGS__); \ + } \ } while (0) #else // TK_LOG_ENABLED diff --git a/include/pytorch/tokenizers/pre_tokenizer.h b/include/pytorch/tokenizers/pre_tokenizer.h index ae9ad83..56218c7 100644 --- a/include/pytorch/tokenizers/pre_tokenizer.h +++ b/include/pytorch/tokenizers/pre_tokenizer.h @@ -28,7 +28,7 @@ namespace tokenizers { * input string piece */ class PreTokenizer { -public: + public: /** Shared pointer type */ typedef std::shared_ptr Ptr; @@ -41,8 +41,8 @@ class PreTokenizer { * NOTE: Pass by value per best practice * https://abseil.io/docs/cpp/guides/strings#string_view */ - virtual std::vector - pre_tokenize(re2::StringPiece input) const = 0; + virtual std::vector pre_tokenize( + re2::StringPiece input) const = 0; virtual ~PreTokenizer() = default; }; // end class PreTokenizer @@ -50,11 +50,11 @@ class PreTokenizer { // -- Factory ------------------------------------------------------------------ // Helper macro to standardize addition of config member fields -#define CONFIG_MEMBER(type, name) \ - std::optional name; \ - PreTokenizerConfig &set_##name(type arg) { \ - this->name = std::move(arg); \ - return *this; \ +#define CONFIG_MEMBER(type, name) \ + std::optional name; \ + PreTokenizerConfig& set_##name(type arg) { \ + this->name = std::move(arg); \ + return *this; \ } /** @@ -75,7 +75,7 @@ class PreTokenizer { * const auto pre_tokenized = pre_tokenizer->pre_tokenize("Hello World!"); */ class PreTokenizerConfig { -public: + public: /*------------------------*/ /* Public mutable members */ /*------------------------*/ @@ -123,7 +123,7 @@ class PreTokenizerConfig { /** * Populate from a json config file */ - PreTokenizerConfig &parse_json(const nlohmann::json &json_config); + PreTokenizerConfig& parse_json(const nlohmann::json& json_config); }; // end class PreTokenizerConfig @@ -137,17 +137,17 @@ class PreTokenizerConfig { // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/tokenizer/pattern.rs#L128 class RegexPreTokenizer : public PreTokenizer { -public: + public: typedef std::unique_ptr Re2UPtr; - explicit RegexPreTokenizer(const std::string &pattern) + explicit RegexPreTokenizer(const std::string& pattern) : regex_(RegexPreTokenizer::create_regex_(pattern)) {} /** Pre-tokenize with the stored regex */ std::vector pre_tokenize(re2::StringPiece input) const; -protected: - static Re2UPtr create_regex_(const std::string &pattern); + protected: + static Re2UPtr create_regex_(const std::string& pattern); Re2UPtr regex_; @@ -159,10 +159,11 @@ class 
RegexPreTokenizer : public PreTokenizer { // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/digits.rs class DigitsPreTokenizer : public RegexPreTokenizer { -public: + public: explicit DigitsPreTokenizer(bool individual_digits = false) - : RegexPreTokenizer(individual_digits ? R"([^\p{N}]+|\p{N})" - : R"([^\p{N}]+|[\p{N}]+)") {} + : RegexPreTokenizer( + individual_digits ? R"([^\p{N}]+|\p{N})" + : R"([^\p{N}]+|[\p{N}]+)") {} }; // end class DigitsPreTokenizer // -- ByteLevel ---------------------------------------------------------------- @@ -171,21 +172,22 @@ class DigitsPreTokenizer : public RegexPreTokenizer { // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs class ByteLevelPreTokenizer : public PreTokenizer { -public: + public: /** * @param add_prefix_space: Whether to add a leading space to the first word * @param pattern: A user-supplied regex to use for token splitting. If not * provided, it use the standard GPT2 pattern. */ - ByteLevelPreTokenizer(bool add_prefix_space = true, - const std::string &pattern = ""); - explicit ByteLevelPreTokenizer(const std::string &pattern) + ByteLevelPreTokenizer( + bool add_prefix_space = true, + const std::string& pattern = ""); + explicit ByteLevelPreTokenizer(const std::string& pattern) : ByteLevelPreTokenizer(true, pattern) {} /** Perform pre-tokenization */ std::vector pre_tokenize(re2::StringPiece input) const override; -private: + private: const std::string pattern_; const bool add_prefix_space_; @@ -197,7 +199,7 @@ class ByteLevelPreTokenizer : public PreTokenizer { // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/sequence.rs class SequencePreTokenizer : public PreTokenizer { -public: + public: /** * @param pre_tokenizers: The sequence of owned pre-tokenizer objects to use */ @@ -206,7 +208,7 @@ class SequencePreTokenizer : public PreTokenizer { /** Perform pre-tokenization */ std::vector pre_tokenize(re2::StringPiece input) const override; -private: + private: const std::vector pre_tokenizers_; }; // end class ByteLevelPreTokenizer diff --git a/include/pytorch/tokenizers/result.h b/include/pytorch/tokenizers/result.h index e40beb5..868c38c 100644 --- a/include/pytorch/tokenizers/result.h +++ b/include/pytorch/tokenizers/result.h @@ -13,9 +13,9 @@ #pragma once +#include #include #include -#include #include namespace tokenizers { @@ -30,8 +30,9 @@ namespace tokenizers { * void generate() * @endcode */ -template class Result final { -public: +template +class Result final { + public: /// `value_type` member for generic programming. typedef T value_type; @@ -43,17 +44,17 @@ template class Result final { * a non-Ok value. */ /* implicit */ Result(Error error) - : error_(error == Error::Ok ? Error::Internal : error), hasValue_(false) { - } + : error_(error == Error::Ok ? Error::Internal : error), + hasValue_(false) {} /// Value copy constructor. - /* implicit */ Result(const T &val) : value_(val), hasValue_(true) {} + /* implicit */ Result(const T& val) : value_(val), hasValue_(true) {} /// Value move constructor. - /* implicit */ Result(T &&val) : value_(std::move(val)), hasValue_(true) {} + /* implicit */ Result(T&& val) : value_(std::move(val)), hasValue_(true) {} /// Result move constructor. - /* implicit */ Result(Result &&rhs) noexcept : hasValue_(rhs.hasValue_) { + /* implicit */ Result(Result&& rhs) noexcept : hasValue_(rhs.hasValue_) { if (hasValue_) { // Use the value type's move constructor. 
new (&value_) T(std::move(rhs.value_)); @@ -76,7 +77,9 @@ template class Result final { * If true, it is guaranteed that `error()` will return `Error::Ok`. * If false, it is guaranteed that `error()` will not return `Error::Ok`. */ - bool ok() const { return hasValue_; } + bool ok() const { + return hasValue_; + } /** * Returns the error code of this Result. @@ -98,7 +101,7 @@ template class Result final { * * Only legal to call if `ok()` returns true. */ - T &get() { + T& get() { CheckOk(); return value_; } @@ -108,7 +111,7 @@ template class Result final { * * Only legal to call if `ok()` returns true. */ - const T &get() const { + const T& get() const { CheckOk(); return value_; } @@ -118,29 +121,29 @@ template class Result final { * * Only legal to call if `ok()` returns true. */ - const T &operator*() const &; - T &operator*() &; + const T& operator*() const&; + T& operator*() &; /* * Returns a pointer to the Result's value. * * Only legal to call if `ok()` returns true. */ - const T *operator->() const; - T *operator->(); + const T* operator->() const; + T* operator->(); -private: + private: /** * Delete default constructor since all Results should contain a value or * error. */ Result() = delete; /// Delete copy constructor since T may not be copyable. - Result(const Result &) = delete; + Result(const Result&) = delete; /// Delete copy assignment since T may not be copyable. - Result &operator=(const Result &) = delete; + Result& operator=(const Result&) = delete; /// Delete move assignment since it's not a supported pattern to reuse Result. - Result &operator=(Result &&rhs) = delete; + Result& operator=(Result&& rhs) = delete; // Panics if ok() would return false; void CheckOk() const { @@ -148,7 +151,7 @@ template class Result final { } union { - T value_; // Used if hasValue_ is true. + T value_; // Used if hasValue_ is true. Error error_; // Used if hasValue_ is false. }; @@ -156,22 +159,26 @@ template class Result final { const bool hasValue_; }; -template const T &Result::operator*() const & { +template +const T& Result::operator*() const& { CheckOk(); return value_; } -template T &Result::operator*() & { +template +T& Result::operator*() & { CheckOk(); return value_; } -template const T *Result::operator->() const { +template +const T* Result::operator->() const { CheckOk(); return &value_; } -template T *Result::operator->() { +template +T* Result::operator->() { CheckOk(); return &value_; } @@ -191,34 +198,34 @@ template T *Result::operator->() { #define TK_UNWRAP(result__, ...) TK_INTERNAL_UNWRAP(result__, ##__VA_ARGS__) // Internal only: Use TK_UNWRAP() instead. -#define TK_INTERNAL_UNWRAP(...) \ - TK_INTERNAL_UNWRAP_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ +#define TK_INTERNAL_UNWRAP(...) \ + TK_INTERNAL_UNWRAP_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \ (__VA_ARGS__) // Internal only: Use TK_UNWRAP() instead. -#define TK_INTERNAL_UNWRAP_SELECT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, \ - ...) \ +#define TK_INTERNAL_UNWRAP_SELECT( \ + _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \ TK_INTERNAL_UNWRAP_##N // Internal only: Use TK_UNWRAP() instead. -#define TK_INTERNAL_UNWRAP_1(result__) \ - ({ \ - auto et_result__ = (result__); \ - if (!et_result__.ok()) { \ - return et_result__.error(); \ - } \ - std::move(*et_result__); \ +#define TK_INTERNAL_UNWRAP_1(result__) \ + ({ \ + auto et_result__ = (result__); \ + if (!et_result__.ok()) { \ + return et_result__.error(); \ + } \ + std::move(*et_result__); \ }) // Internal only: Use TK_UNWRAP() instead. 
-#define TK_INTERNAL_UNWRAP_2(result__, message__, ...) \ - ({ \ - auto et_result__ = (result__); \ - if (!et_result__.ok()) { \ - TK_LOG(Error, message__, ##__VA_ARGS__); \ - return et_result__.error(); \ - } \ - std::move(*et_result__); \ +#define TK_INTERNAL_UNWRAP_2(result__, message__, ...) \ + ({ \ + auto et_result__ = (result__); \ + if (!et_result__.ok()) { \ + TK_LOG(Error, message__, ##__VA_ARGS__); \ + return et_result__.error(); \ + } \ + std::move(*et_result__); \ }) // Internal only: Use TK_UNWRAP() instead. diff --git a/include/pytorch/tokenizers/sentencepiece.h b/include/pytorch/tokenizers/sentencepiece.h index 517b903..be7fff6 100644 --- a/include/pytorch/tokenizers/sentencepiece.h +++ b/include/pytorch/tokenizers/sentencepiece.h @@ -10,31 +10,31 @@ // A tokenizer that works with sentencepiece. Used by Llama2. #pragma once -#include "sentencepiece_processor.h" -#include #include +#include #include +#include "sentencepiece_processor.h" namespace tokenizers { struct TokenIndex { - const char *str; + const char* str; int32_t id; }; class SPTokenizer : public Tokenizer { -public: + public: explicit SPTokenizer(); ~SPTokenizer() override; - Error load(const std::string &tokenizer_path) override; + Error load(const std::string& tokenizer_path) override; - Result> encode(const std::string &input, int8_t bos, - int8_t eos) const override; + Result> + encode(const std::string& input, int8_t bos, int8_t eos) const override; - Result decode(uint64_t prev_token, - uint64_t token) const override; + Result decode(uint64_t prev_token, uint64_t token) + const override; -private: + private: std::unique_ptr _processor; }; diff --git a/include/pytorch/tokenizers/tiktoken.h b/include/pytorch/tokenizers/tiktoken.h index 9ee3f95..4706a8e 100644 --- a/include/pytorch/tokenizers/tiktoken.h +++ b/include/pytorch/tokenizers/tiktoken.h @@ -28,11 +28,14 @@ static constexpr size_t kBOSTokenIndex = 0; static constexpr size_t kEOSTokenIndex = 1; class Tiktoken : public detail::BPETokenizerBase { -public: - explicit Tiktoken(std::unique_ptr> special_tokens, - size_t bos_token_index, size_t eos_token_index) + public: + explicit Tiktoken( + std::unique_ptr> special_tokens, + size_t bos_token_index, + size_t eos_token_index) : _special_tokens(std::move(special_tokens)), - _bos_token_index(bos_token_index), _eos_token_index(eos_token_index) { + _bos_token_index(bos_token_index), + _eos_token_index(eos_token_index) { if (_bos_token_index >= _special_tokens->size() || _eos_token_index >= _special_tokens->size()) { abort(); @@ -41,19 +44,27 @@ class Tiktoken : public detail::BPETokenizerBase { explicit Tiktoken() : _special_tokens(_get_default_special_tokens()), - _bos_token_index(kBOSTokenIndex), _eos_token_index(kEOSTokenIndex){}; + _bos_token_index(kBOSTokenIndex), + _eos_token_index(kEOSTokenIndex){}; - Error load(const std::string &tokenizer_path) override; + Error load(const std::string& tokenizer_path) override; -private: + private: static inline std::unique_ptr> _get_default_special_tokens() { auto special_tokens = std::make_unique>(std::vector{ - "<|begin_of_text|>", "<|end_of_text|>", - "<|reserved_special_token_0|>", "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", "<|step_id|>", "<|start_header_id|>", - "<|end_header_id|>", "<|eom_id|>", "<|eot_id|>", "<|python_tag|>"}); + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + 
"<|eom_id|>", + "<|eot_id|>", + "<|python_tag|>"}); // pad the rest of the special tokens with reserved tokens ssize_t reserved_special_token_num = 2; while (special_tokens->size() < kSpecialTokensSize) { @@ -66,18 +77,21 @@ class Tiktoken : public detail::BPETokenizerBase { template std::pair, re2::StringPiece> - _split_with_allowed_special_token(re2::StringPiece &input, - const T &allowed_special) const; + _split_with_allowed_special_token( + re2::StringPiece& input, + const T& allowed_special) const; - Error _encode(re2::StringPiece &input, std::vector &ret, - uint64_t &last_piece_token_len) const override; + Error _encode( + re2::StringPiece& input, + std::vector& ret, + uint64_t& last_piece_token_len) const override; - void _decode(re2::StringPiece input, std::string &ret) const override; + void _decode(re2::StringPiece input, std::string& ret) const override; template - Result, uint64_t>> - _encode_with_special_token(const std::string &text, - const T &allowed_special) const; + Result, uint64_t>> _encode_with_special_token( + const std::string& text, + const T& allowed_special) const; detail::Encoder _build_special_token_encoder(ssize_t num_base_tokens) const; diff --git a/include/pytorch/tokenizers/token_decoder.h b/include/pytorch/tokenizers/token_decoder.h index 30759d6..825e95a 100644 --- a/include/pytorch/tokenizers/token_decoder.h +++ b/include/pytorch/tokenizers/token_decoder.h @@ -28,7 +28,7 @@ namespace tokenizers { * Base class for all token decoders */ class TokenDecoder { -public: + public: /* -- Types -- */ /** Shared pointer type */ @@ -58,7 +58,7 @@ class TokenDecoder { * Factory and config class for creating a new TokenDecoder */ class TokenDecoderConfig { -public: + public: /** * The Type name string matching from decoders * https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/mod.rs#L55 @@ -82,7 +82,7 @@ class TokenDecoderConfig { /** * Populate from a json config file */ - TokenDecoderConfig &parse_json(const nlohmann::json &json_config); + TokenDecoderConfig& parse_json(const nlohmann::json& json_config); }; // end class TokenDecoderConfig // -- ByteLevel ---------------------------------------------------------------- @@ -91,7 +91,7 @@ class TokenDecoderConfig { // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs class ByteLevelTokenDecoder : public TokenDecoder { -public: + public: std::string decode(re2::StringPiece token) const override; }; // end class ByteLevelTokenDecoder diff --git a/include/pytorch/tokenizers/tokenizer.h b/include/pytorch/tokenizers/tokenizer.h index ca0d9e2..23bde19 100644 --- a/include/pytorch/tokenizers/tokenizer.h +++ b/include/pytorch/tokenizers/tokenizer.h @@ -21,14 +21,14 @@ namespace tokenizers { class Tokenizer { -public: + public: explicit Tokenizer() {} virtual ~Tokenizer() {} - virtual Error load(const std::string &tokenizer_path) = 0; + virtual Error load(const std::string& tokenizer_path) = 0; virtual Result> - encode(const std::string &input, int8_t bos, int8_t eos) const = 0; + encode(const std::string& input, int8_t bos, int8_t eos) const = 0; Error decode_verify(uint64_t token) const { if (!initialized_) { @@ -40,17 +40,23 @@ class Tokenizer { return Error::Ok; } - virtual Result decode(uint64_t prev_token, - uint64_t token) const = 0; + virtual Result decode(uint64_t prev_token, uint64_t token) + const = 0; // getters - int32_t vocab_size() const { return vocab_size_; } + int32_t vocab_size() const { + return vocab_size_; + } - uint64_t bos_tok() 
const { return bos_tok_; } + uint64_t bos_tok() const { + return bos_tok_; + } - uint64_t eos_tok() const { return eos_tok_; } + uint64_t eos_tok() const { + return eos_tok_; + } -protected: + protected: bool initialized_ = false; int32_t vocab_size_ = 0; uint64_t bos_tok_, eos_tok_ = 0; diff --git a/src/bpe_tokenizer_base.cpp b/src/bpe_tokenizer_base.cpp index 1a530b8..6a50b91 100644 --- a/src/bpe_tokenizer_base.cpp +++ b/src/bpe_tokenizer_base.cpp @@ -18,12 +18,14 @@ namespace detail { // ---- Helper utils start ----------------------------------------------------- namespace { -static uint64_t _max_size() { return std::numeric_limits::max(); } +static uint64_t _max_size() { + return std::numeric_limits::max(); +} -static std::vector -_byte_pair_merge(const std::string &piece, - const std::unordered_map &ranks, - std::function func) { +static std::vector _byte_pair_merge( + const std::string& piece, + const std::unordered_map& ranks, + std::function func) { // This is a vector of (start, rank). // The rank is of the byte pair starting at position start. // The rank of the last item in the vector is not a valid value. @@ -33,10 +35,10 @@ _byte_pair_merge(const std::string &piece, parts.emplace_back(idx, _max_size()); } - auto get_rank = - [&piece, &ranks](const std::vector> &parts, - uint64_t start_idx, - uint64_t skip) -> std::optional { + auto get_rank = [&piece, &ranks]( + const std::vector>& parts, + uint64_t start_idx, + uint64_t skip) -> std::optional { if (start_idx + skip + 2 < parts.size()) { auto s = parts[start_idx].first; auto e = parts[start_idx + skip + 2].first; @@ -132,7 +134,8 @@ _byte_pair_merge(const std::string &piece, std::pair, re2::StringPiece> BPETokenizerBase::split_with_allowed_special_token_( - re2::StringPiece &input, const Encoder &allowed_special) const { + re2::StringPiece& input, + const Encoder& allowed_special) const { if (!special_token_regex_) { return std::make_pair(std::nullopt, input); } @@ -158,7 +161,8 @@ BPETokenizerBase::split_with_allowed_special_token_( Result, uint64_t>> BPETokenizerBase::encode_with_special_token_( - const std::string &text, const Encoder &allowed_special) const { + const std::string& text, + const Encoder& allowed_special) const { std::vector tokens; uint64_t last_piece_token_len = 0; re2::StringPiece input(text); @@ -172,7 +176,7 @@ BPETokenizerBase::encode_with_special_token_( uint64_t token = 0; try { token = special_token_encoder_.at(*special); - } catch (const std::out_of_range &) { + } catch (const std::out_of_range&) { // Should never go here, since special pattern includes all special // chars. TK_LOG(Error, "unknown special token: %s\n", special->c_str()); @@ -192,9 +196,9 @@ BPETokenizerBase::encode_with_special_token_( return std::make_pair(tokens, last_piece_token_len); } -Result> -BPETokenizerBase::byte_pair_encode_(const std::string &piece, - const Encoder &encoder) const { +Result> BPETokenizerBase::byte_pair_encode_( + const std::string& piece, + const Encoder& encoder) const { if (piece.size() == 1) { auto iter = encoder.find(piece); if (iter != encoder.end()) { @@ -205,26 +209,27 @@ BPETokenizerBase::byte_pair_encode_(const std::string &piece, } } - return _byte_pair_merge(piece, encoder, - [&piece, &encoder](uint64_t start, uint64_t stop) { - std::string key = piece.substr(start, stop - start); - auto iter = encoder.find(key); - if (iter != encoder.end()) { - return iter->second; - } else { - // TODO: what if key does not exist? Should we - // return `unknown`? assert(false); // ?? 
- return uint64_t(0); - } - }); + return _byte_pair_merge( + piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) { + std::string key = piece.substr(start, stop - start); + auto iter = encoder.find(key); + if (iter != encoder.end()) { + return iter->second; + } else { + // TODO: what if key does not exist? Should we + // return `unknown`? assert(false); // ?? + return uint64_t(0); + } + }); } // ---- protected end ---------------------------------------------------------- // ---- public start ----------------------------------------------------------- -Result> BPETokenizerBase::encode(const std::string &text, - int8_t bos, - int8_t eos) const { +Result> BPETokenizerBase::encode( + const std::string& text, + int8_t bos, + int8_t eos) const { if (!initialized_) { return Error::Uninitialized; } @@ -239,8 +244,8 @@ Result> BPETokenizerBase::encode(const std::string &text, return Result>(std::move(res)); } -Result BPETokenizerBase::decode(uint64_t prev, - uint64_t cur) const { +Result BPETokenizerBase::decode(uint64_t prev, uint64_t cur) + const { (void)prev; if (!initialized_) { return Error::Uninitialized; diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp index f2ef4a0..0eefbcc 100644 --- a/src/hf_tokenizer.cpp +++ b/src/hf_tokenizer.cpp @@ -26,7 +26,7 @@ namespace tokenizers { // -------------------------private method end------------------------------- // -------------------------public method start------------------------------- -Error HFTokenizer::load(const std::string &path) { +Error HFTokenizer::load(const std::string& path) { // If this is a directory, look for tokenizer.json and tokenizer_config.json std::string model_json = path; std::string model_config_json = ""; @@ -49,19 +49,19 @@ Error HFTokenizer::load(const std::string &path) { fprintf(stderr, "failed to open encoder file: %s\n", path.c_str()); return Error::LoadFailure; } - std::string contents((std::istreambuf_iterator(file)), - std::istreambuf_iterator()); + std::string contents( + (std::istreambuf_iterator(file)), std::istreambuf_iterator()); json parsed_json; try { parsed_json = json::parse(contents); - } catch (const json::exception &e) { + } catch (const json::exception& e) { std::cout << "Error parsing json file: " << e.what() << std::endl; return Error::LoadFailure; } // Parse the special tokens try { - const auto &special_tokens = parsed_json.at("added_tokens"); + const auto& special_tokens = parsed_json.at("added_tokens"); for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it) { const std::string token = it->at("content"); const uint64_t token_id = it->at("id"); @@ -74,15 +74,15 @@ Error HFTokenizer::load(const std::string &path) { return Error::LoadFailure; } } - } catch (const json::out_of_range &e) { + } catch (const json::out_of_range& e) { fprintf(stderr, "Could not parse special tokens: %s\n", e.what()); return Error::LoadFailure; } // Parse the standard tokens try { - const auto &vocab = parsed_json.at("/model/vocab"_json_pointer); - for (const auto &entry : vocab.items()) { + const auto& vocab = parsed_json.at("/model/vocab"_json_pointer); + for (const auto& entry : vocab.items()) { const std::string token = entry.key(); const uint64_t token_id = entry.value(); // Skip adding special tokens to the standard encoder/decoder @@ -98,7 +98,7 @@ Error HFTokenizer::load(const std::string &path) { } } } - } catch (const json::out_of_range &e) { + } catch (const json::out_of_range& e) { fprintf(stderr, "Could not parse tokens: %s\n", e.what()); return Error::LoadFailure; } @@ 
-111,7 +111,7 @@ Error HFTokenizer::load(const std::string &path) { _pretokenizer = PreTokenizerConfig() .parse_json(parsed_json.at("pre_tokenizer")) .create(); - } catch (const json::out_of_range &e) { + } catch (const json::out_of_range& e) { fprintf(stderr, "Could not parse pre_tokenizer: %s\n", e.what()); return Error::LoadFailure; } @@ -120,7 +120,7 @@ Error HFTokenizer::load(const std::string &path) { try { _decoder = TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create(); - } catch (const json::out_of_range &e) { + } catch (const json::out_of_range& e) { // No decoder specified } @@ -134,12 +134,13 @@ Error HFTokenizer::load(const std::string &path) { fprintf(stderr, "failed to open encoder file: %s\n", path.c_str()); return Error::LoadFailure; } - std::string config_contents((std::istreambuf_iterator(config_file)), - std::istreambuf_iterator()); + std::string config_contents( + (std::istreambuf_iterator(config_file)), + std::istreambuf_iterator()); json parsed_config_json; try { parsed_config_json = json::parse(config_contents); - } catch (const json::exception &e) { + } catch (const json::exception& e) { std::cout << "Error parsing model config json json file: " << e.what() << std::endl; return Error::LoadFailure; @@ -149,23 +150,23 @@ Error HFTokenizer::load(const std::string &path) { try { const std::string bos_token = parsed_config_json.at("bos_token"); const std::string eos_token = parsed_config_json.at("eos_token"); - const auto &bos_it = special_token_encoder_.find(bos_token); - const auto &eos_it = special_token_encoder_.find(eos_token); + const auto& bos_it = special_token_encoder_.find(bos_token); + const auto& eos_it = special_token_encoder_.find(eos_token); if (bos_it == special_token_encoder_.end()) { - fprintf(stderr, "BOS token %s not in special tokens\n", - bos_token.c_str()); + fprintf( + stderr, "BOS token %s not in special tokens\n", bos_token.c_str()); return Error::LoadFailure; } if (eos_it == special_token_encoder_.end()) { - fprintf(stderr, "EOS token %s not in special tokens\n", - eos_token.c_str()); + fprintf( + stderr, "EOS token %s not in special tokens\n", eos_token.c_str()); return Error::LoadFailure; } bos_tok_ = bos_it->second; eos_tok_ = eos_it->second; - } catch (const json::out_of_range &e) { - fprintf(stderr, "Could not eos/bos from tokenizer config: %s\n", - e.what()); + } catch (const json::out_of_range& e) { + fprintf( + stderr, "Could not eos/bos from tokenizer config: %s\n", e.what()); return Error::LoadFailure; } } @@ -177,7 +178,7 @@ Error HFTokenizer::load(const std::string &path) { else { std::vector bos_candidates; std::vector eos_candidates; - for (const auto &token : special_token_encoder_) { + for (const auto& token : special_token_encoder_) { if (token.first.find("bos") != std::string::npos || token.first.find("begin") != std::string::npos) { bos_candidates.push_back(token.first); @@ -190,7 +191,7 @@ Error HFTokenizer::load(const std::string &path) { if (bos_candidates.size() > 1) { const auto orig_candidates = bos_candidates; bos_candidates.clear(); - for (const auto &cand : orig_candidates) { + for (const auto& cand : orig_candidates) { if (cand.find("text") != std::string::npos) { bos_candidates.push_back(cand); } @@ -199,7 +200,7 @@ Error HFTokenizer::load(const std::string &path) { if (eos_candidates.size() > 1) { const auto orig_candidates = eos_candidates; eos_candidates.clear(); - for (const auto &cand : orig_candidates) { + for (const auto& cand : orig_candidates) { if (cand.find("text") != std::string::npos) { 
eos_candidates.push_back(cand); } @@ -234,9 +235,11 @@ Error HFTokenizer::load(const std::string &path) { // -------------------------public method end----------------------------------- // -------------------------private method start-------------------------------- -Error HFTokenizer::_encode(re2::StringPiece &input, std::vector &ret, - uint64_t &last_piece_token_len) const { - for (const auto &piece : _pretokenizer->pre_tokenize(input)) { +Error HFTokenizer::_encode( + re2::StringPiece& input, + std::vector& ret, + uint64_t& last_piece_token_len) const { + for (const auto& piece : _pretokenizer->pre_tokenize(input)) { auto iter = encoder_.find(piece); if (iter != encoder_.end()) { last_piece_token_len = 1; @@ -251,7 +254,7 @@ Error HFTokenizer::_encode(re2::StringPiece &input, std::vector &ret, return Error::Ok; } -void HFTokenizer::_decode(re2::StringPiece input, std::string &ret) const { +void HFTokenizer::_decode(re2::StringPiece input, std::string& ret) const { if (_decoder) { ret += _decoder->decode(input); } else { diff --git a/src/llama2c_tokenizer.cpp b/src/llama2c_tokenizer.cpp index b5e8691..951ee3d 100644 --- a/src/llama2c_tokenizer.cpp +++ b/src/llama2c_tokenizer.cpp @@ -6,19 +6,19 @@ * LICENSE file in the root directory of this source tree. */ // @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude -#include #include +#include namespace tokenizers { -static int compare_tokens(const void *a, const void *b) { - if (((TokenIndex *)a)->str == nullptr) { +static int compare_tokens(const void* a, const void* b) { + if (((TokenIndex*)a)->str == nullptr) { return -1; } - if (((TokenIndex *)b)->str == nullptr) { + if (((TokenIndex*)b)->str == nullptr) { return 1; } - return strcmp(((TokenIndex *)a)->str, ((TokenIndex *)b)->str); + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); } Llama2cTokenizer::Llama2cTokenizer() : Tokenizer() { @@ -38,13 +38,13 @@ Llama2cTokenizer::Llama2cTokenizer() : Tokenizer() { * @param tokenizer_path The path to the tokenizer file. 
* @return Error */ -Error Llama2cTokenizer::load(const std::string &tokenizer_path) { +Error Llama2cTokenizer::load(const std::string& tokenizer_path) { if (initialized_) { TK_LOG(Info, "Tokenizer already initialized"); return Error::Ok; } // read in the file - FILE *file = fopen(tokenizer_path.c_str(), "rb"); + FILE* file = fopen(tokenizer_path.c_str(), "rb"); if (!file) { TK_LOG(Error, "couldn't load %s", tokenizer_path.c_str()); return Error::LoadFailure; @@ -52,10 +52,10 @@ Error Llama2cTokenizer::load(const std::string &tokenizer_path) { int32_t metadata[4]; for (int i = 0; i < 4; i++) { if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) { - TK_LOG(Error, - "Failed to read the metadata at position %d, the tokenizer file " - "is not valid!", - i); + TK_LOG( + Error, + "Failed to read the metadata at position %d, the tokenizer file is not valid!", + i); return Error::ParseFailure; } } @@ -69,7 +69,7 @@ Error Llama2cTokenizer::load(const std::string &tokenizer_path) { max_token_length_ = metadata[3]; // allocate space for the vocabulary - vocab_ = std::make_unique(vocab_size_); + vocab_ = std::make_unique(vocab_size_); vocab_scores_ = std::make_unique(vocab_size_); sorted_vocab_ = std::make_unique(vocab_size_); @@ -90,8 +90,11 @@ Error Llama2cTokenizer::load(const std::string &tokenizer_path) { } vocab_[i] = new char[len + 1]; if (fread(vocab_[i], len, 1, file) != 1) { - TK_LOG(Error, "Failed to read the word, total length %d, index %d\n", len, - i); + TK_LOG( + Error, + "Failed to read the word, total length %d, index %d\n", + len, + i); return Error::ParseFailure; } vocab_[i][len] = '\0'; // add the string terminating token @@ -122,10 +125,11 @@ Llama2cTokenizer::~Llama2cTokenizer() { * @return Result A pointer to the string representation of the * token. */ -Result Llama2cTokenizer::decode(uint64_t prev_token, - uint64_t token) const { +Result Llama2cTokenizer::decode( + uint64_t prev_token, + uint64_t token) const { TK_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token)); - const char *piece = vocab_[token]; + const char* piece = vocab_[token]; // following BOS token, sentencepiece decoder strips any leading // whitespace if (prev_token == bos_tok_ && piece[0] == ' ') { @@ -135,19 +139,19 @@ Result Llama2cTokenizer::decode(uint64_t prev_token, // parse this and convert and return the actual byte unsigned char byte_val; if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { - piece = (char *)byte_pieces_ + byte_val * 2; + piece = (char*)byte_pieces_ + byte_val * 2; } std::string res(piece); return res; } -static int32_t str_lookup(const char *str, TokenIndex *sorted_vocab, - int32_t vocab_size) { +static int32_t +str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) { // efficiently find the perfect match for str in vocab, return its index or -1 // if not found TokenIndex tok = {.str = str}; // acts as the key to search for - TokenIndex *res = (TokenIndex *)bsearch(&tok, sorted_vocab, vocab_size, - sizeof(TokenIndex), compare_tokens); + TokenIndex* res = (TokenIndex*)bsearch( + &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); return res != nullptr ? res->id : -1; } @@ -161,9 +165,10 @@ static int32_t str_lookup(const char *str, TokenIndex *sorted_vocab, * @param n_tokens The number of tokens. 
* @return Result> */ -Result> Llama2cTokenizer::encode(const std::string &text, - int8_t bos, - int8_t eos) const { +Result> Llama2cTokenizer::encode( + const std::string& text, + int8_t bos, + int8_t eos) const { if (!initialized_) { TK_LOG(Error, "Tokenizer not initialized"); return Error::Uninitialized; @@ -179,7 +184,7 @@ Result> Llama2cTokenizer::encode(const std::string &text, // create a temporary buffer that will store merge candidates of always two // consecutive tokens *2 for concat, +1 for null terminator +2 for UTF8 (in // case max_token_length is 1) - char *str_buffer = new char[max_token_length_ * 2 + 1 + 2]; + char* str_buffer = new char[max_token_length_ * 2 + 1 + 2]; size_t str_len = 0; // start at 0 tokens @@ -200,7 +205,7 @@ Result> Llama2cTokenizer::encode(const std::string &text, // TODO: pretty sure this isn't correct in the general case but I don't have // the energy to read more of the sentencepiece code to figure out what it's // doing - const char *space = " "; + const char* space = " "; if (text[0] != '\0') { int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_); tokens.push_back(dummy_prefix); @@ -215,7 +220,7 @@ Result> Llama2cTokenizer::encode(const std::string &text, // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx // process the raw (UTF-8) byte sequence of the input string - for (const char *c = text.c_str(); *c != '\0'; c++) { + for (const char* c = text.c_str(); *c != '\0'; c++) { // reset buffer if the current byte is ASCII or a leading byte // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the // rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in @@ -265,8 +270,12 @@ Result> Llama2cTokenizer::encode(const std::string &text, for (int i = 0; i < tokens.size() - 1; i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) - snprintf(str_buffer, max_token_length_ * 2 + 3, "%s%s", vocab_[tokens[i]], - vocab_[tokens[i + 1]]); + snprintf( + str_buffer, + max_token_length_ * 2 + 3, + "%s%s", + vocab_[tokens[i]], + vocab_[tokens[i + 1]]); int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_); if (id != -1 && vocab_scores_[id] > best_score) { // this merge pair exists in vocab! 
record its score and position diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp index 9b47e20..5e6e662 100644 --- a/src/pre_tokenizer.cpp +++ b/src/pre_tokenizer.cpp @@ -63,35 +63,37 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const { "Missing pretokenizers for PreTokenizer of type Sequence"); } std::vector pretoks; - std::transform(pretokenizers->begin(), pretokenizers->end(), - std::back_inserter(pretoks), - [](const PreTokenizerConfig &cfg) { return cfg.create(); }); + std::transform( + pretokenizers->begin(), + pretokenizers->end(), + std::back_inserter(pretoks), + [](const PreTokenizerConfig& cfg) { return cfg.create(); }); return PreTokenizer::Ptr(new SequencePreTokenizer(pretoks)); } throw std::runtime_error("Unsupported PreTokenizer type: " + type); } -PreTokenizerConfig &PreTokenizerConfig::parse_json(const json &json_config) { +PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) { type = json_config.at("type"); if (type == "Split") { try { pattern = json_config.at("pattern"); - } catch (json::out_of_range &) { + } catch (json::out_of_range&) { } } else if (type == "Digits") { try { individual_digits = json_config.at("individual_digits"); - } catch (json::out_of_range &) { + } catch (json::out_of_range&) { } } else if (type == "ByteLevel") { try { add_prefix_space = json_config.at("add_prefix_space"); - } catch (json::out_of_range &) { + } catch (json::out_of_range&) { } // TODO: trim_offsets, use_regex } else if (type == "Sequence") { pretokenizers = std::vector(); - for (const auto &entry : json_config.at("pretokenizers")) { + for (const auto& entry : json_config.at("pretokenizers")) { pretokenizers->push_back(PreTokenizerConfig().parse_json(entry)); } } else { @@ -102,14 +104,14 @@ PreTokenizerConfig &PreTokenizerConfig::parse_json(const json &json_config) { // RegexPreTokenizer /////////////////////////////////////////////////////////// -RegexPreTokenizer::Re2UPtr -RegexPreTokenizer::create_regex_(const std::string &pattern) { +RegexPreTokenizer::Re2UPtr RegexPreTokenizer::create_regex_( + const std::string& pattern) { assert(!pattern.empty()); return std::make_unique("(" + pattern + ")"); } -std::vector -RegexPreTokenizer::pre_tokenize(re2::StringPiece input) const { +std::vector RegexPreTokenizer::pre_tokenize( + re2::StringPiece input) const { std::vector result; std::string piece; while (RE2::FindAndConsume(&input, *regex_, &piece)) { @@ -136,13 +138,14 @@ constexpr char GPT2_EXPR[] = // Construction // ////////////////// -ByteLevelPreTokenizer::ByteLevelPreTokenizer(bool add_prefix_space, - const std::string &pattern) +ByteLevelPreTokenizer::ByteLevelPreTokenizer( + bool add_prefix_space, + const std::string& pattern) : pattern_(pattern.empty() ? 
GPT2_EXPR : pattern), add_prefix_space_(add_prefix_space) {} -std::vector -ByteLevelPreTokenizer::pre_tokenize(re2::StringPiece input) const { +std::vector ByteLevelPreTokenizer::pre_tokenize( + re2::StringPiece input) const { // Add the prefix space if configured to do so std::string input_str(input); if (add_prefix_space_ && !input_str.empty() && input_str[0] != ' ') { @@ -158,13 +161,13 @@ SequencePreTokenizer::SequencePreTokenizer( std::vector pre_tokenizers) : pre_tokenizers_(std::move(pre_tokenizers)) {} -std::vector -SequencePreTokenizer::pre_tokenize(re2::StringPiece input) const { +std::vector SequencePreTokenizer::pre_tokenize( + re2::StringPiece input) const { std::vector pieces{std::string(input)}; - for (const auto &pre_tokenizer : pre_tokenizers_) { + for (const auto& pre_tokenizer : pre_tokenizers_) { std::vector new_pieces; - for (const auto &piece : pieces) { - for (const auto &subpiece : pre_tokenizer->pre_tokenize(piece)) { + for (const auto& piece : pieces) { + for (const auto& subpiece : pre_tokenizer->pre_tokenize(piece)) { new_pieces.push_back(subpiece); } } diff --git a/src/sentencepiece.cpp b/src/sentencepiece.cpp index 8c06e10..7401dd9 100644 --- a/src/sentencepiece.cpp +++ b/src/sentencepiece.cpp @@ -8,10 +8,10 @@ // A tokenizer that works with sentencepiece. -#include "third_party/absl/strings/str_replace.h" -#include #include +#include #include +#include "third_party/absl/strings/str_replace.h" namespace tokenizers { const char kSpaceSymbol[] = "\xe2\x96\x81"; @@ -29,7 +29,7 @@ SPTokenizer::SPTokenizer() * @param tokenizer_path The path to the tokenizer file. * @return Error */ -Error SPTokenizer::load(const std::string &tokenizer_path) { +Error SPTokenizer::load(const std::string& tokenizer_path) { if (initialized_) { fprintf(stderr, "Tokenizer already initialized.\n"); return Error::Ok; @@ -37,11 +37,13 @@ Error SPTokenizer::load(const std::string &tokenizer_path) { // read in the file const auto status = _processor->Load(tokenizer_path); if (!status.ok()) { - fprintf(stderr, - "couldn't load %s. \nError message: \n%s\n" - "It is likely that the tokenizer artifact is " - "broken or of a different format.", - tokenizer_path.c_str(), status.error_message()); + fprintf( + stderr, + "couldn't load %s. \nError message: \n%s\n" + "It is likely that the tokenizer artifact is " + "broken or of a different format.", + tokenizer_path.c_str(), + status.error_message()); return Error::LoadFailure; } // load vocab_size, bos_tok, eos_tok @@ -62,8 +64,8 @@ SPTokenizer::~SPTokenizer() {} * @return Result The string representation of the * token. */ -Result SPTokenizer::decode(uint64_t prev_token, - uint64_t token) const { +Result SPTokenizer::decode(uint64_t prev_token, uint64_t token) + const { if (!initialized_) { fprintf(stderr, "Tokenizer not initialized\n"); return Error::Uninitialized; @@ -99,7 +101,7 @@ Result SPTokenizer::decode(uint64_t prev_token, * @return Result> */ Result> -SPTokenizer::encode(const std::string &text, int8_t bos, int8_t eos) const { +SPTokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const { if (!initialized_) { fprintf(stderr, "Tokenizer not initialized\n"); return Error::Uninitialized; diff --git a/src/tiktoken.cpp b/src/tiktoken.cpp index 4195980..cdc31f7 100644 --- a/src/tiktoken.cpp +++ b/src/tiktoken.cpp @@ -25,12 +25,12 @@ limitations under the License. 
*************************************************************************/ -#include "re2/re2.h" +#include +#include #include #include #include -#include -#include +#include "re2/re2.h" namespace tokenizers { @@ -39,15 +39,15 @@ using namespace detail; // ------------------------------Util start------------------------------------ namespace { -static Re2UPtr _create_regex(const std::string &pattern) { +static Re2UPtr _create_regex(const std::string& pattern) { assert(!pattern.empty()); return std::make_unique("(" + pattern + ")"); } -static Re2UPtr _build_special_token_regex(const Encoder &special_encoder) { +static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) { std::string special_pattern; - for (const auto &ele : special_encoder) { + for (const auto& ele : special_encoder) { if (!special_pattern.empty()) { special_pattern += "|"; } @@ -61,52 +61,60 @@ static Re2UPtr _build_special_token_regex(const Encoder &special_encoder) { return _create_regex(special_pattern); } -static Result> -_parse(const std::string &line) { +static Result> _parse( + const std::string& line) { // Tiktoken format // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140 auto pos = line.find(" "); - TK_CHECK_OR_RETURN_ERROR(pos != std::string::npos, ParseFailure, - "invalid tiktoken line: %s", line.c_str()); + TK_CHECK_OR_RETURN_ERROR( + pos != std::string::npos, + ParseFailure, + "invalid tiktoken line: %s", + line.c_str()); auto token = TK_UNWRAP(base64::decode({line.data(), pos})); uint64_t rank = 0; try { rank = std::stoul(line.substr(pos + 1)); - } catch (const std::exception &) { - TK_CHECK_OR_RETURN_ERROR(false, EncodeFailure, "invalid encoder rank: %s", - line.c_str()); + } catch (const std::exception&) { + TK_CHECK_OR_RETURN_ERROR( + false, EncodeFailure, "invalid encoder rank: %s", line.c_str()); } return std::pair{std::move(token), rank}; } -static Result _load_encoder(const std::string &path) { +static Result _load_encoder(const std::string& path) { std::ifstream file(path); - TK_CHECK_OR_RETURN_ERROR(file, LoadFailure, "failed to open encoder file: %s", - path.c_str()); + TK_CHECK_OR_RETURN_ERROR( + file, LoadFailure, "failed to open encoder file: %s", path.c_str()); Encoder encoder; std::string line; while (std::getline(file, line)) { auto [token, rank] = TK_UNWRAP(_parse(line)); - TK_CHECK_OR_RETURN_ERROR(encoder.emplace(std::move(token), rank).second, - ParseFailure, "duplicate item: %s", line.c_str()); + TK_CHECK_OR_RETURN_ERROR( + encoder.emplace(std::move(token), rank).second, + ParseFailure, + "duplicate item: %s", + line.c_str()); } return encoder; } -static Result _build_decoder(const Encoder &encoder) { +static Result _build_decoder(const Encoder& encoder) { Decoder decoder; - for (const auto &[k, v] : encoder) { + for (const auto& [k, v] : encoder) { decoder.emplace(v, k); } - TK_CHECK_OR_RETURN_ERROR(encoder.size() == decoder.size(), LoadFailure, - "duplicate items in encoder"); + TK_CHECK_OR_RETURN_ERROR( + encoder.size() == decoder.size(), + LoadFailure, + "duplicate items in encoder"); return decoder; } @@ -118,8 +126,9 @@ static Result _build_decoder(const Encoder &encoder) { template std::pair, re2::StringPiece> -Tiktoken::_split_with_allowed_special_token(re2::StringPiece &input, - const T &allowed_special) const { +Tiktoken::_split_with_allowed_special_token( + re2::StringPiece& input, + const T& allowed_special) const { if (!special_token_regex_) { return std::make_pair(std::nullopt, input); } @@ -127,7 +136,7 @@ 
Tiktoken::_split_with_allowed_special_token(re2::StringPiece &input, #if __cplusplus >= 202002L auto start = input.begin(); #else - const char *start = input.data(); + const char* start = input.data(); #endif std::string special; while (true) { @@ -153,8 +162,10 @@ Tiktoken::_split_with_allowed_special_token(re2::StringPiece &input, return std::make_pair(std::nullopt, input); } -Error Tiktoken::_encode(re2::StringPiece &input, std::vector &ret, - uint64_t &last_piece_token_len) const { +Error Tiktoken::_encode( + re2::StringPiece& input, + std::vector& ret, + uint64_t& last_piece_token_len) const { std::string piece; assert(_regex); while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) { @@ -171,7 +182,7 @@ Error Tiktoken::_encode(re2::StringPiece &input, std::vector &ret, return Error::Ok; } -void Tiktoken::_decode(re2::StringPiece input, std::string &ret) const { +void Tiktoken::_decode(re2::StringPiece input, std::string& ret) const { #ifdef _USE_INTERNAL_STRING_VIEW ret += input.as_string(); #else @@ -181,8 +192,9 @@ void Tiktoken::_decode(re2::StringPiece input, std::string &ret) const { template Result, uint64_t>> -Tiktoken::_encode_with_special_token(const std::string &text, - const T &allowed_special) const { +Tiktoken::_encode_with_special_token( + const std::string& text, + const T& allowed_special) const { std::vector tokens; uint64_t last_piece_token_len = 0; re2::StringPiece input(text); @@ -197,7 +209,7 @@ Tiktoken::_encode_with_special_token(const std::string &text, uint64_t token = 0; try { token = special_token_encoder_.at(*special); - } catch (const std::out_of_range &) { + } catch (const std::out_of_range&) { // Should never go here, since special pattern includes all special // chars. TK_LOG(Error, "unknown special token: %s", special->c_str()); @@ -228,7 +240,7 @@ Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const { // -------------------------private method end------------------------------- // -------------------------public method start------------------------------- -Error Tiktoken::load(const std::string &path) { +Error Tiktoken::load(const std::string& path) { encoder_ = TK_UNWRAP(_load_encoder(path)); special_token_encoder_ = _build_special_token_encoder(encoder_.size()); diff --git a/src/token_decoder.cpp b/src/token_decoder.cpp index 1d04473..669f6dd 100644 --- a/src/token_decoder.cpp +++ b/src/token_decoder.cpp @@ -37,7 +37,7 @@ TokenDecoder::Ptr TokenDecoderConfig::create() const { throw std::runtime_error("Unsupported TokenDecoder type: " + type); } -TokenDecoderConfig &TokenDecoderConfig::parse_json(const json &json_config) { +TokenDecoderConfig& TokenDecoderConfig::parse_json(const json& json_config) { type = json_config.at("type"); if (type == "ByteLevel") { // No parameters to parse @@ -54,7 +54,7 @@ namespace { // Copied from llama.cpp // CITE: // https://github.com/ggerganov/llama.cpp/blob/master/src/llama-vocab.cpp#L20 -static std::string format(const char *fmt, ...) { +static std::string format(const char* fmt, ...) { va_list ap; va_list ap2; va_start(ap, fmt); @@ -84,7 +84,7 @@ std::string ByteLevelTokenDecoder::decode(re2::StringPiece token) const { const auto utf8 = unicode_cpt_to_utf8(cpt); try { decoded_text += unicode_utf8_to_byte(utf8); - } catch (const std::out_of_range & /*e*/) { + } catch (const std::out_of_range& /*e*/) { decoded_text += "[UNK_BYTE_0x"; for (const auto c : utf8) { decoded_text += format("%02x", (uint8_t)c);
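
The hunks above are formatting-only and do not change behavior. For readers skimming the reformatted files, below is a minimal, hypothetical usage sketch of the Tiktoken API whose definitions are touched by this patch. The load/encode/decode signatures follow the diff; the include path, the default constructor, the Result::ok()/Result::get() accessors, and the bos/eos counting semantics are assumptions for illustration, not something this patch establishes.

#include <iostream>
#include <string>
#include <vector>

#include <pytorch/tokenizers/tiktoken.h>

int main() {
  tokenizers::Tiktoken tokenizer;

  // load() returns tokenizers::Error; Error::Ok signals success.
  if (tokenizer.load("/path/to/tiktoken.model") != tokenizers::Error::Ok) {
    std::cerr << "failed to load tokenizer artifact" << std::endl;
    return 1;
  }

  // encode(text, bos, eos) yields token ids; bos/eos are assumed to be the
  // number of BOS/EOS tokens to prepend/append.
  auto tokens = tokenizer.encode("hello world", /*bos=*/1, /*eos=*/0);
  if (!tokens.ok()) {
    return 1;
  }

  // decode() maps a single token id back to its text, given the previous
  // token id.
  uint64_t prev = tokens.get()[0];
  for (size_t i = 1; i < tokens.get().size(); ++i) {
    auto piece = tokenizer.decode(prev, tokens.get()[i]);
    if (piece.ok()) {
      std::cout << piece.get();
    }
    prev = tokens.get()[i];
  }
  std::cout << std::endl;
  return 0;
}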