
Enhance LlamaVision model with video loading improvements and configuration updates
pufanyi committed Dec 1, 2024
1 parent 5e23413 commit 4fddfff
Showing 2 changed files with 14 additions and 22 deletions.
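
The user-visible change is in how the llama_vision model loads video: frames now come from read_video_pyav_pil, and two new constructor arguments, fps and max_image_size, are threaded through to the loader. As a hedged illustration of the new knobs (direct construction like this is hypothetical; lmms_eval normally builds the model from the --model_args CLI string):

```python
# Hypothetical direct construction; in practice lmms_eval builds the model
# from --model_args, but these keyword arguments match the new __init__.
from lmms_eval.models.llama_vision import LlamaVision

model = LlamaVision(
    pretrained="meta-llama/Llama-3.2-11B-Vision",
    max_frames_num=32,   # existing cap on the number of sampled frames
    fps=1,               # new: passed to the video loader as a sampling rate
    max_image_size=768,  # new: passed to the loader to bound frame size
)
```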
31 changes: 11 additions & 20 deletions lmms_eval/models/llama_vision.py
@@ -15,6 +15,7 @@
 from lmms_eval.api.instance import Instance
 from lmms_eval.api.model import lmms
 from lmms_eval.api.registry import register_model
+from lmms_eval.models.model_utils.load_video import read_video_pyav_pil

 warnings.filterwarnings("ignore")

@@ -25,33 +25,19 @@

 @register_model("llama_vision")
 class LlamaVision(lmms):
-    """
-    Llava Model for Hugging Face Transformers: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava
-    Adapted from the InstructBLIP model in lmms_eval/models/instructblip.py
-    Example usage:
-    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
-        --model llava_hf \
-        --model_args pretrained=llava-hf/llava-1.5-7b-hf \
-        --tasks seedbench \
-        --batch_size 1 \
-        --output_path ./logs/ \
-        --log_samples
-    """
-
     def __init__(
         self,
         pretrained: str = "meta-llama/Llama-3.2-11B-Vision",
         revision: str = "main",
         device: str = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
         batch_size: int = 1,
-        trust_remote_code: Optional[bool] = False,
+        trust_remote_code: Optional[bool] = True,
         attn_implementation: Optional[str] = None,
         device_map: str = "",
         max_frames_num: Optional[int] = 32,
+        fps: Optional[int] = None,
+        max_image_size: Optional[int] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -68,7 +55,9 @@ def __init__(
         if isinstance(dtype, str) and dtype != "auto":
             dtype = getattr(torch, dtype)

+        self.fps = fps
         self.max_frames_num = max_frames_num
+        self.max_image_size = max_image_size
         self._model = MllamaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)
         self.model.eval()
         self.processor = AutoProcessor.from_pretrained(pretrained)
@@ -193,9 +182,11 @@ def generate_until(self, requests: List[Instance]) -> List[str]:

         for visual in visuals:
             if isinstance(visual, str):
-                frames = self.load_video(visual, self.max_frames_num)
-                frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
-                images.extend([to_pil_image(frame) for frame in frames])
+                frames = read_video_pyav_pil(visual, num_frm=self.max_frames_num, fps=self.fps, max_image_size=self.max_image_size)
+                images.extend(frames)
+                # frames = self.load_video(visual, self.max_frames_num)
+                # frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
+                # images.extend([to_pil_image(frame) for frame in frames])
             elif isinstance(visual, PIL.Image.Image):
                 images.append(visual)
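
read_video_pyav_pil itself is not part of this diff; it lives in lmms_eval/models/model_utils/load_video.py. Purely as a sketch of what a PyAV-based loader honoring num_frm, fps, and max_image_size might look like (the function body, sampling policy, and the _sketch name are assumptions, not the actual implementation):

```python
# Assumed sketch of a PyAV-based frame loader; the real read_video_pyav_pil
# in lmms_eval/models/model_utils/load_video.py may differ in its details.
from typing import List, Optional

import av
import numpy as np
from PIL import Image


def read_video_pyav_pil_sketch(
    video_path: str,
    num_frm: int = 32,
    fps: Optional[int] = None,
    max_image_size: Optional[int] = None,
) -> List[Image.Image]:
    container = av.open(video_path)
    stream = container.streams.video[0]
    avg_rate = stream.average_rate  # native frame rate; may be None
    # Decode everything up front: simple, but memory-heavy for long videos.
    decoded = [frame.to_image() for frame in container.decode(video=0)]
    container.close()
    if not decoded:
        return []
    total = len(decoded)

    if fps is not None and avg_rate:
        # Sample at the requested rate, never exceeding num_frm frames.
        duration_s = total / float(avg_rate)
        wanted = min(num_frm, max(1, round(duration_s * fps)))
    else:
        wanted = min(num_frm, total)

    indices = np.linspace(0, total - 1, wanted).astype(int)
    frames = [decoded[i] for i in indices]

    if max_image_size is not None:
        for f in frames:
            # Shrinks in place so the longer side is at most max_image_size.
            f.thumbnail((max_image_size, max_image_size))
    return frames
```

Returning PIL images directly (images.extend(frames)) also drops the old numpy-to-torch-to-PIL round-trip that the replaced lines needed.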
5 changes: 3 additions & 2 deletions lmms_eval/tasks/mix_evals/video2text/utils.py
@@ -386,8 +386,9 @@ def apply(self, resps, docs):
             # response.raise_for_status()

             # content =["choices"][0]["message"]["content"].strip()
-            content = response.choices[0].message.content.strip()
-            if content != "":
+            content = response.choices[0].message.content
+            if content:
+                content = content.strip()
                 match = re.search(r"r'\b([A-Z])\.?\b'", content)
                 if match:
                     result = ord(match.group(1)) - ord("A")
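
This change guards against the client returning None for message.content, which would previously crash on .strip(). Note, though, that the regex on the unchanged line below it looks off (it predates this commit): the pattern carries a stray raw-string prefix and quotes inside the pattern text itself. A hypothetical fix, not part of the commit, might look like:

```python
import re

content = "The answer is B."  # example model response

# The committed pattern, r"r'\b([A-Z])\.?\b'", can only match literal text
# such as "r'B'", because a stray raw-string prefix and its quotes ended up
# inside the pattern. Dropping them extracts a standalone capital letter:
match = re.search(r"\b([A-Z])\.?\b", content)
if match:
    result = ord(match.group(1)) - ord("A")  # "B" -> 1
```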
