From 84221b0126c346d21ae422b7abdcb8dfc592914f Mon Sep 17 00:00:00 2001
From: vince62s
Date: Mon, 11 Dec 2023 20:19:43 +0100
Subject: [PATCH 1/3] MoE for mixtral 8x7b

---
 onmt/decoders/transformer.py | 57 ++++++++++++++++++++++++++++--------
 onmt/model_builder.py        |  2 +-
 onmt/modules/bnb_linear.py   | 22 +++++++++++++-
 onmt/opts.py                 | 15 ++++++++++
 4 files changed, 82 insertions(+), 14 deletions(-)

diff --git a/onmt/decoders/transformer.py b/onmt/decoders/transformer.py
index 05797d36c3..af40109efe 100644
--- a/onmt/decoders/transformer.py
+++ b/onmt/decoders/transformer.py
@@ -9,6 +9,7 @@
 from onmt.modules import MultiHeadedAttention, AverageAttention
 from onmt.modules.position_ffn import PositionwiseFeedForward
 from onmt.modules.position_ffn import ActivationFunction
+from onmt.modules.moe import MoE
 from onmt.utils.misc import sequence_mask
 
 try:
@@ -43,6 +44,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         """
         Args:
@@ -109,18 +112,34 @@ def __init__(
                 d_model, dropout=attention_dropout, aan_useffn=aan_useffn
             )
 
-        self.feed_forward = PositionwiseFeedForward(
-            d_model,
-            d_ff,
-            dropout,
-            pos_ffn_activation_fn,
-            add_ffnbias,
-            parallel_residual,
-            layer_norm,
-            norm_eps,
-            use_ckpting=use_ckpting,
-            parallel_gpu=parallel_gpu,
-        )
+        if num_experts > 0:
+            self.feed_forward = MoE(
+                num_experts,
+                num_experts_per_tok,
+                d_model,
+                d_ff,
+                dropout,
+                pos_ffn_activation_fn,
+                add_ffnbias,
+                parallel_residual,
+                layer_norm,
+                norm_eps,
+                use_ckpting=use_ckpting,
+                parallel_gpu=parallel_gpu,
+            )
+        else:
+            self.feed_forward = PositionwiseFeedForward(
+                d_model,
+                d_ff,
+                dropout,
+                pos_ffn_activation_fn,
+                add_ffnbias,
+                parallel_residual,
+                layer_norm,
+                norm_eps,
+                use_ckpting=use_ckpting,
+                parallel_gpu=parallel_gpu,
+            )
         self.parallel_residual = parallel_residual
         self.shared_layer_norm = shared_layer_norm
         if layer_norm == "standard":
@@ -260,6 +279,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         """
         Args:
@@ -289,6 +310,8 @@ def __init__(
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
             rotary_interleave=rotary_interleave,
+            num_experts=num_experts,
+            num_experts_per_tok=num_experts_per_tok,
         )
         self.context_attn = MultiHeadedAttention(
             heads,
@@ -448,6 +471,8 @@ def from_opt(cls, opt, embeddings):
            else 1,
            sliding_window=opt.sliding_window,
            rotary_interleave=opt.rotary_interleave,
+            num_experts=opt.num_experts,
+            num_experts_per_tok=opt.num_experts_per_tok,
        )

    def init_state(self, src, enc_out, enc_final_hs):
@@ -556,6 +581,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         super(TransformerDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
         )
@@ -587,6 +614,8 @@ def __init__(
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
                     rotary_interleave=rotary_interleave,
+                    num_experts=num_experts,
+                    num_experts_per_tok=num_experts_per_tok,
                 )
                 for i in range(num_layers)
             ]
@@ -823,6 +852,8 @@ def __init__(
         parallel_gpu=1,
         sliding_window=0,
         rotary_interleave=True,
+        num_experts=0,
+        num_experts_per_tok=2,
     ):
         super(TransformerLMDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
         )
@@ -853,6 +884,8 @@ def __init__(
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
                     rotary_interleave=rotary_interleave,
+                    num_experts=num_experts,
+                    num_experts_per_tok=num_experts_per_tok,
                 )
                 for i in range(num_layers)
             ]
diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index 37831c50d1..d2de64ef33 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -313,7 +313,7 @@ def build_base_model(model_opt, vocabs):
     ]
 
     if hasattr(model_opt, "quant_layers") and len(nonlora_to_quant) > 0:
-        if model_opt.quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4"]:
+        if model_opt.quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4", "bnb_sparse"]:
             logger.info(
                 "%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
             )
diff --git a/onmt/modules/bnb_linear.py b/onmt/modules/bnb_linear.py
index 342d470d2e..15ec7008dc 100644
--- a/onmt/modules/bnb_linear.py
+++ b/onmt/modules/bnb_linear.py
@@ -7,7 +7,14 @@
 try:
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
     from bitsandbytes import MatmulLtState
-    from bitsandbytes.nn import Linear4bit, Linear8bitLt, Params4bit, Int8Params
+    from bitsandbytes.nn import (
+        Linear4bit,
+        Linear8bitLt,
+        Params4bit,
+        Int8Params,
+        LinearSparse,
+        ParamsSparse,
+    )
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")
 
@@ -64,4 +71,17 @@ def replace_bnb_linear(
                 quant_type=q_type[-3:].lower(),
             )
             model._modules[name].compute_dtype = compute_dtype
+        elif q_type in ["bnb_sparse"]:
+            model._modules[name] = nn.utils.skip_init(
+                LinearSparse,
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+                device=torch.device("cpu"),
+            )
+            model._modules[name].weight = ParamsSparse(
+                model._modules[name].weight.data,
+                requires_grad=False,
+                sparsity_level=0.98,
+            )
     return model
diff --git a/onmt/opts.py b/onmt/opts.py
index 861e96172e..05d17607bf 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -894,6 +894,20 @@ def model_opts(parser):
         default=2048,
         help="Size of hidden transformer feed-forward",
     )
+    group.add(
+        "--num_experts",
+        "-num_experts",
+        type=int,
+        default=0,
+        help="Number of experts",
+    )
+    group.add(
+        "--num_experts_per_tok",
+        "-num_experts_per_tok",
+        type=int,
+        default=2,
+        help="Number of experts per token",
+    )
     group.add(
         "--aan_useffn",
         "-aan_useffn",
@@ -1570,6 +1584,7 @@ def _add_quant_opts(parser):
         "bnb_8bit",
         "bnb_FP4",
         "bnb_NF4",
+        "bnb_sparse",
         "llm_awq",
         "aawq_gemm",
         "aawq_gemv",
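The two options added above control the Mixtral-style routing that the next patch implements in onmt/modules/moe.py: each token is scored by a learned gate over num_experts feed-forward experts, only the top num_experts_per_tok experts are evaluated for that token, and their outputs are mixed with softmax-normalised gate weights. A minimal, self-contained sketch of that routing follows; it is illustrative only, with plain nn.Linear layers standing in for the PositionwiseFeedForward experts and made-up sizes.

# Standalone sketch of top-k expert routing (illustration, not part of the patch).
import torch
import torch.nn as nn

d_model, num_experts, num_experts_per_tok = 16, 8, 2
tokens = torch.randn(5, d_model)                     # 5 flattened token states

gate = nn.Linear(d_model, num_experts, bias=False)   # router, as in moe.py
experts = nn.ModuleList(
    [nn.Linear(d_model, d_model) for _ in range(num_experts)]  # stand-in experts
)

scores = gate(tokens)                                 # (5, num_experts)
weights, indices = torch.topk(scores, num_experts_per_tok, dim=-1)
weights = weights.softmax(dim=-1)                     # per-token mixing weights

# Route each (token, selected expert) pair through its expert.
flat_indices = indices.view(-1)                       # (5 * k,)
expanded = tokens.repeat_interleave(num_experts_per_tok, dim=0)
out = torch.empty_like(expanded)
for i, expert in enumerate(experts):
    mask = flat_indices == i
    if mask.any():
        out[mask] = expert(expanded[mask])

# Weighted sum over the k selected experts -> one vector per token.
mixed = (out.view(*weights.shape, -1) * weights.unsqueeze(-1)).sum(dim=1)
print(mixed.shape)                                    # torch.Size([5, 16])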
From 54641186689cbca4aafa5935cdb0b9cc2f815077 Mon Sep 17 00:00:00 2001
From: vince62s
Date: Mon, 11 Dec 2023 20:25:45 +0100
Subject: [PATCH 2/3] add missing file

---
 onmt/modules/moe.py | 63 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 63 insertions(+)
 create mode 100644 onmt/modules/moe.py

diff --git a/onmt/modules/moe.py b/onmt/modules/moe.py
new file mode 100644
index 0000000000..2e1c959636
--- /dev/null
+++ b/onmt/modules/moe.py
@@ -0,0 +1,63 @@
+"""MoE mixture of experts."""
+import torch
+import torch.nn as nn
+from onmt.modules.position_ffn import PositionwiseFeedForward
+
+
+class MoE(nn.Module):
+    def __init__(
+        self,
+        num_experts,
+        num_experts_per_tok,
+        d_model,
+        d_ff,
+        dropout,
+        pos_ffn_activation_fn,
+        add_ffnbias,
+        parallel_residual,
+        layer_norm,
+        norm_eps,
+        use_ckpting=[],
+        parallel_gpu=1,
+    ):
+        super().__init__()
+        self.experts = nn.ModuleList(
+            [
+                PositionwiseFeedForward(
+                    d_model,
+                    d_ff,
+                    dropout,
+                    pos_ffn_activation_fn,
+                    add_ffnbias,
+                    parallel_residual,
+                    layer_norm,
+                    norm_eps,
+                    use_ckpting=use_ckpting,
+                    parallel_gpu=parallel_gpu,
+                )
+                for i in range(num_experts)
+            ]
+        )
+        self.gate = nn.Linear(d_model, num_experts, bias=False)
+        self.num_experts_per_tok = num_experts_per_tok
+
+    def forward(self, x):
+        orig_shape = x.shape
+        x = x.view(-1, x.shape[-1])
+
+        scores = self.gate(x)
+        expert_weights, expert_indices = torch.topk(
+            scores, self.num_experts_per_tok, dim=-1
+        )
+        expert_weights = expert_weights.softmax(dim=-1)
+        flat_expert_indices = expert_indices.view(-1)
+
+        x = x.repeat_interleave(self.num_experts_per_tok, dim=0)
+        y = torch.empty_like(x)
+        for i, expert in enumerate(self.experts):
+            if torch.any(flat_expert_indices == i):
+                y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
+        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
+            dim=1
+        )
+        return y.view(*orig_shape)

From 9f79f4e05642548f8dc42cd204ea839e0b70b9dc Mon Sep 17 00:00:00 2001
From: vince62s
Date: Tue, 19 Dec 2023 15:41:21 +0100
Subject: [PATCH 3/3] removing bnb_sparse for now

---
 onmt/model_builder.py      |  2 +-
 onmt/modules/bnb_linear.py | 15 ---------------
 onmt/opts.py               |  1 -
 3 files changed, 1 insertion(+), 17 deletions(-)

diff --git a/onmt/model_builder.py b/onmt/model_builder.py
index d2de64ef33..37831c50d1 100644
--- a/onmt/model_builder.py
+++ b/onmt/model_builder.py
@@ -313,7 +313,7 @@ def build_base_model(model_opt, vocabs):
     ]
 
     if hasattr(model_opt, "quant_layers") and len(nonlora_to_quant) > 0:
-        if model_opt.quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4", "bnb_sparse"]:
+        if model_opt.quant_type in ["bnb_8bit", "bnb_FP4", "bnb_NF4"]:
             logger.info(
                 "%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
             )
diff --git a/onmt/modules/bnb_linear.py b/onmt/modules/bnb_linear.py
index 15ec7008dc..bcf61fdb31 100644
--- a/onmt/modules/bnb_linear.py
+++ b/onmt/modules/bnb_linear.py
@@ -12,8 +12,6 @@
         Linear8bitLt,
         Params4bit,
         Int8Params,
-        LinearSparse,
-        ParamsSparse,
     )
 except ImportError:
     raise ImportError("Install bitsandbytes to use 4/8bit compression")
@@ -71,17 +69,4 @@ def replace_bnb_linear(
                 quant_type=q_type[-3:].lower(),
             )
             model._modules[name].compute_dtype = compute_dtype
-        elif q_type in ["bnb_sparse"]:
-            model._modules[name] = nn.utils.skip_init(
-                LinearSparse,
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                device=torch.device("cpu"),
-            )
-            model._modules[name].weight = ParamsSparse(
-                model._modules[name].weight.data,
-                requires_grad=False,
-                sparsity_level=0.98,
-            )
     return model
diff --git a/onmt/opts.py b/onmt/opts.py
index 05d17607bf..ccfdca6766 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -1584,7 +1584,6 @@ def _add_quant_opts(parser):
         "bnb_8bit",
         "bnb_FP4",
         "bnb_NF4",
-        "bnb_sparse",
         "llm_awq",
         "aawq_gemm",
         "aawq_gemv",
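With the three patches applied on top of OpenNMT-py, the new block can be smoke-tested on a dummy batch. The constructor call below is a sketch that follows the positional signature defined in onmt/modules/moe.py; the sizes (8 experts, 2 per token, d_model 64) are illustrative values chosen to match Mixtral-style settings, not taken from the patches.

# Quick smoke test of the MoE block added by PATCH 2/3.
# Assumes the patch series is applied on top of OpenNMT-py; sizes are illustrative.
import torch
from onmt.modules.moe import MoE
from onmt.modules.position_ffn import ActivationFunction

moe = MoE(
    8,                        # num_experts (Mixtral 8x7B uses 8)
    2,                        # num_experts_per_tok
    64,                       # d_model
    256,                      # d_ff
    0.0,                      # dropout
    ActivationFunction.relu,  # pos_ffn_activation_fn
    True,                     # add_ffnbias
    False,                    # parallel_residual
    "standard",               # layer_norm
    1e-6,                     # norm_eps
)

x = torch.randn(2, 5, 64)     # (batch, seq_len, d_model)
y = moe(x)
print(y.shape)                # torch.Size([2, 5, 64]); shape is preserved

During training the same two values would be supplied through the new --num_experts and --num_experts_per_tok options added to onmt/opts.py, which from_opt passes down to every decoder layer.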