Fix compilation on metal (#337)
li-plus authored Jul 30, 2024
1 parent 0f7a8a9 commit 606eb1b
Showing 4 changed files with 38 additions and 33 deletions.
24 changes: 15 additions & 9 deletions CMakeLists.txt
@@ -25,6 +25,8 @@ if (CHATGLM_ENABLE_PYBIND)
 endif ()
 
 # third-party libraries
+
+# ggml
 if (GGML_CUDA)
     add_compile_definitions(GGML_USE_CUDA)
     enable_language(CUDA)
@@ -42,33 +44,37 @@ if (GGML_CUDA)
     set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH_LIST} CACHE STRING "")
 endif ()
 
+if (GGML_METAL)
+    add_compile_definitions(GGML_USE_METAL)
+    set(GGML_METAL_EMBED_LIBRARY ON CACHE BOOL "" FORCE)
+endif ()
+
+if (GGML_PERF)
+    add_compile_definitions(GGML_PERF)
+endif ()
+
 include_directories(third_party/ggml/include/ggml third_party/ggml/src)
 add_subdirectory(third_party/ggml)
 
+# sentencepiece
 set(SPM_ENABLE_SHARED OFF CACHE BOOL "chatglm: disable sentencepiece shared libraries by default")
 set(SPM_ENABLE_TCMALLOC OFF CACHE BOOL "chatglm: disable tcmalloc by default")
 include_directories(third_party/sentencepiece/src)
 add_subdirectory(third_party/sentencepiece)
 
 include_directories(third_party/sentencepiece/third_party/protobuf-lite)
 
+# absl
 set(ABSL_ENABLE_INSTALL ON CACHE BOOL "" FORCE)
 set(ABSL_PROPAGATE_CXX_STD ON CACHE BOOL "" FORCE)
 add_subdirectory(third_party/abseil-cpp)
 
+# re2
 add_subdirectory(third_party/re2)
 
+# stb
 include_directories(third_party/stb)
 
-if (GGML_METAL)
-    add_compile_definitions(GGML_USE_METAL)
-    configure_file(third_party/ggml/src/ggml-metal.metal ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
-endif ()
-
-if (GGML_PERF)
-    add_compile_definitions(GGML_PERF)
-endif ()
-
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 
 file(GLOB CPP_SOURCES
25 changes: 22 additions & 3 deletions chatglm.cpp
@@ -1118,8 +1118,27 @@ ggml_tensor *GLMBlock::forward(ModelContext *mctx, ggml_tensor *hidden_states, g
     return output;
 }
 
+static void alloc_weight_context(ModelContext *mctx, const ggml_backend_buffer_t sd_buf) {
+    void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
+    const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
+    if (ggml_backend_is_cpu(mctx->backend.get())) {
+        mctx->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
+    }
+#ifdef GGML_USE_METAL
+    else if (ggml_backend_is_metal(mctx->backend.get())) {
+        const size_t max_size = ggml_get_max_tensor_size(mctx->ctx_w.get());
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
+    }
+#endif
+    else {
+        mctx->buf_w =
+            unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx->ctx_w.get(), mctx->backend.get()));
+    }
+}
+
 void ChatGLMForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     StateDict self_sd = state_dict();
     for (auto &item : self_sd.kv) {
@@ -1259,7 +1278,7 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
 }
 
 void ChatGLM2ForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     if (config.num_virtual_tokens > 0) {
         ggml_tensor *past_key_values = sd.kv.at("past_key_values");
@@ -1959,7 +1978,7 @@ int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const
 }
 
 void ChatGLM4VForCausalLM::load_state_dict(const StateDict &sd) {
-    alloc_weight_context(sd.buf.get());
+    alloc_weight_context(mctx_.get(), sd.buf.get());
 
     auto self_sd = state_dict();
     ChatGLM2ForCausalLM::load_state_dict(mctx_.get(), self_sd, sd);
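For context, the relocated alloc_weight_context wraps the state dict's existing buffer as the weight buffer instead of allocating fresh backend storage and copying into it. Below is a minimal standalone sketch of the same zero-copy technique against ggml's mid-2024 public API; the 16-float weight, the alignment constant, and the tensor allocator usage are illustrative assumptions, not part of this commit:

    #include <cstdlib>

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    int main() {
        // Host memory that already holds weight data; stands in for the
        // memory-mapped state dict buffer (sd_buf) above. ggml expects an
        // aligned base pointer, hence aligned_alloc.
        constexpr size_t kNumel = 16;
        constexpr size_t kAlign = 32; // assumed to match ggml's TENSOR_ALIGNMENT
        float *weights = static_cast<float *>(std::aligned_alloc(kAlign, kNumel * sizeof(float)));

        // no_alloc context: tensor metadata only, no data storage.
        ggml_init_params params{/*.mem_size   =*/ggml_tensor_overhead() * 8,
                                /*.mem_buffer =*/nullptr,
                                /*.no_alloc   =*/true};
        ggml_context *ctx = ggml_init(params);
        ggml_tensor *w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, kNumel);

        // Zero-copy: expose the existing host memory as a CPU backend buffer,
        // as in the ggml_backend_is_cpu branch of alloc_weight_context.
        ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr(weights, kNumel * sizeof(float));

        // Place the tensor inside that buffer rather than allocating new storage.
        ggml_tallocr talloc = ggml_tallocr_new(buf);
        ggml_tallocr_alloc(&talloc, w);

        ggml_backend_buffer_free(buf);
        ggml_free(ctx);
        std::free(weights);
        return 0;
    }

On the Metal path the helper does the same through ggml_backend_metal_buffer_from_ptr, passing ggml_get_max_tensor_size as the max_size hint. Since those calls sit behind #ifdef GGML_USE_METAL, moving the helper out of chatglm.h (see the removal below) keeps the Metal-only ggml API out of every translation unit that includes the public header.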
20 changes: 0 additions & 20 deletions chatglm.h
@@ -999,26 +999,6 @@ class BasicModelForCausalLM : public BaseModelForCausalLM {
 
     void load_prefix_cache(ggml_tensor *past_key_values) { transformer.load_prefix_cache(config, past_key_values); }
 
-  protected:
-    void alloc_weight_context(const ggml_backend_buffer_t sd_buf) const {
-        void *sd_buf_base = ggml_backend_buffer_get_base(sd_buf);
-        const size_t sd_buf_size = ggml_backend_buffer_get_size(sd_buf);
-        if (ggml_backend_is_cpu(mctx_->backend.get())) {
-            mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
-        }
-#ifdef GGML_USE_METAL
-        else if (ggml_backend_is_metal(mctx_->backend.get())) {
-            const size_t max_size = ggml_get_max_tensor_size(mctx_->ctx_w.get());
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
-        }
-#endif
-        else {
-            mctx_->buf_w =
-                unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx_->ctx_w.get(), mctx_->backend.get()));
-        }
-    }
-
   public:
     Model transformer;
     Linear lm_head;
2 changes: 1 addition & 1 deletion chatglm_cpp/__init__.py
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage, Image
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 
 @dataclass
