This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit edc2ac9

finish loramoe support

mikecovlee committed Aug 8, 2024 · 1 parent 513194d
Showing 9 changed files with 68 additions and 21 deletions.
26 changes: 25 additions & 1 deletion mlora/model.py
@@ -26,6 +26,7 @@
     LoraMoeConfig,
     MixLoraConfig,
     lora_config_factory,
+    moe_layer_dict,
     moe_layer_factory,
     router_loss_factory,
 )
@@ -190,6 +191,9 @@ def init_lora_layer_weight(
         transformer_layer.mlp_.moes_[lora_config.adapter_name] = moe_layer_factory(
             llm_config, lora_config
         )
+        moe_initializer = moe_layer_dict[
+            lora_config.routing_strategy_
+        ].adapter_initializer
     else:
         model_prefix_name = "base_model.model.model"
         moe_layer_name_list = []
@@ -208,6 +212,26 @@ def init_lora_layer_weight(
         if moe_initializer is not None:
             # init for gating mechanisms
             moe_initializer(llm_config, lora_config, proj_weight)
+            if (
+                hasattr(proj_weight, "_moe_gates")
+                and lora_config.adapter_name in proj_weight._moe_gates
+            ):
+                gate_weight = (
+                    lora_weights.get(f"{module_name}.moe_gate.weight", None)
+                    if lora_weights is not None
+                    else None
+                )
+                if gate_weight is None:
+                    torch.nn.init.normal_(
+                        proj_weight._moe_gates[lora_config.adapter_name].weight,
+                        mean=0.0,
+                        std=lora_config.router_init_range_,
+                    )
+                else:
+                    with torch.no_grad():
+                        proj_weight._moe_gates[
+                            lora_config.adapter_name
+                        ].weight.copy_(gate_weight)
 
         for expert_idx in range(lora_config.num_experts_):
             if lora_weights is None:
@@ -559,7 +583,7 @@ def get_adapter_weight_dict(self, adapter_name: str) -> Dict[str, torch.Tensor]:
                 hasattr(proj_weight, "_moe_gates")
                 and adapter_name in proj_weight._moe_gates
             ):
-                lora_weight_dict[f"{module_name}.mlp.moe_gate.weight"] = (
+                lora_weight_dict[f"{module_name}.moe_gate.weight"] = (
                     proj_weight._moe_gates[adapter_name].weight
                 )
 
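Note on the last two hunks above: the key used to save a gate in get_adapter_weight_dict (f"{module_name}.moe_gate.weight", without the extra ".mlp") now matches the key that init_lora_layer_weight looks up when loading, so gate weights round-trip through a saved adapter; when no saved gate is found, the gate is initialized from a normal distribution with std router_init_range_. A minimal sketch of that round trip, using a stand-in object with a _moe_gates dict of torch.nn.Linear gates (the real classes live elsewhere in mlora and are not part of this diff):

import torch

class FakeProjWeight:
    """Stand-in for the projection wrapper that owns per-adapter MoE gates."""

    def __init__(self) -> None:
        self._moe_gates = {"demo_adapter": torch.nn.Linear(16, 4, bias=False)}

module_name = "layers.0.mlp"
proj_weight = FakeProjWeight()

# Saving (mirrors get_adapter_weight_dict after this commit): the gate is
# stored under f"{module_name}.moe_gate.weight".
lora_weight_dict = {
    f"{module_name}.moe_gate.weight": (
        proj_weight._moe_gates["demo_adapter"].weight.detach().clone()
    )
}

# Loading (mirrors init_lora_layer_weight): look up the same key; copy it in
# if present, otherwise fall back to a normal init (0.02 mirrors the
# LoraMoeConfig default for router_init_range).
gate_weight = lora_weight_dict.get(f"{module_name}.moe_gate.weight", None)
if gate_weight is None:
    torch.nn.init.normal_(
        proj_weight._moe_gates["demo_adapter"].weight, mean=0.0, std=0.02
    )
else:
    with torch.no_grad():
        proj_weight._moe_gates["demo_adapter"].weight.copy_(gate_weight)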
2 changes: 2 additions & 0 deletions mlora/modules/__init__.py
@@ -50,6 +50,7 @@
 
 # MixLoRA MoEs
 from .lora_moes import (
+    LoraMoe,
     MixtralRouterLoss,
     MixtralSparseMoe,
     SwitchRouterLoss,
@@ -83,6 +84,7 @@
     "MixtralSparseMoe",
     "SwitchRouterLoss",
     "SwitchSparseMoe",
+    "LoraMoe",
     "router_loss_dict",
     "moe_layer_dict",
     "router_loss_factory",
31 changes: 20 additions & 11 deletions mlora/modules/config.py
@@ -196,7 +196,7 @@ def export(self) -> Dict[str, any]:
         return config
 
 
-available_routing_strategies = ["mixtral", "switch"]
+available_routing_strategies = ["mixlora", "mixlora-switch"]
 
 
 @dataclass
@@ -240,9 +240,9 @@ def check(self) -> "MixLoraConfig":
         assert self.act_fn_ is None or (
             isinstance(self.act_fn_, str) and self.act_fn_ in ACT2FN
         )
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             assert isinstance(self.top_k_, int) and self.top_k_ > 0
-        elif self.routing_strategy_ == "switch":
+        elif self.routing_strategy_ == "mixlora-switch":
             assert (
                 isinstance(self.router_z_loss_coef_, float)
                 and self.router_z_loss_coef_ >= 0
@@ -270,11 +270,11 @@ def from_config(config: Dict[str, any]) -> "MixLoraConfig":
         # silu for mixtral or gelu_new for switch transformers
         # left blank to automatically use the original act_fn of FFN
         lora_config.act_fn_ = config.get("act_fn", None)
-        if lora_config.routing_strategy_ == "mixtral":
+        if lora_config.routing_strategy_ == "mixlora":
             lora_config.router_init_range_ = config.get("router_init_range", 0.02)
             lora_config.jitter_noise_ = config.get("jitter_noise", 0.0)
             lora_config.top_k_ = config.get("top_k", 2)
-        elif lora_config.routing_strategy_ == "switch":
+        elif lora_config.routing_strategy_ == "mixlora-switch":
             lora_config.router_init_range_ = config.get("router_init_range", 1.0)
             lora_config.jitter_noise_ = config.get("jitter_noise", 0.01)
             lora_config.router_z_loss_coef_ = config.get(
@@ -300,9 +300,9 @@ def export(self) -> Dict[str, any]:
         config["num_experts"] = self.num_experts_
         if self.act_fn_ is not None:
             config["act_fn"] = self.act_fn_
-        if self.routing_strategy_ == "mixtral":
+        if self.routing_strategy_ == "mixlora":
             config["top_k"] = self.top_k_
-        elif self.routing_strategy_ == "switch":
+        elif self.routing_strategy_ == "mixlora-switch":
             config["expert_capacity"] = self.expert_capacity_
             config["sparse_step"] = self.sparse_step_
 
@@ -322,20 +322,27 @@ class LoraMoeConfig(LoraConfig):
     blc_alpha_: float = None
     blc_weight_: float = None
     num_experts_: int = None
+    router_init_range_: float = None
     routing_strategy_: str = "loramoe"
 
     def check(self) -> "LoraMoeConfig":
         super().check()
         assert isinstance(self.blc_alpha_, float) and self.blc_alpha_ >= 0.0
         assert isinstance(self.blc_weight_, float) and self.blc_weight_ >= 0.0
         assert isinstance(self.num_experts_, int) and self.num_experts_ > 0
+        assert (
+            isinstance(self.router_init_range_, float) and self.router_init_range_ >= 0
+        )
 
         return self
 
     @staticmethod
     def from_config(config: Dict[str, any]) -> "LoraMoeConfig":
         return LoraMoeConfig(
             blc_alpha_=config.get("blc_alpha", 0.0),
             blc_weight_=config.get("blc_weight", 0.0),
             num_experts_=config["num_experts"],
+            router_init_range_=config.get("router_init_range", 0.02),
             **LoraConfig.from_config(config).__dict__,
         )
 
@@ -355,11 +362,13 @@ def expert_config(self, expert_idx: int) -> LoraConfig:
 
 
 def lora_config_factory(config: Dict[str, any]) -> LoraConfig:
-    if (
-        "peft_type" in config and config["peft_type"] == "MIXLORA"
-    ) or "routing_strategy" in config:
+    if ("peft_type" in config and config["peft_type"] == "MIXLORA") or (
+        config.get("routing_strategy", "") in ["mixlora", "mixlora-switch"]
+    ):
         return MixLoraConfig.from_config(config).check()
-    if "peft_type" in config and config["peft_type"] == "LORAMOE":
+    if ("peft_type" in config and config["peft_type"] == "LORAMOE") or (
+        config.get("routing_strategy", "") == "loramoe"
+    ):
         return LoraMoeConfig.from_config(config).check()
     else:
         return LoraConfig.from_config(config).check()
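With the factory change above, a LoRAMoE adapter can now be selected from a plain "routing_strategy" field as well as from peft_type == "LORAMOE". A minimal sketch of such a config as a Python dict, restricted to the keys visible in this diff (the base LoRA fields read by LoraConfig.from_config, such as rank and target modules, are omitted here because their exact names are defined elsewhere):

# Hypothetical LoRAMoE adapter settings; keys and defaults mirror
# LoraMoeConfig.from_config above, values are just examples.
loramoe_config = {
    "routing_strategy": "loramoe",  # now enough for lora_config_factory to
                                    # dispatch to LoraMoeConfig.from_config
    "num_experts": 8,               # required: read as config["num_experts"]
    "blc_alpha": 0.0,               # balance-loss settings, default 0.0
    "blc_weight": 0.0,
    "router_init_range": 0.02,      # std for gate init when no saved gate exists
}

# lora_config_factory(loramoe_config) would now take the LORAMOE branch and
# return LoraMoeConfig.from_config(...).check(); before this commit only
# peft_type == "LORAMOE" triggered that branch. The base LoRA fields still
# have to be supplied for check() to pass.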
13 changes: 8 additions & 5 deletions mlora/modules/lora_moes.py
@@ -69,10 +69,13 @@ def adapter_initializer(
         linear.selective_hook_[adapter_config.adapter_name] = LoraMoe.selective_hook
 
     def forward(self, mlp: LLMFeedForward, hidden_states: torch.Tensor) -> Tuple:
-        return mlp._selective_forward(hidden_states, self.adapter_name_, moe_layer=self)
+        return (
+            mlp._selective_forward(hidden_states, self.adapter_name_, moe_layer=self),
+            None,
+        )
 
 
-router_loss_dict = {"mixtral": MixtralRouterLoss, "switch": SwitchRouterLoss}
+router_loss_dict = {"mixlora": MixtralRouterLoss, "mixlora-switch": SwitchRouterLoss}
 
 
 def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
@@ -85,13 +88,13 @@ def router_loss_factory(config: MixLoraConfig) -> torch.nn.Module:
 
 
 moe_layer_dict = {
-    "mixtral": MixtralSparseMoe,
-    "switch": SwitchSparseMoe,
+    "mixlora": MixtralSparseMoe,
+    "mixlora-switch": SwitchSparseMoe,
     "loramoe": LoraMoe,
 }
 
 
 def moe_layer_factory(args: LLMModelConfig, config: MixLoraConfig) -> torch.nn.Module:
-    if config.routing_strategy_ not in router_loss_dict:
+    if config.routing_strategy_ not in moe_layer_dict:
         raise ValueError(f"Unknown routing strategy {config.routing_strategy_}")
     return moe_layer_dict[config.routing_strategy_](args, config)
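Two things are worth noting in this file: LoraMoe.forward now returns a 2-tuple (output, None), presumably so callers can unpack it the same way as the other MoE layers, and moe_layer_factory now validates the routing strategy against moe_layer_dict rather than router_loss_dict, which is what lets "loramoe" (present only in moe_layer_dict) through. A small sketch of that dispatch with stand-in values (the real classes are the ones imported in this file):

# Stand-in dicts mirroring the ones above; strings replace the real classes.
router_loss_dict = {
    "mixlora": "MixtralRouterLoss",
    "mixlora-switch": "SwitchRouterLoss",
}
moe_layer_dict = {
    "mixlora": "MixtralSparseMoe",
    "mixlora-switch": "SwitchSparseMoe",
    "loramoe": "LoraMoe",
}

for strategy in ("mixlora", "mixlora-switch", "loramoe"):
    # Old check: `strategy not in router_loss_dict` rejected "loramoe".
    # New check: every key of moe_layer_dict is accepted.
    assert strategy in moe_layer_dict
    moe_cls = moe_layer_dict[strategy]
    loss_cls = router_loss_dict.get(strategy)  # None for "loramoe" here
    print(f"{strategy}: moe layer -> {moe_cls}, router loss -> {loss_cls}")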
9 changes: 9 additions & 0 deletions mlora/trainer.py
@@ -163,6 +163,15 @@ def prepare(self, train_params: Dict[str, torch.Tensor]):
         # preparing optimizer
         paramas_count = sum(t.numel() for t in train_params.values() if t.requires_grad)
         logging.info(f"{self.adapter_name} total trainable params: {paramas_count}")
+        paramas_count = sum(
+            t.numel()
+            for n, t in train_params.items()
+            if "moe_gate" not in n and t.requires_grad
+        )
+        if paramas_count > 0:
+            logging.info(
+                f"{self.adapter_name} total trainable params (except gates): {paramas_count}"
+            )
         grouped_parameters = self._optimizer_grouped_parameters(train_params)
         if self.optimizer_type == "sgd":
             self.optimizer_ = torch.optim.SGD(
2 changes: 1 addition & 1 deletion templates/mixlora.json
@@ -28,7 +28,7 @@
         "down_proj": true,
         "up_proj": true
     },
-    "routing_strategy": "mixtral",
+    "routing_strategy": "mixlora",
     "num_experts": 8,
     "top_k": 2,
     "group_by_length": false
2 changes: 1 addition & 1 deletion templates/mixlora_glm.json
@@ -25,7 +25,7 @@
         "dense_h_to_4h": true,
         "dense_4h_to_h": true
     },
-    "routing_strategy": "mixtral",
+    "routing_strategy": "mixlora",
    "num_experts": 8,
    "top_k": 2,
    "group_by_length": false
2 changes: 1 addition & 1 deletion templates/mixlora_phi.json
@@ -27,7 +27,7 @@
         "fc1": true,
         "fc2": true
     },
-    "routing_strategy": "mixtral",
+    "routing_strategy": "mixlora",
     "num_experts": 8,
     "top_k": 2,
     "group_by_length": false
2 changes: 1 addition & 1 deletion templates/mixlora_phi3.json
@@ -25,7 +25,7 @@
         "gate_up_proj": true,
         "down_proj": true
     },
-    "routing_strategy": "mixtral",
+    "routing_strategy": "mixlora",
     "num_experts": 8,
     "top_k": 2,
     "group_by_length": false
