
Commit 65e6904

NouamaneTazi committed Feb 19, 2024
1 parent 9d28530 commit 65e6904
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions examples/moe/moe.py
@@ -177,7 +177,7 @@ def forward_once(self, x, expert_weights, top_experts):  # TODO: sparse
         with torch.no_grad():
             topo = self.topology(x, padded_bins)
 
-        x = self.mlp(x, topo)  # TODO: exp_pg=1 and num_experts=2 means the experts will get same data.
+        x = self.mlp(x, topo)
 
         # Un-route the data for the MoE output.
         x = ops.padded_scatter(
@@ -429,7 +429,6 @@ def __init__(
         )
 
         if self.tp_pg.size() == 1:
-            # transpose self.w1.module.weight
             self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous()
 
         # TODO @nouamane: jit
@@ -438,7 +437,6 @@ def __init__(
         self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd
 
     def forward(self, x, topo):
-        # Compute the MLP.
         self.w1.scale_gradients(), self.w2.scale_gradients()
         x = self.sdd(x.contiguous(), self.w1.module.weight, topo)
         activation_fn_out = act_fn(x, self.act)
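For context on the code path this commit touches: `forward_once` gathers and pads tokens per expert, runs them through a block-sparse expert MLP (an `sdd` matmul, the activation, then a `dsd` matmul, built from `stk.ops` or the tensor-parallel `wp` wrappers), and un-routes the result with `ops.padded_scatter`. The `w1` transpose in the second hunk presumably puts the weight in the layout the single-device `stk.ops.sdd` path expects. The sketch below is a dense, single-device stand-in for that per-expert computation, not the repository's block-sparse implementation; all tensor names, shapes, and the GELU activation are illustrative assumptions.

import torch
import torch.nn.functional as F

# Illustrative sizes only; the real model reads these from its config.
num_experts, d_model, d_ff, tokens_per_expert = 2, 16, 64, 8

# One weight matrix per expert; in moe.py these live inside sharded
# column-/row-parallel modules, here they are plain tensors.
w1 = torch.randn(num_experts, d_model, d_ff)
w2 = torch.randn(num_experts, d_ff, d_model)

# x stands in for tokens that are already routed: grouped (and padded)
# per expert, which self.topology / the padded gather prepare upstream.
x = torch.randn(num_experts, tokens_per_expert, d_model)

# sdd-like step: a per-expert GEMM producing the hidden activations.
h = torch.einsum("etd,edf->etf", x, w1)
h = F.gelu(h)  # plays the role of act_fn(x, self.act) in the diff

# dsd-like step: project back to the model dimension per expert.
y = torch.einsum("etf,efd->etd", h, w2)

# Downstream, ops.padded_scatter would un-route y back to the original
# token order and combine expert outputs with the router weights.
print(y.shape)  # (num_experts, tokens_per_expert, d_model)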
