
Commit 784373a

Apply clang-format linter
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent ca4b79d commit 784373a

19 files changed (+473 −368 lines)
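
The changes below are consistent with a Google-derived clang-format style: references and pointers bind to the type (`std::string& output` rather than `std::string &output`), argument and parameter lists are no longer bin-packed, wrapped lists break after the open bracket with a 4-space continuation indent, and access specifiers indent by one space. A hypothetical `.clang-format` that would produce this shape; the repository's actual config is not part of this commit, so every option here is an inference from the diffs:

BasedOnStyle: Google
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
BinPackArguments: false
BinPackParameters: false
ContinuationIndentWidth: 4
PointerAlignment: Left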

include/pytorch/tokenizers/base64.h  (+26 −14)

@@ -36,7 +36,7 @@ namespace base64 {
 using tokenizers::Error;
 using tokenizers::Result;
 
-Result<std::string> decode(const std::string_view &input);
+Result<std::string> decode(const std::string_view& input);
 
 namespace detail {
 
@@ -68,9 +68,12 @@ inline Error validate(uint32_t v) {
   return Error::Ok;
 }
 
-inline Error decode(const std::string_view &input, std::string &output) {
-  TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure,
-                           "input length must be 4, got %zu", input.size());
+inline Error decode(const std::string_view& input, std::string& output) {
+  TK_CHECK_OR_RETURN_ERROR(
+      input.size() == 4,
+      Base64DecodeFailure,
+      "input length must be 4, got %zu",
+      input.size());
 
   uint32_t val = 0;
 
@@ -100,10 +103,14 @@ inline Error decode(const std::string_view &input, std::string &output) {
   return Error::Ok;
 }
 
-inline Error decode_1_padding(const std::string_view &input,
-                              std::string &output) {
-  TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure,
-                           "input length must be 3, got %zu", input.size());
+inline Error decode_1_padding(
+    const std::string_view& input,
+    std::string& output) {
+  TK_CHECK_OR_RETURN_ERROR(
+      input.size() == 3,
+      Base64DecodeFailure,
+      "input length must be 3, got %zu",
+      input.size());
 
   uint32_t val = 0;
 
@@ -127,10 +134,14 @@ inline Error decode_1_padding(const std::string_view &input,
   return Error::Ok;
 }
 
-inline Error decode_2_padding(const std::string_view &input,
-                              std::string &output) {
-  TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure,
-                           "input length must be 2, got %zu", input.size());
+inline Error decode_2_padding(
+    const std::string_view& input,
+    std::string& output) {
+  TK_CHECK_OR_RETURN_ERROR(
+      input.size() == 2,
+      Base64DecodeFailure,
+      "input length must be 2, got %zu",
+      input.size());
 
   uint32_t val = 0;
 
@@ -150,12 +161,13 @@ inline Error decode_2_padding(const std::string_view &input,
 
 } // namespace detail
 
-inline tokenizers::Result<std::string> decode(const std::string_view &input) {
+inline tokenizers::Result<std::string> decode(const std::string_view& input) {
   TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input");
 
   // Faster than `input.size() % 4`.
   TK_CHECK_OR_RETURN_ERROR(
-      (input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure,
+      (input.size() & 3) == 0 && input.size() >= 4,
+      Base64DecodeFailure,
       "input length must be larger than 4 and is multiple of 4, got %zu",
       input.size());
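
For orientation, a minimal caller of the reformatted `base64::decode` might look like the sketch below. It assumes `Result<T>` exposes ExecuTorch-style `ok()` and `get()` accessors; those names are an assumption, not shown in this diff.

#include <iostream>
#include <pytorch/tokenizers/base64.h>

int main() {
  // "aGVsbG8=" is base64 for "hello"; its length (8) passes the
  // non-empty and multiple-of-4 checks enforced above.
  auto result = base64::decode("aGVsbG8=");
  if (!result.ok()) { // assumed Result accessor
    std::cerr << "decode failed\n";
    return 1;
  }
  std::cout << result.get() << '\n'; // prints "hello"
  return 0;
}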

include/pytorch/tokenizers/bpe_tokenizer_base.h  (+22 −18)

@@ -32,27 +32,29 @@ using Decoder = std::unordered_map<uint64_t, std::string>;
 using Re2UPtr = std::unique_ptr<re2::RE2>;
 
 class BPETokenizerBase : public Tokenizer {
-public:
-  Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
-                                       int8_t eos) const override;
+ public:
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos) const override;
 
-  Result<std::string> decode(uint64_t prev_token,
-                             uint64_t token) const override;
+  Result<std::string> decode(uint64_t prev_token, uint64_t token)
+      const override;
 
-protected:
+ protected:
   explicit BPETokenizerBase() {}
-  virtual ~BPETokenizerBase() override {}
+  virtual ~BPETokenizerBase() {}
 
   std::pair<std::optional<std::string>, re2::StringPiece>
-  split_with_allowed_special_token_(re2::StringPiece &input,
-                                    const Encoder &allowed_special) const;
+  split_with_allowed_special_token_(
+      re2::StringPiece& input,
+      const Encoder& allowed_special) const;
 
-  Result<std::pair<std::vector<uint64_t>, uint64_t>>
-  encode_with_special_token_(const std::string &text,
-                             const Encoder &allowed_special) const;
+  Result<std::pair<std::vector<uint64_t>, uint64_t>> encode_with_special_token_(
+      const std::string& text,
+      const Encoder& allowed_special) const;
 
-  Result<std::vector<uint64_t>> byte_pair_encode_(const std::string &piece,
-                                                  const Encoder &encoder) const;
+  Result<std::vector<uint64_t>> byte_pair_encode_(
+      const std::string& piece,
+      const Encoder& encoder) const;
 
   // Protected members that can be overloaded by other BPE tokenizers
   Re2UPtr special_token_regex_;
@@ -61,11 +63,13 @@ class BPETokenizerBase : public Tokenizer {
   Decoder decoder_;
   Decoder special_token_decoder_;
 
-private:
-  virtual Error _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
-                        uint64_t &last_piece_token_len) const = 0;
+ private:
+  virtual Error _encode(
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len) const = 0;
 
-  virtual void _decode(re2::StringPiece input, std::string &ret) const = 0;
+  virtual void _decode(re2::StringPiece input, std::string& ret) const = 0;
 };
 
 } // namespace detail
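
The reformatted declarations make the subclass contract easier to read: a concrete BPE tokenizer supplies the two private pure-virtual hooks and fills the protected tables. A skeletal sketch of that contract; the class is hypothetical, and it assumes `Tokenizer::load` is the remaining pure-virtual member and that an `encoder_` table exists alongside the `decoder_` members shown above:

using tokenizers::Error;

// Hypothetical subclass illustrating the BPETokenizerBase contract.
class MyBPETokenizer : public tokenizers::detail::BPETokenizerBase {
 public:
  Error load(const std::string& tokenizer_path) override {
    // Populate the protected tables (encoder/decoder and the special
    // token maps) from tokenizer_path here.
    return Error::Ok;
  }

 private:
  Error _encode(
      re2::StringPiece& input,
      std::vector<uint64_t>& ret,
      uint64_t& last_piece_token_len) const override {
    // Split input on a pretokenizer pattern, then run
    // byte_pair_encode_() on each piece.
    return Error::Ok;
  }

  void _decode(re2::StringPiece input, std::string& ret) const override {
    // Trivial passthrough for illustration only.
    ret.append(input.data(), input.size());
  }
};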

include/pytorch/tokenizers/error.h  (+33 −33)

@@ -70,12 +70,12 @@ enum class Error : error_code_t {
  * @param[in] message__ Format string for the log error message.
  * @param[in] ... Optional additional arguments for the format string.
  */
-#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
-  { \
-    if (!(cond__)) { \
-      TK_LOG(Error, message__, ##__VA_ARGS__); \
-      return ::tokenizers::Error::error__; \
-    } \
+#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
+  {                                                                \
+    if (!(cond__)) {                                               \
+      TK_LOG(Error, message__, ##__VA_ARGS__);                     \
+      return ::tokenizers::Error::error__;                         \
+    }                                                              \
   }
 
 /**
@@ -86,13 +86,13 @@ enum class Error : error_code_t {
  * @param[in] ... Optional format string for the log error message and its
  * arguments.
  */
-#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
+#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...)                 \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__)
 
 // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
-  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \
-                                              4, 3, 2, 1) \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...)                 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(                    \
+      __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)                 \
   (__VA_ARGS__)
 
 /**
@@ -119,43 +119,43 @@ enum class Error : error_code_t {
  * TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1
  * TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2
  */
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \
-                                                    _7, _8, _9, _10, N, ...) \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT( \
+    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N
 
 // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
-  do { \
-    const auto et_error__ = (error__); \
-    if (et_error__ != ::tokenizers::Error::Ok) { \
-      return et_error__; \
-    } \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__)           \
+  do {                                                            \
+    const auto et_error__ = (error__);                            \
+    if (et_error__ != ::tokenizers::Error::Ok) {                  \
+      return et_error__;                                          \
+    }                                                             \
   } while (0)
 
 // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
-  do { \
-    const auto et_error__ = (error__); \
-    if (et_error__ != ::tokenizers::Error::Ok) { \
-      TK_LOG(Error, message__, ##__VA_ARGS__); \
-      return et_error__; \
-    } \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
+  do {                                                                  \
+    const auto et_error__ = (error__);                                  \
+    if (et_error__ != ::tokenizers::Error::Ok) {                        \
+      TK_LOG(Error, message__, ##__VA_ARGS__);                          \
+      return et_error__;                                                \
+    }                                                                   \
   } while (0)
 
 // Internal only: Use ET_CHECK_OK_OR_RETURN_ERROR() instead.
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9                    \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
-#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10                   \
   TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
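
The `_SELECT` macro being realigned here is a standard variadic-arity dispatch: the fixed descending list `10, 9, ..., 1` shifts the right suffix into the `N` slot, so one macro name fans out to the one-argument or message-taking variant. A standalone illustration of the same trick, with hypothetical `DEMO_*` names that are not part of the library:

#include <cstdio>

// The countdown pushes the right suffix into N: one user argument
// selects DEMO_1, two select DEMO_2, three select DEMO_3. (Passing
// zero variadic arguments is pedantically non-standard before C++20
// but accepted by the major compilers, as in the library's macro.)
#define DEMO_SELECT(_1, _2, _3, N, ...) DEMO_##N
#define DEMO(...) DEMO_SELECT(__VA_ARGS__, 3, 2, 1)(__VA_ARGS__)

#define DEMO_1(a) std::printf("one: %d\n", a)
#define DEMO_2(a, b) std::printf("two: %d %s\n", a, b)
#define DEMO_3(a, b, c) std::printf("three: %d %s %d\n", a, b, c)

int main() {
  DEMO(42);           // expands to DEMO_1(42)
  DEMO(42, "msg");    // expands to DEMO_2(42, "msg")
  DEMO(42, "msg", 7); // expands to DEMO_3(42, "msg", 7)
  return 0;
}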

include/pytorch/tokenizers/hf_tokenizer.h  (+8 −6)

@@ -27,7 +27,7 @@
 
 namespace tokenizers {
 class HFTokenizer : public detail::BPETokenizerBase {
-public:
+ public:
   /*-- Public Interface --*/
 
@@ -39,13 +39,15 @@ class HFTokenizer : public detail::BPETokenizerBase {
   /**
   * Load the model data into the
   */
-  Error load(const std::string &tokenizer_path) override;
+  Error load(const std::string& tokenizer_path) override;
 
-private:
-  Error _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
-                uint64_t &last_piece_token_len) const override;
+ private:
+  Error _encode(
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len) const override;
 
-  void _decode(re2::StringPiece input, std::string &ret) const override;
+  void _decode(re2::StringPiece input, std::string& ret) const override;
 
   PreTokenizer::Ptr _pretokenizer;
   TokenDecoder::Ptr _decoder;
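
A minimal use of the reformatted `HFTokenizer` interface, as a sketch only: the file path is a placeholder, and reading `bos`/`eos` as counts of BOS/EOS tokens to add follows the llama2.c convention, which this diff does not confirm.

#include <pytorch/tokenizers/hf_tokenizer.h>

int main() {
  tokenizers::HFTokenizer tok;
  // Placeholder path to a HuggingFace tokenizer.json.
  if (tok.load("tokenizer.json") != tokenizers::Error::Ok) {
    return 1;
  }
  // Assumed semantics: prepend one BOS token, append no EOS token.
  auto ids = tok.encode("hello world", /*bos=*/1, /*eos=*/0);
  return ids.ok() ? 0 : 1; // ok() assumed, as above
}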

include/pytorch/tokenizers/llama2c_tokenizer.h  (+10 −10)

@@ -7,33 +7,33 @@
  */
 // @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
 #pragma once
-#include <memory>
 #include <pytorch/tokenizers/tokenizer.h>
+#include <memory>
 
 namespace tokenizers {
 
 struct TokenIndex {
-  const char *str;
+  const char* str;
   int32_t id;
 };
 
 // A simple Byte Pair Encoding (BPE) Tokenizer. Note that the current C++ code
 // won't work with this class, it needs to go through tokenizer.py first.
 class Llama2cTokenizer : public Tokenizer {
-public:
+ public:
   explicit Llama2cTokenizer();
   ~Llama2cTokenizer() override;
 
-  Error load(const std::string &tokenizer_path) override;
+  Error load(const std::string& tokenizer_path) override;
 
-  Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
-                                       int8_t eos) const override;
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos) const override;
 
-  Result<std::string> decode(uint64_t prev_token,
-                             uint64_t token) const override;
+  Result<std::string> decode(uint64_t prev_token, uint64_t token)
+      const override;
 
-private:
-  std::unique_ptr<char *[]> vocab_ = nullptr;
+ private:
+  std::unique_ptr<char*[]> vocab_ = nullptr;
   std::unique_ptr<float[]> vocab_scores_ = nullptr;
   std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
   unsigned int max_token_length_ = 0;
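
Finally, the two-argument `decode(prev_token, token)` shape reformatted above suits a streaming loop, where each step sees the previous token (for example, for leading-space handling). A hedged sketch: the token IDs and model path are illustrative, and the `Result` accessors are assumed as in the earlier examples.

#include <iostream>
#include <vector>
#include <pytorch/tokenizers/llama2c_tokenizer.h>

int main() {
  tokenizers::Llama2cTokenizer tok;
  // Placeholder for an artifact produced by tokenizer.py (see the
  // class comment above).
  if (tok.load("tokenizer.bin") != tokenizers::Error::Ok) {
    return 1;
  }
  std::vector<uint64_t> tokens = {1, 15043, 3186}; // illustrative IDs
  for (size_t i = 1; i < tokens.size(); ++i) {
    auto piece = tok.decode(tokens[i - 1], tokens[i]);
    if (piece.ok()) { // assumed Result accessor
      std::cout << piece.get();
    }
  }
  std::cout << '\n';
  return 0;
}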
