From f8da983c17233ef895581623ea4efd172c34c322 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Thu, 14 Nov 2024 11:16:26 +0800
Subject: [PATCH 1/3] add qwen2vl

---
 .../examples/qwen2_vl/multi_image_infer.py  | 61 +++++++++++++++++--
 .../examples/qwen2_vl/single_image_infer.py | 58 ++++++++++++++++--
 paddlemix/examples/qwen2_vl/video_infer.py  | 58 +++++++++++++++++-
 3 files changed, 165 insertions(+), 12 deletions(-)

diff --git a/paddlemix/examples/qwen2_vl/multi_image_infer.py b/paddlemix/examples/qwen2_vl/multi_image_infer.py
index c0ef506d1..164ad73c8 100644
--- a/paddlemix/examples/qwen2_vl/multi_image_infer.py
+++ b/paddlemix/examples/qwen2_vl/multi_image_infer.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
+
+import paddle
 from paddlenlp.transformers import Qwen2Tokenizer
 
 from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
@@ -21,7 +24,12 @@
     process_vision_info,
 )
 
-MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+benchmark = True
+warm_up = 3
+
+# MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
+
 model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
 
 image_processor = Qwen2VLImageProcessor()
@@ -37,8 +45,14 @@
     {
         "role": "user",
         "content": [
-            {"type": "image", "image": "paddlemix/demo_images/examples_image1.jpg"},
-            {"type": "image", "image": "paddlemix/demo_images/examples_image2.jpg"},
+            {
+                "type": "image",
+                "image": "/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/demo_images/examples_image1.jpg",
+            },
+            {
+                "type": "image",
+                "image": "/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/demo_images/examples_image2.jpg",
+            },
             {"type": "text", "text": "Identify the similarities between these images."},
         ],
     }
@@ -59,7 +73,44 @@
     return_tensors="pd",
 )
 
-# Inference: Generation of the output
-generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+if warm_up > 0:
+    for _ in range(warm_up):
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+if benchmark:
+    repeat_times = 10
+    sumtime = 0.0
+    for i in range(repeat_times):
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
+
+        duringtime = endtime - starttime
+        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        sumtime += duringtime
+        print(f"Multi {MODEL_NAME} end to end time : ", duringtime, "ms")
+
+        paddle.device.cuda.empty_cache()
+        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+
+    print(f"Multi {MODEL_NAME} ave end to end time : ", sumtime / repeat_times, "ms")
+
+    paddle.device.cuda.empty_cache()
+    inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+    print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+    cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
+    print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
+else:
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+
 output_text = processor.batch_decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 print("output_text:\n", output_text[0])
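Note: the warm-up/timing loop added above is duplicated across all three example scripts in this patch. A minimal sketch of the same measurement pattern as a reusable helper (not part of the patch; `run_fn` is a hypothetical callable):

    import datetime

    import paddle

    def benchmark(run_fn, warm_up=3, repeat_times=10):
        """Return the average end-to-end latency of run_fn in milliseconds."""
        for _ in range(warm_up):
            run_fn()  # warm-up iterations are not timed
        sumtime = 0.0
        for _ in range(repeat_times):
            paddle.device.synchronize()  # drain queued GPU work before starting the clock
            starttime = datetime.datetime.now()
            run_fn()
            paddle.device.synchronize()  # wait for the GPU to finish before stopping the clock
            duringtime = datetime.datetime.now() - starttime
            sumtime += duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
        return sumtime / repeat_times

    # e.g.: avg_ms = benchmark(lambda: model.generate(**inputs, max_new_tokens=128))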
diff --git a/paddlemix/examples/qwen2_vl/single_image_infer.py b/paddlemix/examples/qwen2_vl/single_image_infer.py
index f98aad649..cf7f44d5d 100644
--- a/paddlemix/examples/qwen2_vl/single_image_infer.py
+++ b/paddlemix/examples/qwen2_vl/single_image_infer.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
+
+import paddle
 from paddlenlp.transformers import Qwen2Tokenizer
 
 from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
@@ -21,7 +24,11 @@
     process_vision_info,
 )
 
+benchmark = True
+warm_up = 3
+
 MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
+# MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
 model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
 
 image_processor = Qwen2VLImageProcessor()
@@ -32,14 +39,14 @@
 # max_pixels = 1280*28*28 # 1003520
 # processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)
 
-
 messages = [
     {
         "role": "user",
         "content": [
             {
                 "type": "image",
-                "image": "paddlemix/demo_images/examples_image1.jpg",
+                # "image": "paddlemix/demo_images/examples_image1.jpg",
+                "image": "/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/demo_images/examples_image1.jpg",
             },
             {"type": "text", "text": "Describe this image."},
         ],
@@ -61,7 +68,50 @@
     return_tensors="pd",
 )
 
-# Inference: Generation of the output
-generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+if warm_up > 0:
+    for _ in range(warm_up):
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+if benchmark:
+    repeat_times = 10
+    sumtime = 0.0
+    for i in range(repeat_times):
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
+
+        paddle.device.synchronize()
+        import nvtx
+
+        generate_nvtx = nvtx.start_range(message="generate", color="green")
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+        paddle.device.synchronize()
+        nvtx.end_range(generate_nvtx)
+
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
+
+        duringtime = endtime - starttime
+        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        sumtime += duringtime
+        print(f"Single {MODEL_NAME} end to end time : ", duringtime, "ms")
+
+        paddle.device.cuda.empty_cache()
+        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+
+    print(f"Single {MODEL_NAME} ave end to end time : ", sumtime / repeat_times, "ms")
+
+    paddle.device.cuda.empty_cache()
+    inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+    print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+    cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
+    print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
+else:
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
 output_text = processor.batch_decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 print("output_text:\n", output_text[0])
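Note: the single-image and video scripts wrap generate() in an NVTX range so the region is easy to locate in an Nsight Systems timeline. A sketch of equivalent instrumentation using the same `nvtx` package's context-manager form (assuming nvtx.annotate, which avoids unbalanced start_range/end_range pairs; `model` and `inputs` as prepared by the script above):

    import nvtx
    import paddle

    paddle.device.synchronize()  # start the range only once prior GPU work has drained
    with nvtx.annotate(message="generate", color="green"):
        generated_ids = model.generate(**inputs, max_new_tokens=128)
        paddle.device.synchronize()  # keep the async GPU work inside the range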
diff --git a/paddlemix/examples/qwen2_vl/video_infer.py b/paddlemix/examples/qwen2_vl/video_infer.py
index 7b28a7ff8..5c7e54460 100644
--- a/paddlemix/examples/qwen2_vl/video_infer.py
+++ b/paddlemix/examples/qwen2_vl/video_infer.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
+
+import paddle
 from paddlenlp.transformers import Qwen2Tokenizer
 
 from paddlemix.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
@@ -21,6 +24,10 @@
     process_vision_info,
 )
 
+benchmark = True
+warm_up = 3
+
+# MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
 MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
 
 model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype="bfloat16")
@@ -37,7 +44,7 @@
         "content": [
             {
                 "type": "video",
-                "video": "paddlemix/demo_images/red-panda.mp4",
+                "video": "/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/demo_images/red-panda.mp4",
                 "max_pixels": 360 * 420,
                 "fps": 1.0,
             },
@@ -58,8 +65,53 @@
     padding=True,
     return_tensors="pd",
 )
-# Inference: Generation of the output
-generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+if warm_up > 0:
+    for _ in range(warm_up):
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+if benchmark:
+    repeat_times = 10
+    sumtime = 0.0
+    for i in range(repeat_times):
+        paddle.device.synchronize()
+        starttime = datetime.datetime.now()
+
+        paddle.device.synchronize()
+        import nvtx
+
+        generate_nvtx = nvtx.start_range(message="generate", color="green")
+
+        # Inference: Generation of the output
+        generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+        paddle.device.synchronize()
+        nvtx.end_range(generate_nvtx)
+
+        paddle.device.synchronize()
+        endtime = datetime.datetime.now()
+
+        duringtime = endtime - starttime
+        duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
+        sumtime += duringtime
+        print(f"Video {MODEL_NAME} end to end time : ", duringtime, "ms")
+
+        paddle.device.cuda.empty_cache()
+        inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+        print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+
+    print(f"Video {MODEL_NAME} ave end to end time : ", sumtime / repeat_times, "ms")
+
+    paddle.device.cuda.empty_cache()
+    inference_global_mem = paddle.device.cuda.memory_reserved() / (1024**3)
+    print(f"Inference used CUDA memory : {inference_global_mem:.3f} GiB")
+    cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
+    print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
+else:
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
+
+
+# print("generated_ids:\n", generated_ids)
 output_text = processor.batch_decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
 print("output_text:\n", output_text[0])
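Note: all three scripts also repeat the same CUDA memory reporting block. A small helper built from the same Paddle calls used above (a sketch, not part of this patch):

    import paddle

    def report_cuda_memory(tag):
        # memory_reserved(): memory currently held by Paddle's caching allocator;
        # max_memory_allocated(): peak tensor allocation observed so far.
        paddle.device.cuda.empty_cache()
        reserved_gib = paddle.device.cuda.memory_reserved() / (1024**3)
        peak_gib = paddle.device.cuda.max_memory_allocated() / (1024**3)
        print(f"{tag} reserved CUDA memory : {reserved_gib:.3f} GiB")
        print(f"{tag} max allocated CUDA memory : {peak_gib:.3f} GiB")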
From acf476e1c52c5b01ef65074568e0011b25babdc5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Tue, 19 Nov 2024 13:56:21 +0800
Subject: [PATCH 2/3] test

---
 .../examples/qwen2_vl/single_image_infer.py | 40 +++++++++--
 .../models/qwen2_vl/modeling_qwen2_vl.py    | 69 ++++++++++++++++---
 2 files changed, 96 insertions(+), 13 deletions(-)

diff --git a/paddlemix/examples/qwen2_vl/single_image_infer.py b/paddlemix/examples/qwen2_vl/single_image_infer.py
index cf7f44d5d..248d8ae0d 100644
--- a/paddlemix/examples/qwen2_vl/single_image_infer.py
+++ b/paddlemix/examples/qwen2_vl/single_image_infer.py
@@ -23,9 +23,29 @@
     Qwen2VLProcessor,
     process_vision_info,
 )
+import argparse
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Use PaddleMIX to accelerate the Qwen2-VL vision-language model."
+    )
+    parser.add_argument(
+        "--benchmark",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="if set to True, measure inference performance",
+    )
+    parser.add_argument(
+        "--inference_optimize",
+        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+        default=False,
+        help="If set to True, all optimizations except Triton are enabled.",
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+
 
-benchmark = True
-warm_up = 3
 MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
 # MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct"
@@ -68,11 +88,22 @@
     return_tensors="pd",
 )
 
-if warm_up > 0:
+# pipe.transformer = paddle.incubate.jit.inference(
+#     pipe.transformer,
+#     save_model_dir="./tmp/sd3",
+#     enable_new_ir=True,
+#     cache_static_model=True,
+#     # On V100, exp_enable_use_cutlass must be set to False.
+#     exp_enable_use_cutlass=True,
+#     delete_pass_lists=["add_norm_fuse_pass"],
+# )
+
+
+if args.benchmark:
+    warm_up = 3
     for _ in range(warm_up):
         # Inference: Generation of the output
         generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
-if benchmark:
     repeat_times = 10
     sumtime = 0.0
     for i in range(repeat_times):
@@ -110,6 +141,7 @@
     cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
     print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
 else:
+    # breakpoint()
     # Inference: Generation of the output
     generated_ids = model.generate(**inputs, max_new_tokens=128)  # already trimmed in paddle
 
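Note: argparse has no built-in boolean type, so the patch parses flag strings with a lambda; any case-insensitive spelling of "true", "1", or "yes" enables the flag and anything else disables it. A quick illustration of the resulting behavior:

    to_bool = lambda x: str(x).lower() in ["true", "1", "yes"]

    assert to_bool("True") and to_bool("1") and to_bool("YES")
    assert not to_bool("false") and not to_bool("0") and not to_bool("off")

    # so benchmark mode is enabled with, e.g.:
    #   python single_image_infer.py --benchmark True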
diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
index 84ef694b6..bfd719896 100644
--- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
@@ -38,7 +38,7 @@
 from ...activations import ACT2FN
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
-
+from paddlemix.triton_ops.triton_ops import rms_norm
 
 logger = logging.get_logger(__name__)
 
@@ -282,7 +282,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
                 stride=self.proj._stride)
             hidden_states = hidden_states.to(target_dtype).reshape([-1, self.embed_dim])
         else:
-            hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
+            # NOTE(changwenbin): AttributeError: 'Variable' object has no attribute 'to'
+            # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
+            hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)).reshape([-1, self.embed_dim])
         return hidden_states
 
 
@@ -485,6 +487,7 @@ def forward(self, hidden_states):
         if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
             hidden_states = paddle.cast(hidden_states, self.weight.dtype)
         return hidden_states * self.weight
+        # hidden_states = rms_norm(hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
 
 
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
@@ -774,7 +777,7 @@ def _flash_attention_forward(
             key_states,
             value_states,
             dropout,
-            causal=causal,  # no softmax_scale=
+            causal=True,  # no softmax_scale=
         )[0]
 
         return attn_output
@@ -834,7 +837,15 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
+        self.config = config
+    # @paddle.incubate.jit.inference(
+    #     enable_new_ir=False,
+    #     cache_static_model=False,
+    #     save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_model_decoder",
+    #     exp_enable_use_cutlass=False,
+    #     skip_prune_program=True,
+    #     # delete_pass_lists=["fc_fuse_pass"],
+    # )
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -844,7 +855,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[paddle.Tensor] = None,
-        **kwargs,
+        # **kwargs,
     ):
         """
         Args:
@@ -867,7 +878,13 @@ def forward(
 
         residual = hidden_states
 
+        # Note:(changwenbin) use triton_rmsnorm
         hidden_states = self.input_layernorm(hidden_states)
+        # print(self.input_layernorm.weight)
+        # print(self.input_layernorm.bias)
+        # exit(0)
+        # from paddlemix.triton_ops.triton_ops import rms_norm
+        # hidden_states = rms_norm(hidden_states, weight=self.input_layernorm.weight, epsilon=self.config.rms_norm_eps)
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
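Note: the commented-out lines above record an experiment that swaps Qwen2RMSNorm for the fused Triton kernel imported at the top of the file. Assuming rms_norm keeps the signature used in those comments, the swap inside Qwen2VLDecoderLayer.forward would look like:

    from paddlemix.triton_ops.triton_ops import rms_norm

    # fused replacement for: hidden_states = self.input_layernorm(hidden_states)
    hidden_states = rms_norm(
        hidden_states, weight=self.input_layernorm.weight, epsilon=self.config.rms_norm_eps
    )

This also explains the `self.config = config` added in __init__: the layer needs config.rms_norm_eps at call time.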
@@ -978,7 +995,14 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
         return rotary_pos_emb
-
+
+    # @paddle.incubate.jit.inference(
+    #     enable_new_ir=False,
+    #     cache_static_model=False,
+    #     save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_vision",
+    #     exp_enable_use_cutlass=False,
+    #     # delete_pass_lists=["fc_fuse_pass"],
+    # )
     def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor:
         hidden_states = self.patch_embed(hidden_states)
 
@@ -1042,7 +1066,15 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
         expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
-
+
+    @paddle.incubate.jit.inference(
+        enable_new_ir=True,
+        cache_static_model=False,
+        save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_model",
+        exp_enable_use_cutlass=False,
+        # skip_prune_program=True,
+        switch_ir_optim=False,
+        # delete_pass_lists=["fc_fuse_pass"],
+    )
     def forward(
         self,
         input_ids: paddle.Tensor = None,
@@ -1454,7 +1486,10 @@ def forward(
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # fmt:skip
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
+        # import datetime
+        # paddle.device.synchronize()
+        # starttime_visual = datetime.datetime.now()
+
         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
             if pixel_values is not None:
@@ -1471,7 +1506,16 @@ def forward(
                 inputs_embeds[video_mask] = video_embeds
             if attention_mask is not None:
                 attention_mask = attention_mask
-
+        # paddle.device.synchronize()
+        # endtime_visual = datetime.datetime.now()
+
+        # duringtime_visual = endtime_visual - starttime_visual
+        # duringtime_visual = duringtime_visual.seconds * 1000 + duringtime_visual.microseconds / 1000.0
+        # print("duringtime_visual", duringtime_visual)
+
+        # paddle.device.synchronize()
+        # starttime_text = datetime.datetime.now()
+
         outputs = self.model(
             input_ids=None,
             position_ids=position_ids,
@@ -1483,6 +1527,13 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
+        # paddle.device.synchronize()
+        # endtime_text = datetime.datetime.now()
+
+        # duringtime_text = endtime_text - starttime_text
+        # duringtime_text = duringtime_text.seconds * 1000 + duringtime_text.microseconds / 1000.0
+        # print("duringtime_text", duringtime_text)
+
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
 
@@ -1503,7 +1554,7 @@ def forward(
             loss = loss / label_sum
 
         if not return_dict:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + tuple(outputs[1:])
             return (loss,) + output if loss is not None else output
 
         # return logits + 28 layers k and v
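Note on the `tuple(outputs[1:])` change at the end of this patch: it guards against `outputs` arriving as a list rather than a tuple (presumably the case in some non-return_dict path), since Python refuses to concatenate the two. A minimal illustration with placeholder values:

    logits = object()
    outputs = [object(), "past_key_values", "hidden_states"]  # list-shaped model output

    # (logits,) + outputs[1:]              -> TypeError: can only concatenate tuple (not "list") to tuple
    output = (logits,) + tuple(outputs[1:])  # works whether outputs is a list or a tuple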
From 270a4c852aa5371548c92aba0c93f375c8a8f481 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cchang-wenbin=E2=80=9D?= <1286094601@qq.com>
Date: Tue, 19 Nov 2024 18:42:46 +0800
Subject: [PATCH 3/3] test

---
 .../models/qwen2_vl/modeling_qwen2_vl.py | 71 +++-----------------
 1 file changed, 10 insertions(+), 61 deletions(-)

diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
index bfd719896..dead6e71c 100644
--- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
@@ -38,7 +38,7 @@
 from ...activations import ACT2FN
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
-from paddlemix.triton_ops.triton_ops import rms_norm
+
 
 logger = logging.get_logger(__name__)
 
@@ -282,9 +282,7 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
                 stride=self.proj._stride)
             hidden_states = hidden_states.to(target_dtype).reshape([-1, self.embed_dim])
         else:
-            # NOTE(changwenbin): AttributeError: 'Variable' object has no attribute 'to'
-            # hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
-            hidden_states = self.proj(paddle.cast(hidden_states, dtype=target_dtype)).reshape([-1, self.embed_dim])
+            hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).reshape([-1, self.embed_dim])
         return hidden_states
 
 
@@ -487,7 +485,6 @@ def forward(self, hidden_states):
         if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
             hidden_states = paddle.cast(hidden_states, self.weight.dtype)
         return hidden_states * self.weight
-        # hidden_states = rms_norm(hidden_states, weight=self.weight, epsilon=self.variance_epsilon)
 
 
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
@@ -777,7 +774,7 @@ def _flash_attention_forward(
             key_states,
             value_states,
             dropout,
-            causal=True,  # no softmax_scale=
+            causal=causal,  # no softmax_scale=
         )[0]
 
         return attn_output
@@ -837,15 +834,7 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.config = config
-    # @paddle.incubate.jit.inference(
-    #     enable_new_ir=False,
-    #     cache_static_model=False,
-    #     save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_model_decoder",
-    #     exp_enable_use_cutlass=False,
-    #     skip_prune_program=True,
-    #     # delete_pass_lists=["fc_fuse_pass"],
-    # )
+
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -855,7 +844,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[paddle.Tensor] = None,
-        # **kwargs,
+        **kwargs,
     ):
         """
         Args:
@@ -878,13 +867,7 @@ def forward(
 
         residual = hidden_states
 
-        # Note:(changwenbin) use triton_rmsnorm
         hidden_states = self.input_layernorm(hidden_states)
-        # print(self.input_layernorm.weight)
-        # print(self.input_layernorm.bias)
-        # exit(0)
-        # from paddlemix.triton_ops.triton_ops import rms_norm
-        # hidden_states = rms_norm(hidden_states, weight=self.input_layernorm.weight, epsilon=self.config.rms_norm_eps)
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -995,14 +978,7 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_axis=1)
         return rotary_pos_emb
-
-    # @paddle.incubate.jit.inference(
-    #     enable_new_ir=False,
-    #     cache_static_model=False,
-    #     save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_vision",
-    #     exp_enable_use_cutlass=False,
-    #     # delete_pass_lists=["fc_fuse_pass"],
-    # )
+
     def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor:
         hidden_states = self.patch_embed(hidden_states)
 
@@ -1066,15 +1042,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
         # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
         expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
-
-    @paddle.incubate.jit.inference(
-        enable_new_ir=True,
-        cache_static_model=False,
-        save_model_dir="/root/paddlejob/workspace/env_run/output/changwenbin/PaddleMIX/paddlemix/examples/qwen2_vl/tmp/qwen2vl_model",
-        exp_enable_use_cutlass=False,
-        # skip_prune_program=True,
-        switch_ir_optim=False,
-        # delete_pass_lists=["fc_fuse_pass"],
-    )
+
     def forward(
         self,
         input_ids: paddle.Tensor = None,
@@ -1486,10 +1454,7 @@ def forward(
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states  # fmt:skip
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        # import datetime
-        # paddle.device.synchronize()
-        # starttime_visual = datetime.datetime.now()
-
+
         if inputs_embeds is None:
             inputs_embeds = self.model.embed_tokens(input_ids)
             if pixel_values is not None:
@@ -1506,16 +1471,7 @@ def forward(
                 inputs_embeds[video_mask] = video_embeds
             if attention_mask is not None:
                 attention_mask = attention_mask
-        # paddle.device.synchronize()
-        # endtime_visual = datetime.datetime.now()
-
-        # duringtime_visual = endtime_visual - starttime_visual
-        # duringtime_visual = duringtime_visual.seconds * 1000 + duringtime_visual.microseconds / 1000.0
-        # print("duringtime_visual", duringtime_visual)
-
-        # paddle.device.synchronize()
-        # starttime_text = datetime.datetime.now()
-
+
         outputs = self.model(
             input_ids=None,
             position_ids=position_ids,
@@ -1527,13 +1483,6 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        # paddle.device.synchronize()
-        # endtime_text = datetime.datetime.now()
-
-        # duringtime_text = endtime_text - starttime_text
-        # duringtime_text = duringtime_text.seconds * 1000 + duringtime_text.microseconds / 1000.0
-        # print("duringtime_text", duringtime_text)
-
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
 
@@ -1554,7 +1503,7 @@ def forward(
             loss = loss / label_sum
 
         if not return_dict:
-            output = (logits,) + tuple(outputs[1:])
+            output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
         # return logits + 28 layers k and v
@@ -1661,4 +1610,4 @@ def prepare_inputs_for_generation(
                 "rope_deltas": rope_deltas,  # [[-3504]]
             }
         )
-        return model_inputs
+        return model_inputs
\ No newline at end of file
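Closing note on the duration arithmetic used in the benchmark loops of patches 1 and 2: `duringtime.seconds * 1000 + duringtime.microseconds / 1000.0` is only correct for runs shorter than a day, because timedelta.seconds wraps at 24 hours. timedelta.total_seconds() expresses the same quantity without that caveat:

    import datetime

    starttime = datetime.datetime.now()
    # ... timed work ...
    elapsed_ms = (datetime.datetime.now() - starttime).total_seconds() * 1000.0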