From 0e26684691703baff229609019d3044dc892ef65 Mon Sep 17 00:00:00 2001 From: Fanyi Pu Date: Sun, 31 Mar 2024 14:56:32 +0800 Subject: [PATCH] lint --- lmms_eval/models/llava.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 3bb0b30b..3a9c8407 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -270,8 +270,8 @@ def _collate(x): contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) task = task[0] split = split[0] - batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] - flattened_visuals = self.flatten(batched_visuals) # [B*N] + batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] + flattened_visuals = self.flatten(batched_visuals) # [B*N] ############### for debugging ################### # TODO: remove this block # if len(visuals) > 1: @@ -349,7 +349,7 @@ def _collate(x): # input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) # preconfigure gen_kwargs with defaults - gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals)) + gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))] if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 if "temperature" not in gen_kwargs: