Skip to content

Commit

Permalink
fix bs>1 bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang38 committed Apr 15, 2024
1 parent ae99553 commit 353607a
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions lmms_eval/models/llava_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,8 @@ def _collate(x):
num_iters = len(requests) // self.parallel if len(requests) % self.parallel == 0 else len(requests) // self.parallel + 1
pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
for chunk in chunks:
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
contexts, all_gen_kwargs, doc_to_visuals, doc_id, tasks, splits = zip(*chunk)
batched_visuals = [doc_to_visual(self.task_dict[task][split][ids]) for ids,task,split,doc_to_visual in zip(doc_id, tasks, splits, doc_to_visuals)] # [B, N]
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
Expand Down

0 comments on commit 353607a

Please sign in to comment.