Commit

Router zloss
josejg committed Sep 8, 2024
1 parent 66d7894 commit abc0638
Showing 2 changed files with 39 additions and 1 deletion.
4 changes: 4 additions & 0 deletions megablocks/layers/arguments.py
@@ -68,6 +68,10 @@ class Arguments:
         int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
     shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
 
+    # Router Z-loss arguments
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
     def __post_init__(self):
         if self.__getattribute__('mlp_impl') == 'grouped':
             grouped_gemm.assert_grouped_gemm_is_available()
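Both new fields default to a disabled state (moe_zloss_weight = 0), so existing configurations are unaffected. A minimal configuration sketch, not part of this commit: moe_num_experts and moe_top_k are assumed pre-existing Arguments fields, and the remaining fields are assumed to keep their defaults.

from megablocks.layers.arguments import Arguments

# Hypothetical example: enable the router z-loss; 1e-3 matches the value the
# field comment above calls reasonable.
args = Arguments(
    moe_num_experts=8,           # assumed existing field, not part of this diff
    moe_top_k=2,                 # assumed existing field, not part of this diff
    moe_zloss_weight=1e-3,
    moe_zloss_in_fp32=True,      # cast router logits to fp32 before logsumexp
)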
36 changes: 35 additions & 1 deletion megablocks/layers/router.py
@@ -7,6 +7,38 @@
 from megablocks.layers import common
 from megablocks.layers.arguments import Arguments
 
+_ROUTER_LOGITS = []
+
+def _save_router_logits(logits: torch.Tensor, args: Arguments):
+    if args.moe_zloss_weight == 0:
+        return
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.append(logits)
+
+def clear_router_zloss():
+    global _ROUTER_LOGITS
+    _ROUTER_LOGITS.clear()
+
+def batched_router_zloss(args: Arguments):
+    global _ROUTER_LOGITS
+
+    if args.moe_zloss_weight == 0:
+        import warnings
+        warnings.warn("Call to batched_router_zloss, but moe_zloss_weight=0")
+        return 0
+
+    logits_per_router = _ROUTER_LOGITS
+
+    if args.moe_zloss_in_fp32:
+        logits_per_router = [logits.float() for logits in logits_per_router]
+
+    unscaled_zloss_per_router = torch.stack([
+        torch.logsumexp(logits, dim=1).square().mean()
+        for logits in logits_per_router
+    ])
+
+    return args.moe_zloss_weight * unscaled_zloss_per_router
+
 
 # NOTE: To enable end-to-end benchmarking without convergence we
 # support a flag to force the router to assign tokens uniformly
@@ -60,7 +92,9 @@ def forward(self, x: torch.Tensor):
         if self.training and self.args.moe_jitter_eps is not None:
             x = x * self.jitter(x)
 
-        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        logits = self.layer(x.view(-1, x.shape[-1]))
+        _save_router_logits(logits, self.args)
+        scores = logits.softmax(dim=-1)
         expert_weights, expert_indices = self._top_k(scores)
         if self.args.moe_normalize_expert_weights:
             expert_weights = expert_weights / torch.norm(
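For context, batched_router_zloss penalizes large router logits: for each router it takes the per-token logsumexp of the logits (dim=1), squares it, averages over tokens, and scales by moe_zloss_weight, returning one value per router layer. A hedged training-step sketch, not part of this commit, showing how the helpers might be combined with a task loss; training_step, model, batch, and compute_task_loss are placeholders.

from megablocks.layers import router

def training_step(model, batch, args):
    # Each Router.forward() call records its logits via _save_router_logits.
    output = model(batch)
    loss = compute_task_loss(output, batch)  # placeholder task loss

    # One scaled z-loss term per router layer, stacked into a 1-D tensor.
    zloss_per_router = router.batched_router_zloss(args)
    loss = loss + zloss_per_router.sum()

    # batched_router_zloss does not empty the logit buffer itself, so clear it
    # to avoid mixing logits across steps and to release the stored graph.
    router.clear_router_zloss()
    return loss

Returning a per-router tensor rather than a single scalar also makes it easy to log each layer's z-loss separately before summing.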
