From d274643cda896d3b5212c66c7c82e56a710c4285 Mon Sep 17 00:00:00 2001 From: Yixuan Qiao Date: Sat, 1 Feb 2025 23:22:23 -0800 Subject: [PATCH] Replace with SiLU --- vllm/model_executor/models/qwen2_5_vl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index be21fe2ef37b6..a6792e36ecf13 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -113,7 +113,7 @@ def __init__( hidden_size, intermediate_size, bias = True, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: Type[nn.Module] = F.silu, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -143,7 +143,7 @@ def __init__( prefix=f"{prefix}.down_proj" ) - self.act = act_layer() + self.act = act_layer def forward(self, x): gate_out, _ = self.gate_proj(x) @@ -276,11 +276,10 @@ def forward( if rotary_pos_emb is not None: q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - if self.attn_backend == _Backend.FLASH_ATTN: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + from vllm_flash_attn.flash_attn_interface import ( + flash_attn_varlen_func) + # from flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -338,7 +337,7 @@ def __init__( num_heads: int, hidden_size: int, intermediate_size: int, - act_layer: Type[nn.Module] = QuickGELU, + act_layer: Type[nn.Module] = F.silu, norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -421,7 +420,7 @@ def __init__( bias=True, quant_config=quant_config, prefix=f"{prefix}.mlp.0"), - nn.GELU(), + nn.SiLU(), RowParallelLinear(self.hidden_size, d_model, bias=True, @@ -1096,7 +1095,6 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of image pixel values. " f"Got type: {type(pixel_values)}") - return Qwen2_5_VLImagePixelInputs(type="pixel_values", pixel_values=pixel_values, image_grid_thw=image_grid_thw)