From 36a1ea01448e56242222b68201207a7219d72b4b Mon Sep 17 00:00:00 2001
From: Kye
Date: Wed, 24 Jan 2024 15:52:16 -0500
Subject: [PATCH] [FEATS][Depth] [add_norm]

---
 example.py                   |  2 +-
 switch_transformers/model.py | 74 +++++++++++++++++++++---------------
 2 files changed, 45 insertions(+), 31 deletions(-)

diff --git a/example.py b/example.py
index a9482bc..89b6a0e 100644
--- a/example.py
+++ b/example.py
@@ -1,5 +1,5 @@
 import torch
-from switch_transformers import SwitchTransformer
+from switch_transformers.model import SwitchTransformer
 
 # Generate a random tensor of shape (1, 10) with values between 0 and 100
 x = torch.randint(0, 100, (1, 10))
diff --git a/switch_transformers/model.py b/switch_transformers/model.py
index d98a604..8757ff4 100644
--- a/switch_transformers/model.py
+++ b/switch_transformers/model.py
@@ -236,6 +236,9 @@ def __init__(
         self.ffn = SwitchMoE(
             dim, dim * mult, dim, num_experts, *args, **kwargs
         )
+
+        self.add_norm = nn.LayerNorm(dim)
+
 
     def forward(self, x: Tensor):
         """
@@ -251,12 +254,31 @@ def forward(self, x: Tensor):
         resi = x
         x, _, _ = self.attn(x)
         x = x + resi
+        x = self.add_norm(x)
+        add_normed = x
+
+        ##### MoE #####
         x, _ = self.ffn(x)
-        x = x + resi
+        x = x + add_normed
+        x = self.add_norm(x)
         return x
 
 
 class SwitchTransformer(nn.Module):
+    """
+    SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.
+
+    Args:
+        num_tokens (int): The number of tokens in the input vocabulary.
+        dim (int): The dimensionality of the token embeddings and hidden states.
+        heads (int): The number of attention heads.
+        dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
+        mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
+        *args: Additional positional arguments.
+        **kwargs: Additional keyword arguments.
+    """
     def __init__(
         self,
         num_tokens: int,
@@ -266,23 +288,10 @@ def __init__(
         mult: int = 4,
         dropout: float = 0.1,
         num_experts: int = 3,
+        depth: int = 4,
         *args,
         **kwargs,
     ):
-        """
-        SwitchTransformer is a PyTorch module that implements a transformer model with switchable experts.
-
-        Args:
-            num_tokens (int): The number of tokens in the input vocabulary.
-            dim (int): The dimensionality of the token embeddings and hidden states.
-            heads (int): The number of attention heads.
-            dim_head (int, optional): The dimensionality of each attention head. Defaults to 64.
-            mult (int, optional): The multiplier for the hidden dimension in the feed-forward network. Defaults to 4.
-            dropout (float, optional): The dropout rate. Defaults to 0.1.
-            num_experts (int, optional): The number of experts in the switchable experts mechanism. Defaults to 3.
-            *args: Additional positional arguments.
-            **kwargs: Additional keyword arguments.
-        """
         super().__init__()
         self.num_tokens = num_tokens
         self.dim = dim
@@ -291,19 +300,24 @@ def __init__(
         self.mult = mult
         self.dropout = dropout
         self.num_experts = num_experts
+        self.depth = depth
 
         self.embedding = nn.Embedding(num_tokens, dim)
-
-        self.block = SwitchTransformerBlock(
-            dim,
-            heads,
-            dim_head,
-            mult,
-            dropout,
-            num_experts,
-            *args,
-            **kwargs,
-        )
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(
+                SwitchTransformerBlock(
+                    dim,
+                    heads,
+                    dim_head,
+                    mult,
+                    dropout,
+                    num_experts,
+                    *args,
+                    **kwargs,
+                )
+            )
 
         self.to_out = nn.Sequential(
             nn.Softmax(dim=-1),
@@ -323,11 +337,11 @@ def forward(self, x: Tensor) -> Tensor:
         """
         # Embed tokens through embedding layer
        x = self.embedding(x)
-        # Pass through the transformer block with MoE
-        x = self.block(x)
+
+        # Pass through the transformer block with MoE, it's in modulelist
+        for layer in self.layers:
+            x = layer(x)
 
         # Project to output tokens
        x = self.to_out(x)
        return x
-
-