Skip to content

Commit

Permalink
fix visuals
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 31, 2024
1 parent 826e5fe commit e4cb4e6
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ def _collate(x):
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
visuals = self.flatten(visuals)
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
flattened_visuals = self.flatten(batched_visuals) # [B*N]
############### for debugging ###################
# TODO: remove this block
if len(visuals) > 1:
for i in range(len(visuals)):
path = f"./logs/llava/{i}.png"
visuals[i].save(path)
pass
# if len(visuals) > 1:
# for i in range(len(visuals)):
# path = f"./logs/llava/{i}.png"
# visuals[i].save(path)
# pass
#################################################
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
Expand All @@ -300,8 +300,8 @@ def _collate(x):
self._config.image_aspect_ratio = gen_kwargs.pop("image_aspect_ratio")
eval_logger.info(f"Setting image aspect ratio: {self._config.image_aspect_ratio}")
# encode, pad, and truncate contexts for this batch
if visuals:
image_tensor = process_images(visuals, self._image_processor, self._config)
if flattened_visuals:
image_tensor = process_images(flattened_visuals, self._image_processor, self._config)
if type(image_tensor) is list:
image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]
else:
Expand All @@ -313,7 +313,7 @@ def _collate(x):

question_input = []

for visual, context in zip(visuals, contexts):
for visual, context in zip(batched_visuals, contexts):
if image_tensor is not None and len(image_tensor) != 0 and DEFAULT_IMAGE_TOKEN not in context:
"""
Three senarios:
Expand All @@ -323,6 +323,8 @@ def _collate(x):
"""
image_tokens = [DEFAULT_IMAGE_TOKEN] * len(visual) if isinstance(visual, list) else [DEFAULT_IMAGE_TOKEN]
image_tokens = " ".join(image_tokens)
if isinstance(context, list):
context = "".join(context)
question = image_tokens + "\n" + context
else:
question = context
Expand All @@ -336,7 +338,7 @@ def _collate(x):
# The above for loop has bugs. When there is no visuals, e.g. pure text,
# there will be no for loop execute resulting in an empty question_input (because no visuals)
# Scenario 1 won't even be execute
if len(visuals) == 0:
if len(flattened_visuals) == 0:
for context in contexts:
question = context
conv = conv_templates[self.conv_template].copy()
Expand All @@ -347,7 +349,7 @@ def _collate(x):

# input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
# preconfigure gen_kwargs with defaults
gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
Expand Down

0 comments on commit e4cb4e6

Please sign in to comment.