Skip to content

Fix clang format #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions include/pytorch/tokenizers/base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace base64 {
using tokenizers::Error;
using tokenizers::Result;

Result<std::string> decode(const std::string_view &input);
Result<std::string> decode(const std::string_view& input);

namespace detail {

Expand Down Expand Up @@ -68,9 +68,12 @@ inline Error validate(uint32_t v) {
return Error::Ok;
}

inline Error decode(const std::string_view &input, std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure,
"input length must be 4, got %zu", input.size());
inline Error decode(const std::string_view& input, std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 4,
Base64DecodeFailure,
"input length must be 4, got %zu",
input.size());

uint32_t val = 0;

Expand Down Expand Up @@ -100,10 +103,14 @@ inline Error decode(const std::string_view &input, std::string &output) {
return Error::Ok;
}

inline Error decode_1_padding(const std::string_view &input,
std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure,
"input length must be 3, got %zu", input.size());
inline Error decode_1_padding(
const std::string_view& input,
std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 3,
Base64DecodeFailure,
"input length must be 3, got %zu",
input.size());

uint32_t val = 0;

Expand All @@ -127,10 +134,14 @@ inline Error decode_1_padding(const std::string_view &input,
return Error::Ok;
}

inline Error decode_2_padding(const std::string_view &input,
std::string &output) {
TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure,
"input length must be 2, got %zu", input.size());
inline Error decode_2_padding(
const std::string_view& input,
std::string& output) {
TK_CHECK_OR_RETURN_ERROR(
input.size() == 2,
Base64DecodeFailure,
"input length must be 2, got %zu",
input.size());

uint32_t val = 0;

Expand All @@ -150,12 +161,13 @@ inline Error decode_2_padding(const std::string_view &input,

} // namespace detail

inline tokenizers::Result<std::string> decode(const std::string_view &input) {
inline tokenizers::Result<std::string> decode(const std::string_view& input) {
TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input");

// Faster than `input.size() % 4`.
TK_CHECK_OR_RETURN_ERROR(
(input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure,
(input.size() & 3) == 0 && input.size() >= 4,
Base64DecodeFailure,
"input length must be larger than 4 and is multiple of 4, got %zu",
input.size());

Expand Down
40 changes: 22 additions & 18 deletions include/pytorch/tokenizers/bpe_tokenizer_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,29 @@ using Decoder = std::unordered_map<uint64_t, std::string>;
using Re2UPtr = std::unique_ptr<re2::RE2>;

class BPETokenizerBase : public Tokenizer {
public:
Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
int8_t eos) const override;
public:
Result<std::vector<uint64_t>>
encode(const std::string& input, int8_t bos, int8_t eos) const override;

Result<std::string> decode(uint64_t prev_token,
uint64_t token) const override;
Result<std::string> decode(uint64_t prev_token, uint64_t token)
const override;

protected:
protected:
explicit BPETokenizerBase() {}
virtual ~BPETokenizerBase() override {}
virtual ~BPETokenizerBase() {}

std::pair<std::optional<std::string>, re2::StringPiece>
split_with_allowed_special_token_(re2::StringPiece &input,
const Encoder &allowed_special) const;
split_with_allowed_special_token_(
re2::StringPiece& input,
const Encoder& allowed_special) const;

Result<std::pair<std::vector<uint64_t>, uint64_t>>
encode_with_special_token_(const std::string &text,
const Encoder &allowed_special) const;
Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
const std::string& text,
const Encoder& allowed_special) const;

Result<std::vector<uint64_t>> byte_pair_encode_(const std::string &piece,
const Encoder &encoder) const;
Result<std::vector<uint64_t>> byte_pair_encode_(
const std::string& piece,
const Encoder& encoder) const;

// Protected members that can be overloaded by other BPE tokenizers
Re2UPtr special_token_regex_;
Expand All @@ -61,11 +63,13 @@ class BPETokenizerBase : public Tokenizer {
Decoder decoder_;
Decoder special_token_decoder_;

private:
virtual Error _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
uint64_t &last_piece_token_len) const = 0;
private:
virtual Error _encode(
re2::StringPiece& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const = 0;

virtual void _decode(re2::StringPiece input, std::string &ret) const = 0;
virtual void _decode(re2::StringPiece input, std::string& ret) const = 0;
};

} // namespace detail
Expand Down
66 changes: 33 additions & 33 deletions include/pytorch/tokenizers/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ enum class Error : error_code_t {
* @param[in] message__ Format string for the log error message.
* @param[in] ... Optional additional arguments for the format string.
*/
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
{ \
if (!(cond__)) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return ::tokenizers::Error::error__; \
} \
#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
{ \
if (!(cond__)) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return ::tokenizers::Error::error__; \
} \
}

/**
Expand All @@ -86,13 +86,13 @@ enum class Error : error_code_t {
* @param[in] ... Optional format string for the log error message and its
* arguments.
*/
#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__)

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \
4, 3, 2, 1) \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \
__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) \
(__VA_ARGS__)

/**
Expand All @@ -119,43 +119,43 @@ enum class Error : error_code_t {
* TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1
* TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2
*/
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \
_7, _8, _9, _10, N, ...) \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
return et_error__; \
} \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
return et_error__; \
} \
} while (0)

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return et_error__; \
} \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
do { \
const auto et_error__ = (error__); \
if (et_error__ != ::tokenizers::Error::Ok) { \
TK_LOG(Error, message__, ##__VA_ARGS__); \
return et_error__; \
} \
} while (0)

// Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
14 changes: 8 additions & 6 deletions include/pytorch/tokenizers/hf_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

namespace tokenizers {
class HFTokenizer : public detail::BPETokenizerBase {
public:
public:
/*-- Public Interface --*/

/**
Expand All @@ -39,13 +39,15 @@ class HFTokenizer : public detail::BPETokenizerBase {
/**
* Load the model data into the
*/
Error load(const std::string &tokenizer_path) override;
Error load(const std::string& tokenizer_path) override;

private:
Error _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
uint64_t &last_piece_token_len) const override;
private:
Error _encode(
re2::StringPiece& input,
std::vector<uint64_t>& ret,
uint64_t& last_piece_token_len) const override;

void _decode(re2::StringPiece input, std::string &ret) const override;
void _decode(re2::StringPiece input, std::string& ret) const override;

PreTokenizer::Ptr _pretokenizer;
TokenDecoder::Ptr _decoder;
Expand Down
20 changes: 10 additions & 10 deletions include/pytorch/tokenizers/llama2c_tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,33 @@
*/
// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
#pragma once
#include <memory>
#include <pytorch/tokenizers/tokenizer.h>
#include <memory>

namespace tokenizers {

struct TokenIndex {
const char *str;
const char* str;
int32_t id;
};

// A simple Byte Pair Encoding (BPE) Tokenizer. Note that the current C++ code
// won't work with this class, it needs to go through tokenizer.py first.
class Llama2cTokenizer : public Tokenizer {
public:
public:
explicit Llama2cTokenizer();
~Llama2cTokenizer() override;

Error load(const std::string &tokenizer_path) override;
Error load(const std::string& tokenizer_path) override;

Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
int8_t eos) const override;
Result<std::vector<uint64_t>>
encode(const std::string& input, int8_t bos, int8_t eos) const override;

Result<std::string> decode(uint64_t prev_token,
uint64_t token) const override;
Result<std::string> decode(uint64_t prev_token, uint64_t token)
const override;

private:
std::unique_ptr<char *[]> vocab_ = nullptr;
private:
std::unique_ptr<char*[]> vocab_ = nullptr;
std::unique_ptr<float[]> vocab_scores_ = nullptr;
std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
unsigned int max_token_length_ = 0;
Expand Down
Loading