Skip to content

Commit

Permalink
Replace with SiLU
Browse files Browse the repository at this point in the history
  • Loading branch information
yixqiao committed Feb 2, 2025
1 parent a75217f commit d274643
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
hidden_size,
intermediate_size,
bias = True,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: Type[nn.Module] = F.silu,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
Expand Down Expand Up @@ -143,7 +143,7 @@ def __init__(
prefix=f"{prefix}.down_proj"
)

self.act = act_layer()
self.act = act_layer

def forward(self, x):
gate_out, _ = self.gate_proj(x)
Expand Down Expand Up @@ -276,11 +276,10 @@ def forward(
if rotary_pos_emb is not None:
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
from vllm_flash_attn.flash_attn_interface import (
flash_attn_varlen_func)
# from flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

Expand Down Expand Up @@ -338,7 +337,7 @@ def __init__(
num_heads: int,
hidden_size: int,
intermediate_size: int,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: Type[nn.Module] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
Expand Down Expand Up @@ -421,7 +420,7 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp.0"),
nn.GELU(),
nn.SiLU(),
RowParallelLinear(self.hidden_size,
d_model,
bias=True,
Expand Down Expand Up @@ -1096,7 +1095,6 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")

return Qwen2_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
Expand Down

0 comments on commit d274643

Please sign in to comment.