diff --git a/mlora/models/modeling_phi3.py b/mlora/models/modeling_phi3.py
index e1a5dc4b..0fb2d855 100644
--- a/mlora/models/modeling_phi3.py
+++ b/mlora/models/modeling_phi3.py
@@ -7,14 +7,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers.activations import ACT2FN
-from transformers.models.phi3.modeling_phi3 import (
-    apply_rotary_pos_emb,
-    repeat_kv,
-)
+from transformers.models.phi3.modeling_phi3 import apply_rotary_pos_emb, repeat_kv
 from transformers.utils import (
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
 )
+
 from mlora.backends import backend
 from mlora.common import (
     CHECKPOINT_CLASSES,
@@ -32,6 +30,7 @@
 )
 from mlora.common.mix_lora import _mixtral_slice_tensor
 from mlora.utils import copy_parameters
+
 from .modeling_gemma2 import Gemma2RotaryEmbedding
@@ -60,7 +59,6 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
         v = data.to(torch.float32).pow(2).mean(-1, keepdim=True)
         data = data * torch.rsqrt(v + self.norm_eps_)
-        print("Phi3 RMSNorm passed.")
         return (self.weight_ * data).to(input_dtype)
@@ -73,7 +71,6 @@ def __init__(self, embedding: torch.Tensor, pad_token: int):
     def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         data = F.embedding(tokens, self.token_embedding_, padding_idx=self.padding_idx_)
         # normalizer = torch.tensor(self.normalizer_, dtype=data.dtype)
-        print("Phi3Embedding passed.")
         return data
@@ -174,8 +171,6 @@ def forward(
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, -1)
-        print("Phi3 Attention passed.")
-
         return self.o_proj_(attn_output, input_args)
@@ -335,7 +330,6 @@ def forward(
         hidden_states, router_logits = self.mlp_.forward(hidden_states, input_args)
         hidden_states = residual + hidden_states
-        print("phi3 Decoder passed.")
         return hidden_states, *router_logits
@@ -392,7 +386,6 @@ def _batch_forward(
         gate, up_states = up_states.chunk(2, dim=-1)
         up_states = up_states * self.act_(gate)
-        print("Phi3MLP._batch_forward() passed.")
         return self.down_(up_states, input_args)
 
     def _lora_forward(
@@ -415,18 +408,18 @@ def _lora_forward(
         act_result = act_fn(gate) * up
         if lora_name in self.down_.loras_:
-            print("Phi3MLP._lora_forward(1) passed.")
             return self.down_.loras_[lora_name].forward(
                 self.down_.base_layer_.forward(act_result), act_result
             )
         else:
-            print("Phi3MLP._lora_forward(2) passed.")
             return self.down_.base_layer_.forward(act_result)
 
     def _mixlora_forward(
         self, moe_name, act_fn, expert_mask, hidden_states, input_dtype
     ):
-        common_gate_up = self.gate_up_.base_layer_.forward(hidden_states.to(input_dtype)).to(hidden_states.dtype)
+        common_gate_up = self.gate_up_.base_layer_.forward(
+            hidden_states.to(input_dtype)
+        ).to(hidden_states.dtype)
 
         final_expert_states = []
         for expert_idx in range(expert_mask.shape[0]):
@@ -440,7 +433,9 @@ def _mixlora_forward(
                 )
             else:
                 lora_data = None
-                gate_up_states = _mixtral_slice_tensor(common_gate_up, top_x, input_dtype)
+                gate_up_states = _mixtral_slice_tensor(
+                    common_gate_up, top_x, input_dtype
+                )
             gate, up_states = gate_up_states.chunk(2, dim=-1)
             act_result = up_states * self.act_(gate)
@@ -455,7 +450,6 @@ def _mixlora_forward(
             else:
                 final_expert_states.append(self.down_.base_layer_.forward(act_result))
-        print("Phi3MLP._mixlora_forward() passed.")
         return final_expert_states
@@ -647,5 +641,4 @@ def from_pretrained(
         )
         model.layers_.append(decoder)
-    print("phi3FormPretrained passed ")
     return model