Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
pufanyi committed Mar 31, 2024
1 parent e4cb4e6 commit 0e26684
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lmms_eval/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def _collate(x):
contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
task = task[0]
split = split[0]
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
flattened_visuals = self.flatten(batched_visuals) # [B*N]
batched_visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id] # [B, N]
flattened_visuals = self.flatten(batched_visuals) # [B*N]
############### for debugging ###################
# TODO: remove this block
# if len(visuals) > 1:
Expand Down Expand Up @@ -349,7 +349,7 @@ def _collate(x):

# input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
# preconfigure gen_kwargs with defaults
gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))
gen_kwargs["image_sizes"] = [flattened_visuals[idx].size for idx in range(len(flattened_visuals))]
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
Expand Down

0 comments on commit 0e26684

Please sign in to comment.