Skip to content

Commit

Permalink
[Model] Add models (EvolvingLMMs-Lab#47)
Browse files Browse the repository at this point in the history
* Refactor logging and model initialization

* Fix wandb_logger.online() method call

* Add error handling during evaluation

* Add wait time and error handling in get_chat_response function

* Update wait_time in get_chat_response function

* Refactor code for improved readability and maintainability

* Refactor doc_to_visual function to handle multiple images in ICON-QA tasks

* Refactor logging_utils.py and utils.py

This commit refactors the `logging_utils.py` and `utils.py` files. It removes unused imports, adjusts code formatting, and updates the `get_chat_response` function to increase the `wait_time` parameter from 5 to 10.

* Refactor code for wandb logging and generation in OtterHD class

* Refactor prepare_report_by_task method in logging_utils.py

* Update generation parameters in OtterHD model

* Update generation parameters in OtterHD model

* Squashed commit of the following:

commit 21dea7b
Author: kcz358 <[email protected]>
Date:   Tue Feb 13 18:50:37 2024 +0800

    Fix seedbench choices bugs (EvolvingLMMs-Lab#45)

commit 12144a6
Author: XinrunDu <[email protected]>
Date:   Tue Feb 13 18:50:23 2024 +0800

    add stvqa and multidocvqa (EvolvingLMMs-Lab#46)

commit aca1e6d
Author: XinrunDu <[email protected]>
Date:   Sun Feb 11 00:54:39 2024 +0800

    add cmmmu (EvolvingLMMs-Lab#44)

    Co-authored-by: ygjin11 <[email protected]>

commit 0925443
Author: kcz358 <[email protected]>
Date:   Sun Feb 11 00:54:23 2024 +0800

    [Feat] Add qwen loglikelihood (EvolvingLMMs-Lab#43)

    * Add qwen loglikelihood

    * Revise the pyproject dependency. Move tiktoken out from optional-dependencies

    * Add ferret-bench

    * Add seedbench 2, test on llava

commit 16f1cf2
Author: JvThunder <[email protected]>
Date:   Wed Feb 7 00:08:22 2024 +0800

    Joshua/vizwizvqa refactor (EvolvingLMMs-Lab#42)

    * refactor vizwizvqa task

    * Merge commit '9bbbad51a77051fcf676438f81e81f723c1b438b'

    * Fix exact_match accuracy calculation in vizwiz_vqa_process_results

    * Update vizwiz_vqa tasks

    ---------

    Co-authored-by: Fanyi Pu <[email protected]>
  • Loading branch information
Luodian authored Feb 13, 2024
1 parent 21dea7b commit fdcc63f
Show file tree
Hide file tree
Showing 14 changed files with 320 additions and 262 deletions.
28 changes: 16 additions & 12 deletions lmms_eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,21 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
is_main_process = False

for args in args_list:
if is_main_process:
wandb_logger = WandbLogger(args)
results = cli_evaluate_single(args)

accelerator.wait_for_everyone()
if is_main_process:
wandb_logger.log_eval_result(results)
if wandb_logger.online():
wandb_logger.write_to_report(results)
wandb_logger.finish()
results_list.append(results)
try:
if is_main_process and args.wandb_args:
wandb_logger = WandbLogger(args)
results = cli_evaluate_single(args)

accelerator.wait_for_everyone()
if is_main_process and args.wandb_args and results is not None:
wandb_logger.log_eval_result(results)
if wandb_logger.online():
wandb_logger.write_to_report()
wandb_logger.finish()
results_list.append(results)
except Exception as e:
eval_logger.error(f"Error during evaluation: {e}")
results_list.append(None)

for args, results in zip(args_list, results_list):
# cli_evaluate will return none if the process is not the main process (rank 0)
Expand All @@ -209,7 +213,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
elif args.tasks == "list":
eval_logger.info("Available Tasks:\n - {}".format(f"\n - ".join(sorted(ALL_TASKS))))
sys.exit()
elif args.tasks == "list_tasks_num":
elif args.tasks == "list_with_num":
log_message = (
"\n" + "=" * 70 + "\n" + "\n\tYou are trying to check all the numbers in each task." + "\n\tThis action will download the complete dataset." + "\n\tIf the results are not clear initially, call this again." + "\n\n" + "=" * 70
)
Expand Down
66 changes: 26 additions & 40 deletions lmms_eval/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def init_run(self):
# initialize a W&B run
self.run = wandb.init(**self.wandb_args)

# call wr inside the init_run method to avoid multiple times logging
import wandb.apis.reports as wr

self.wr = wr

def log_eval_result(self, results):
# Log configs to wandb
configs = self.get_config(results)
Expand Down Expand Up @@ -217,80 +212,71 @@ def sanitize_results_dict(self):
return wandb_summary, _results

def prepare_report_by_task(self, results):
task_names = list(results.get("results", {}).keys())
import wandb.apis.reports as wr

task_names = list(self.results.get("results", {}).keys())
blocks = []
for task_name in task_names:
blocks.append(self.wr.H2(task_name))
blocks.append(wr.H2(task_name))
panels = []
for metric_name, metric_value in results.items():
if task_name in metric_name:
panels.append(
self.wr.ScalarChart(
wr.ScalarChart(
title=f"{metric_name}",
metric=f"{metric_name}",
font_size="large",
)
)
_results = {
"results": {f"{task_name}": results.get("results").get(task_name)},
"versions": {f"{task_name}": results.get("versions").get(task_name)},
"n-shot": {f"{task_name}": results.get("n-shot").get(task_name)},
"results": {f"{task_name}": self.results.get("results").get(task_name)},
"versions": {f"{task_name}": self.results.get("versions").get(task_name)},
"n-shot": {f"{task_name}": self.results.get("n-shot").get(task_name)},
}
results_md = utils.make_table(_results)
blocks.extend([self.wr.MarkdownBlock(results_md), self.wr.PanelGrid(panels=panels)])
# blocks.extend([
# self.wr.WeaveBlockSummaryTable(
# project=self.run.project,
# entity=self.run.entity,
# table_name=f"{task_name}_eval_results",
# ),
# self.wr.PanelGrid(
# runsets=[
# self.wr.Runset(
# project=self.run.project, entity=self.run.entity,
# ).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'),
# ]
# ),
# ])
blocks.extend([wr.MarkdownBlock(results_md), wr.PanelGrid(panels=panels)])
# TODO: Add results table

return blocks

def write_to_report(self, results):
wandb_project = self.run.project
wandb_entity = self.run.entity
report = self.wr.Report(
def write_to_report(self):
import wandb.apis.reports as wr

report = wr.Report(
project=self.run.project,
entity=self.run.entity,
title=f"({datetime.now().strftime('%Y-%m-%d %H:%M:%S')}) {self.run.id} - Evaluation report",
description=f"Evaluation run by: {self.run.entity} logged to {self.run.url}",
)

results_md = utils.make_table(results)
task_blocks = self.prepare_report_by_task(self.wandb_results)

blocks = (
[
self.wr.TableOfContents(),
self.wr.H1("Complete Evaluation Results"),
self.wr.WeaveBlockSummaryTable(
wr.TableOfContents(),
wr.H1("Complete Evaluation Results"),
wr.WeaveBlockSummaryTable(
project=self.run.project,
entity=self.run.entity,
table_name=f"evaluation/eval_results",
table_name="evaluation/eval_results",
),
self.wr.PanelGrid(
wr.PanelGrid(
runsets=[
self.wr.Runset(
wr.Runset(
project=self.run.project,
entity=self.run.entity,
).set_filters_with_python_expr(f'Name == "{str(self.run.name)}"'),
]
),
self.wr.H1("Evaluation Results By Task"),
wr.H1("Evaluation Results By Task"),
]
+ task_blocks
+ [
self.wr.H1("Evaluation Config"),
self.wr.CodeBlock(json.dumps(self.results["config"], indent=5).split("\n"), language="json"),
wr.H1("Evaluation Config"),
wr.CodeBlock(
json.dumps(self.results["config"], indent=5).split("\n"),
language="json",
),
# TODO: Add appendix
]
)
Expand Down
90 changes: 19 additions & 71 deletions lmms_eval/models/model_utils/qwen/qwen_generate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def get_ltor_masks_and_position_ids(
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(att_mask_batch, 1, seq_length, seq_length)

# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
Expand All @@ -66,7 +64,6 @@ def get_ltor_masks_and_position_ids(
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):

# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
Expand Down Expand Up @@ -134,9 +131,7 @@ def make_context(
nl_tokens = tokenizer.encode("\n")

def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set(tokenizer.IMAGE_ST)) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))

system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
Expand All @@ -148,22 +143,16 @@ def _tokenize_str(role, content):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"

current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
Expand All @@ -172,16 +161,7 @@ def _tokenize_str(role, content):

context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
context_tokens += nl_tokens + im_start_tokens + _tokenize_str("user", query)[1] + im_end_tokens + nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"

elif chat_format == "raw":
Expand All @@ -202,7 +182,7 @@ def _decode_default(
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
errors: str = "replace",
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
Expand All @@ -227,16 +207,7 @@ def _decode_default(


def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
tokens: List[int], *, stop_words: List[str], eod_token_ids: List[int], tokenizer: PreTrainedTokenizer, raw_text_len: int, context_length: int, verbose: bool = False, return_end_reason: bool = False, errors: str = "replace"
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
Expand Down Expand Up @@ -270,7 +241,7 @@ def decode_tokens(
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
errors: str = "replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
Expand Down Expand Up @@ -315,42 +286,19 @@ class StopWordsLogitsProcessor(LogitsProcessor):
"""

def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):

if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.")
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)

self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.")
if any(any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) for stop_word_ids in stop_words_ids):
raise ValueError(f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}.")

self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids))
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)

def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format(stop_words_ids)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
Expand Down Expand Up @@ -416,4 +364,4 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):

def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
return (1 - boolean) * val1 + boolean * val2
4 changes: 3 additions & 1 deletion lmms_eval/models/otterhd.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def _collate(x):
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
generation_output = self.model.generate(**model_inputs, max_new_tokens=gen_kwargs["max_new_tokens"], pad_token_id=self.tokenizer.eos_token_id)
generation_output = self.model.generate(
**model_inputs, temperature=gen_kwargs["temperature"], max_new_tokens=gen_kwargs["max_new_tokens"], top_p=gen_kwargs["top_p"], num_beams=gen_kwargs["num_beams"], pad_token_id=self.tokenizer.eos_token_id
)
generation_texts = self.processor.batch_decode(generation_output, skip_special_tokens=True)
response = [gen_text.split("\x04")[1].strip(" ").strip("\n") for gen_text in generation_texts]
res.extend(response)
Expand Down
Loading

0 comments on commit fdcc63f

Please sign in to comment.