[V1] Fix multimodal profiling for Molmo (vllm-project#11325)
Signed-off-by: ywang96 <[email protected]>
Co-authored-by: ywang96 <[email protected]>
ywang96 authored Dec 19, 2024
1 parent 6c7f881 commit 7379b3d
Showing 4 changed files with 24 additions and 4 deletions.
5 changes: 5 additions & 0 deletions vllm/model_executor/models/molmo.py
@@ -928,7 +928,11 @@ def image_input_mapper_for_molmo(
     data: object,
 ):
     if isinstance(data, list):
         assert len(data) == 1, "Molmo supports only one image per prompt."
         data = data[0]
 
+    # Remove unused dummy PIL image
+    data.pop('raw_mm_data', None)
     return MultiModalKwargs(data)


@@ -974,6 +978,7 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
     dummy_imgdata = {
         "images": out["images"],
         "image_input_idx": out["image_input_idx"],
+        "raw_mm_data": dummy_image,
     }
     if "image_masks" in out:
         dummy_imgdata["image_masks"] = out["image_masks"]
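
Taken together, the two molmo.py hunks thread the raw dummy image through V1 profiling: dummy_data_for_molmo now keeps the original PIL image under a "raw_mm_data" key next to the already-processed tensors, and image_input_mapper_for_molmo strips that key again before wrapping the dict in MultiModalKwargs. A minimal sketch of the round trip (illustrative only; the placeholder values and image size below are assumptions, not what the Molmo processor actually produces):

    from PIL import Image

    # Profiling side (dummy_data_for_molmo): keep the raw PIL image alongside
    # the processed outputs so the V1 hasher can fingerprint it later.
    dummy_image = Image.new("RGB", (336, 336))  # placeholder size
    dummy_imgdata = {
        "images": None,            # placeholder for the processed pixel tensors
        "image_input_idx": None,   # placeholder for the patch/token index map
        "raw_mm_data": dummy_image,
    }

    # Mapper side (image_input_mapper_for_molmo): the raw image is only needed
    # for hashing, so drop it before building the model inputs.
    dummy_imgdata.pop("raw_mm_data", None)
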
19 changes: 17 additions & 2 deletions vllm/v1/engine/mm_input_mapper.py
@@ -151,17 +151,31 @@ class MMHasher:
     def __init__(self):
         pass
 
-    def hash_mm_data(
+    def hash_dummy_mm_data(
             self,
             mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
+        """Hash user-defined dummy multimodal data used for profiling."""
 
         if mm_data is None:
             return None
 
         image_inputs = mm_data['image']
 
+        # This is a temporary workaround for models (e.g, Molmo) that
+        # process multimodal data in the input processor (therefore
+        # image_inputs is MultiModalKwargs instead of raw input format).
+        # `raw_mm_data` with the original input format is expected
+        # in this case.
+        if isinstance(image_inputs, dict):
+            assert "raw_mm_data" in image_inputs and isinstance(
+                image_inputs["raw_mm_data"], PIL.Image.Image)
+            image_inputs = image_inputs.pop("raw_mm_data")
 
         return self.hash_images(image_inputs)
 
-    def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
+    def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
+        """Hash multimodal data in the user input prompt if they exist."""
 
         if "multi_modal_data" not in prompt:
             return None

@@ -171,6 +185,7 @@ def hash_prompt(self, prompt: PromptType) -> Optional[List[str]]:
         return self.hash_images(image_inputs)
 
     def hash_images(self, image_inputs) -> Optional[List[str]]:
+        """Hash PIL image objects to strings."""
         if not isinstance(image_inputs, list):
             image_inputs = [image_inputs]
         assert len(image_inputs) > 0
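
hash_images itself is unchanged by this commit and only its first lines appear above. As a rough illustration of what "hash PIL image objects to strings" can mean (an assumption for exposition, not the actual vLLM implementation), one stable scheme digests each image's mode, size, and raw pixel bytes:

    import hashlib
    from typing import List

    from PIL import Image

    def hash_images_sketch(image_inputs) -> List[str]:
        """Illustrative only: map one or more PIL images to hex digests."""
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]
        assert len(image_inputs) > 0

        digests = []
        for image in image_inputs:
            hasher = hashlib.sha256()
            # Include mode and size so images with identical pixel bytes but
            # different shapes or channel layouts do not collide.
            hasher.update(image.mode.encode())
            hasher.update(str(image.size).encode())
            hasher.update(image.tobytes())
            digests.append(hasher.hexdigest())
        return digests
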
2 changes: 1 addition & 1 deletion vllm/v1/engine/processor.py
@@ -79,7 +79,7 @@ def process_inputs(
         # Compute MM hashes (if enabled)
         mm_hashes = None
         if self.use_hash:
-            mm_hashes = self.mm_hasher.hash_prompt(prompt)
+            mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)
 
         # Process inputs.
         preprocessed_inputs = self.input_preprocessor.preprocess(
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -638,7 +638,7 @@ def profile_run(self) -> None:
         # Compute MM hashes (if enabled)
         mm_hashes = None
         if self.use_hash:
-            mm_hashes = self.mm_hasher.hash_mm_data(dummy_mm_data)
+            mm_hashes = self.mm_hasher.hash_dummy_mm_data(dummy_mm_data)
 
         dummy_mm_kwargs = self.mm_input_mapper_client.process_inputs(
             mm_data=dummy_mm_data,
