From edc2ac995763abefcf46f297253ffbab4b0f8281 Mon Sep 17 00:00:00 2001
From: Mike Lee
Date: Thu, 8 Aug 2024 17:50:06 +0800
Subject: [PATCH] finish loramoe support

---
 mlora/model.py              | 26 +++++++++++++++++++++++++-
 mlora/modules/__init__.py   |  2 ++
 mlora/modules/config.py     | 31 ++++++++++++++++++++-----------
 mlora/modules/lora_moes.py  | 13 ++++++++-----
 mlora/trainer.py            |  9 +++++++++
 templates/mixlora.json      |  2 +-
 templates/mixlora_glm.json  |  2 +-
 templates/mixlora_phi.json  |  2 +-
 templates/mixlora_phi3.json |  2 +-
 9 files changed, 68 insertions(+), 21 deletions(-)

diff --git a/mlora/model.py b/mlora/model.py
index 2a25ff8e..5f90072b 100644
--- a/mlora/model.py
+++ b/mlora/model.py
@@ -26,6 +26,7 @@
     LoraMoeConfig,
     MixLoraConfig,
     lora_config_factory,
+    moe_layer_dict,
     moe_layer_factory,
     router_loss_factory,
 )
@@ -190,6 +191,9 @@ def init_lora_layer_weight(
         transformer_layer.mlp_.moes_[lora_config.adapter_name] = moe_layer_factory(
             llm_config, lora_config
         )
+        moe_initializer = moe_layer_dict[
+            lora_config.routing_strategy_
+        ].adapter_initializer
     else:
         model_prefix_name = "base_model.model.model"
         moe_layer_name_list = []
@@ -208,6 +212,26 @@
         if moe_initializer is not None:
             # init for gating mechanisms
             moe_initializer(llm_config, lora_config, proj_weight)
+            if (
+                hasattr(proj_weight, "_moe_gates")
+                and lora_config.adapter_name in proj_weight._moe_gates
+            ):
+                gate_weight = (
+                    lora_weights.get(f"{module_name}.moe_gate.weight", None)
+                    if lora_weights is not None
+                    else None
+                )
+                if gate_weight is None:
+                    torch.nn.init.normal_(
+                        proj_weight._moe_gates[lora_config.adapter_name].weight,
+                        mean=0.0,
+                        std=lora_config.router_init_range_,
+                    )
+                else:
+                    with torch.no_grad():
+                        proj_weight._moe_gates[
+                            lora_config.adapter_name
+                        ].weight.copy_(gate_weight)

         for expert_idx in range(lora_config.num_experts_):
             if lora_weights is None:
@@ -559,7 +583,7 @@ def get_adapter_weight_dict(self, adapter_name: str) -> Dict[str, torch.Tensor]:
                 hasattr(proj_weight, "_moe_gates")
                 and adapter_name in proj_weight._moe_gates
             ):
-                lora_weight_dict[f"{module_name}.mlp.moe_gate.weight"] = (
+                lora_weight_dict[f"{module_name}.moe_gate.weight"] = (
                     proj_weight._moe_gates[adapter_name].weight
                 )

diff --git a/mlora/modules/__init__.py b/mlora/modules/__init__.py
index a06697c1..df9926ad 100644
--- a/mlora/modules/__init__.py
+++ b/mlora/modules/__init__.py
@@ -50,6 +50,7 @@

 # MixLoRA MoEs
 from .lora_moes import (
+    LoraMoe,
     MixtralRouterLoss,
     MixtralSparseMoe,
     SwitchRouterLoss,
@@ -83,6 +84,7 @@
     "MixtralSparseMoe",
     "SwitchRouterLoss",
     "SwitchSparseMoe",
+    "LoraMoe",
     "router_loss_dict",
     "moe_layer_dict",
     "router_loss_factory",
diff --git a/mlora/modules/config.py b/mlora/modules/config.py
index ac65e477..5b69b86f 100644
--- a/mlora/modules/config.py
+++ b/mlora/modules/config.py
@@ -196,7 +196,7 @@ def export(self) -> Dict[str, any]:
         return config


-available_routing_strategies = ["mixtral", "switch"]
+available_routing_strategies = ["mixlora", "mixlora-switch"]


 @dataclass
@@ -240,9 +240,9 @@ def check(self) -> "MixLoraConfig":
         assert self.act_fn_ is None or (
             isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
         )
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             assert isinstance(self.top_k_, int) and self.top_k_ > 0
-        elif self.routing_strategy_ == "switch":
+        elif self.routing_strategy_ == "mixlora-switch":
             assert (
                 isinstance(self.router_z_loss_coef_, float)
                 and self.router_z_loss_coef_ >= 0
@@ -270,11 +270,11 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
         # silu for mixtral or gelu_new for switch transformers
         # left blank to automatically use the original act_fn of FFN
         lora_config.act_fn_ = config.get("act_fn", None)
-        if lora_config.routing_strategy_ == "mixtral":
+        if lora_config.routing_strategy_ == "mixlora":
             lora_config.router_init_range_ = config.get("router_init_range", 0.02)
             lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
             lora_config.top_k_ = config.get("top_k", 2)
-        elif lora_config.routing_strategy_ == "switch":
+        elif lora_config.routing_strategy_ == "mixlora-switch":
             lora_config.router_init_range_ = config.get("router_init_range", 1.0)
             lora_config.jitter_noise_ = config.get("jitter_noise", 0.01)
             lora_config.router_z_loss_coef_ = config.get(
@@ -300,9 +300,9 @@ def export(self) -> Dict[str, any]:
         config["num_experts"] = self.num_experts_
         if self.act_fn_ is not None:
             config["act_fn"] = self.act_fn_
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             config["top_k"] = self.top_k_
-        elif self.routing_strategy_ == "switch":
+        elif self.routing_strategy_ == "mixlora-switch":
             config["expert_capacity"] = self.expert_capacity_
             config["sparse_step"] = self.sparse_step_

@@ -322,6 +322,7 @@ class LoraMoeConfig(LoraConfig):
     blc_alpha_: float = None
     blc_weight_: float = None
     num_experts_: int = None
+    router_init_range_: float = None
     routing_strategy_: str = "loramoe"

     def check(self) -> "LoraMoeConfig":
@@ -329,6 +330,11 @@ def check(self) -> "LoraMoeConfig":
         assert isinstance(self.blc_alpha_, float) and self.blc_alpha_ >= 0.0
         assert isinstance(self.blc_weight_, float) and self.blc_weight_ >= 0.0
         assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
+        assert (
+            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
+        )
+
+        return self

     @staticmethod
     def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
@@ -336,6 +342,7 @@ def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
             blc_alpha_=config.get("blc_alpha", 0.0),
             blc_weight_=config.get("blc_weight", 0.0),
             num_experts_=config["num_experts"],
+            router_init_range_=config.get("router_init_range", 0.02),
             **LoraConfig.from_config(config).__dict__,
         )

@@ -355,11 +362,13 @@ def expert_config(self, expert_idx: int) -> LoraConfig:


 def lora_config_factory(config: Dict[str, any]) -> LoraConfig:
-    if (
-        "peft_type" in config and config["peft_type"] == "MIXLORA"
-    ) or "routing_strategy" in config:
+    if ("peft_type" in config and config["peft_type"] == "MIXLORA") or (
+        config.get("routing_strategy", "") in ["mixlora", "mixlora-switch"]
+    ):
         return MixLoraConfig.from_config(config).check()
-    if "peft_type" in config and config["peft_type"] == "LORAMOE":
+    if ("peft_type" in config and config["peft_type"] == "LORAMOE") or (
+        config.get("routing_strategy", "") == "loramoe"
+    ):
         return LoraMoeConfig.from_config(config).check()
     else:
         return LoraConfig.from_config(config).check()
diff --git a/mlora/modules/lora_moes.py b/mlora/modules/lora_moes.py
index b2ef6ab1..ca8f35dd 100644
--- a/mlora/modules/lora_moes.py
+++ b/mlora/modules/lora_moes.py
@@ -69,10 +69,13 @@ def adapter_initializer(
         linear.selective_hook_[adapter_config.adapter_name] = LoraMoe.selective_hook

     def forward(self, mlp: LLMFeedForward, hidden_states: torch.Tensor) -> Tuple:
-        return mlp._selective_forward(hidden_states, self.adapter_name_, moe_layer=self)
+        return (
+            mlp._selective_forward(hidden_states, self.adapter_name_, moe_layer=self),
+            None,
+        )


-router_loss_dict = {"mixtral": MixtralRouterLoss, "switch": SwitchRouterLoss}
+router_loss_dict = {"mixlora": MixtralRouterLoss, "mixlora-switch": SwitchRouterLoss}


 def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
@@ -85,13 +88,13 @@ def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:


 moe_layer_dict = {
-    "mixtral": MixtralSparseMoe,
-    "switch": SwitchSparseMoe,
+    "mixlora": MixtralSparseMoe,
+    "mixlora-switch": SwitchSparseMoe,
     "loramoe": LoraMoe,
 }


 def moe_layer_factory(args: LLMModelConfig, config: MixLoraConfig) -> torch.nn.Module:
-    if config.routing_strategy_ not in router_loss_dict:
+    if config.routing_strategy_ not in moe_layer_dict:
         raise ValueError(f"Unknown routing strategy {config.routing_strategy_}")
     return moe_layer_dict[config.routing_strategy_](args, config)
diff --git a/mlora/trainer.py b/mlora/trainer.py
index d138ec0d..bf0d69a5 100644
--- a/mlora/trainer.py
+++ b/mlora/trainer.py
@@ -163,6 +163,15 @@ def prepare(self, train_params: Dict[str, torch.Tensor]):
         # preparing optimizer
         paramas_count = sum(t.numel() for t in train_params.values() if t.requires_grad)
         logging.info(f"{self.adapter_name} total trainable params: {paramas_count}")
+        paramas_count = sum(
+            t.numel()
+            for n, t in train_params.items()
+            if "moe_gate" not in n and t.requires_grad
+        )
+        if paramas_count > 0:
+            logging.info(
+                f"{self.adapter_name} total trainable params (except gates): {paramas_count}"
+            )
         grouped_parameters = self._optimizer_grouped_parameters(train_params)
         if self.optimizer_type == "sgd":
             self.optimizer_ = torch.optim.SGD(
diff --git a/templates/mixlora.json b/templates/mixlora.json
index 3b0d17f3..24a3bd36 100644
--- a/templates/mixlora.json
+++ b/templates/mixlora.json
@@ -28,7 +28,7 @@
                 "down_proj": true,
                 "up_proj": true
             },
-            "routing_strategy": "mixtral",
+            "routing_strategy": "mixlora",
             "num_experts": 8,
             "top_k": 2,
             "group_by_length": false
diff --git a/templates/mixlora_glm.json b/templates/mixlora_glm.json
index 8863de79..d1d93b17 100644
--- a/templates/mixlora_glm.json
+++ b/templates/mixlora_glm.json
@@ -25,7 +25,7 @@
                 "dense_h_to_4h": true,
                 "dense_4h_to_h": true
             },
-            "routing_strategy": "mixtral",
+            "routing_strategy": "mixlora",
             "num_experts": 8,
             "top_k": 2,
             "group_by_length": false
diff --git a/templates/mixlora_phi.json b/templates/mixlora_phi.json
index fa6bffe3..2eba4ec9 100644
--- a/templates/mixlora_phi.json
+++ b/templates/mixlora_phi.json
@@ -27,7 +27,7 @@
                 "fc1": true,
                 "fc2": true
             },
-            "routing_strategy": "mixtral",
+            "routing_strategy": "mixlora",
             "num_experts": 8,
             "top_k": 2,
             "group_by_length": false
diff --git a/templates/mixlora_phi3.json b/templates/mixlora_phi3.json
index 5f2361a1..1ba9ff94 100644
--- a/templates/mixlora_phi3.json
+++ b/templates/mixlora_phi3.json
@@ -25,7 +25,7 @@
                 "gate_up_proj": true,
                 "down_proj": true
             },
-            "routing_strategy": "mixtral",
+            "routing_strategy": "mixlora",
             "num_experts": 8,
             "top_k": 2,
             "group_by_length": false
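
Usage note (illustration only, not part of the applied diff):

With this change an adapter config reaches LoraMoeConfig either through
"peft_type": "LORAMOE" or through the new "routing_strategy": "loramoe", and
the old "mixtral"/"switch" strategy names become "mixlora"/"mixlora-switch".
The sketch below shows how the new dispatch in lora_config_factory could be
exercised. It assumes the factory is importable from mlora.modules (as the
import block touched in mlora/model.py suggests), and the base-LoRA keys
(r, lora_alpha, lora_dropout, target_modules) are assumptions about what
LoraConfig.from_config expects, not something defined by this patch.

    from mlora.modules import lora_config_factory

    loramoe_settings = {
        # keys handled by LoraMoeConfig.from_config in this patch
        "routing_strategy": "loramoe",  # dispatches to LoraMoeConfig even without peft_type
        "num_experts": 8,
        "blc_alpha": 0.0,
        "blc_weight": 0.0,
        "router_init_range": 0.02,      # new field; used as std of the gate init in model.py
        # assumed base-LoRA keys (not defined by this patch, may need adjusting)
        "r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "target_modules": {"down_proj": True, "up_proj": True},
    }

    config = lora_config_factory(loramoe_settings)  # returns LoraMoeConfig after .check()

A config without peft_type that still uses the old "mixtral" or "switch" names
now falls through to the plain LoraConfig branch of lora_config_factory, which
is why the four MixLoRA templates above are updated to "mixlora" in this patch.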