llamax : add a possible implementation of a simple API for llama.cpp … #12835

Open · wants to merge 1 commit into master

11 changes: 11 additions & 0 deletions CMakeLists.txt
@@ -84,6 +84,9 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)

# llamax
option(LLAMAX "llama: enable high-level C++ API." ON)

# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)
@@ -187,6 +190,14 @@ if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES)
    add_subdirectory(pocs)
endif()

#
# llamax
#

if (LLAMAX)
    add_subdirectory(llamax)
endif()

#
# install
#
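With the `LLAMAX` option added by this hunk (it defaults to `ON`), the new library and its examples build through the usual llama.cpp CMake workflow; a minimal sketch, assuming an out-of-tree `build` directory:

```bash
# Configure with the high-level API enabled (LLAMAX is ON by default in this PR)
cmake -B build -DLLAMAX=ON

# Build the shared library and the three examples added by this PR
cmake --build build --config Release --target llamax llamax_simple llamax_chat llamax_grammar
```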
48 changes: 48 additions & 0 deletions llamax/CMakeLists.txt
@@ -0,0 +1,48 @@
#
# Define version
#

set(LLAMAX_MAJOR_VERSION 2)
set(LLAMAX_MINOR_VERSION 1)
set(LLAMAX_PATCH_VERSION 0)
set(LLAMAX_VERSION ${LLAMAX_MAJOR_VERSION}.${LLAMAX_MINOR_VERSION}.${LLAMAX_PATCH_VERSION})

#
# Build llamax
#
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)

set(LLAMAX_SRCS src/llamax.cpp)
add_library(llamax SHARED ${LLAMAX_SRCS})
target_link_libraries(llamax PRIVATE llama)

set(LLAMAX_PUBLIC_HEADERS
    ${CMAKE_CURRENT_SOURCE_DIR}/include/llamax.h)

set_target_properties(llamax
    PROPERTIES
    PUBLIC_HEADER "${LLAMAX_PUBLIC_HEADERS}")

add_subdirectory(examples)

#
# install
#

install(TARGETS llamax LIBRARY PUBLIC_HEADER)
#
# Config files
#

include(CMakePackageConfigHelpers)
configure_package_config_file(llamaxConfig.cmake.in "${PROJECT_BINARY_DIR}/llamaxConfig.cmake"
    INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/llama
    PATH_VARS CMAKE_INSTALL_INCLUDEDIR
)
write_basic_package_version_file("${PROJECT_BINARY_DIR}/llamaxConfigVersion.cmake" VERSION ${LLAMAX_VERSION} COMPATIBILITY SameMajorVersion)

# Install the llamaxConfig.cmake and llamaxConfigVersion.cmake
install(FILES
    "${PROJECT_BINARY_DIR}/llamaxConfig.cmake"
    "${PROJECT_BINARY_DIR}/llamaxConfigVersion.cmake"
    DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/llama" COMPONENT dev)
13 changes: 13 additions & 0 deletions llamax/README.MD
@@ -0,0 +1,13 @@
llamax
======

`llamax` is an experimental high-level API for [llama](https://github.com/ggerganov/llama.cpp).

Development occurs in the `dev/1` branch.

The roadmap includes:

* ~~support for text-based LLM models~~
* support for multi-modal models
* support for embeddings
* ~~support for grammars~~
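The flavour of the API can be seen in the examples shipped with this PR (`llamax/examples/`); below is a minimal sketch based on the declarations in `include/llamax.h`, with the model path and prompt as placeholders:

```cpp
#include <llamax.h>

#include <iostream>

int main() {
    try {
        // Load a GGUF model with default parameters (the path is a placeholder).
        llamax::model model =
            llamax::model::load_from_file("model.gguf", llamax::model_params::default_params());

        // Create a context with a simple sampling chain: top-k, temperature, then random sampling.
        llamax::context ctx =
            model.create_context(llamax::context_params::default_params(),
                                 llamax::sampler_builder().top_k(40).temp(0.8f).dist(llamax::default_seed()));

        // Stream the completion piece by piece until the iterator is exhausted.
        llamax::iterator it = ctx.prompt("What is up doctor?");
        while (std::optional<std::string> s = it.next()) {
            std::cout << s.value();
        }
        std::cout << std::endl;
    } catch (const llamax::exception & e) {
        std::cerr << e.what() << std::endl;
        return -1;
    }
    return 0;
}
```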
8 changes: 8 additions & 0 deletions llamax/examples/CMakeLists.txt
@@ -0,0 +1,8 @@
add_executable(llamax_simple simple.cpp)
target_link_libraries(llamax_simple llamax)

add_executable(llamax_chat chat.cpp)
target_link_libraries(llamax_chat llamax)

add_executable(llamax_grammar grammar.cpp)
target_link_libraries(llamax_grammar llamax)
6 changes: 6 additions & 0 deletions llamax/examples/README.MD
@@ -0,0 +1,6 @@
Download a small GGUF model and run the simple example:

```bash
wget https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v0.2-GGUF/resolve/main/ggml-model-q4_0.gguf
./llamax_simple ggml-model-q4_0.gguf "What is up doctor?"
```
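The chat and grammar examples from this PR can be run against the same model; the prompt below is only an illustrative placeholder:

```bash
# Interactive chat loop (exit with an empty line)
./llamax_chat ggml-model-q4_0.gguf

# Constrained generation with the JSON grammar built into the example
./llamax_grammar ggml-model-q4_0.gguf "Describe a llama as a JSON object."
```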
49 changes: 49 additions & 0 deletions llamax/examples/chat.cpp
@@ -0,0 +1,49 @@
#include <llamax.h>

#include <iostream>

int main(int _argc, const char ** _argv) {
    if (_argc != 2) {
        std::cerr << "llamax_chat [model]" << std::endl;
        return -1;
    }

    llamax::model model = llamax::model::load_from_file(_argv[1], llamax::model_params::default_params());
    llamax::context ctx =
        model.create_context(llamax::context_params::default_params(),
                             llamax::sampler_builder().min_p(0.05, 1).temp(0.8f).dist(llamax::default_seed()));
    llamax::chat_template ct = model.create_chat_template();
    std::vector<llamax::chat_message> messages;

    messages.push_back({ llamax::chat_message_role::system, "You are an assistant." });

    // Length of the templated prompt that has already been fed to the context.
    int offset = 0;

    while (true) {
        std::cout << "\033[32m> \033[0m";
        std::string user;
        std::getline(std::cin, user);

        if (user.empty()) {
            break;
        }

        messages.push_back({ llamax::chat_message_role::user, user });

        std::string prompt = ct.generate(messages);

        std::string answer;
        // Only feed the new part of the prompt; the earlier part is already in the context.
        llamax::iterator it = ctx.prompt(prompt.substr(offset));
        offset = prompt.size();

        while (std::optional<std::string> s = it.next()) {
            answer += s.value();
            std::cout << s.value();
        }
        std::cout << std::endl;

        messages.push_back({ llamax::chat_message_role::assistant, answer });
    }

    return 0;
}
51 changes: 51 additions & 0 deletions llamax/examples/grammar.cpp
@@ -0,0 +1,51 @@
#include <llamax.h>

#include <iostream>

const char * json_grammar = R"(
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\\x7F\x00-\x1F] |
    "\\" (["\\bfnrt] | "u" [0-9a-fA-F]{4}) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]{0,15})) ("." [0-9]+)? ([eE] [-+]? [0-9] [1-9]{0,15})? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= | " " | "\n" [ \t]{0,20}
)";

int main(int _argc, const char ** _argv) {
    if (_argc != 3) {
        std::cerr << "llamax_grammar [model] \"What is up doctor?\"" << std::endl;
        return -1;
    }

    llamax::model model = llamax::model::load_from_file(_argv[1], llamax::model_params::default_params());
    llamax::context ctx = model.create_context(
        llamax::context_params::default_params().set_context_size(4096).set_batch_size(2048),
        llamax::sampler_builder().grammar(json_grammar, "root").min_p(0.05, 1).temp(0.8f).dist(llamax::default_seed()));
    llamax::iterator it = ctx.prompt(_argv[2]);

    while (std::optional<std::string> s = it.next()) {
        std::cout << s.value();
    }
    std::cout << std::endl;

    return 0;
}
22 changes: 22 additions & 0 deletions llamax/examples/simple.cpp
@@ -0,0 +1,22 @@
#include <llamax.h>

#include <iostream>

int main(int _argc, const char ** _argv) {
    if (_argc != 3) {
        std::cerr << "llamax_simple [model] \"What is up doctor?\"" << std::endl;
        return -1;
    }

    llamax::model model = llamax::model::load_from_file(_argv[1], llamax::model_params::default_params());
    llamax::context ctx =
        model.create_context(llamax::context_params::default_params(), llamax::sampler_builder().greedy());
    llamax::iterator it = ctx.prompt(_argv[2]);

    while (std::optional<std::string> s = it.next()) {
        std::cout << s.value();
    }
    std::cout << std::endl;

    return 0;
}
144 changes: 144 additions & 0 deletions llamax/include/llamax.h
@@ -0,0 +1,144 @@
#pragma once

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>

namespace llamax {
uint32_t default_seed();

class context;
class chat_template;
class model;
class sampler_builder;

/// Exception class for errors in llamax
class exception : public std::exception {
    friend class chat_template;
    friend class context;
    friend class iterator;
    friend class model;

    exception(const std::string & what) : m_what(what) {}

  public:
    const char * what() const noexcept override { return m_what.c_str(); }

  private:
    std::string m_what;
};

/// Parameters for a llama model
class model_params {
    friend class model;

  public:
    static model_params default_params();
    /// Set the number of layers to offload to the GPU
    model_params & set_n_gpu_layers(unsigned _n_gpu_layers);

  private:
    struct data;
    std::shared_ptr<data> d;
};

/// Parameters for the context
class context_params {
    friend class model;

  public:
    static context_params default_params();
    /// Set the context size
    context_params & set_context_size(unsigned _context_size);
    /// Set the batch size, i.e. the maximum number of tokens that can be processed
    /// in a single call to llama_decode
    context_params & set_batch_size(unsigned _batch_size);

  private:
    struct data;
    std::shared_ptr<data> d;
};

/// Builder for the sampling chain used by a context
class sampler_builder {
    friend class model;

  public:
    sampler_builder();
    ~sampler_builder();
    sampler_builder & top_k(int32_t _k);
    sampler_builder & top_p(float p, size_t min_keep);
    sampler_builder & min_p(float p, size_t min_keep);
    sampler_builder & grammar(const std::string & _grammar, const std::string & _root);
    sampler_builder & temp(float t);
    sampler_builder & greedy();
    sampler_builder & dist(uint32_t seed);

  private:
    struct data;
    std::unique_ptr<data> d;
};

class model {
    friend class iterator;
    friend class context;
    friend class chat_template;

  public:
    /**
     * Attempt to load a model from a file.
     *
     * This function can trigger an exception.
     */
    static model load_from_file(const std::string & _name, const model_params & _params);
    /**
     * Create a context that can be used to generate text based on a prompt.
     */
    context create_context(const context_params & _context_params, const sampler_builder & _sampler_builder) const;
    /**
     * Create a chat template that can be used to generate the prompt for a chat bot.
     */
    chat_template create_chat_template(bool _add_assistant = true) const;

  private:
    struct data;
    std::shared_ptr<data> d;
};

class iterator {
    friend class context;
    iterator();

  public:
    iterator(iterator && _rhs);
    ~iterator();
    /**
     * Return the next token, or an empty optional if there are no more tokens.
     *
     * This function can trigger an exception.
     */
    std::optional<std::string> next();

  private:
    struct data;
    std::unique_ptr<data> d;
};

class context {
    friend class model;
    friend class iterator;

  public:
    /**
     * Prompt the LLM and return an iterator over the generated text.
     */
    iterator prompt(const std::string & _prompt);

  private:
    struct data;
    std::shared_ptr<data> d;
};

enum class chat_message_role { system, user, assistant };

struct chat_message {
    chat_message_role role;
    std::string content;
};

class chat_template {
    friend class model;

  public:
    /**
     * Generate a prompt based on a template and a set of messages.
     */
    std::string generate(const std::vector<chat_message> & _messages);

  private:
    struct data;
    std::shared_ptr<data> d;
};
} // namespace llamax
18 changes: 18 additions & 0 deletions llamax/llamaxConfig.cmake.in
@@ -0,0 +1,18 @@
@PACKAGE_INIT@

set_and_check(LLAMAX_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
set_and_check(LLAMAX_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")

find_library(llamax_LIBRARY llamax
    REQUIRED
    HINTS ${LLAMAX_LIB_DIR}
    NO_CMAKE_FIND_ROOT_PATH
)

add_library(llamax UNKNOWN IMPORTED)
set_target_properties(llamax
    PROPERTIES
    INTERFACE_INCLUDE_DIRECTORIES "${LLAMAX_INCLUDE_DIR}"
    IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
    IMPORTED_LOCATION "${llamax_LIBRARY}"
    POSITION_INDEPENDENT_CODE ON)
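Once installed, a downstream project could consume the library through this package; a minimal sketch of a hypothetical consumer `CMakeLists.txt` (project and target names are placeholders):

```cmake
cmake_minimum_required(VERSION 3.14)
project(my_llamax_app CXX)

# This PR installs llamaxConfig.cmake under <prefix>/lib/cmake/llama, so the package
# directory may need to be pointed at explicitly, e.g.
#   cmake -Dllamax_DIR=<prefix>/lib/cmake/llama ..
find_package(llamax REQUIRED CONFIG)

add_executable(my_app main.cpp)
target_link_libraries(my_app PRIVATE llamax)
```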