This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

format and remove redundant codes
mikecovlee committed Jul 29, 2024
1 parent cb0ebe0 commit 15cc636
Showing 1 changed file with 9 additions and 16 deletions.
mlora/models/modeling_phi3.py (25 changes: 9 additions & 16 deletions)
@@ -7,14 +7,12 @@
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
-from transformers.models.phi3.modeling_phi3 import (
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
+from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)

from mlora.backends import backend
from mlora.common import (
CHECKPOINT_CLASSES,
@@ -32,6 +30,7 @@
)
from mlora.common.mix_lora import _mixtral_slice_tensor
from mlora.utils import copy_parameters

from .modeling_gemma2 import Gemma2RotaryEmbedding


@@ -60,7 +59,6 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
v = data.to(torch.float32).pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(v + self.norm_eps_)

print("Phi3 RMSNorm passed.")
return (self.weight_ * data).to(input_dtype)


@@ -73,7 +71,6 @@ def __init__(self, embedding: torch.Tensor, pad_token: int):
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
data = F.embedding(tokens, self.token_embedding_, padding_idx=self.padding_idx_)
# normalizer = torch.tensor(self.normalizer_, dtype=data.dtype)
print("Phi3Embedding passed.")
return data


@@ -174,8 +171,6 @@ def forward(
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, -1)

print("Phi3 Attention passed.")

return self.o_proj_(attn_output, input_args)


@@ -335,7 +330,6 @@ def forward(
hidden_states, router_logits = self.mlp_.forward(hidden_states, input_args)
hidden_states = residual + hidden_states

print("phi3 Decoder passed.")
return hidden_states, *router_logits


@@ -392,7 +386,6 @@ def _batch_forward(
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.act_(gate)

print("Phi3MLP._batch_forward() passed.")
return self.down_(up_states, input_args)

def _lora_forward(
@@ -415,18 +408,18 @@ def _lora_forward(

act_result = act_fn(gate) * up
if lora_name in self.down_.loras_:
print("Phi3MLP._lora_forward(1) passed.")
return self.down_.loras_[lora_name].forward(
self.down_.base_layer_.forward(act_result), act_result
)
else:
print("Phi3MLP._lora_forward(2) passed.")
return self.down_.base_layer_.forward(act_result)

def _mixlora_forward(
self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
):
-common_gate_up = self.gate_up_.base_layer_.forward(hidden_states.to(input_dtype)).to(hidden_states.dtype)
+common_gate_up = self.gate_up_.base_layer_.forward(
+    hidden_states.to(input_dtype)
+).to(hidden_states.dtype)

final_expert_states = []
for expert_idx in range(expert_mask.shape[0]):
@@ -440,7 +433,9 @@ def _mixlora_forward(
)
else:
lora_data = None
-gate_up_states = _mixtral_slice_tensor(common_gate_up, top_x, input_dtype)
+gate_up_states = _mixtral_slice_tensor(
+    common_gate_up, top_x, input_dtype
+)

gate, up_states = gate_up_states.chunk(2, dim=-1)
act_result = up_states * self.act_(gate)
@@ -455,7 +450,6 @@ def _mixlora_forward(
else:
final_expert_states.append(self.down_.base_layer_.forward(act_result))

print("Phi3MLP._mixlora_forward() passed.")
return final_expert_states


@@ -647,5 +641,4 @@ def from_pretrained(
)
model.layers_.append(decoder)

print("phi3FormPretrained passed ")
return model
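The lines removed above are debug checkpoints printed from the forward passes. If similar tracing is ever wanted again, a minimal sketch using Python's standard logging module is shown below; the logger name and the trace() helper are illustrative assumptions, not part of this commit or of mLoRA.

# Minimal sketch (assumption, not part of this commit): route debug
# checkpoints through logging instead of bare print() calls.
import logging

logger = logging.getLogger("mlora.models.phi3")  # illustrative logger name


def trace(message: str) -> None:
    # Emits output only when debug logging is enabled, e.g. via
    # logging.basicConfig(level=logging.DEBUG); otherwise stays silent.
    logger.debug(message)


# Example usage at a former checkpoint site:
#     trace("Phi3 Attention passed.")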
