From 353607a82f3db298feea76c60b2a8ba47324f52e Mon Sep 17 00:00:00 2001 From: jzhang38 Date: Mon, 15 Apr 2024 12:54:03 +0800 Subject: [PATCH] fix bs>1 bug --- lmms_eval/models/llava_sglang.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/lmms_eval/models/llava_sglang.py b/lmms_eval/models/llava_sglang.py index df53a02d..d22f18be 100644 --- a/lmms_eval/models/llava_sglang.py +++ b/lmms_eval/models/llava_sglang.py @@ -96,10 +96,8 @@ def _collate(x): num_iters = len(requests) // self.parallel if len(requests) % self.parallel == 0 else len(requests) // self.parallel + 1 pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding") for chunk in chunks: - contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) - task = task[0] - split = split[0] - batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] + contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk) + batched_visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids,task,split,doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)] # [B, N] # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. gen_kwargs = all_gen_kwargs[0]