Skip to content

Commit

Permalink
fix internlm1 ci
Browse files Browse the repository at this point in the history
  • Loading branch information
yingtongxiong committed Jan 20, 2025
1 parent 9e65623 commit 295fba7
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class InternLM1Decoder(nn.Module):
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization.
multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization.
enable_qkv_fusion(bool): Whether to fuse Wq,Wk,Wv computation. True by default.
"""

def __init__(
Expand All @@ -91,6 +92,7 @@ def __init__(
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
enable_qkv_fusion: bool = True,
):
super().__init__()
self.checkpoint = checkpoint
Expand All @@ -116,7 +118,7 @@ def __init__(
device=device,
dtype=dtype,
qk_interleaved=qk_interleaved,
enable_qkv_fusion=True,
enable_qkv_fusion=enable_qkv_fusion,
)

# Compatible with the name of internlm1 Wqkv linear layer
Expand Down Expand Up @@ -261,6 +263,7 @@ class InternLM1(BaseModel):
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization.
multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization.
enable_qkv_fusion(bool): Whether to fuse Wq,Wk,Wv computation. True by default.
"""

def __init__(
Expand Down Expand Up @@ -293,6 +296,7 @@ def __init__(
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
enable_qkv_fusion: bool = True,
):
super().__init__()

Expand Down Expand Up @@ -329,6 +333,7 @@ def __init__(
qk_interleaved=qk_interleaved,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
enable_qkv_fusion=enable_qkv_fusion,
)
for lid in range(num_layers)
]
Expand Down

0 comments on commit 295fba7

Please sign in to comment.