diff --git a/lmms_eval/models/srt_api.py b/lmms_eval/models/srt_api.py
index 82320fcf..9752bd96 100755
--- a/lmms_eval/models/srt_api.py
+++ b/lmms_eval/models/srt_api.py
@@ -12,7 +12,7 @@
 from accelerate import Accelerator, DistributedType
 from decord import VideoReader, cpu
 from loguru import logger as eval_logger
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAI
 from PIL import Image
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
@@ -93,7 +93,10 @@ def __init__(
             other_args=other_args,
         )
         self.base_url += "/v1"
-        self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
+        if self.modality == "video":
+            self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+        else:
+            self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
         self.num_processes = num_processes
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
         if accelerator.num_processes > 1:
@@ -120,28 +123,14 @@ def encode_image(self, image: Image):
 
     # Function to encode the video
     def encode_video(self, video_path, for_get_frames_num):
-        try:
-            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
-            total_frame_num = len(vr)
-            uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
-            frame_idx = uniform_sampled_frames.tolist()
-            frames = vr.get_batch(frame_idx).asnumpy()
-        except:
-            import av
-
-            container = av.open(video_path)
-
-            frames = []
-            # https://github.com/PyAV-Org/PyAV/issues/1269
-            # https://www.cnblogs.com/beyond-tester/p/17641872.html
-            # context = CodecContext.create("libvpx-vp9", "r")
-            for packet in container.demux(video=0):
-                for frame in packet.decode():
-                    frames.append(frame)
-            total_frames = len(frames)
-            sampled_frm = min(total_frames, for_get_frames_num)
-            indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)
-            frames = [frames[i] for i in indices]
+        if type(video_path) == str:
+            vr = VideoReader(video_path, ctx=cpu(0))
+        else:
+            vr = VideoReader(video_path[0], ctx=cpu(0))
+        total_frame_num = len(vr)
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+        frames = vr.get_batch(frame_idx).asnumpy()
 
         base64_frames = []
         for frame in frames:
@@ -167,7 +156,7 @@ async def generate(self, request):
         visuals = self.flatten(visuals)
         imgs = []  # multiple images or frames for video
         for visual in visuals:
-            if self.modality == "image":
+            if self.modality == "image" or self.modality == "multi-images":
                 img = self.encode_image(visual)
                 imgs.append(img)
             elif self.modality == "video":
@@ -190,7 +179,7 @@ async def generate(self, request):
         # put the images in the first place
         content = []
         for img in imgs:
-            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
+            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}, "modalities": self.modality})
 
         content.append({"type": "text", "text": contexts})
         messages.append({"role": "user", "content": content})
@@ -217,6 +206,62 @@ async def generate(self, request):
 
         return response_text
 
+    def generate_sync(self, request):
+        contexts, gen_kwargs, doc_to_visual, doc_id, task, split = request.args
+        visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
+        visuals = self.flatten(visuals)
+        imgs = []  # multiple images or frames for video
+        for visual in visuals:
+            if self.modality == "image" or self.modality == "multi-images":
+                img = self.encode_image(visual)
+                imgs.append(img)
+            elif self.modality == "video":
+                try:
+                    frames = self.encode_video(visual, self.max_frames_num)
+                    imgs.extend(frames)
+                except Exception as e:
+                    eval_logger.error(f"Exception : {e} \n When loading video {visual}")
+                    imgs = None
+                    break
+
+        # Handling video decode error
+        # If we can't even load using pyav, then we will skip
+        if imgs is None:
+            resps = ""
+            return resps
+
+        messages = []
+
+        # put the images in the first place
+        content = []
+        for img in imgs:
+            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}, "modalities": self.modality})
+
+        content.append({"type": "text", "text": contexts})
+        messages.append({"role": "user", "content": content})
+
+        if "max_new_tokens" not in gen_kwargs:
+            gen_kwargs["max_new_tokens"] = 1024
+
+        if "temperature" not in gen_kwargs:
+            gen_kwargs["temperature"] = 0
+
+        for attempt in range(5):
+            try:
+                response = self.client.chat.completions.create(model=self.model_version, messages=messages, temperature=gen_kwargs["temperature"], max_tokens=gen_kwargs["max_new_tokens"], timeout=self.timeout)
+                response_text = response.choices[0].message.content.strip()
+                break  # If successful, break out of the loop
+
+            except Exception as e:
+                eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.")
+                if attempt < 4:
+                    time.sleep(NUM_SECONDS_TO_SLEEP)
+                else:  # If this was the last attempt, log and return empty string
+                    eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.")
+                    response_text = ""
+
+        return response_text
+
     def generate_until(self, requests) -> List[str]:
         res = []
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
@@ -234,7 +279,13 @@ async def _process(request):
                 res.append(result)
                 pbar.update(1)
 
-        asyncio.run(run(requests))
+        if self.modality == "video":
+            for req in requests:
+                response = self.generate_sync(req)
+                res.append(response)
+                pbar.update(1)
+        else:
+            asyncio.run(run(requests))
 
         kill_child_process(self.process.pid)
         return res
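
Below is a minimal, self-contained sketch of the synchronous video path this patch wires together: sample frames uniformly with decord, base64-encode them as PNG data URLs, and issue one blocking `chat.completions` call with the extra `"modalities"` field the patch attaches to each image part. The server URL, model name, and video filename here are illustrative assumptions, not values taken from the patch.

```python
import base64
from io import BytesIO

import numpy as np
from decord import VideoReader, cpu
from openai import OpenAI
from PIL import Image


def sample_frames_base64(video_path: str, num_frames: int) -> list:
    # Uniformly sample `num_frames` frames across the whole clip,
    # mirroring the np.linspace indexing used in encode_video above.
    vr = VideoReader(video_path, ctx=cpu(0))
    idx = np.linspace(0, len(vr) - 1, num_frames, dtype=int).tolist()
    frames = vr.get_batch(idx).asnumpy()  # (num_frames, H, W, 3) uint8
    encoded = []
    for frame in frames:
        buf = BytesIO()
        Image.fromarray(frame).save(buf, format="PNG")
        encoded.append(base64.b64encode(buf.getvalue()).decode("utf-8"))
    return encoded


# Assumed local SGLang server exposing an OpenAI-compatible /v1 endpoint.
client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# Images go first in the content list, then the text prompt, as in the patch.
content = [
    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{f}"}, "modalities": "video"}
    for f in sample_frames_base64("clip.mp4", 8)  # hypothetical input file
]
content.append({"type": "text", "text": "Describe the video."})

response = client.chat.completions.create(
    model="llava",  # assumed model name; the patch uses self.model_version
    messages=[{"role": "user", "content": content}],
    temperature=0,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```

The synchronous `OpenAI` client is used here because, as in the patch's `generate_until` change, video requests are sent one at a time rather than fanned out through `asyncio`; the retry loop and error handling from `generate_sync` are omitted for brevity.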