MoE for mixtral 8x7b #2535

Merged: 3 commits, Dec 19, 2023
57 changes: 45 additions & 12 deletions onmt/decoders/transformer.py
@@ -9,6 +9,7 @@
from onmt.modules import MultiHeadedAttention, AverageAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.modules.moe import MoE
from onmt.utils.misc import sequence_mask

try:
@@ -43,6 +44,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
"""
Args:
@@ -109,18 +112,34 @@ def __init__(
d_model, dropout=attention_dropout, aan_useffn=aan_useffn
)

self.feed_forward = PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
if num_experts > 0:
self.feed_forward = MoE(
num_experts,
num_experts_per_tok,
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
else:
self.feed_forward = PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
self.parallel_residual = parallel_residual
self.shared_layer_norm = shared_layer_norm
if layer_norm == "standard":
@@ -260,6 +279,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
"""
Args:
@@ -289,6 +310,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
self.context_attn = MultiHeadedAttention(
heads,
@@ -448,6 +471,8 @@ def from_opt(cls, opt, embeddings):
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)

def init_state(self, src, enc_out, enc_final_hs):
@@ -556,6 +581,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
super(TransformerDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -587,6 +614,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
for i in range(num_layers)
]
@@ -823,6 +852,8 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
num_experts=0,
num_experts_per_tok=2,
):
super(TransformerLMDecoder, self).__init__(
d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -853,6 +884,8 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
for i in range(num_layers)
]
7 changes: 6 additions & 1 deletion onmt/modules/bnb_linear.py
@@ -7,7 +7,12 @@
try:
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from bitsandbytes import MatmulLtState
from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
from bitsandbytes.nn import (
Linear4bit,
Linear8bitLt,
Params4bit,
Int8Params,
)
except ImportError:
raise ImportError("Install bitsandbytes to use 4/8bit compression")

63 changes: 63 additions & 0 deletions onmt/modules/moe.py
@@ -0,0 +1,63 @@
"""MoE mixture of experts"."""
import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward


class MoE(nn.Module):
def __init__(
self,
num_experts,
num_experts_per_tok,
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=[],
parallel_gpu=1,
):
super().__init__()
self.experts = nn.ModuleList(
[
PositionwiseFeedForward(
d_model,
d_ff,
dropout,
pos_ffn_activation_fn,
add_ffnbias,
parallel_residual,
layer_norm,
norm_eps,
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
)
for i in range(num_experts)
]
)
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.num_experts_per_tok = num_experts_per_tok

def forward(self, x):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])

scores = self.gate(x)
expert_weights, expert_indices = torch.topk(
scores, self.num_experts_per_tok, dim=-1
)
expert_weights = expert_weights.softmax(dim=-1)
flat_expert_indices = expert_indices.view(-1)

x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
y = torch.empty_like(x)
for i, expert in enumerate(self.experts):
if torch.any(flat_expert_indices == i):
y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
dim=1
)
return y.view(*orig_shape)
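
For reference, here is a minimal, self-contained sketch of the same top-k routing pattern implemented in the forward pass above; the toy nn.Linear experts and the tensor sizes are illustrative assumptions standing in for the PositionwiseFeedForward experts used by the real module.

import torch
import torch.nn as nn

d_model, num_experts, num_experts_per_tok = 16, 4, 2
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
gate = nn.Linear(d_model, num_experts, bias=False)

x = torch.randn(2, 5, d_model)                       # (batch, seq_len, d_model)
orig_shape = x.shape
x = x.view(-1, d_model)                              # flatten to (tokens, d_model)

scores = gate(x)                                     # (tokens, num_experts)
weights, indices = torch.topk(scores, num_experts_per_tok, dim=-1)
weights = weights.softmax(dim=-1)                    # normalize over the k selected experts
flat_indices = indices.view(-1)                      # (tokens * k,)

x = x.repeat_interleave(num_experts_per_tok, dim=0)  # one copy of each token per selected expert
y = torch.empty_like(x)
for i, expert in enumerate(experts):
    mask = flat_indices == i
    if mask.any():
        y[mask] = expert(x[mask])                    # each expert sees only its routed tokens

y = (y.view(*weights.shape, -1) * weights.unsqueeze(-1)).sum(dim=1)  # weighted sum of the k outputs
print(y.view(*orig_shape).shape)                     # torch.Size([2, 5, 16])
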
14 changes: 14 additions & 0 deletions onmt/opts.py
@@ -894,6 +894,20 @@ def model_opts(parser):
default=2048,
help="Size of hidden transformer feed-forward",
)
group.add(
"--num_experts",
"-num_experts",
type=int,
default=0,
help="Number of experts",
)
group.add(
"--num_experts_per_tok",
"-num_experts_per_tok",
type=int,
default=2,
help="Number of experts per token",
)
group.add(
"--aan_useffn",
"-aan_useffn",
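
Note on the new options: num_experts defaults to 0, which keeps the dense PositionwiseFeedForward path, so existing configurations behave as before. A Mixtral-8x7B-style decoder corresponds to setting num_experts to 8 and num_experts_per_tok to 2.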