Apply flash attention on vision encoder (#339)

li-plus authored Jul 31, 2024
1 parent 606eb1b commit 60c89b7
Showing 8 changed files with 182 additions and 150 deletions.
README.md: 6 changes (2 additions, 4 deletions)

@@ -196,10 +196,8 @@ python3 chatglm_cpp/convert.py -i THUDM/glm-4-9b-chat -t q4_0 -o models/chatglm4
 You may use `-vt <vision_type>` to set the quantization type for the vision encoder. It is recommended to run GLM4V on GPU, since vision encoding runs too slowly on CPU even with 4-bit quantization.
 ```sh
 python3 chatglm_cpp/convert.py -i THUDM/glm-4v-9b -t q4_0 -vt q4_0 -o models/chatglm4v-ggml.bin
-./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常之处" --temp 0
-# 这张图片中不寻常的是,一个男人站在一辆黄色SUV的后备箱上,正在使用一个铁板熨烫衣物。
-# 通常情况下,熨衣是在室内进行的,使用的是家用电熨斗,而不是在户外使用汽车后备箱作为工作台。
-# 此外,他似乎是在一个繁忙的城市街道上,周围有行驶的车辆和建筑物,这增加了场景的荒谬性。
+./build/bin/main -m models/chatglm4v-ggml.bin --image examples/03-Confusing-Pictures.jpg -p "这张图片有什么不寻常的地方" --temp 0
+# 这张图片中不寻常的地方在于,男子正在一辆黄色出租车后面熨衣服。通常情况下,熨衣是在家中或洗衣店进行的,而不是在车辆上。此外,出租车在行驶中,男子却能够稳定地熨衣,这增加了场景的荒诞感。
 ```

</details>
chatglm.cpp: 177 changes (118 additions, 59 deletions)

@@ -538,11 +538,9 @@ static ggml_tensor *apply_rotary_emb_basic(ModelContext *mctx, ggml_tensor *laye
     // tensor a (activation) is of shape [s, #h, d]
     // tensor b (position_ids) is of shape [s]
     ggml_context *ctx = mctx->ctx_b.get();
-#ifdef GGML_USE_CUDA
-    if (!ggml_is_contiguous(layer)) {
+    if (ggml_cpu_has_cuda() && !ggml_is_contiguous(layer)) {
         layer = ggml_cont(ctx, layer);
     }
-#endif
     const int head_size = layer->ne[0];
     layer = ggml_rope_ext_inplace(ctx, layer, position_ids, nullptr, head_size, (int)rope_type, 0, rope_theta, 1.0f,
                                   0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
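
The hunk above swaps a compile-time `#ifdef GGML_USE_CUDA` guard for a runtime `ggml_cpu_has_cuda()` check, so a single code path serves both backends. A minimal sketch of the pattern, assuming only ggml's public API (the helper name `cont_for_cuda` is illustrative, not from this commit):

```cpp
#include "ggml.h"

// Make a tensor contiguous only when ggml was built with CUDA: the CUDA rope
// kernel appears to require contiguous input here, while the CPU backend can
// operate on strided views directly.
static ggml_tensor *cont_for_cuda(ggml_context *ctx, ggml_tensor *layer) {
    if (ggml_cpu_has_cuda() && !ggml_is_contiguous(layer)) {
        layer = ggml_cont(ctx, layer); // materialize a contiguous copy
    }
    return layer;
}
```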
@@ -568,18 +566,20 @@ static ggml_tensor *apply_rotary_emb_glm(ModelContext *mctx, ggml_tensor *layer,

     ggml_tensor *a1_rope = a1;
     ggml_tensor *a2_rope = a2;
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cont(ctx, a1_rope);
-    a2_rope = ggml_cont(ctx, a2_rope);
-#endif
+
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cont(ctx, a1_rope);
+        a2_rope = ggml_cont(ctx, a2_rope);
+    }
 
     a1_rope = ggml_rope_inplace(ctx, a1_rope, b1, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
     a2_rope = ggml_rope_inplace(ctx, a2_rope, b2, rope_dim, (int)RopeType::NEOX); // [s, #h, d/2]
 
-#ifdef GGML_USE_CUDA
-    a1_rope = ggml_cpy(ctx, a1_rope, a1);
-    a2_rope = ggml_cpy(ctx, a2_rope, a2);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        a1_rope = ggml_cpy(ctx, a1_rope, a1);
+        a2_rope = ggml_cpy(ctx, a2_rope, a2);
+    }
 
     ggml_build_forward_expand(mctx->gf, a1_rope);
     ggml_build_forward_expand(mctx->gf, a2_rope);

@@ -599,15 +599,15 @@ static ggml_tensor *apply_rotary_emb_glm2(ModelContext *mctx, ggml_tensor *layer
         ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);
 
     ggml_tensor *half_layer = half_layer_view;
-#ifdef GGML_USE_CUDA
-    half_layer = ggml_cont(ctx, half_layer);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        half_layer = ggml_cont(ctx, half_layer);
+    }
     ggml_tensor *roped_half_layer =
         ggml_rope_ext_inplace(ctx, half_layer, position_ids, nullptr, rope_dim, (int)RopeType::GPTJ, 0, rope_theta,
                               1.0f, 0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
-#ifdef GGML_USE_CUDA
-    roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
-#endif
+    if (ggml_cpu_has_cuda()) {
+        roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
+    }
     ggml_build_forward_expand(mctx->gf, roped_half_layer);
 
     return layer;
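
All three rotary-embedding variants above share one round-trip idiom: take a view of the channels to rotate, make the view contiguous when CUDA is present, rope it, copy the result back into the original tensor, and pin the copy into the graph. A condensed sketch of that idiom as a free function (a hypothetical consolidation, not code from this commit):

```cpp
// Rotate the first rope_dim channels of layer [d, #h, s] in place.
// Hypothetical helper illustrating the idiom used by apply_rotary_emb_glm2.
static ggml_tensor *rope_partial_inplace(ModelContext *mctx, ggml_tensor *layer,
                                         ggml_tensor *position_ids, int rope_dim, float rope_theta) {
    ggml_context *ctx = mctx->ctx_b.get();
    // strided view of the rotary half of every head
    ggml_tensor *view =
        ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);
    ggml_tensor *half = view;
    if (ggml_cpu_has_cuda()) {
        half = ggml_cont(ctx, half); // CUDA cannot rope a strided view in place
    }
    ggml_tensor *roped = ggml_rope_ext_inplace(ctx, half, position_ids, nullptr, rope_dim,
                                               (int)RopeType::GPTJ, 0, rope_theta, 1.0f, 0.0f,
                                               1.0f, 0.0f, 0.0f);
    if (ggml_cpu_has_cuda()) {
        roped = ggml_cpy(ctx, roped, view); // write the rotated values back into layer
    }
    // the copy's output is not consumed downstream, so force it into the graph
    ggml_build_forward_expand(mctx->gf, roped);
    return layer;
}
```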
@@ -677,6 +677,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
     key_layer = ggml_permute(ctx, key_layer, 0, 2, 1, 3);     // [#kvh, s, d]
     value_layer = ggml_permute(ctx, value_layer, 1, 2, 0, 3); // [#kvh, d, s]
 
+    ggml_tensor *context_layer;
     if (k_cache && v_cache) {
         // store key & value to cache
         ggml_tensor *k_cache_view =
@@ -695,46 +696,47 @@
         value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_key_value_heads,
                                    v_cache->nb[1], v_cache->nb[2],
                                    0); // [#kvh, d, kvs]
-    } else {
-        key_layer = ggml_cont(ctx, key_layer);
-        value_layer = ggml_cont(ctx, value_layer);
-    }
 
-    // attention
-    query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
-    ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
+        // attention
+        query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
+        ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
 
-    if (n_past == 0) {
-        // build attention mask for context input
-        if (num_shared_q_heads > 1) {
-            attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
-                                          num_attention_heads); // [#h, s, kvs]
-        }
-
-        if (attn_mask_type == AttentionMaskType::BIDIRECTIONAL) {
-            // pass
-        } else if (attn_mask_type == AttentionMaskType::CAUSAL) {
-            attn_scores = ggml_diag_mask_inf_inplace(ctx, attn_scores, num_virtual_tokens + n_past);
-        } else {
-            attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
-        }
-
-        if (num_shared_q_heads > 1) {
-            attn_scores =
-                ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
-                                num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
-        }
-    }
+        if (n_past == 0) {
+            // build attention mask for context input
+            if (num_shared_q_heads > 1) {
+                attn_scores = ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, qlen,
+                                              num_attention_heads); // [#h, s, kvs]
+            }
+
+            if (attention_mask) {
+                attn_scores = ggml_add_inplace(ctx, attn_scores, attention_mask);
+            }
+
+            if (num_shared_q_heads > 1) {
+                attn_scores =
+                    ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
+                                    num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
+            }
+        }
 
-    ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
+        ggml_tensor *attn_probs = ggml_soft_max_inplace(ctx, attn_scores); // [#kvh, (#h/#kvh) * s, kvs]
 
-    ggml_tensor *context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
-    if (num_shared_q_heads > 1) {
-        context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
-                                        num_attention_heads); // [#h, s, d]
-    }
-    context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
-    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen); // [s, #h * d]
+        context_layer = ggml_mul_mat(ctx, value_layer, attn_probs); // [#kvh, (#h/#kvh) * s, d]
+        if (num_shared_q_heads > 1) {
+            context_layer = ggml_reshape_3d(ctx, context_layer, head_size, qlen,
+                                            num_attention_heads); // [#h, s, d]
+        }
+        context_layer = ggml_cont(ctx, ggml_permute(ctx, context_layer, 0, 2, 1, 3)); // [s, #h, d]
+    } else {
+        // qkv must be correctly padded
+        key_layer = ggml_cast(ctx, key_layer, GGML_TYPE_F16);                                    // [#kvh, s, d]
+        value_layer = ggml_cast(ctx, ggml_permute(ctx, value_layer, 1, 0, 2, 3), GGML_TYPE_F16); // [#kvh, s, d]
+        context_layer = ggml_flash_attn_ext(ctx, query_layer, key_layer, value_layer, attention_mask,
+                                            1.f / std::sqrt(head_size), 0);
+        ggml_flash_attn_ext_set_prec(context_layer, GGML_PREC_F32);
+    }
+
+    context_layer = ggml_reshape_2d(ctx, context_layer, hidden_size, qlen); // [s, #h * d]
 
     ggml_tensor *attn_output = dense.forward(mctx, context_layer);
     return attn_output;
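
This is the core of the commit: when no KV cache is present (the vision encoder case), `BasicAttention` now routes through ggml's fused flash-attention kernel instead of the explicit scale / mul_mat / softmax / mul_mat chain. A condensed sketch of that contract, assuming the `ggml_flash_attn_ext` signature as of this point in ggml's history (the helper name is illustrative):

```cpp
#include <cmath>
#include "ggml.h"

// Fused attention over a full sequence: K and V must be cast to F16, the
// sequence and mask padded by the caller (see the GGML_PAD logic in
// EVA2CLIPModel::forward below), and the kernel asked to compute in F32
// to sidestep F16 precision problems.
static ggml_tensor *fused_attention(ggml_context *ctx, ggml_tensor *q, ggml_tensor *k,
                                    ggml_tensor *v, ggml_tensor *mask, int head_size) {
    k = ggml_cast(ctx, k, GGML_TYPE_F16);                                // [#kvh, s, d]
    v = ggml_cast(ctx, ggml_permute(ctx, v, 1, 0, 2, 3), GGML_TYPE_F16); // [#kvh, d, s] -> [#kvh, s, d]
    ggml_tensor *out = ggml_flash_attn_ext(ctx, q, k, v, mask,
                                           1.f / std::sqrt((float)head_size),
                                           /*max_bias=*/0.f);
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);
    return out;
}
```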
@@ -1341,6 +1343,19 @@ void ChatGLM2Model::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &in
     std::vector<int> position_ids_buffer(position_ids->ne[0]);
     std::iota(position_ids_buffer.begin(), position_ids_buffer.end(), n_past);
     ggml_backend_tensor_set(position_ids, position_ids_buffer.data(), 0, position_ids_buffer.size() * sizeof(int));
+
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
 }
 
 StateDict ChatGLM2ForCausalLM::state_dict() const {
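
The buffer filled here is the additive causal mask that the reworked `BasicAttention` adds to the scores: entry (i, j) is `-INFINITY` exactly when key j lies in the future of query i, and the first `kvlen - qlen` cached keys are always visible. A standalone sketch with hypothetical sizes qlen = 3 and kvlen = 5:

```cpp
#include <cmath>
#include <vector>

// Additive causal mask, row-major [qlen, kvlen]. With qlen = 3 and kvlen = 5
// (2 cached tokens + 3 new ones), row i masks key j whenever j > i + 2:
//   row 0:   0    0    0  -inf -inf
//   row 1:   0    0    0    0  -inf
//   row 2:   0    0    0    0    0
std::vector<float> build_causal_mask(int qlen, int kvlen) {
    std::vector<float> mask(qlen * kvlen);
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < kvlen; j++) {
            mask[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
        }
    }
    return mask;
}
```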
@@ -1827,14 +1842,14 @@ EVA2CLIPTransformer::EVA2CLIPTransformer(ModelContext *mctx, const VisionModelCo
     for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
         layers.emplace_back(mctx, config.dtype, config.hidden_size, config.num_attention_heads,
                             config.num_attention_heads, config.intermediate_size, config.num_positions, config.norm_eps,
-                            config.hidden_act, true, true, false, RopeType::DISABLED, -1,
-                            AttentionMaskType::BIDIRECTIONAL, 0, false);
+                            config.hidden_act, true, true, false, RopeType::DISABLED, -1, 0, false);
     }
 }
 
-ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states) const {
+ggml_tensor *EVA2CLIPTransformer::forward(ModelContext *mctx, ggml_tensor *hidden_states,
+                                          ggml_tensor *attention_mask) const {
     for (const auto &layer : layers) {
-        hidden_states = layer.forward(mctx, hidden_states, nullptr, nullptr, 0);
+        hidden_states = layer.forward(mctx, hidden_states, attention_mask, nullptr, 0);
     }
     return hidden_states;
 }
@@ -1843,17 +1858,29 @@ ggml_tensor *EVA2CLIPModel::forward(ModelContext *mctx, ggml_tensor *input) cons
     ggml_context *ctx = mctx->ctx_b.get();
 
     ggml_tensor *hidden_states = patch_embedding.forward(mctx, input);
-    hidden_states = transformer.forward(mctx, hidden_states); // [s, hd]
 
-    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - 1));
+    // padding for flash attn
+    const int pad_to_multiple_of = ggml_cpu_has_cuda() ? 256 : GGML_KQ_MASK_PAD;
+    const int pad_size = GGML_PAD(hidden_states->ne[1], pad_to_multiple_of) - hidden_states->ne[1];
+    if (pad_size) {
+        hidden_states = ggml_pad(ctx, hidden_states, 0, pad_size, 0, 0);
+    }
+
+    ggml_tensor *encoder_attention_mask =
+        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hidden_states->ne[1], hidden_states->ne[1]);
+    ggml_set_input(encoder_attention_mask);
+    ggml_set_name(encoder_attention_mask, "encoder_attention_mask");
+
+    encoder_attention_mask = ggml_cast(ctx, encoder_attention_mask, GGML_TYPE_F16);
+    hidden_states = transformer.forward(mctx, hidden_states, encoder_attention_mask); // [s, hd]
+
+    const int grid_size = std::round(std::sqrt(hidden_states->ne[1] - pad_size - 1));
     hidden_states = ggml_view_3d(ctx, hidden_states, hidden_states->ne[0], grid_size, grid_size, hidden_states->nb[1],
                                  grid_size * hidden_states->nb[1], hidden_states->nb[1]); // [g, g, hd]
-    // TODO: must use this cont?
     hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 2, 0, 1, 3)); // [hd, g, g]
     hidden_states = conv.forward(mctx, hidden_states);                            // [hd, g/2, g/2]
     hidden_states = ggml_reshape_2d(ctx, hidden_states, hidden_states->ne[0] * hidden_states->ne[1],
                                     hidden_states->ne[2]); // [hd, s]
-    // TODO: this cont?
     hidden_states = ggml_cont(ctx, ggml_permute(ctx, hidden_states, 1, 0, 2, 3)); // [s, hd]
 
     hidden_states = linear_proj.forward(mctx, hidden_states);
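
The fused kernel needs the sequence length rounded up (to 256 on the CUDA path here, `GGML_KQ_MASK_PAD` otherwise), so the vision sequence is padded before the transformer and the pad is subtracted again when the patch grid is recovered. A worked sketch of the arithmetic with a hypothetical token count (`round_up` stands in for ggml's `GGML_PAD` macro; the real count depends on the vision config):

```cpp
#include <cmath>
#include <cstdio>

static int round_up(int x, int n) { return (x + n - 1) / n * n; } // stands in for GGML_PAD

int main() {
    const int n_tokens = 40 * 40 + 1; // hypothetical: a 40x40 patch grid + 1 leading token
    const int pad_to = 256;           // CUDA path; the CPU path pads to GGML_KQ_MASK_PAD
    const int padded = round_up(n_tokens, pad_to); // 1792
    const int pad_size = padded - n_tokens;        // 191 rows of padding
    // after the transformer, the grid is recovered with the pad discounted:
    const int grid_size = (int)std::round(std::sqrt(padded - pad_size - 1)); // 40
    std::printf("padded=%d pad_size=%d grid=%d\n", padded, pad_size, grid_size);
    return 0;
}
```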
@@ -1967,6 +1994,38 @@ void ChatGLM4VModel::set_graph_inputs(ggml_cgraph *gf, const std::vector<int> &i
         // copy to tensor
         ggml_backend_tensor_set(image_tensor, pixels_f32.data(), 0, ggml_nbytes(image_tensor));
     }
+
+    // attention_mask
+    ggml_tensor *attention_mask = ggml_graph_get_tensor(gf, "attention_mask");
+    if (attention_mask) {
+        const int kvlen = attention_mask->ne[0];
+        const int qlen = attention_mask->ne[1];
+        std::vector<float> mask_buf(qlen * kvlen);
+        for (int i = 0; i < qlen; i++) {
+            for (int j = 0; j < kvlen; j++) {
+                mask_buf[i * kvlen + j] = (i < j + qlen - kvlen) ? -INFINITY : 0.f;
+            }
+        }
+        ggml_backend_tensor_set(attention_mask, mask_buf.data(), 0, ggml_nbytes(attention_mask));
+    }
+
+    // encoder_attention_mask
+    ggml_tensor *encoder_attention_mask = ggml_graph_get_tensor(gf, "encoder_attention_mask");
+    if (encoder_attention_mask) {
+        const int valid_tokens = vision.patch_embedding.num_positions();
+        const int M = encoder_attention_mask->ne[1];
+        const int N = encoder_attention_mask->ne[0];
+        std::vector<float> encoder_mask_f32(M * N);
+        CHATGLM_CHECK((size_t)ggml_nelements(encoder_attention_mask) == encoder_mask_f32.size());
+        for (int i = 0; i < M; i++) {
+            for (int j = 0; j < N; j++) {
+                encoder_mask_f32[i * N + j] =
+                    (i < valid_tokens && j < valid_tokens) ? 0.f : -65504.f; // -INFINITY causes nan/inf logits
+            }
+        }
+        ggml_backend_tensor_set(encoder_attention_mask, encoder_mask_f32.data(), 0,
+                                ggml_nbytes(encoder_attention_mask));
+    }
 }
 
 int ChatGLM4VForCausalLM::count_tokens(const std::vector<int> &input_ids, const std::optional<Image> &image) const {
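
The `-65504.f` in the encoder mask is the most negative finite IEEE-754 half-precision value. Since the encoder mask is cast to F16 before reaching `ggml_flash_attn_ext` and the padding rows are masked everywhere, a true `-inf` would leave those rows fully masked and can drive the kernel's softmax to NaN, which is what the inline comment warns about. A small sketch of the bound, assuming ggml's F16 conversion helpers:

```cpp
#include <cmath>
#include <cstdio>
#include "ggml.h"

int main() {
    // -65504 round-trips through F16 exactly (it is the F16 finite minimum),
    // while -INFINITY stays infinite after the F32 -> F16 cast.
    const ggml_fp16_t lo   = ggml_fp32_to_fp16(-65504.f);
    const ggml_fp16_t ninf = ggml_fp32_to_fp16(-INFINITY);
    std::printf("%.1f %.1f\n", ggml_fp16_to_fp32(lo), ggml_fp16_to_fp32(ninf)); // -65504.0 -inf
    return 0;
}
```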