Skip to content

Commit

Permalink
Merge pull request #246 from EvolvingLMMs-Lab/fix/srt_videos
Browse files Browse the repository at this point in the history
[Fix] Strict video to be single processing
  • Loading branch information
Luodian authored Sep 13, 2024
2 parents 07eee00 + da2f73c commit bfff5ab
Showing 1 changed file with 78 additions and 27 deletions.
105 changes: 78 additions & 27 deletions lmms_eval/models/srt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu
from loguru import logger as eval_logger
from openai import AsyncOpenAI
from openai import AsyncOpenAI, OpenAI
from PIL import Image
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
Expand Down Expand Up @@ -93,7 +93,10 @@ def __init__(
other_args=other_args,
)
self.base_url += "/v1"
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
if self.modality == "video":
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
else:
self.client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
self.num_processes = num_processes
# assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
if accelerator.num_processes > 1:
Expand All @@ -120,28 +123,14 @@ def encode_image(self, image: Image):

# Function to encode the video
def encode_video(self, video_path, for_get_frames_num):
try:
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
except:
import av

container = av.open(video_path)

frames = []
# https://github.com/PyAV-Org/PyAV/issues/1269
# https://www.cnblogs.com/beyond-tester/p/17641872.html
# context = CodecContext.create("libvpx-vp9", "r")
for packet in container.demux(video=0):
for frame in packet.decode():
frames.append(frame)
total_frames = len(frames)
sampled_frm = min(total_frames, for_get_frames_num)
indices = np.linspace(0, total_frames - 1, sampled_frm, dtype=int)
frames = [frames[i] for i in indices]
if type(video_path) == str:
vr = VideoReader(video_path, ctx=cpu(0))
else:
vr = VideoReader(video_path[0], ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, for_get_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()

base64_frames = []
for frame in frames:
Expand All @@ -167,7 +156,7 @@ async def generate(self, request):
visuals = self.flatten(visuals)
imgs = [] # multiple images or frames for video
for visual in visuals:
if self.modality == "image":
if self.modality == "image" or self.modality == "multi-images":
img = self.encode_image(visual)
imgs.append(img)
elif self.modality == "video":
Expand All @@ -190,7 +179,7 @@ async def generate(self, request):
# put the images in the first place
content = []
for img in imgs:
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}, "modalities": self.modality})

content.append({"type": "text", "text": contexts})
messages.append({"role": "user", "content": content})
Expand All @@ -217,6 +206,62 @@ async def generate(self, request):

return response_text

def generate_sync(self, request):
    """Answer a single evaluation request with a blocking chat-completion call.

    Args:
        request: Object whose ``args`` attribute unpacks to
            ``(contexts, gen_kwargs, doc_to_visual, doc_id, task, split)``.
            ``gen_kwargs`` is only read; missing ``max_new_tokens`` and
            ``temperature`` fall back to 1024 and 0 without being written
            back into the caller's dict.

    Returns:
        The stripped text of the first completion choice. An empty string is
        returned when a video fails to load or when every attempt to call the
        server raises.
    """
    contexts, gen_kwargs, doc_to_visual, doc_id, task, split = request.args
    visuals = self.flatten([doc_to_visual(self.task_dict[task][split][doc_id])])

    imgs = []  # multiple images, or the sampled frames of a video
    for visual in visuals:
        if self.modality in ("image", "multi-images"):
            imgs.append(self.encode_image(visual))
        elif self.modality == "video":
            try:
                imgs.extend(self.encode_video(visual, self.max_frames_num))
            except Exception as e:
                # A video that cannot be decoded makes the whole request
                # unanswerable, so skip it with an empty response.
                eval_logger.error(f"Exception : {e} \n When loading video {visual}")
                return ""

    # Visual parts come first, followed by the text prompt.
    content = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/png;base64,{img}"},
            "modalities": self.modality,
        }
        for img in imgs
    ]
    content.append({"type": "text", "text": contexts})
    messages = [{"role": "user", "content": content}]

    # Read defaults without mutating gen_kwargs, which the caller may reuse.
    max_new_tokens = gen_kwargs.get("max_new_tokens", 1024)
    temperature = gen_kwargs.get("temperature", 0)

    max_attempts = 5
    response_text = ""
    for attempt in range(max_attempts):
        try:
            response = self.client.chat.completions.create(
                model=self.model_version,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
                timeout=self.timeout,
            )
            response_text = response.choices[0].message.content.strip()
            break  # success: stop retrying
        except Exception as e:
            eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.")
            if attempt < max_attempts - 1:
                time.sleep(NUM_SECONDS_TO_SLEEP)
            else:
                # Last attempt failed: log and fall through with the empty string.
                eval_logger.error(f"All {max_attempts} attempts failed. Last error message: {str(e)}.")

    return response_text

def generate_until(self, requests) -> List[str]:
res = []
pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
Expand All @@ -234,7 +279,13 @@ async def _process(request):
res.append(result)
pbar.update(1)

asyncio.run(run(requests))
if self.modality == "video":
for req in requests:
response = self.generate_sync(req)
res.append(response)
pbar.update(1)
else:
asyncio.run(run(requests))
kill_child_process(self.process.pid)

return res
Expand Down

0 comments on commit bfff5ab

Please sign in to comment.