Skip to content

Commit

Permalink
[refactor] replace c++17 to c++11.
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhaode committed Jul 4, 2024
1 parent 9f0a3ce commit 6e03844
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 19 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ if (BUILD_JNI)
endif()

if (MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++17")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /std:c++11")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/source-charset:utf-8>")
# compile static lib, support Windows
add_library(llm STATIC ${SRCS})
link_directories(${CMAKE_BINARY_DIR}/MNN/Release)
target_link_libraries(llm MNN.lib)
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# compile dynamic so, support Linux/Mac
add_library(llm SHARED ${SRCS})
set_target_properties(llm PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
Expand Down
56 changes: 53 additions & 3 deletions include/tokenizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,54 @@
#include <string>
#include <unordered_map>
#include <iostream>
#include <string_view>
// #include <string_view>
#include <cstring>

// std::string_view impl in c++11 start
// Minimal std::string_view replacement for C++11 builds.
// Non-owning view over a character sequence; the viewed data must outlive
// the view. Mirrors the subset of the std::string_view API this project uses.
class string_view_ {
public:
    // Empty view.
    string_view_() : data_(nullptr), size_(0) {}
    // View over a NUL-terminated C string.
    string_view_(const char* data) : data_(data), size_(std::strlen(data)) {}
    // View over an explicit [data, data + size) range (may contain '\0').
    string_view_(const char* data, std::size_t size) : data_(data), size_(size) {}
    // View over a std::string's contents (no copy is made).
    string_view_(const std::string& str) : data_(str.data()), size_(str.size()) {}
    constexpr string_view_(const string_view_&) noexcept = default;
    string_view_& operator=(const string_view_&) noexcept = default;
    // Unchecked element access, matching std::string_view::operator[].
    const char& operator[](size_t pos) const { return data_[pos]; }
    constexpr const char* data() const noexcept { return data_; }
    constexpr std::size_t size() const noexcept { return size_; }
    constexpr bool empty() const { return size_ == 0; }
    std::string to_string() const { return std::string(data_, size_); }
    bool operator==(const string_view_& other) const noexcept {
        // Use memcmp, not strncmp: views may contain embedded '\0' bytes,
        // and strncmp stops at the first NUL, wrongly reporting equality
        // for views that differ after it. Guard size_ == 0 so we never
        // call memcmp with a null data_ pointer.
        return size_ == other.size_ &&
               (size_ == 0 || std::memcmp(data_, other.data_, size_) == 0);
    }
    // Drop the first n characters. n >= size() leaves the view empty
    // (std::string_view makes that UB; this implementation is forgiving).
    void remove_prefix(size_t n) {
        if (n < size_) {
            data_ += n;
            size_ -= n;
        } else {
            data_ = "";
            size_ = 0;
        }
    }
private:
    const char* data_;
    std::size_t size_ = 0;
};

// std::hash specialization so string_view_ can be used as a key in
// unordered containers (e.g. the rev_merge map in bpe_encode).
// Simple polynomial (multiplier 31) hash over the viewed bytes.
namespace std {
template<>
class hash<string_view_> {
public:
    size_t operator()(const string_view_& sv) const {
        size_t result = 0;
        for (size_t i = 0; i < sv.size(); ++i) {
            // Hash the raw byte value: converting through unsigned char
            // keeps the result identical whether plain char is signed or
            // unsigned on the target platform (a signed char would
            // sign-extend and yield platform-dependent hashes).
            result = (result * 31) + static_cast<unsigned char>(sv[i]);
        }
        return result;
    }
};
}
// std::string_view impl in c++11 end

class Tokenizer {
public:
Expand All @@ -28,6 +75,7 @@ class Tokenizer {
virtual ~Tokenizer() = default;
static Tokenizer* createTokenizer(const std::string& filename);
bool is_stop(int token);
bool is_special(int token);
std::vector<int> encode(const std::string& str);
virtual std::string decode(int id) = 0;
protected:
Expand Down Expand Up @@ -65,8 +113,10 @@ class Sentencepiece : public Tokenizer {
std::string piece;
float score;
PieceType type = PieceType::NORMAL;
SentencePiece() {}
SentencePiece(const std::string& p, float s, PieceType t) : piece(p), score(s), type(t) {}
};
using EncodeResult = std::vector<std::pair<std::string_view, int>>;
using EncodeResult = std::vector<std::pair<string_view_, int>>;
private:
// model train type
ModelType type_ = BPE;
Expand All @@ -86,7 +136,7 @@ class Sentencepiece : public Tokenizer {
bool is_control(int id) const;
int piece_to_id(const std::string& w) const;
std::string byte_to_piece(unsigned char c) const;
EncodeResult bpe_encode(std::string_view str, float alpha = 0.f);
EncodeResult bpe_encode(string_view_ str, float alpha = 0.f);
};

class Tiktoken : public Tokenizer {
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_relative_rpath(path):
sources=['./mnnllm.cpp'],
include_dirs=['../include', '../MNN/include'],
library_dirs=['../build'],
extra_compile_args=['-std=c++17'],
extra_compile_args=['-std=c++11'],
extra_link_args=['-lllm'] + make_relative_rpath('lib'))

setup(name='mnnllm',
Expand Down
30 changes: 17 additions & 13 deletions src/tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ bool Tokenizer::is_stop(int token) {
return std::find(stop_tokens_.begin(), stop_tokens_.end(), token) != stop_tokens_.end();
}

// Returns true if `token` appears in the special-token id list.
bool Tokenizer::is_special(int token) {
    for (const auto& special : special_tokens_) {
        if (special == token) {
            return true;
        }
    }
    return false;
}

void Tokenizer::load_special(std::ifstream& tok_file) {
std::string line;
std::getline(tok_file, line);
Expand Down Expand Up @@ -201,7 +205,7 @@ bool Sentencepiece::load_vocab(std::ifstream& tok_file) {
line_str >> token >> score >> type;
token = base64_decode(token);
auto piece_type = static_cast<PieceType>(type);
SentencePiece piece {token, score, piece_type};
SentencePiece piece = {token, score, piece_type};
sentence_pieces_[index] = std::move(piece);
if (piece_type == PieceType::NORMAL) {
pieces_.insert({token, index});
Expand Down Expand Up @@ -236,7 +240,7 @@ std::string Sentencepiece::byte_to_piece(unsigned char c) const {
}

// ref: https://github.com/google/sentencepiece/blob/master/src/bpe_model.cc
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalized, float alpha) {
Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, float alpha) {
// util class begin
struct SymbolPair {
int left; // left index of this pair
Expand All @@ -256,7 +260,7 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalize
int prev; // prev index of this symbol. -1 for BOS.
int next; // next index of tihs symbol. -1 for EOS.
bool freeze = false; // this symbol is never be merged.
std::string_view piece;
string_view_ piece;
};
// util class end

Expand All @@ -265,16 +269,16 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalize
std::vector<Symbol> symbols;
symbols.reserve(normalized.size());
// Reverse merge rules. key: merged symbol, value: pair of original symbols.
std::unordered_map<std::string_view, std::pair<std::string_view, std::string_view>> rev_merge;
std::unordered_map<string_view_, std::pair<string_view_, string_view_>> rev_merge;
// SymbolPair holder.
std::vector<std::unique_ptr<SymbolPair>> symbol_pair_holder;
// Lookup new symbol pair at [left, right] and inserts it to agenda.
auto MaybeAddNewSymbolPair = [this, &symbol_pair_holder, &symbols, &agenda, &rev_merge](int left, int right) {
if (left == -1 || right == -1 || symbols[left].freeze || symbols[right].freeze) {
return;
}
const std::string_view piece(symbols[left].piece.data(), symbols[left].piece.size() + symbols[right].piece.size());
std::string piece_str(piece);
const string_view_ piece(symbols[left].piece.data(), symbols[left].piece.size() + symbols[right].piece.size());
std::string piece_str(piece.to_string());
const auto it = pieces_.find(piece_str);
if (it == pieces_.end()) {
return;
Expand All @@ -298,7 +302,7 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalize
Symbol s;
// const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
int mblen = std::min<int>(normalized.size(), one_char_len(normalized.data()));
s.piece = std::string_view(normalized.data(), mblen);
s.piece = string_view_(normalized.data(), mblen);
s.prev = index == 0 ? -1 : index - 1;
normalized.remove_prefix(mblen);
s.next = normalized.empty() ? -1 : index + 1;
Expand Down Expand Up @@ -338,7 +342,7 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalize

if (skip_merge()) continue;
// Replaces symbols with `top` rule.
symbols[top->left].piece = std::string_view(
symbols[top->left].piece = string_view_(
symbols[top->left].piece.data(),
symbols[top->left].piece.size() + symbols[top->right].piece.size());

Expand All @@ -347,16 +351,16 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(std::string_view normalize
if (symbols[top->right].next >= 0) {
symbols[symbols[top->right].next].prev = top->left;
}
symbols[top->right].piece = std::string_view("");
symbols[top->right].piece = string_view_("");

// Adds new symbol pairs which are newly added after symbol replacement.
MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
}

std::function<void(std::string_view, EncodeResult*)> resegment;
resegment = [this, &resegment, &rev_merge](std::string_view w, EncodeResult *output) -> void {
std::string w_str(w);
std::function<void(string_view_, EncodeResult*)> resegment;
resegment = [this, &resegment, &rev_merge](string_view_ w, EncodeResult *output) -> void {
std::string w_str(w.to_string());
const int id = piece_to_id(w_str);
// std::cout << "piece: " << w << ", id = " << id << std::endl;
if (id == -1 || !is_unused(id)) {
Expand Down Expand Up @@ -385,7 +389,7 @@ void Sentencepiece::encode(const std::string& str, std::vector<int>& ids) {
auto result = bpe_encode(str);
size_t consumed = 0;
for (const auto &p : result) {
const std::string_view w = p.first; // piece
const string_view_ w = p.first; // piece
const int id = p.second; // id
const bool is_unk = (id == unk_id_);
if (is_unk && byte_fall_back_) {
Expand Down

0 comments on commit 6e03844

Please sign in to comment.