From 04139ade599eedd493ce8effcda7ceabb57f2fb5 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Fri, 20 Dec 2024 04:04:21 -0800 Subject: [PATCH] [V1] Fix profiling for models with merged input processor (#11370) Signed-off-by: ywang96 --- vllm/v1/worker/gpu_model_runner.py | 44 ++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cb89246db0cc9..ace62d8978bea 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -635,17 +635,6 @@ def profile_run(self) -> None: ) dummy_mm_data = dummy_request_data.multi_modal_data - # Compute MM hashes (if enabled) - mm_hashes = None - if self.use_hash: - mm_hashes = self.mm_hasher.hash_dummy_mm_data(dummy_mm_data) - - dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs( - mm_data=dummy_mm_data, - mm_hashes=mm_hashes, - mm_processor_kwargs=None, - precomputed_mm_inputs=None) - # NOTE: Currently model is profiled with a single non-text # modality even when it supports multiple. max_tokens_per_mm_item = max( @@ -660,8 +649,39 @@ def profile_run(self) -> None: # (e.g, multiple images) for a single request, therefore here we # always replicate first item by max_num_mm_items times since in V1 # they are scheduled to be processed separately. + + # Case when models have a merged processor, their dummy data is + # already batched `MultiModalKwargs`, therefore we need to "unbatch" + # and take the first item in each batched tensor. + # TODO (ywang96): This is somewhat hacky. Refactor this to be + # consistent with the other case. + if isinstance(dummy_mm_data, MultiModalKwargs): + dummy_mm_kwargs = { + k: v[0].unsqueeze(0) + for k, v in dummy_mm_data.items() + } + + # Case when models have dummy data explicitly defined as + # `MultiModalDataDict`, so they need to be processed through input + # mapper. + else: + # Compute MM hashes (if enabled) + mm_hashes = None + if self.use_hash: + mm_hashes = self.mm_hasher.hash_dummy_mm_data( + dummy_mm_data) + + mm_kwargs_list = self.mm_input_mapper_client.process_inputs( + mm_data=dummy_mm_data, + mm_hashes=mm_hashes, + mm_processor_kwargs=None, + precomputed_mm_inputs=None) + + # Take the first `MultiModalKwargs` + dummy_mm_kwargs = mm_kwargs_list[0] + batched_dummy_mm_inputs = MultiModalKwargs.batch( - [dummy_mm_kwargs[0]] * max_num_mm_items) + [dummy_mm_kwargs] * max_num_mm_items) batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs( batched_dummy_mm_inputs, device=self.device)