diff --git a/.gitignore b/.gitignore
index 7ec216e0..950d4183 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,4 +16,5 @@ temp
 # IPython
 profile_default/
 ipython_config.py
-logs/
\ No newline at end of file
+logs/
+scripts/
\ No newline at end of file
diff --git a/lmms_eval/api/instance.py b/lmms_eval/api/instance.py
index 13bab803..7324dae4 100644
--- a/lmms_eval/api/instance.py
+++ b/lmms_eval/api/instance.py
@@ -5,7 +5,6 @@
 @dataclass
 class Instance:
     request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
-    doc: dict
     arguments: tuple
     idx: int
     metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None))  # TODO: better typehints here
@@ -16,6 +15,7 @@ class Instance:
     task_name: str = None
     doc_id: str = None
     repeats: str = None
+    doc: dict = None

     def __post_init__(self) -> None:
         # unpack metadata field
diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index a9a3ddc0..b38c36d3 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -1,11 +1,13 @@
 import abc
 from dataclasses import dataclass, field, asdict
+import itertools
 import os
 import re
 import ast
 import logging
 import random
+from tqdm import tqdm

 import datasets
 import numpy as np
@@ -338,38 +340,38 @@ def build_all_requests(self, limit=None, rank=None, world_size=None) -> None:
         eval_logger.info(f"Building contexts for task on rank {rank}...")

         instances = []
-        for doc_id, doc in utils.create_iterator(enumerate(docs), rank, world_size, limit):
+        doc_id_iterator = utils.create_iterator([i for i in range(len(docs))], rank, world_size, limit)
+        doc_id_iterator, doc_id_iterator_counting = itertools.tee(doc_id_iterator)
+        total_docs = sum(1 for _ in doc_id_iterator_counting)
+        pbar = tqdm(total=total_docs, desc="Building context")
+        for doc_id in doc_id_iterator:
             # sample fewshot context #TODO: need to offset doc_id by rank now!
-            fewshot_ctx = self.fewshot_context(
-                doc,
-                0 if self.config.num_fewshot is None else self.config.num_fewshot,
-            )
+            fewshot_ctx = self.fewshot_context(doc_id, 0 if self.config.num_fewshot is None else self.config.num_fewshot, self.config.training_split if self.has_training_docs() else split)

             # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
-            inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self.config["task"], doc_id, self.config.repeats), split=split)
+            inst = self.construct_requests(doc_id=doc_id, ctx=fewshot_ctx, metadata=(self.config["task"], doc_id, self.config.repeats), split=split)

             if not isinstance(inst, list):
                 inst = [inst]

             instances.extend(inst)
+            pbar.update(1)

         self._instances = instances
         assert len(self._instances) != 0, "task.build_requests() did not find any docs!"

     @abc.abstractmethod
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc_id, ctx, **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LMM.

-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
+        :param doc_id: int
+            The index of a document within `self.test_docs()` or `self.validation_docs()`,
+            whichever is the main split used.
         :param ctx: str
             The context string, generated by fewshot_context. This includes the natural
             language description, as well as the few shot examples, and the question
             part of the document for `doc`.
-        :param doc_idx: int
-            The index of a document within `self.test_docs()` or `self.validation_docs()`,
-            whichever is the main split used.
         :param repeats: int
             TODO: update this docstring
             The number of times each instance in a dataset is inferred on. Defaults to 1,
@@ -421,18 +423,21 @@ def count_words(cls, doc):
     @utils.positional_deprecated
     def fewshot_context(
         self,
-        doc,
+        doc_id,
         num_fewshot,
+        split,
         rnd=random.Random(1234),
         description=None,
     ):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.

-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
+        :param doc_id: int
+            The document id as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
+        :param split: str
+            The split of the document to retrieve from the dataset
         :param rnd: random.Random
             The pseudo-random number generator used to randomly sample examples.
             WARNING: This is currently a required arg although it's optionalized with a default `None`.
@@ -444,6 +449,7 @@ def fewshot_context(
         assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`"

         description = description if description else ""
+        doc = self.dataset[split][doc_id]

         if num_fewshot == 0:
             labeled_examples = ""
@@ -676,18 +682,18 @@ def fewshot_docs(self):
         return super().fewshot_docs()

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot):
+    def fewshot_context(self, doc_id, num_fewshot, split):
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.

-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
+        :param doc_id: str
+            The document id as returned from training_docs, validation_docs, or test_docs.
         :param num_fewshot: int
             The number of fewshot examples to provide in the returned context string.
         :returns: str
             The fewshot context.
         """
-
+        doc = self.dataset[split][doc_id]
         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
             labeled_examples = self.config.description
@@ -840,26 +846,28 @@ def doc_to_choice(self, doc: Any) -> List[str]:
         else:
             raise TypeError

-    def construct_requests(self, doc: dict, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
+    def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Instance], Instance]:
+        split = kwargs.get("split")
+        kwargs.pop("split")
         if self.OUTPUT_TYPE == "loglikelihood":
-            arguments = (ctx, self.doc_to_target(doc), self.doc_to_visual, kwargs.get("metadata")[1], self.config.task, kwargs.get("split"))
+            arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split)
         elif self.OUTPUT_TYPE == "loglikelihood_rolling":
-            arguments = (self.doc_to_target(doc),)
+            arguments = (self.doc_to_target,)
         elif self.OUTPUT_TYPE == "multiple_choice":
+            doc = self.dataset[split][doc_id]
             choices = self.doc_to_choice(doc)
             target_delimiter = self.config.target_delimiter
             if self.multiple_input:
                 # If there are multiple inputs, choices are placed in the ctx
                 cont = self.doc_to_target(doc)
-                arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, kwargs.get("metadata")[1], self.config.task, kwargs.get("split")) for ctx in choices]
+                arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for ctx in choices]
             else:
                 # Otherwise they are placed in the continuation
-                arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, kwargs.get("metadata")[1], self.config.task, kwargs.get("split")) for cont in choices]
-            kwargs.pop("split")
+                arguments = [(ctx, f"{target_delimiter}{cont}", self.doc_to_visual, doc_id, self.config.task, split) for cont in choices]
             request_list = [
                 Instance(
                     request_type="loglikelihood",
-                    doc=doc,
+                    # doc=doc,
                     arguments=arg,
                     idx=i,
                     **kwargs,
@@ -878,7 +886,7 @@ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> Union[List[Instan
                     [
                         Instance(
                             request_type="loglikelihood",
-                            doc=doc,
+                            # doc=doc,
                             arguments=("", "{}".format(choice)),
                             idx=i,
                             **kwargs,
@@ -889,9 +897,8 @@ def construct_requests(self, doc: dict, ctx: str, **kwargs) -> Union[List[Instan
             return request_list

         elif self.OUTPUT_TYPE == "generate_until":
-            arguments = (ctx, self.config.generation_kwargs, self.doc_to_visual, kwargs.get("metadata")[1], self.config.task, kwargs.get("split"))
-            kwargs.pop("split")
-            return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs)
+            arguments = (ctx, self.config.generation_kwargs, self.doc_to_visual, doc_id, self.config.task, split)
+            return Instance(request_type=self.OUTPUT_TYPE, arguments=arguments, idx=0, **kwargs)

     def process_results(self, doc, results):
         if callable(self.config.process_results):
diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index 6662a3e7..eefaff86 100644
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -325,7 +325,7 @@ def evaluate(
                     "doc_id": doc_id,
                     "doc": {k: v for k, v in doc.items() if "image" not in k},  # do not include image
                     "target": target,
-                    "arguments": [req.args[:2] for req in requests],  # do not include image
+                    "arguments": [tuple(a for a in req.args if isinstance(a, (int, str))) for req in requests],  # do not include image
                     "resps": [req.resps for req in requests],
                     "filtered_resps": [req.filtered_resps[key] for req in requests],
                 }
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 2033f093..e4eecd02 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -145,8 +145,9 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         res = []
         pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

-        for contexts, continuation, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
+        for contexts, doc_to_target, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
             # encode, pad, and truncate contexts for this batch
+            continuation = doc_to_target(self.task_dict[task][split][doc_id])
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
             visuals = self.flatten(visuals)
             if visuals:
diff --git a/lmms_eval/tasks/llava-bench-coco/llava-bench-coco.yaml b/lmms_eval/tasks/llava-bench-coco/llava-bench-coco.yaml
index eb4973f8..4ca8f35a 100644
--- a/lmms_eval/tasks/llava-bench-coco/llava-bench-coco.yaml
+++ b/lmms_eval/tasks/llava-bench-coco/llava-bench-coco.yaml
@@ -17,17 +17,17 @@ generation_kwargs:
   num_beams: 1
 process_results: !function utils.llava_process_results
 metric_list:
+  - metric: gpt_eval_llava_all
+    aggregation: !function utils.llava_all_aggregation
+    higher_is_better: true
   - metric: gpt_eval_llava_conv
-    aggregation: !function utils.llava_aggregation
+    aggregation: !function utils.llava_conv_aggregation
     higher_is_better: true
   - metric: gpt_eval_llava_detail
-    aggregation: !function utils.llava_aggregation
+    aggregation: !function utils.llava_detail_aggregation
     higher_is_better: true
   - metric: gpt_eval_llava_complex
-    aggregation: !function utils.llava_aggregation
-    higher_is_better: true
-  - metric: gpt_eval_llava_all
-    aggregation: !function utils.llava_aggregation
+    aggregation: !function utils.llava_complex_aggregation
     higher_is_better: true
 metadata:
   version: 0.0
diff --git a/lmms_eval/tasks/llava-bench-coco/utils.py b/lmms_eval/tasks/llava-bench-coco/utils.py
index ee2a7614..19075868 100644
--- a/lmms_eval/tasks/llava-bench-coco/utils.py
+++ b/lmms_eval/tasks/llava-bench-coco/utils.py
@@ -29,17 +29,28 @@
 GPT_EVAL_MODEL_NAME = config["metadata"]["gpt_eval_model_name"]


-def get_eval(content: str, max_tokens: int):
+def get_eval(content: str, max_tokens: int, retries: int = 3):
     headers = {
         "Authorization": f"Bearer {API_KEY}",
         "Content-Type": "application/json",
     }

-    messages = [{"role": "system", "content": "You are a helpful and precise assistant for checking the quality of the answer."}, {"role": "user", "content": content}]
-
-    payload = {"model": GPT_EVAL_MODEL_NAME, "messages": messages, "temperature": 0.2, "max_tokens": max_tokens}
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful and precise assistant for checking the quality of the answer.",
+        },
+        {"role": "user", "content": content},
+    ]
+
+    payload = {
+        "model": GPT_EVAL_MODEL_NAME,
+        "messages": messages,
+        "temperature": 0.2,
+        "max_tokens": max_tokens,
+    }

-    while True:
+    for attempt in range(retries):
         try:
             response = requests.post(API_URL, headers=headers, json=payload)
             response.raise_for_status()
@@ -48,12 +59,15 @@ def get_eval(content: str, max_tokens: int):
             content = response_data["choices"][0]["message"]["content"].strip()
             if content != "":
                 return content, response_data["model"]
+            break  # If successful, break out of the loop
         except Exception as e:
-            eval_logger.info(f"Error in response : {response.json()['error']['message']}")
-            if "Rate limit" in str(e):
-                eval_logger.info("Sleeping due to rate limit...")
+            eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}")
+            if attempt < retries - 1:  # If we have retries left, sleep and then continue to next attempt
                 time.sleep(NUM_SECONDS_TO_SLEEP)
+            else:  # If this was the last attempt, log and return empty
+                eval_logger.error(f"All {retries} attempts failed. Last error message: {str(e)}")
+                return "", ""

     return "", ""
@@ -65,11 +79,11 @@ def parse_score(review):
         if len(sp) == 2:
             return [float(sp[0]), float(sp[1])]
         else:
-            print("error", review)
+            eval_logger.debug("error", review)
             return [-1, -1]
     except Exception as e:
-        print(e)
-        print("error", review)
+        eval_logger.debug(e)
+        eval_logger.debug("error", review)
         return [-1, -1]
@@ -90,32 +104,73 @@ def llava_process_results(doc, result):
     Returns:
         a dictionary with key: metric name (in this case coco_bleu), value: metric value
     """
-    question = doc["question"]
-    ans1 = doc["answer"]
-    ans2 = result
-    context = doc["caption"]
-    category = "llava_bench_" + doc["category"]
-    rule = rule_dict[category]
-    prompt = rule["prompt"]
-    role = rule["role"]
-    content = f"[Context]\n{context}\n\n" f"[Question]\n{question}\n\n" f"[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n" f"[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n" f"[System]\n{prompt}\n\n"
-
-    review, model_name = get_eval(content, 1024)
-    scores = parse_score(review)
-    metric = f"gpt_eval_llava_{doc['category']}"
-    review_dict = {"question": question, "ans1": ans1, "ans2": ans2, "context": context, "category": category, "review": review, "scores": scores, "eval_model": model_name}
+    try:
+        question = doc.get("question", "")
+        ans1 = doc.get("gpt_answer", "")
+        ans2 = result[0] if result else ""
+        captions = doc.get("caption", [])
+        context = "\n".join(captions) if isinstance(captions, list) else captions
+        category = "llava_bench_" + doc.get("category", "")
+        rule = rule_dict.get(category, {})
+        prompt = rule.get("prompt", "")
+        role = rule.get("role", "user")
+        content = f"[Context]\n{context}\n\n" f"[Question]\n{question}\n\n" f"[{role} 1]\n{ans1}\n\n[End of {role} 1]\n\n" f"[{role} 2]\n{ans2}\n\n[End of {role} 2]\n\n" f"[System]\n{prompt}\n\n"
+
+        review, model_name = get_eval(content, 1024)
+        scores = parse_score(review)
+    except Exception as e:
+        eval_logger.error(f"Error for Question ID: {doc.get('question_id', 'Unknown')}: {e}")
+        review = "Failed to Get a Proper Review."
+        model_name = "Failed Request"
+        scores = [-1, -1]
+
+    metric = f"gpt_eval_llava_{doc.get('category', 'unknown')}"
+    review_dict = {
+        "question": question,
+        "ans1": ans1,
+        "ans2": ans2,
+        "context": context,
+        "category": category,
+        "review": review,
+        "scores": scores,
+        "eval_model": model_name,
+    }
+    # return {"gpt_eval_llava_all": review_dict}

     return {metric: review_dict, "gpt_eval_llava_all": review_dict}


-def llava_aggregation(results):
-    scores = []
-    category = results[0]["category"]
-    for result in results:
-        scores.append(result["scores"])
+def llava_conv_aggregation(results):
+    return llava_aggregation(results, "conv")
+
+
+def llava_complex_aggregation(results):
+    return llava_aggregation(results, "complex")

-    stats = np.asarray(scores).mean(0).tolist()
-    stats = [round(x, 3) for x in stats]
-    print(category, round(stats[1] / stats[0] * 100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
-    print("=========================")
-    return round(stats[1] / stats[0] * 100, 1)
+
+def llava_detail_aggregation(results):
+    return llava_aggregation(results, "detail")
+
+
+def llava_all_aggregation(results):
+    return llava_aggregation(results, "all")
+
+
+def llava_aggregation(results, category):
+    try:
+        scores = []
+        for result in results:
+            scores.append(result["scores"])
+
+        stats = np.asarray(scores).mean(0).tolist()
+        stats = [round(x, 3) for x in stats]
+        # gpt4_score_percentage = stats[0] * 10
+        # model_score_percentage = stats[1] * 10
+        # eval_logger.info(f"Category: {category}")
+        # eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
+        # eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
+        # eval_logger.info("=========================")
+        return round(stats[1] / stats[0] * 100, 1)
+    except Exception as e:
+        eval_logger.error(f"Error in llava_aggregation: {e}")
+        return None
diff --git a/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml b/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
index aa7b4fe4..1a0a2b1e 100644
--- a/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
+++ b/lmms_eval/tasks/llava-in-the-wild/llava-in-the-wild.yaml
@@ -19,7 +19,16 @@ generation_kwargs:
 process_results: !function utils.llava_process_results
 metric_list:
   - metric: gpt_eval_llava_all
-    aggregation: !function utils.llava_aggregation
+    aggregation: !function utils.llava_all_aggregation
+    higher_is_better: true
+  - metric: gpt_eval_llava_conv
+    aggregation: !function utils.llava_conv_aggregation
+    higher_is_better: true
+  - metric: gpt_eval_llava_detail
+    aggregation: !function utils.llava_detail_aggregation
+    higher_is_better: true
+  - metric: gpt_eval_llava_complex
+    aggregation: !function utils.llava_complex_aggregation
     higher_is_better: true
 # - metric: gpt_eval_llava_conv
 #   aggregation: !function utils.llava_aggregation
diff --git a/lmms_eval/tasks/llava-in-the-wild/utils.py b/lmms_eval/tasks/llava-in-the-wild/utils.py
index 903270a2..0c6445f3 100644
--- a/lmms_eval/tasks/llava-in-the-wild/utils.py
+++ b/lmms_eval/tasks/llava-in-the-wild/utils.py
@@ -124,7 +124,7 @@ def llava_process_results(doc, result):
         model_name = "Failed Request"
         scores = [-1, -1]

-    # metric = f"gpt_eval_llava_{doc.get('category', 'unknown')}"
+    metric = f"gpt_eval_llava_{doc.get('category', 'unknown')}"
     review_dict = {
         "question": question,
         "ans1": ans1,
@@ -136,26 +136,40 @@ def llava_process_results(doc, result):
         "eval_model": model_name,
     }

-    return {"gpt_eval_llava_all": review_dict}
-    # return {"gpt_eval_llava_all": review_dict}
+    return {metric: review_dict, "gpt_eval_llava_all": review_dict}


-def llava_aggregation(results):
-    return 0
+def llava_conv_aggregation(results):
+    return llava_aggregation(results, "conv")
+
+
+def llava_complex_aggregation(results):
+    return llava_aggregation(results, "complex")
+
+
+def llava_detail_aggregation(results):
+    return llava_aggregation(results, "detail")
+
+
+def llava_all_aggregation(results):
+    return llava_aggregation(results, "all")
+
+
+def llava_aggregation(results, category):
     try:
         scores = []
-        category = results[0]["category"]
         for result in results:
             scores.append(result["scores"])

         stats = np.asarray(scores).mean(0).tolist()
         stats = [round(x, 3) for x in stats]
-        gpt4_score_percentage = stats[0] * 10
-        model_score_percentage = stats[1] * 10
-        eval_logger.info(f"Category: {category}")
-        eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
-        eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
-        eval_logger.info("=========================")
+        # gpt4_score_percentage = stats[0] * 10
+        # model_score_percentage = stats[1] * 10
+        # eval_logger.info(f"Category: {category}")
+        # eval_logger.info(f"GPT4 Score: {gpt4_score_percentage:.1f}%")
+        # eval_logger.info(f"Model Score: {model_score_percentage:.1f}%")
+        # eval_logger.info("=========================")
         return round(stats[1] / stats[0] * 100, 1)
     except Exception as e:
         eval_logger.error(f"Error in llava_aggregation: {e}")
diff --git a/lmms_eval/tasks/nocaps/utils.py b/lmms_eval/tasks/nocaps/utils.py
index 4c996e27..00a38b13 100644
--- a/lmms_eval/tasks/nocaps/utils.py
+++ b/lmms_eval/tasks/nocaps/utils.py
@@ -134,7 +134,7 @@ def nocaps_test_process_result(doc, result):
     Returns:
         a dictionary with key: metric name (in this case nocaps_passthrough), value: metric value
     """
-    return {"nocaps_passthrough": {"pred": result, "image_id": doc["image_id"]}}
+    return {"nocaps_passthrough": {"pred": result[0], "image_id": doc["image_id"]}}


 def nocaps_test_aggregation_result(results):
@@ -144,7 +144,7 @@ def nocaps_test_aggregation_result(results):
     if not os.path.exists("./captions_nocaps_test_alg_results.json"):
         eval_logger.info("Storing prediction that can be submitted to the server ...")
-        with open("./captions_nocaps_val_alg_results.json", "w") as f:
+        with open("./captions_nocaps_test_alg_results.json", "w") as f:
            json.dump(stored_results, f, indent=4)

     eval_logger.info("Your test result has been stored. Make sure you also have the val result stored to submit to the server on https://codalab.lisn.upsaclay.fr/competitions/7404#participate.")
diff --git a/lmms_eval/tasks/scienceqa_img/scienceqa.yaml b/lmms_eval/tasks/scienceqa_img/scienceqa.yaml
index c3ff1611..224195da 100644
--- a/lmms_eval/tasks/scienceqa_img/scienceqa.yaml
+++ b/lmms_eval/tasks/scienceqa_img/scienceqa.yaml
@@ -8,8 +8,11 @@ doc_to_visual: !function utils.sqa_doc_to_visual
 doc_to_text: !function utils.sqa_doc_to_text
 doc_to_target: !function utils.sqa_doc_to_target
 generation_kwargs:
-  max_new_tokens: 16
+  - "ASSISTANT:"
   image_aspect_ratio: original
+  max_new_tokens: 16
+  temperature: 0
+  do_sample: False
 metric_list:
   - metric: exact_match
     aggregation: mean
diff --git a/lmms_eval/tasks/seedbench/seedbench.yaml b/lmms_eval/tasks/seedbench/seedbench.yaml
index a573f54b..371d0ba3 100644
--- a/lmms_eval/tasks/seedbench/seedbench.yaml
+++ b/lmms_eval/tasks/seedbench/seedbench.yaml
@@ -10,6 +10,7 @@ doc_to_target: "answer"
 generation_kwargs:
   until:
     - "ASSISTANT:"
+  image_aspect_ratio: original
 # The return value of process_results will be used by metrics
 process_results: !function utils.seed_process_result
 # Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
diff --git a/lmms_eval/tasks/seedbench/utils.py b/lmms_eval/tasks/seedbench/utils.py
index 37943f04..490af986 100644
--- a/lmms_eval/tasks/seedbench/utils.py
+++ b/lmms_eval/tasks/seedbench/utils.py
@@ -7,10 +7,10 @@ def seed_doc_to_visual(doc):

 def seed_doc_to_text(doc):
     question = doc["question"]
-    question += "\n" + f"A.{doc['choice_a']}\n"
-    question += "\n" + f"B.{doc['choice_b']}\n"
-    question += "\n" + f"C.{doc['choice_c']}\n"
-    question += "\n" + f"D.{doc['choice_d']}"
+    question += "\n" + f"A. {doc['choice_a']}\n"
+    question += f"B. {doc['choice_b']}\n"
+    question += f"C. {doc['choice_c']}\n"
+    question += f"D. {doc['choice_d']}"
     return f"{question}\nAnswer with the option's letter from the given choices directly."
@@ -37,7 +37,7 @@ def seed_aggregation_result_all(results):
     stored_results = []
     for result in results:
         stored_results.append({"question_id": result["question_id"], "prediction": result["pred"]})
-    with open("./seed_submission.json", "r") as f:
+    with open("./seed_submission.json", "w") as f:
         json.dump(stored_results, f, indent=4)
     print("Storing files for seed_submission ...")