Skip to content

Commit

Permalink
[CODE QUALITY]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Jan 22, 2024
1 parent ea25ce9 commit 425d5f4
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion swarms_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Particle,
TransformerParticleSwarmOptimization,
)
from swarms_torch.structs import * # noqa
from swarms_torch.structs import * # noqa

__all__ = [
"ParticleSwarmOptimization",
Expand Down
2 changes: 1 addition & 1 deletion swarms_torch/structs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
"SwitchMoE",
"GatingMechanism",
"SimpleMoE",
]
]
19 changes: 15 additions & 4 deletions swarms_torch/structs/mixtral_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(SwiGLU, self).__init__()
Expand All @@ -11,6 +12,7 @@ def __init__(self, input_dim, hidden_dim, output_dim):
def forward(self, x):
return self.fc2(F.silu(self.fc1(x)))


class TopKGate(nn.Module):
def __init__(self, model_dim, num_experts, top_k):
super(TopKGate, self).__init__()
Expand All @@ -20,22 +22,31 @@ def __init__(self, model_dim, num_experts, top_k):
def forward(self, x):
gate_logits = self.w_gate(x)
top_logits, top_indices = torch.topk(gate_logits, self.top_k, dim=-1)
top_k_logits = torch.full_like(gate_logits, float('-inf'))
top_k_logits = torch.full_like(gate_logits, float("-inf"))
top_k_logits.scatter_(1, top_indices, top_logits)
return F.softmax(top_k_logits, dim=-1)


class MoE(nn.Module):
    """Mixture-of-experts layer that blends expert outputs by gate weight.

    Every expert is evaluated on the full input (dense evaluation); the
    gate's per-expert scores then weight the expert outputs, which are
    summed over the expert axis.

    Args:
        model_dim: Input and output feature size passed to each expert
            and to the gate.
        hidden_dim: Hidden size passed to each ``SwiGLU`` expert.
        num_experts: Number of experts to instantiate.
        top_k: Number of experts the gate keeps per position.
    """

    def __init__(self, model_dim, hidden_dim, num_experts, top_k):
        super().__init__()
        self.experts = nn.ModuleList(
            [
                SwiGLU(model_dim, hidden_dim, model_dim)
                for _ in range(num_experts)
            ]
        )
        self.gate = TopKGate(model_dim, num_experts, top_k)

    def forward(self, x):
        """Combine all expert outputs using the gate scores.

        Args:
            x: Input tensor. The gate is expected to return scores whose
                last dimension indexes the experts and whose leading
                dimensions match those of each expert's output.

        Returns:
            The gate-weighted sum of the expert outputs, with the same
            shape as a single expert's output.
        """
        gate_scores = self.gate(x)
        # Insert the expert axis just before the feature axis. For 3-D
        # input this is the same position as dim=2, but indexing from the
        # end also works for 2-D or higher-rank input.
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )
        # (..., experts, 1) * (..., experts, features), then reduce experts.
        weighted_expert_outputs = gate_scores.unsqueeze(-1) * expert_outputs
        return weighted_expert_outputs.sum(dim=-2)


# Model architecture parameters
model_dim = 4096
n_layers = 32
Expand All @@ -56,4 +67,4 @@ def forward(self, x):
# Forward pass through the MoE layer
output = moe_layer(x)

print(output)
print(output)

0 comments on commit 425d5f4

Please sign in to comment.