Add chat template
lewtun committed Apr 5, 2024
1 parent a6e258b commit 86732eb
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions lmms_eval/models/llava_hf.py
@@ -26,6 +26,7 @@ class LlavaHf(lmms):
     def __init__(
         self,
         pretrained: str = "llava-hf/llava-1.5-7b-hf",
+        revision=None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = torch.float16,
         batch_size: Optional[Union[int, str]] = 1,
@@ -40,8 +41,8 @@ def __init__(
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
         else:
             self._device = device
-        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, torch_dtype=dtype)
-        self._image_processor = AutoProcessor.from_pretrained(pretrained)
+        self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype)
+        self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision)
         self._tokenizer = self._image_processor.tokenizer
         self._config = self._model.config
         self._model.eval()
@@ -180,8 +181,20 @@ def _collate(x):
                     raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {type(until)}")
             assert self.batch_size_per_gpu == 1, "Do not support batch_size_per_gpu > 1 for now"
             context = contexts[0]
-            context = f"USER: <image>\n{context}\nASSISTANT:"
-            inputs = self._image_processor(images=visuals, text=context, return_tensors="pt").to(self._device, torch.float16)
+
+            # Some benchmarks like MME do not contain image tokens, so we prepend them to the prompt.
+            if "<image>" not in context:
+                context = f"<image>\n{context}"
+            if self.tokenizer.chat_template is not None:
+                messages = [{"role": "user", "content": context}]
+                text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            else:
+                text = f"USER: {context}\nASSISTANT:"
+
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.info(f"Prompt for doc ID {doc_id[0]}:\n\n{text}\n")
+
+            inputs = self._image_processor(images=visuals, text=text, return_tensors="pt").to(self._device, torch.float16)

             gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
             if "max_new_tokens" not in gen_kwargs:
@@ -204,14 +217,18 @@ def _collate(x):
             except Exception as e:
                 eval_logger.error(f"Error {e} in generating")
                 cont = ""
-            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]#.strip()
+            text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
+            text_outputs = text_outputs.split("ASSISTANT:")[1].strip()
+
+            if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
+                eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
+
             res.append(text_outputs)
             self.cache_hook.add_partial("generate_until", (context, gen_kwargs), text_outputs)
             pbar.update(1)
         # reorder this group of results back to original unsorted form
         res = re_ords.get_original(res)

         pbar.close()
         return res
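
For reference, below is a minimal standalone sketch (not part of the commit) of the two prompt-building paths introduced above. The example question string is made up, and whether the chat-template branch is taken depends on whether the checkpoint's tokenizer actually ships a chat template; the checkpoint name is the default used by this model wrapper.

from transformers import AutoProcessor

# Load the same processor the wrapper uses; this needs the checkpoint to be downloadable.
pretrained = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoProcessor.from_pretrained(pretrained).tokenizer

context = "What is shown in this image?"  # hypothetical benchmark question
# Mirror the commit: benchmarks like MME omit the image token, so prepend one.
if "<image>" not in context:
    context = f"<image>\n{context}"

if tokenizer.chat_template is not None:
    # Checkpoints that define a chat template get their prompt built from messages.
    messages = [{"role": "user", "content": context}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
    # Fallback: the hard-coded vicuna-style prompt the file used before this commit.
    text = f"USER: {context}\nASSISTANT:"

print(text)
# On the fallback path this prints:
# USER: <image>
# What is shown in this image?
# ASSISTANT: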
