Changes in llava_hf.py. Corrected the response split by role and added the ability to specify an EOS token (#153)

* get model answer without splitting on ASSISTANT (see the sketch below the commit summary)

* specified_eot_token_id for generation problems
Dannoopsy authored Jul 20, 2024
1 parent 1703b02 commit 44f3015
Showing 1 changed file with 4 additions and 7 deletions.
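The core of the fix, visible in the diff below, is to stop recovering the model's answer by splitting the decoded text on a role marker ("ASSISTANT:" or "[/INST]") and instead slice the generated token IDs at the prompt length, while optionally forcing a specific end-of-turn token id. A minimal, text-only sketch of that pattern with Hugging Face transformers follows; the "gpt2" checkpoint, the prompt, and the EOS override are placeholders, not the LLaVA setup used by lmms-eval.

# Sketch of the decoding pattern adopted in this commit: keep only the newly
# generated tokens by slicing at the prompt length, and stop generation on an
# explicitly chosen EOS token id. Model name and token id are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("USER: What is 2 + 2? ASSISTANT:", return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=32,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,  # stand-in for specified_eot_token_id
    )

# Drop the first input_ids.shape[-1] positions instead of splitting the decoded
# string on "ASSISTANT:"; this stays correct for chat templates with other role markers.
continuation_ids = output_ids[:, inputs["input_ids"].shape[-1]:]
text_output = tokenizer.batch_decode(continuation_ids, skip_special_tokens=True)[0]
print(text_output)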
11 changes: 4 additions & 7 deletions lmms_eval/models/llava_hf.py
@@ -52,6 +52,7 @@ def __init__(
device_map: str = "",
chat_template: Optional[str] = None,
use_cache: bool = True,
specified_eot_token_id: Optional[int] = None,
**kwargs,
) -> None:
super().__init__()
@@ -85,6 +86,7 @@ def __init__(
self.batch_size_per_gpu = int(batch_size)
self.chat_template = chat_template
self.use_cache = use_cache
self.specified_eot_token_id = specified_eot_token_id
if accelerator.num_processes > 1 and device_map == "":
assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
# If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
@@ -316,18 +318,13 @@ def _collate(x):
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.specified_eot_token_id,
)
cont = cont[:, inputs["input_ids"].shape[-1]:]
except Exception as e:
eval_logger.error(f"Error {e} in generating")
cont = ""
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
if "1.5" in self.pretrained:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
elif "mistral" in self.pretrained:
text_outputs = text_outputs.split("[/INST]")[-1].strip()
else:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()

if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
eval_logger.debug(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")

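With the new keyword argument, a caller can force a particular end-of-turn token when a checkpoint's default EOS causes the generation problems mentioned in the commit notes; transformers' generate() accepts an int (or a list of ints) for eos_token_id, so decoding simply stops at the specified token. A hedged usage sketch follows; the class name LlavaHf and the token id 2 (the LLaMA-family </s> id) are assumptions, not taken from this diff, and in practice the argument would normally be supplied through lmms-eval's model_args string.

# Hypothetical usage sketch: only the specified_eot_token_id kwarg itself comes
# from this commit; the class name and token id here are assumptions.
from lmms_eval.models.llava_hf import LlavaHf  # assumed class name

model = LlavaHf(
    pretrained="llava-hf/llava-1.5-7b-hf",
    batch_size=1,
    specified_eot_token_id=2,  # assumed: </s> for LLaMA/Vicuna-based LLaVA checkpoints
)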
