reduce inference memory of vit #819

Merged: 12 commits, Apr 9, 2025

3 changes: 1 addition & 2 deletions lightllm/models/vit/layer_infer/post_layer_infer.py
@@ -45,8 +45,7 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
)

# vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
vit_embeds_gelu = gelu_fwd(vit_embeds_1)
vit_embeds_gelu = gelu_fwd(vit_embeds_1, use_custom_tensor_mananger=True)

vit_embeds_out = torch.addmm(
layer_weight.mlp1_3_bias_,
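
The single change in this file is representative of the whole PR: the GELU output is no longer materialized with a fresh allocation but taken from lightllm's cache tensor manager, so the storage can be reused for later intermediates and peak inference memory drops. Below is a minimal sketch of the pooling idea; SimpleBufferPool is a hypothetical illustration, much simpler than the real g_cache_manager:

import torch

class SimpleBufferPool:
    """Hypothetical pooled allocator; illustrates buffer reuse only, not lightllm's implementation."""

    def __init__(self):
        # (shape, dtype, device) -> list of released tensors available for reuse
        self._free = {}

    def alloc_tensor(self, shape, dtype, device):
        key = (tuple(shape), dtype, torch.device(device))
        bucket = self._free.setdefault(key, [])
        # Reuse a released buffer with a matching key; otherwise allocate a new one.
        return bucket.pop() if bucket else torch.empty(shape, dtype=dtype, device=device)

    def release(self, t):
        self._free.setdefault((tuple(t.shape), t.dtype, t.device), []).append(t)

pool = SimpleBufferPool()
a = pool.alloc_tensor((4, 1024), torch.float32, device=torch.device("cpu"))
pool.release(a)
b = pool.alloc_tensor((4, 1024), torch.float32, device=torch.device("cpu"))
assert b is a  # the released buffer is handed back instead of allocating again
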
44 changes: 30 additions & 14 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
@@ -12,6 +12,7 @@
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


class ViTTransformerLayerInfer:
@@ -60,7 +61,9 @@ def tp_norm(self, input, weight):

def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
b = rms_norm(
input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
)
else:
b = torch.nn.functional.layer_norm(
input,
@@ -73,7 +76,9 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:

def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
return rms_norm(
input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
)
else:
return torch.nn.functional.layer_norm(
input,
@@ -84,20 +89,28 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
)

def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
if self.tp_world_size_ > 1:
q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
else:
q_norm = rms_norm(
q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
)
k_norm = rms_norm(
k, weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
)
return q_norm, k_norm

def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
batch_size = input.shape[0]
seq_len = input.shape[1]
qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
q, k, v = qkv.unbind(2)
return q, k, v

def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
out = torch.empty_like(q)
out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
batch_size = q.shape[0]
seq_len = q.shape[1]
flash_attention_fwd(q, k, v, out)
@@ -107,30 +120,33 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
batch_size = input.shape[0]
seq_len = input.shape[1]
o_tensor = layer_weight.o_proj.mm(
input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False
input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=True
)
if layer_weight.use_ls:
o_tensor *= layer_weight.ls1
o_tensor.mul_(layer_weight.ls1)
return o_tensor.reshape((batch_size, seq_len, -1))

def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
# ffn1_out = torch.nn.functional.gelu(fc1)
ffn1_out = gelu_fwd(fc1)
fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
input_shape = input.shape
input = None
ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
if layer_weight.use_ls:
ffn2_out *= layer_weight.ls2
ffn1_out = gelu_fwd(fc1, use_custom_tensor_mananger=True)
ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=True)
ffn1_out = None
if layer_weight.use_ls:
ffn2_out.mul_(layer_weight.ls2)
return ffn2_out.reshape(input_shape)

def _context_attention(self, input_embding, layer_weight):
input1 = self._att_norm(input_embding, layer_weight)
q, k, v = self._get_qkv(input1, layer_weight)
input1 = None
if layer_weight.qk_norm:
q, k = self._qk_norm(q, k, layer_weight)
o = self._context_attention_kernel(q, k, v)
q = None
k = None
v = None
o = self._get_o(o, layer_weight)
if self.tp_world_size_ > 1:
dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
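
Two patterns recur in this file. First, temporary outputs (the attention output buffer, the qkv/o/ffn projections, the norm results) are now allocated through g_cache_manager rather than torch.empty_like or a private buffer, so their storage comes from a shared pool. Second, references to tensors that are no longer needed (input, ffn1_out, q, k, v) are dropped with = None as soon as possible, which lets the pool recycle those buffers for the next consumer; the switch from o_tensor *= layer_weight.ls1 to o_tensor.mul_(layer_weight.ls1) just makes the in-place layer-scale multiply explicit. A minimal sketch of the allocation side, assuming lightllm is installed; alloc_attention_out is an illustrative helper name, not lightllm code:

import torch
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

def alloc_attention_out(q: torch.Tensor) -> torch.Tensor:
    # Before this PR the attention output was a fresh allocation per call:
    #   out = torch.empty_like(q)
    # Afterwards the buffer is drawn from the shared cache manager, so it can be
    # recycled by later layers inside the same cache_env scope.
    return g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
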
4 changes: 4 additions & 0 deletions lightllm/models/vit/model.py
@@ -19,6 +19,7 @@
from rpyc.utils.classic import obtain
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


logger = init_logger(__name__)
@@ -128,11 +129,14 @@ def _init_datatype(self):
else:
raise ValueError(f"Unsupport datatype {self.data_type}!")

@torch.no_grad()
def forward(self, pixel_values):
g_cache_manager.cache_env_in()
input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
for i in range(self.layers_num + self.select_layer + 1):
input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
g_cache_manager.cache_env_out()
return input_embs

@torch.no_grad()
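
The forward pass is now bracketed by g_cache_manager.cache_env_in() and cache_env_out(), so every intermediate allocated through the cache manager during the ViT layers is confined to one scope, and the @torch.no_grad() shown in this hunk keeps autograd from retaining activations during inference. A minimal usage sketch of that bracketing, assuming lightllm is installed; forward_with_cache_scope and layers are illustrative names, not part of the PR:

import torch
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

@torch.no_grad()
def forward_with_cache_scope(layers, x: torch.Tensor) -> torch.Tensor:
    # Mirrors the bracketing added to ViTModel.forward in this PR.
    g_cache_manager.cache_env_in()
    for layer in layers:
        x = layer.forward(x)  # layer internals allocate via g_cache_manager
    g_cache_manager.cache_env_out()
    return x
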
11 changes: 9 additions & 2 deletions lightllm/models/vit/triton_kernel/gelu_vit.py
@@ -1,6 +1,7 @@
import torch
import triton
import triton.language as tl
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


@triton.jit
@@ -21,8 +22,14 @@ def gelu_kernel(output_ptr, input_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
tl.store(output_ptr + offsets, output, mask=mask)


def gelu_fwd(input):
output = torch.empty_like(input)
def gelu_fwd(input, use_custom_tensor_mananger=False):
if use_custom_tensor_mananger:
shape = input.shape
dtype = input.dtype
device = input.device
output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
else:
output = torch.empty_like(input)
n_elements = input.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
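
A usage sketch for the updated wrapper, assuming a CUDA device with Triton and lightllm installed; x is arbitrary test data. With the flag at its default the behavior is unchanged; with use_custom_tensor_mananger=True the output buffer comes from g_cache_manager, so the sketch opens a cache scope the way ViTModel.forward now does:

import torch
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

y_plain = gelu_fwd(x)  # default path: output from torch.empty_like, as before this PR

g_cache_manager.cache_env_in()
y_pooled = gelu_fwd(x, use_custom_tensor_mananger=True)  # output from the cache pool
# Same kernel either way; only the output allocation differs, so the values should match.
print(torch.allclose(y_plain, y_pooled))
g_cache_manager.cache_env_out()
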
16 changes: 11 additions & 5 deletions lightllm/models/vit/triton_kernel/rms_norm_vit.py
@@ -2,6 +2,7 @@
import triton
import triton.language as tl
from torch import Tensor
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


@triton.jit
@@ -32,26 +33,31 @@ def rms_norm_kernel(
tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)


def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5):
def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5, use_custom_tensor_mananger: bool = False):
"""Rms norm."""
feat_size = weight.shape[0]
seq_len = hidden_states.numel() // hidden_states.size(-1)
input_stride = hidden_states.stride(-2)

BLOCK_N = triton.next_power_of_2(feat_size)
out = torch.empty_like(hidden_states)
if use_custom_tensor_mananger:
shape = hidden_states.shape
dtype = hidden_states.dtype
device = hidden_states.device
output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
else:
output = torch.empty_like(hidden_states)

grid = (seq_len,)
rms_norm_kernel[grid](
hidden_states,
weight,
out,
output,
input_row_stride=input_stride,
eps=eps,
N_COLS=feat_size,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=3,
)

return out
return output
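
The same idea for the RMSNorm wrapper, again assuming a CUDA device with Triton and lightllm installed. The fp32 reference below uses the standard RMSNorm formula, x * rsqrt(mean(x^2) + eps) * w, which is assumed rather than verified to be what the Triton kernel computes; the printed difference is informational only:

import torch
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

hidden = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
weight = torch.randn(1024, dtype=torch.float16, device="cuda")

g_cache_manager.cache_env_in()
out = rms_norm(hidden, weight, eps=1e-6, use_custom_tensor_mananger=True)
# Informal fp32 reference (standard RMSNorm; assumed to match the kernel).
h32 = hidden.float()
ref = h32 * torch.rsqrt(h32.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * weight.float()
print((out.float() - ref).abs().max())  # informational; expected to be small
g_cache_manager.cache_env_out()
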