diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index b93be935..3bb0b30b 100644 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -270,15 +270,15 @@ def _collate(x): contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk) task = task[0] split = split[0] - visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] - visuals = self.flatten(visuals) + batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N] + flattened_visuals = self.flatten(batched_visuals) # [B*N] ############### for debugging ################### # TODO: remove this block - if len(visuals) > 1: - for i in range(len(visuals)): - path = f"./logs/llava/{i}.png" - visuals[i].save(path) - pass + # if len(visuals) > 1: + # for i in range(len(visuals)): + # path = f"./logs/llava/{i}.png" + # visuals[i].save(path) + # pass ################################################# # we assume all gen kwargs in the batch are the same # this is safe to assume because the `grouper` object ensures it. 
@@ -300,8 +300,8 @@ def _collate(x): self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio") eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}") # encode, pad, and truncate contexts for this batch - if visuals: - image_tensor = process_images(visuals, self._image_processor, self._config) + if flattened_visuals: + image_tensor = process_images(flattened_visuals, self._image_processor, self._config) if type(image_tensor) is list: image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] else: @@ -313,7 +313,7 @@ def _collate(x): question_input = [] - for visual, context in zip(visuals, contexts): + for visual, context in zip(batched_visuals, contexts): if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context: """ Three senarios: @@ -323,6 +323,8 @@ def _collate(x): """ image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN] image_tokens = " ".join(image_tokens) + if isinstance(context, list): + context = "".join(context) question = image_tokens + "\n" + context else: question = context @@ -336,7 +338,7 @@ def _collate(x): # The above for loop has bugs. When there is no visuals, e.g. 
pure text, # there will be no for loop execute resulting in an empty question_input (because no visuals) # Scenario 1 won't even be execute - if len(visuals) == 0: + if len(flattened_visuals) == 0: for context in contexts: question = context conv = conv_templates[self.conv_template].copy() @@ -347,7 +349,7 @@ def _collate(x): # input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) # preconfigure gen_kwargs with defaults - gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] + gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))] if "max_new_tokens" not in gen_kwargs: gen_kwargs["max_new_tokens"] = 1024 if "temperature" not in gen_kwargs: