
[FEATS][Depth] [add_norm]
Kye committed Jan 24, 2024
1 parent ce86f81 commit 36a1ea0
Showing 2 changed files with 45 additions and 31 deletions.
2 changes: 1 addition & 1 deletion example.py
@@ -1,5 +1,5 @@
 import torch
-from switch_transformers import SwitchTransformer
+from switch_transformers.model import SwitchTransformer
 
 # Generate a random tensor of shape (1, 10) with values between 0 and 100
 x = torch.randint(0, 100, (1, 10))
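
For reference, a minimal sketch of how the updated example.py might be run end to end. The hyperparameters below are assumptions drawn from the SwitchTransformer signature in the model.py diff further down, not from the truncated remainder of example.py:

import torch
from switch_transformers.model import SwitchTransformer

# Token ids of shape (1, 10); values stay below num_tokens
x = torch.randint(0, 100, (1, 10))

# Hypothetical hyperparameters; dim_head, mult, dropout, num_experts, and the new
# depth argument fall back to the defaults shown in the model.py diff below
model = SwitchTransformer(num_tokens=100, dim=512, heads=8)

out = model(x)
print(out.shape)  # expected (1, 10, 100) if to_out projects back to the vocabulary
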
74 changes: 44 additions & 30 deletions switch_transformers/model.py
@@ -236,6 +236,9 @@ def __init__(
         self.ffn = SwitchMoE(
             dim, dim * mult, dim, num_experts, *args, **kwargs
         )
+
+        self.add_norm = nn.LayerNorm(dim)
+
 
     def forward(self, x: Tensor):
         """
@@ -251,12 +254,31 @@ def forward(self, x: Tensor):
         resi = x
         x, _, _ = self.attn(x)
         x = x + resi
+        x = self.add_norm(x)
+        add_normed = x
 
         ##### MoE #####
         x, _ = self.ffn(x)
-        x = x + resi
+        x = x + add_normed
+        x = self.add_norm(x)
         return x
 
 
 class SwitchTransformer(nn.Module):
+    """
+    SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.
+    Args:
+        num_tokens (int): The number of tokens in the input vocabulary.
+        dim (int): The dimensionality of the token embeddings and hidden states.
+        heads (int): The number of attention heads.
+        dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
+        mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+    """
     def __init__(
         self,
         num_tokens: int,
@@ -266,23 +288,10 @@ def __init__(
         mult: int = 4,
         dropout: float = 0.1,
         num_experts: int = 3,
+        depth: int = 4,
         *args,
         **kwargs,
     ):
-        """
-        SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.
-        Args:
-            num_tokens (int): The number of tokens in the input vocabulary.
-            dim (int): The dimensionality of the token embeddings and hidden states.
-            heads (int): The number of attention heads.
-            dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
-            mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
-            dropout (float, optional): The dropout rate. Defaults to 0.1.
-            num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
-            *args: Additional positional arguments.
-            **kwargs: Additional keyword arguments.
-        """
         super().__init__()
         self.num_tokens = num_tokens
         self.dim = dim
@@ -291,19 +300,24 @@ def __init__(
         self.mult = mult
         self.dropout = dropout
         self.num_experts = num_experts
+        self.depth = depth
 
         self.embedding = nn.Embedding(num_tokens, dim)
 
-        self.block = SwitchTransformerBlock(
-            dim,
-            heads,
-            dim_head,
-            mult,
-            dropout,
-            num_experts,
-            *args,
-            **kwargs,
-        )
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(
+                SwitchTransformerBlock(
+                    dim,
+                    heads,
+                    dim_head,
+                    mult,
+                    dropout,
+                    num_experts,
+                    *args,
+                    **kwargs,
+                )
+            )
 
         self.to_out = nn.Sequential(
             nn.Softmax(dim=-1),
@@ -323,11 +337,11 @@ def forward(self, x: Tensor) -> Tensor:
"""
# Embed tokens through embedding layer
x = self.embedding(x)
# Pass through the transformer block with MoE
x = self.block(x)

# Pass through the transformer block with MoE, it's in modulelist
for layer in self.layers:
x = layer(x)

# Project to output tokens
x = self.to_out(x)
return x
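
The new SwitchTransformerBlock.forward path above applies a residual add followed by LayerNorm after the attention sublayer, then reuses the same add_norm module after the MoE sublayer, with the new depth argument controlling how many such blocks are stacked. A minimal, self-contained sketch of that add-and-norm ordering, with plain Linear layers standing in for the attention and SwitchMoE sublayers (the *_stub names are placeholders, not the library's API):

import torch
from torch import nn, Tensor

class AddNormBlockSketch(nn.Module):
    """Illustrative only: mirrors the residual add + LayerNorm ordering in the diff."""

    def __init__(self, dim: int):
        super().__init__()
        self.attn_stub = nn.Linear(dim, dim)   # stand-in for the attention sublayer
        self.moe_stub = nn.Linear(dim, dim)    # stand-in for the SwitchMoE sublayer
        self.add_norm = nn.LayerNorm(dim)      # one LayerNorm reused after both sublayers

    def forward(self, x: Tensor) -> Tensor:
        resi = x
        x = self.attn_stub(x)
        x = self.add_norm(x + resi)        # add & norm after attention
        add_normed = x
        x = self.moe_stub(x)
        x = self.add_norm(x + add_normed)  # add & norm after MoE, residual from the normed activations
        return x

print(AddNormBlockSketch(64)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])

Note that reusing a single add_norm module means the two sublayers share LayerNorm parameters; per-sublayer norms are the more common post-norm arrangement, so the sharing here may be deliberate or simply a simplification.
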

