diff --git a/README.md b/README.md
index 80e2f28..f4837a0 100644
--- a/README.md
+++ b/README.md
@@ -196,10 +196,8 @@ python3 chatglm_cpp/convert.py -i THUDM/glm-4-9b-chat -t q4_0 -o models/chatglm4
 You may use `-vt ` to set quantization type for the vision encoder. It is recommended to run GLM4V on GPU since vision encoding runs too slow on CPU even with 4-bit quantization.
 ```sh
 python3 chatglm_cpp/convert.py -i THUDM/glm-4v-9b -t q4_0 -vt q4_0 -o models/chatglm4v-ggml.bin
-./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常之处" --temp 0
-# 这张图片中不寻常的是,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。
-# 通常情况下,熨衣是在室内进行的,使用的是家用电熨斗,而不是在户外使用汽车后备箱作为工作台。
-# 此外,他似乎是在一个繁忙的城市街道上,周围有行驶的车辆和建筑物,这增加了场景的荒谬性。
+./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常的地方" --temp 0
+# 这张图片中不寻常的地方在于,男子正在一辆黄色出租车后面熨衣服。通常情况下,熨衣是在家中或洗衣店进行的,而不是在车辆上。此外,出租车在行驶中,男子却能够稳定地熨衣,这增加了场景的荒诞感。
 ```
diff --git a/chatglm.cpp b/chatglm.cpp
index a9ed601..a8fc4e9 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -538,11 +538,9 @@ static ggml_tensor *apply_rotary_emb_basic(ModelContext *mctx, ggml_tensor *laye
     // tensor a (activation) is of shape [s, #h, d]
     // tensor b (position_ids) is of shape [s]
     ggml_context *ctx = mctx->ctx_b.get();
-#ifdef GGML_USE_CUDA
-    if (!ggml_is_contiguous(layer)) {
+    if (ggml_cpu_has_cuda() && !ggml_is_contiguous(layer)) {
         layer = ggml_cont(ctx, layer);
     }
-#endif
     const int head_size = layer->ne[0];
     layer = ggml_rope_ext_inplace(ctx, layer, position_ids, nullptr, head_size, (int)rope_type, 0, rope_theta, 1.0f,
                                   0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
@@ -568,18 +566,20 @@ static ggml_tensor *apply_rotary_emb_glm(ModelContext *mctx, ggml_tensor *layer,
 
     ggml_tensor *a1_rope = a1;
     ggml_tensor *a2_rope = a2;
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cont(ctx, a1_rope);
-    a2_rope = ggml_cont(ctx, a2_rope);
-#endif
+
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cont(ctx, a1_rope);
+        a2_rope = ggml_cont(ctx, a2_rope);
+    }
     a1_rope = ggml_rope_inplace(ctx, a1_rope, b1, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
     a2_rope = ggml_rope_inplace(ctx, a2_rope, b2, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cpy(ctx, a1_rope, a1);
-    a2_rope = ggml_cpy(ctx, a2_rope, a2);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cpy(ctx, a1_rope, a1);
+        a2_rope = ggml_cpy(ctx, a2_rope, a2);
+    }
+
     ggml_build_forward_expand(mctx->gf, a1_rope);
     ggml_build_forward_expand(mctx->gf, a2_rope);
@@ -599,15 +599,15 @@ static ggml_tensor *apply_rotary_emb_glm2(ModelContext *mctx, ggml_tensor *layer
         ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);
 
     ggml_tensor *half_layer = half_layer_view;
-#ifdef GGML_USE_CUDA
-    half_layer = ggml_cont(ctx, half_layer);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        half_layer = ggml_cont(ctx, half_layer);
+    }
     ggml_tensor *roped_half_layer =
         ggml_rope_ext_inplace(ctx, half_layer, position_ids, nullptr, rope_dim, (int)RopeType::GPTJ, 0, rope_theta,
                               1.0f, 0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
-#ifdef GGML_USE_CUDA
-    roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
+    }
     ggml_build_forward_expand(mctx->gf, roped_half_layer);
 
     return layer;
@@ -677,6 +677,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
     key_layer = ggml_permute(ctx, key_layer, 0, 2, 1, 3);     // [#kvh, s, d]
     value_layer = ggml_permute(ctx, value_layer, 1, 2, 0, 3); // [#kvh, d, s]
 
+    ggml_tensor *context_layer;
     if (k_cache && v_cache) {
         // store key & value to cache
         ggml_tensor *k_cache_view =
@@ -695,46 +696,47 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
         value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_key_value_heads,
                                    v_cache->nb[1], v_cache->nb[2], 0); // [#kvh, d, kvs]
-    } else {
-        key_layer = ggml_cont(ctx, key_layer);
-        value_layer = ggml_cont(ctx, value_layer);
-    }
-    // attention
-    query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
-    ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
+        // attention
+        query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
+        ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
 
-    if (n_past == 0) {
-        // build attention mask for context input
-        if (num_shared_q_heads > 1) {
-            attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
-                                          num_attention_heads); // [#h, s, kvs]
-        }
+        if (n_past == 0) {
+            // build attention mask for context input
+            if (num_shared_q_heads > 1) {
+                attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
+                                              num_attention_heads); // [#h, s, kvs]
+            }
 
-        if (attn_mask_type == AttentionMaskType::BIDIRECTIONAL) {
-            // pass
-        } else if (attn_mask_type == AttentionMaskType::CAUSAL) {
-            attn_scores = ggml_diag_mask_inf_inplace(ctx, attn_scores, num_virtual_tokens + n_past);
-        } else {
-            attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
+            if (attention_mask) {
+                attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
+            }
+
+            if (num_shared_q_heads > 1) {
+                attn_scores =
+                    ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
+                                    num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
+            }
         }
 
+        ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
+
+        context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
         if (num_shared_q_heads > 1) {
-            attn_scores =
-                ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
-                                num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
+            context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
+                                            num_attention_heads); // [#h, s, d]
         }
+        context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
+    } else {
+        // qkv must be correctly padded
+        key_layer = ggml_cast(ctx, key_layer, GGML_TYPE_F16);                                    // [#kvh, s, d]
+        value_layer = ggml_cast(ctx, ggml_permute(ctx, value_layer, 1, 0, 2, 3), GGML_TYPE_F16); // [#kvh, s, d]
+        context_layer = ggml_flash_attn_ext(ctx, query_layer, key_layer, value_layer, attention_mask,
+                                            1.f / std::sqrt(head_size), 0);
+        ggml_flash_attn_ext_set_prec(context_layer, GGML_PREC_F32);
     }
-    ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
-
-    ggml_tensor *context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
-    if (num_shared_q_heads > 1) {
-        context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
-                                        num_attention_heads); // [#h, s, d]
-    }
-    context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
-    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen);       // [s, #h * d]
+
+    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen); // [s, #h * d]
 
     ggml_tensor *attn_output = dense.forward(mctx, context_layer);
     return attn_output;
@@ -1341,6 +1343,19 @@ void ChatGLM2Model::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &in
     std::vector<int> position_ids_buffer(position_ids->ne[0]);
     std::iota(position_ids_buffer.begin(), position_ids_buffer.end(), n_past);
     ggml_backend_tensor_set(position_ids, position_ids_buffer.data(), 0, position_ids_buffer.size() * sizeof(int));
+
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
 }
 
 StateDict ChatGLM2ForCausalLM::state_dict() const {
@@ -1827,14 +1842,14 @@ EVA2CLIPTransformer::EVA2CLIPTransformer(ModelContext *mctx, const VisionModelCo
     for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
         layers.emplace_back(mctx, config.dtype, config.hidden_size, config.num_attention_heads,
                             config.num_attention_heads, config.intermediate_size, config.num_positions, config.norm_eps,
-                            config.hidden_act, true, true, false, RopeType::DISABLED, -1,
-                            AttentionMaskType::BIDIRECTIONAL, 0, false);
+                            config.hidden_act, true, true, false, RopeType::DISABLED, -1, 0, false);
     }
 }
 
-ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states) const {
+ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states,
+                                          ggml_tensor *attention_mask) const {
     for (const auto &layer : layers) {
-        hidden_states = layer.forward(mctx, hidden_states, nullptr, nullptr, 0);
+        hidden_states = layer.forward(mctx, hidden_states, attention_mask, nullptr, 0);
     }
     return hidden_states;
 }
@@ -1843,17 +1858,29 @@ ggml_tensor *EVA2CLIPModel::forward(ModelContext *mctx, ggml_tensor *input) cons
     ggml_context *ctx = mctx->ctx_b.get();
     ggml_tensor *hidden_states = patch_embedding.forward(mctx, input);
-    hidden_states = transformer.forward(mctx, hidden_states); // [s, hd]
-    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - 1));
+
+    // padding for flash attn
+    const int pad_to_multiple_of = ggml_cpu_has_cuda() ? 256 : GGML_KQ_MASK_PAD;
+    const int pad_size = GGML_PAD(hidden_states->ne[1], pad_to_multiple_of) - hidden_states->ne[1];
+    if (pad_size) {
+        hidden_states = ggml_pad(ctx, hidden_states, 0, pad_size, 0, 0);
+    }
+
+    ggml_tensor *encoder_attention_mask =
+        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_states->ne[1], hidden_states->ne[1]);
+    ggml_set_input(encoder_attention_mask);
+    ggml_set_name(encoder_attention_mask, "encoder_attention_mask");
+
+    encoder_attention_mask = ggml_cast(ctx, encoder_attention_mask, GGML_TYPE_F16);
+    hidden_states = transformer.forward(mctx, hidden_states, encoder_attention_mask); // [s, hd]
+
+    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - pad_size - 1));
     hidden_states = ggml_view_3d(ctx, hidden_states, hidden_states->ne[0], grid_size, grid_size, hidden_states->nb[1],
                                  grid_size * hidden_states->nb[1], hidden_states->nb[1]); // [g, g, hd]
-    // TODO: must use this cont?
-    hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3)); // [hd, g, g]
-    hidden_states = conv.forward(mctx, hidden_states); // [hd, g/2, g/2]
+    hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3)); // [hd, g, g]
+    hidden_states = conv.forward(mctx, hidden_states);                            // [hd, g/2, g/2]
     hidden_states = ggml_reshape_2d(ctx, hidden_states, hidden_states->ne[0] * hidden_states->ne[1],
-                                    hidden_states->ne[2]); // [hd, s]
-    // TODO: this cont?
+                                    hidden_states->ne[2]);                        // [hd, s]
     hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 0, 2, 3)); // [s, hd]
     hidden_states = linear_proj.forward(mctx, hidden_states);
@@ -1967,6 +1994,38 @@ void ChatGLM4VModel::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &i
         // copy to tensor
         ggml_backend_tensor_set(image_tensor, pixels_f32.data(), 0, ggml_nbytes(image_tensor));
     }
+
+    // attention_mask
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
+
+    // encoder_attention_mask
+    ggml_tensor *encoder_attention_mask = ggml_graph_get_tensor(gf, "encoder_attention_mask");
+    if (encoder_attention_mask) {
+        const int valid_tokens = vision.patch_embedding.num_positions();
+        const int M = encoder_attention_mask->ne[1];
+        const int N = encoder_attention_mask->ne[0];
+        std::vector<float> encoder_mask_f32(M * N);
+        CHATGLM_CHECK((size_t)ggml_nelements(encoder_attention_mask) == encoder_mask_f32.size());
+        for (int i = 0; i < M; i++) {
+            for (int j = 0; j < N; j++) {
+                encoder_mask_f32[i * N + j] =
+                    (i < valid_tokens && j < valid_tokens) ? 0.f : -65504.f; // -INFINITY causes nan/inf logits
+            }
+        }
+        ggml_backend_tensor_set(encoder_attention_mask, encoder_mask_f32.data(), 0,
+                                ggml_nbytes(encoder_attention_mask));
+    }
 }
 
 int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const std::optional<Image> &image) const {
diff --git a/chatglm.h b/chatglm.h
index 9872e82..2f24437 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -195,7 +195,6 @@ class ModelConfig {
             interleaved_qkv = true;
             tie_word_embeddings = true;
             rope_type = RopeType::CHATGLM;
-            attn_mask_type = AttentionMaskType::CHATGLM;
         } else {
             hidden_act = ActivationType::SILU;
             use_qkv_bias = true;
@@ -203,7 +202,6 @@ class ModelConfig {
             interleaved_qkv = false;
             tie_word_embeddings = false;
             rope_type = RopeType::CHATGLM2;
-            attn_mask_type = AttentionMaskType::CAUSAL;
         }
     }
@@ -238,11 +236,10 @@ class ModelConfig {
            << ", hidden_act=" << (int)self.hidden_act << ", use_qkv_bias=" << self.use_qkv_bias
            << ", use_dense_bias=" << self.use_dense_bias << ", interleaved_qkv=" << self.interleaved_qkv
            << ", tie_word_embeddings=" << self.tie_word_embeddings << ", rope_type=" << (int)self.rope_type
-           << ", rope_theta=" << self.rope_theta << ", attn_mask_type=" << (int)self.attn_mask_type
-           << ", num_virtual_tokens=" << self.num_virtual_tokens << ", max_length=" << self.max_length
-           << ", bos_token_id=" << self.bos_token_id << ", eos_token_id=" << self.eos_token_id
-           << ", pad_token_id=" << self.pad_token_id << ", sep_token_id=" << self.sep_token_id
-           << ", extra_eos_token_ids={";
+           << ", rope_theta=" << self.rope_theta << ", num_virtual_tokens=" << self.num_virtual_tokens
+           << ", max_length=" << self.max_length << ", bos_token_id=" << self.bos_token_id
+           << ", eos_token_id=" << self.eos_token_id << ", pad_token_id=" << self.pad_token_id
+           << ", sep_token_id=" << self.sep_token_id << ", extra_eos_token_ids={";
         for (size_t i = 0; i < self.extra_eos_token_ids.size(); i++) {
             os << (i > 0 ? ", " : "") << self.extra_eos_token_ids[i];
         }
@@ -266,7 +263,6 @@ class ModelConfig {
     bool tie_word_embeddings;
     RopeType rope_type;
     float rope_theta;
-    AttentionMaskType attn_mask_type;
     int num_virtual_tokens;
     int max_length;
     int bos_token_id;
@@ -573,18 +569,17 @@ class BasicAttention {
     BasicAttention(ModelContext *mctx, ggml_type dtype, int hidden_size, int num_attention_heads,
                    int num_key_value_heads, int max_length, bool use_qkv_bias, bool use_dense_bias,
-                   bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type,
-                   int num_virtual_tokens, bool use_cache)
+                   bool interleaved_qkv, RopeType rope_type, float rope_theta, int num_virtual_tokens, bool use_cache)
         : num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
          interleaved_qkv(interleaved_qkv), rope_type(rope_type), rope_theta(rope_theta),
-          attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
+          num_virtual_tokens(num_virtual_tokens),
          query_key_value(mctx, dtype, hidden_size,
                          hidden_size + 2 * (hidden_size / num_attention_heads) * num_key_value_heads, use_qkv_bias),
          dense(mctx, dtype, hidden_size, hidden_size, use_dense_bias),
-          k_cache(use_cache ? ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F32, hidden_size / num_attention_heads,
+          k_cache(use_cache ? ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
                                                  max_length + num_virtual_tokens, num_key_value_heads)
                             : nullptr),
-          v_cache(use_cache ? ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F32, max_length + num_virtual_tokens,
+          v_cache(use_cache ? ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
                                                  hidden_size / num_attention_heads, num_key_value_heads)
                             : nullptr) {}
@@ -597,7 +592,6 @@ class BasicAttention {
     bool interleaved_qkv;
     RopeType rope_type;
     float rope_theta;
-    AttentionMaskType attn_mask_type;
     int num_virtual_tokens;
     Linear query_key_value;
     Linear dense;
@@ -611,12 +605,11 @@ class BasicBlock {
     BasicBlock() = default;
     BasicBlock(ModelContext *mctx, ggml_type dtype, int hidden_size, int num_attention_heads, int num_key_value_heads,
                int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
-               bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
-               AttentionMaskType attn_mask_type, int num_virtual_tokens, bool use_cache)
+               bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta, int num_virtual_tokens,
+               bool use_cache)
         : input_layernorm(mctx, hidden_size, norm_eps),
          attention(mctx, dtype, hidden_size, num_attention_heads, num_key_value_heads, max_length, use_qkv_bias,
-                    use_dense_bias, interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens,
-                    use_cache),
+                    use_dense_bias, interleaved_qkv, rope_type, rope_theta, num_virtual_tokens, use_cache),
          post_attention_layernorm(mctx, hidden_size, norm_eps),
          mlp(mctx, dtype, hidden_size, intermediate_size, hidden_act) {}
@@ -663,17 +656,7 @@ struct GLMPositionIdsAllocator {
     }
 };
 
-struct NoopAttentionMaskAllocator {
-    ggml_tensor *operator()(ggml_context *ctx, int qlen, int kvlen) const { return nullptr; }
-};
-
-struct BasicAttentionMaskAllocator {
-    ggml_tensor *operator()(ggml_context *ctx, int qlen, int kvlen) const {
-        return ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kvlen, qlen);
-    }
-};
-
-template 
+template 
 class BasicModel {
   public:
     BasicModel() = default;
@@ -700,8 +683,9 @@ class BasicModel {
             ggml_set_input(position_ids);
         }
 
-        ggml_tensor *attention_mask = attn_mask_alloc_(ctx, qlen, kvlen);
-        if (attention_mask) {
+        ggml_tensor *attention_mask = nullptr;
+        if (n_past == 0) {
+            attention_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, kvlen, qlen);
             ggml_set_name(attention_mask, "attention_mask");
             ggml_set_input(attention_mask);
         }
@@ -765,8 +749,8 @@ class BasicModel {
             layers.emplace_back(mctx, config.dtype, config.hidden_size, config.num_attention_heads,
                                 config.num_key_value_heads, config.intermediate_size, config.max_length,
                                 config.norm_eps, config.hidden_act, config.use_qkv_bias, config.use_dense_bias,
-                                config.interleaved_qkv, config.rope_type, config.rope_theta, config.attn_mask_type,
-                                config.num_virtual_tokens, true);
+                                config.interleaved_qkv, config.rope_type, config.rope_theta, config.num_virtual_tokens,
+                                true);
         }
         mctx->buf_kv =
             unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx->ctx_kv.get(), mctx->backend.get()));
@@ -779,7 +763,6 @@ class BasicModel {
     Norm final_layernorm;
 
   private:
-    AttentionMaskAllocator attn_mask_alloc_;
     PositionIdsAllocator pos_ids_alloc_;
 };
@@ -1039,12 +1022,12 @@ class GLMBlock : public BasicBlock {
     GLMBlock(ModelContext *mctx, ggml_type dtype, int hidden_size, int num_attention_heads, int num_key_value_heads,
              int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
-             bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
-             AttentionMaskType attn_mask_type, int num_virtual_tokens, bool use_cache)
+             bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta, int num_virtual_tokens,
+             bool use_cache)
         : BasicBlock(LayerNorm(mctx, hidden_size, norm_eps),
                      BasicAttention(mctx, dtype, hidden_size, num_attention_heads, num_attention_heads, max_length,
                                     use_qkv_bias, use_dense_bias, interleaved_qkv, rope_type, rope_theta,
-                                    attn_mask_type, num_virtual_tokens, use_cache),
+                                    num_virtual_tokens, use_cache),
                      LayerNorm(mctx, hidden_size, norm_eps),
                      BasicMLP(mctx, dtype, hidden_size, intermediate_size, hidden_act)),
          alpha(std::sqrt(2.f * 28)) {}
@@ -1056,7 +1039,7 @@ class GLMBlock : public BasicBlock {
     float alpha;
 };
 
-class ChatGLMModel : public BasicModel {
+class ChatGLMModel : public BasicModel {
   public:
     ChatGLMModel() = default;
@@ -1104,7 +1087,7 @@ class ChatGLM2Tokenizer : public BaseTokenizer {
 
 using GLM2Block = BasicBlock;
 
-class ChatGLM2Model : public BasicModel {
+class ChatGLM2Model : public BasicModel {
  public:
     ChatGLM2Model() = default;
@@ -1289,11 +1272,10 @@ class EVA2CLIPBlock : public BasicBlock {
     EVA2CLIPBlock(ModelContext *mctx, ggml_type dtype, int hidden_size, int num_attention_heads,
                   int num_key_value_heads, int intermediate_size, int max_length, float norm_eps,
                   ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv,
-                  RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type, int num_virtual_tokens,
-                  bool use_cache)
+                  RopeType rope_type, float rope_theta, int num_virtual_tokens, bool use_cache)
         : BasicBlock(mctx, dtype, hidden_size, num_attention_heads, num_key_value_heads, intermediate_size, max_length,
                      norm_eps, hidden_act, use_qkv_bias, use_dense_bias, interleaved_qkv, rope_type, rope_theta,
-                     attn_mask_type, num_virtual_tokens, use_cache) {}
+                     num_virtual_tokens, use_cache) {}
 
     ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states, ggml_tensor *attention_mask,
                          ggml_tensor *position_ids, int n_past) const;
@@ -1305,7 +1287,7 @@ class EVA2CLIPTransformer {
 
     EVA2CLIPTransformer(ModelContext *mctx, const VisionModelConfig &config);
 
-    ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states) const;
+    ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states, ggml_tensor *attention_mask) const;
 
   public:
     std::vector<EVA2CLIPBlock> layers;
diff --git a/chatglm_test.cpp b/chatglm_test.cpp
index ce9e278..5f371f4 100644
--- a/chatglm_test.cpp
+++ b/chatglm_test.cpp
@@ -773,7 +773,7 @@ TEST_F(ChatGLMTest, GLM4Model) {
 
 TEST_F(ChatGLMTest, GLM4VModelText) {
     fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/glm4v_model_text.data";
 
-    VisionModelConfig vision(GGML_TYPE_F32, ActivationType::GELU, 32, 28, 3, 56, 1e-6, 2, 1, 17, 7, 8);
+    VisionModelConfig vision(GGML_TYPE_F32, ActivationType::GELU, 128, 28, 3, 56, 1e-6, 2, 1, 17, 7, 8);
     ModelConfig config(ModelType::CHATGLM4V, GGML_TYPE_F32, /*vocab_size=*/8, /*hidden_size=*/32,
                        /*num_attention_heads=*/8, /*num_key_value_heads=*/2, /*num_hidden_layers=*/1,
@@ -804,7 +804,7 @@ TEST_F(ChatGLMTest, GLM4VModel) {
     fs::path data_path = fs::path(__FILE__).parent_path() / "tests/data/glm4v_model.data";
 
-    VisionModelConfig vision(GGML_TYPE_F32, ActivationType::GELU, 32, 28, 3, 56, 1e-6, 2, 1, 17, 7, 8);
+    VisionModelConfig vision(GGML_TYPE_F32, ActivationType::GELU, 128, 28, 3, 56, 1e-6, 2, 1, 17, 7, 8);
     ModelConfig config(ModelType::CHATGLM4V, GGML_TYPE_F32, /*vocab_size=*/8, /*hidden_size=*/32,
                        /*num_attention_heads=*/8, /*num_key_value_heads=*/2, /*num_hidden_layers=*/1,
@@ -1337,41 +1337,32 @@ TEST(Pipeline, ChatGLM3) {
         gen_config.do_sample = false;
         std::vector<ChatMessage> messages{
             {ChatMessage::ROLE_SYSTEM, system_ci},
-            {ChatMessage::ROLE_USER, "找出100以内的所有质数"},
+            {ChatMessage::ROLE_USER, "求出100以内的所有质数"},
         };
         {
             ChatMessage output = pipeline.chat(messages, gen_config);
             EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
             EXPECT_EQ(output.content,
-                      R"(质数是只能被1和它本身整除的正整数。我们可以通过简单的算法来找出100以内的所有质数。
+                      R"(质数是只能被1和自身整除的正整数。我们可以通过遍历1到100的所有数字来找出100以内的所有质数。
 
-这里我们将使用一个简单的线性筛法来找出100以内的所有质数。
-
-线性筛法的基本思想是:
-1. 创建一个列表,其中包含1到n的所有整数。
-2. 从列表中删除所有可以被2整除的数。
-3. 然后从剩余的数中删除所有可以被3整除的数。
-4. 重复上述步骤,直到列表中的数少于100为止。
-
-让我们开始计算。)");
+下面是找出100以内的所有质数的Python代码:)");
             EXPECT_EQ(output.tool_calls.at(0).code.input, R"(```python
-def sieve_of_eratosthenes(n):
-    # Create a boolean array "prime[0..n]" and initialize all entries as true.
-    # A value in prime[i] will finally be false if i is Not a prime, else true bool val.
-    prime = [True for _ in range(n+1)]
-    p = 2
-    while p**2 <= n:
-        # If prime[p] is not changed, then it is a prime
-        if prime[p]:
-            # Update all multiples of p
-            for i in range(p**2, n+1, p):
-                prime[i] = False
-        p += 1
-
-    # Return the list of prime numbers
-    return [p for p in range(2, n+1) if prime[p]]
-
-primes_upto_100 = sieve_of_eratosthenes(100)
+def is_prime(n):
+    """Check if a number is prime."""
+    if n <= 1:
+        return False
+    if n <= 3:
+        return True
+    if n % 2 == 0 or n % 3 == 0:
+        return False
+    i = 5
+    while i * i <= n:
+        if n % i == 0 or n % (i + 2) == 0:
+            return False
+        i += 6
+    return True
+
+primes_upto_100 = [i for i in range(2, 101) if is_prime(i)]
 primes_upto_100
 ```)");
             messages.emplace_back(std::move(output));
@@ -1382,9 +1373,12 @@ primes_upto_100
         {
             ChatMessage output = pipeline.chat(messages, gen_config);
             EXPECT_EQ(output.role, ChatMessage::ROLE_ASSISTANT);
-            EXPECT_EQ(
-                output.content,
-                R"(100以内的所有质数是:2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97。)");
+            EXPECT_EQ(output.content,
+                      R"(100以内的所有质数为:
+
+$$
+2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97
+$$)");
         }
     }
 }
@@ -1593,12 +1587,11 @@ TEST(Pipeline, ChatGLM4V) {
         gen_config.do_sample = false;
         fs::path image_path = fs::path(__FILE__).parent_path() / "examples/03-Confusing-Pictures.jpg";
         Image image = Image::open(image_path.string());
-        std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "描述这张图片", std::move(image)}};
+        std::vector<ChatMessage> messages{{ChatMessage::ROLE_USER, "这张图片有什么不寻常的地方", std::move(image)}};
         ChatMessage output = pipeline.chat(messages, gen_config);
         EXPECT_EQ(output.content,
-                  "这张图片的幽默之处在于场景的荒谬性。看到一个人在出租车后面熨衣服是非常不寻常的,因为这不是出租车通常"
-                  "被用来做的事情。熨衣板和衣服悬挂在车尾,男子站在上面,似乎在专心熨烫衣物。出租车在行驶中,而男子却稳"
-                  "稳地站在上面熨衣服,这样的场景给人一种不稳定的感觉,增加了喜剧效果。");
+                  "这张图片中不寻常的地方在于,男子正在一辆黄色出租车后面熨衣服。通常情况下,熨衣是在家中或洗衣店进行的"
+                  ",而不是在车辆上。此外,出租车在行驶中,男子却能够稳定地熨衣,这增加了场景的荒诞感。");
     }
 }
diff --git a/tests/data/glm4v_model.data b/tests/data/glm4v_model.data
index df2133e..ae0fa3f 100644
Binary files a/tests/data/glm4v_model.data and b/tests/data/glm4v_model.data differ
diff --git a/tests/data/glm4v_model_text.data b/tests/data/glm4v_model_text.data
index 7990ca0..9b89e86 100644
Binary files a/tests/data/glm4v_model_text.data and b/tests/data/glm4v_model_text.data differ
diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py
index f0b90e2..29000af 100644
--- a/tests/test_chatglm_cpp.py
+++ b/tests/test_chatglm_cpp.py
@@ -98,9 +98,9 @@ def test_chatglm4v_pipeline():
     )
     check_pipeline(
         model_path=CHATGLM4V_MODEL_PATH,
-        prompt="这张图片有什么不寻常之处",
+        prompt="这张图片有什么不寻常的地方",
         image=image,
-        target="这张图片中不寻常的是,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用电熨斗,而不是在户外使用汽车后备箱作为工作台。此外,他似乎是在一个繁忙的城市街道上,周围有行驶的车辆和建筑物,这增加了场景的异想天开性。",
+        target="这张图片中不寻常的地方在于,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用熨斗和熨衣板。然而,这个男人却在车外,后备箱充当了临时的工作台。他似乎是在为出租车内的乘客熨烫衣物,这样的场景在现实生活中是比较少见的。",
     )
@@ -182,7 +182,7 @@ def test_openai_api_vision():
         {
             "role": "user",
             "content": [
-                {"type": "text", "text": "这张图片有什么不寻常之处"},
+                {"type": "text", "text": "这张图片有什么不寻常的地方"},
                 {
                     "type": "image_url",
                     "image_url": {
@@ -200,7 +200,7 @@
     assert response_message["role"] == "assistant"
     assert (
         response_message["content"]
-        == "这张图片中不寻常的是,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用电熨斗,而不是在户外使用汽车后备箱作为工作台。此外,他似乎是在一个繁忙的城市街道上,周围有行驶的车辆和建筑物,这增加了场景的异想天开性。"
+        == "这张图片中不寻常的地方在于,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用熨斗和熨衣板。然而,这个男人却在车外,后备箱充当了临时的工作台。他似乎是在为出租车内的乘客熨烫衣物,这样的场景在现实生活中是比较少见的。"
     )
 
     # request with base64 image
@@ -210,7 +210,7 @@
         {
             "role": "user",
             "content": [
-                {"type": "text", "text": "这张图片有什么不寻常之处"},
+                {"type": "text", "text": "这张图片有什么不寻常的地方"},
                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
             ],
         }
@@ -223,5 +223,5 @@
     assert response_message["role"] == "assistant"
     assert (
         response_message["content"]
-        == "这张图片中不寻常的是,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用电熨斗,而不是在户外使用汽车后备箱作为工作台。此外,他似乎是在一个繁忙的城市街道上,周围有行驶的车辆和建筑物,这增加了场景的异想天开性。"
+        == "这张图片中不寻常的地方在于,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。通常情况下,熨衣是在室内进行的,使用的是家用熨斗和熨衣板。然而,这个男人却在车外,后备箱充当了临时的工作台。他似乎是在为出租车内的乘客熨烫衣物,这样的场景在现实生活中是比较少见的。"
     )
diff --git a/tests/test_convert.py b/tests/test_convert.py
index b74a495..cf7cb3f 100644
--- a/tests/test_convert.py
+++ b/tests/test_convert.py
@@ -611,7 +611,7 @@ def make_data_glm4v_model():
     config.torch_dtype = torch.float32
     config.vision_config.update(
         num_hidden_layers=1,
-        hidden_size=32,
+        hidden_size=128,
         patch_size=7,
         num_heads=2,
         intermediate_size=56,