diff --git a/README.md b/README.md
index 7b845f44..a2f8e720 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,9 @@
 
 🏠 [Homepage](https://lmms-lab.github.io/) | 🎉 [Blog](https://lmms-lab.github.io/lmms-eval-blog/lmms-eval-0.1/) | 📚 [Documentation](docs/README.md) | 🤗 [Huggingface Datasets](https://huggingface.co/lmms-lab) | Discord_Thread [discord/lmms-eval](https://discord.gg/ebAMGSsS)
 
-In today's world, we're on a thrilling quest for Artificial General Intelligence (AGI), driven by a passion that reminds us of the excitement surrounding the 1960s moon landing. At the heart of this adventure are the incredible large language models (LLMs) and large multimodal models (LMMs). These models are like brilliant minds that can understand, learn, and interact with a vast array of human tasks, marking a significant leap toward our goal.
+In today's world, we're on an exciting journey toward creating Artificial General Intelligence (AGI), with an enthusiasm much like that surrounding the 1960s moon landing. This journey is powered by advanced large language models (LLMs) and large multimodal models (LMMs), which are complex systems capable of understanding, learning, and performing a wide variety of human tasks. These advancements bring us closer to achieving AGI.
 
-To truly understand how capable these models are, we've started to create and use a wide variety of evaluation benchmarks. These benchmarks help us map out a detailed chart of abilities, showing us how close we are to achieving true AGI. However, this journey is not without its challenges. The sheer number of benchmarks and datasets we need to look at is overwhelming. They're all over the place - tucked away in someone's Google Drive, scattered across Dropbox, and hidden in the corners of various school and research lab websites. It's like embarking on a treasure hunt where the maps are spread far and wide.
+To gauge how advanced these models are, we use a variety of evaluation benchmarks. These benchmarks are tools that help us understand the capabilities of these models, showing us how close we are to achieving AGI. However, finding and using these benchmarks is a major challenge. The necessary benchmarks and datasets are spread out and hidden in various places like Google Drive, Dropbox, and different school and research lab websites. It feels like we're on a treasure hunt, but the maps are scattered everywhere.
 
 In the field of language models, there has been a valuable precedent set by the work of [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). They offer integrated data and model interfaces, enabling rapid evaluation of language models, serve as the backend support framework for the [open-llm-leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard), and have gradually become the underlying ecosystem of the era of foundation models.
 
diff --git a/llava_repr_requirements.txt b/llava_repr_requirements.txt
index 07afb686..e3f0f527 100644
--- a/llava_repr_requirements.txt
+++ b/llava_repr_requirements.txt
@@ -27,6 +27,23 @@ shortuuid==1.0.12
 sqlitedict==2.1.0
 tenacity==8.2.3
 torch==2.0.1
+openai>=1.0.0
+pycocoevalcap
 tokenizers==0.15.2
 tqdm==4.66.2
-transformers==4.37.2
\ No newline at end of file
+tqdm-multiprocess
+transformers==4.37.2
+zstandard
+pillow
+pyyaml
+sympy
+mpmath
+Jinja2
+openpyxl
+Levenshtein
+hf_transfer
+tenacity
+wandb>=0.16.0
+transformers-stream-generator
+tiktoken
+pre-commit
\ No newline at end of file
diff --git a/lmms_eval/tasks/olympiadbench/cn_utils.py b/lmms_eval/tasks/olympiadbench/cn_utils.py
new file mode 100644
index 00000000..34e5ce4d
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/cn_utils.py
@@ -0,0 +1,69 @@
+import os
+import json
+import datetime
+from lmms_eval.tasks.olympiadbench.olympiadbench_evals import OlympiadBenchEvaluator
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+
+import logging
+eval_logger = logging.getLogger("lmms-eval")
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+olympiadbench_evaluator = OlympiadBenchEvaluator()
+
+def olympiadbench_doc_to_visual(doc):
+    return [image.convert("RGB") for image in doc["images"]]
+
+def olympiadbench_doc_to_text(doc):
+    question = doc["question"]
+    subject = doc["subfield"]
+    mul_ans = doc["is_multiple_answer"]
+    if mul_ans is None:
+        mul_ans = False
+    ans_type = doc["answer_type"]
+    if ans_type == "Need_human_evaluate":
+        ans_type = "proof based"
+
+    pre_prompt = f"以下是中国{subject}竞赛中的解答题。\n"
+
+    post_prompt = ""
+    if not mul_ans:
+        post_prompt += f"答案类型为{ans_type}。\n"
+    else:
+        post_prompt += f"题目有多个答案,答案类型均为{ans_type}。\n"
+    post_prompt += "请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以"
+    if not mul_ans:
+        post_prompt += '"所以最终答案是\\boxed{答案}。"\n'
+    else:
+        post_prompt += '"所以最终答案是\\boxed{用英文逗号连接的多个答案}。"\n'
+
+    final_question = pre_prompt + question + '\n' + post_prompt
+    return final_question
+
+def olympiadbench_process_results(doc, results):
+    precision = doc["error"]
+    is_proving = "TP" in doc["source"]
+    if precision is None:
+        precision = 0
+    prediction = results[0].strip()
+
+    if is_proving:
+        return {
+            "submission": prediction
+        }
+    else:
+        prediction = prediction.split("所以最终答案是")[-1]
+        prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
+        accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
+        accuracy = int(accuracy)
+        return {
+            "exact_match": accuracy
+        }
+
+def olympiadbench_aggregate_results(results, args):
+    now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
+    submission_file_name = f"olympiadbench-test-cn-submission-{now_date_time}.json"
+    path = generate_submission_file(submission_file_name, args)
+    with open(path, "w") as f:
+        json.dump(results, f, ensure_ascii=False)
+    print(f"Submission file saved to {path}")
+    
\ No newline at end of file
diff --git a/lmms_eval/tasks/olympiadbench/en_utils.py b/lmms_eval/tasks/olympiadbench/en_utils.py
new file mode 100644
index 00000000..a21ee159
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/en_utils.py
@@ -0,0 +1,69 @@
+import os
+import json
+import datetime
+from lmms_eval.tasks.olympiadbench.olympiadbench_evals import OlympiadBenchEvaluator
+from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
+
+import logging
+eval_logger = logging.getLogger("lmms-eval")
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+olympiadbench_evaluator = OlympiadBenchEvaluator()
+
+def olympiadbench_doc_to_visual(doc):
+    return [image.convert("RGB") for image in doc["images"]]
+
+def olympiadbench_doc_to_text(doc):
+    question = doc["question"]
+    subject = doc["subfield"]
+    mul_ans = doc["is_multiple_answer"]
+    if mul_ans is None:
+        mul_ans = False
+    ans_type = doc["answer_type"]
+    if ans_type == "Need_human_evaluate":
+        ans_type = "proof based"
+
+    pre_prompt = f"The following is a question from an International {subject} competition.\n"
+
+    post_prompt = ""
+    if not mul_ans:
+        post_prompt += f"The answer of the question should be {ans_type}.\n"
+    else:
+        post_prompt += f"The question has multiple answers, each of them should be {ans_type}.\n"
+    post_prompt += "Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "
+    if not mul_ans:
+        post_prompt += '"So the final answer is \\boxed{answer}."\n'
+    else:
+        post_prompt += 'So the final answer is \\boxed{multiple answers connected with commas}.\n'
+
+    final_question = pre_prompt + question + '\n' + post_prompt
+    return final_question
+
+def olympiadbench_process_results(doc, results):
+    precision = doc["error"]
+    is_proving = "TP" in doc["source"]
+    if precision is None:
+        precision = 0
+    prediction = results[0].strip()
+
+    if is_proving:
+        return {
+            "submission": prediction
+        }
+    else:
+        prediction = prediction.split("final answer is")[-1]
+        prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
+        accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
+        accuracy = int(accuracy)
+        return {
+            "exact_match": accuracy
+        }
+
+def olympiadbench_aggregate_results(results, args):
+    now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
+    submission_file_name = f"olympiadbench-test-en-submission-{now_date_time}.json"
+    path = generate_submission_file(submission_file_name, args)
+    with open(path, "w") as f:
+        json.dump(results, f, ensure_ascii=False)
+    print(f"Submission file saved to {path}")
+    
\ No newline at end of file
diff --git a/lmms_eval/tasks/olympiadbench/olympiadbench.yaml b/lmms_eval/tasks/olympiadbench/olympiadbench.yaml
new file mode 100644
index 00000000..1580b158
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/olympiadbench.yaml
@@ -0,0 +1,6 @@
+group: olympiadbench
+task:
+- olympiadbench_test_en
+- olympiadbench_test_cn
+metadata:
+  - version: 0.0
diff --git a/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py b/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py
new file mode 100644
index 00000000..dd40f611
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/olympiadbench_evals.py
@@ -0,0 +1,355 @@
+import re
+import sympy as sp
+from sympy import simplify, Eq, sympify, Pow
+from sympy.parsing.latex import parse_latex
+import math
+
+# how to use
+# scorer = OlympiadBenchEvaluator()
+# exp1 = "10^{10^{10^{10}}}"
+# exp2 = "10^{10}"
+# precision = 1e-4
+# res = scorer.judge(exp1, exp2, precision)
+
+class OlympiadBenchEvaluator:
+    def __init__(self):
+        # Map of special symbols to their replacements
+        self.special_signal_map = {
+            "\\left": "",
+            "\\right": "",
+            "∶": ":",
+            ",": ",",
+            "$": "",
+            "\\approx": "=",
+            "\\simeq": "=",
+            "\\sim": "=",
+            "^\\prime": "'",
+            "^{\\prime}": "'",
+            "^\\circ": "",
+            "%": "",
+        }
+        self.pi = parse_latex("\\pi")
+        self.precision = 1e-8  # Default precision for comparison
+
+    def split_by_comma(self, expr: str):
+        # Splits expressions by commas outside of brackets
+        in_bracket_num = 0
+        splitted_expr = []
+        start_idx = 0
+        for i, char in enumerate(expr):
+            if char in ["(", "["]:
+                in_bracket_num += 1
+            elif char in [")", "]"]:
+                in_bracket_num -= 1
+            elif char == "," and in_bracket_num == 0:
+                splitted_expr.append(expr[start_idx:i].strip())
+                start_idx = i + 1
+
+        if start_idx < len(expr):
+            splitted_expr.append(expr[start_idx:].strip())
+
+        return splitted_expr
+
+    def trans_plus_minus_sign(self, expr_list: list):
+        # Translates plus-minus signs into separate expressions
+        new_expr_list = []
+        for expr in expr_list:
+            if "\\pm" in expr:
+                new_expr_list.append(expr.replace("\\pm", "+"))
+                new_expr_list.append(expr.replace("\\pm", "-"))
+            else:
+                new_expr_list.append(expr)
+
+        return new_expr_list
+
+    def judge(self, expression1, expression2, precision=1e-8):
+        # Judge if two expressions are equal (expression1 is considered as the Ground Truth)
+        # Default precision is a list for supporting multiple expressions
+        precision = precision if isinstance(precision, list) else [precision]
+
+        try:
+            expression1, expression2 = self.preprocess(expression1, expression2)
+        except:
+            return False
+        if expression1 == expression2:
+            # print("Exactly equal")
+            return True
+
+        # Remove Chinese characters from the string, as answers like "yes" or "no" in Chinese have been considered
+        expression1 = re.sub(r'[\u4e00-\u9fff]+', '', expression1)
+        expression2 = re.sub(r'[\u4e00-\u9fff]+', '', expression2)
+
+        expression1 = self.split_by_comma(expression1)
+        expression2 = self.split_by_comma(expression2)
+
+        temp_list1 = self.trans_plus_minus_sign(expression1)
+        temp_list2 = self.trans_plus_minus_sign(expression2)
+
+        # Set up a list for allowed errors
+        if len(precision) <= 1:
+            precision = precision * len(temp_list1)
+
+        if len(temp_list1) != len(temp_list2):
+            return False
+
+        # Check if elements in both lists can be paired and are equal
+        idx = -1
+        while len(temp_list1) != 0:
+            idx = (idx + 1) % len(temp_list1)
+
+            item1 = temp_list1[idx]
+            self.precision = precision[idx]
+
+            for item2 in temp_list2:
+                if self.is_equal(item1, item2):
+                    temp_list1.remove(item1)
+                    temp_list2.remove(item2)
+                    precision.remove(self.precision)
+                    break
+            else:
+                # If no match was found, return False
+                return False
+
+        # If all elements are matched, return True
+        return True
+
+    def is_interval(self, expr):
+        # Checks if an expression is an interval
+        return expr.startswith(("(", "[")) and expr.endswith((")", "]"))
+
+    def sympy_sub_pi(self, expression_sympy):
+        # Replaces the symbol for pi in sympy expressions with its numerical value
+        return expression_sympy.subs(self.pi, math.pi)
+
+    def is_equal(self, expression1, expression2):
+        # Default first expression is ground truth. Check if expressions are equal in different aspects
+        if expression1 == expression2 and expression1 != "" and expression2 != "":
+            # print("Equivalent natively")
+            return True
+
+        # First check if both are intervals
+        if self.is_interval(expression1) and self.is_interval(expression2):
+            try:
+                if self.interval_equal(expression1, expression2):
+                    # print("Interval equivalent")
+                    return True
+            except:
+                return False
+
+        # Then check for numerical equality
+        try:
+            if self.numerical_equal(expression1, expression2):
+                # print("Numerically equivalent")
+                return True
+        except:
+            pass
+
+        # Then check if expressions are mathematically equal
+        try:
+            if self.expression_equal(expression1, expression2) and not ("=" in expression1 and "=" in expression2):
+                # print("Expression equivalent")
+                return True
+        except:
+            pass
+
+        # Lastly, check for equation equality
+        try:
+            if self.equation_equal(expression1, expression2):
+                # print("Equation equivalent")
+                return True
+        except:
+            pass
+
+        return False
+
+    def numerical_equal(self, expression1: str, expression2: str, include_percentage: bool = True):
+        # Check if two numerical values are equal within an allowed error range
+        # Includes possible percentage cases
+        reference = float(expression1)
+        prediction = float(expression2)
+
+        if include_percentage:
+            gt_result = [reference / 100, reference, reference * 100]
+        else:
+            gt_result = [reference]
+
+        for item in gt_result:
+            if abs(item - prediction) <= self.precision * 1.01:
+                return True
+        return False
+
+
+    def expression_equal(self, exp1, exp2):
+        # Check if two expressions are mathematically equivalent
+        # Extract expression and use sympy for equivalence checking
+        def extract_expression(expression):
+            if "=" in expression:
+                expression = expression.split("=")[1]
+            return expression.strip()
+
+        exp1 = extract_expression(exp1)
+        exp2 = extract_expression(exp2)
+
+        expr1_sym = sympify(parse_latex(exp1))
+        expr2_sym = sympify(parse_latex(exp2))
+
+        if expr1_sym == expr2_sym:
+            return True
+        else:
+            expr1_sym = self.sympy_sub_pi(expr1_sym)
+            expr2_sym = self.sympy_sub_pi(expr2_sym)
+
+            if (expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol)) or (not expr1_sym.has(sp.Symbol) and expr2_sym.has(sp.Symbol)):
+                return False
+            elif not expr1_sym.has(sp.Symbol) and not expr2_sym.has(sp.Symbol):
+                try:
+                    if not (self.can_compute_power(expr1_sym) and self.can_compute_power(expr2_sym)):
+                        print(f"These two numbers cannot be calculated by the current computer for: \"{str(expr1_sym)}\" and \"{str(expr2_sym)}\"")
+                        return False
+
+                    if abs(expr1_sym.evalf() - expr2_sym.evalf()) <= self.precision * 1.01:
+                        return True
+                    else:
+                        return False
+                except:
+                    return False
+            else:
+                try:
+                    simplified_expr = simplify(expr1_sym - expr2_sym)
+
+                    num_value = simplified_expr.evalf()
+
+                    return abs(num_value) < 1e-3
+                except:
+                    return False
+
+    def equation_equal(self, expression1, expression2):
+        # Check if two equations are mathematically equivalent
+        # Simplify equations and use sympy for equivalence checking
+        def simplify_equation(latex_eq):
+            lhs, rhs = latex_eq.split('=')
+
+            lhs_expr = parse_latex(lhs)
+            rhs_expr = parse_latex(rhs)
+
+            equation = Eq(lhs_expr, rhs_expr)
+
+            simplified_eq = simplify(equation.lhs - equation.rhs)
+
+            return simplified_eq
+
+        expr1_sym = simplify_equation(expression1)
+        expr2_sym = simplify_equation(expression2)
+
+        division_result_1 = simplify(expr1_sym / expr2_sym)
+        division_result_2 = simplify(expr2_sym / expr1_sym)
+
+        if (division_result_1.is_Integer and division_result_1 != 0) or (division_result_2.is_Integer and division_result_2 != 0):
+            return True
+        else:
+            return False
+
+    def interval_equal(self, expression1, expression2):
+        # Check if two intervals are mathematically equivalent
+        def compare_two_interval(inter1, inter2):
+            if inter1[0] != inter2[0] or inter1[-1] != inter2[-1]:
+                return False
+
+            inter1 = inter1.strip('[]()')
+            inter2 = inter2.strip('[]()')
+
+            items_1 = inter1.split(',')
+            items_2 = inter2.split(',')
+
+            for item_1, item_2 in zip(items_1, items_2):
+                if not self.expression_equal(item_1, item_2):
+                    return False
+            return True
+
+        interval1 = expression1
+        interval2 = expression2
+
+        if interval1 == interval2:
+            return True
+        else:
+            inter_list1 = interval1.split("\\cup")
+            inter_list2 = interval2.split("\\cup")
+
+            if len(inter_list1) != len(inter_list2):
+                return False
+            else:
+                for inter1, inter2 in zip(inter_list1, inter_list2):
+                    if not compare_two_interval(inter1, inter2):
+                        return False
+                return True
+
+    def preprocess(self, expression1, expression2):
+        # Preprocess expressions to extract and replace special symbols
+        def extract_boxed_content(latex_str):
+            boxed_matches = re.finditer(r'\\boxed{', latex_str)
+            results = ""
+
+            for match in boxed_matches:
+                start_index = match.end()
+                end_index = start_index
+                stack = 1
+
+                while stack > 0 and end_index < len(latex_str):
+                    if latex_str[end_index] == '{':
+                        stack += 1
+                    elif latex_str[end_index] == '}':
+                        stack -= 1
+                    end_index += 1
+
+                if stack == 0:
+                    content = latex_str[start_index:end_index - 1]
+                    results += content + ","
+                else:
+                    raise ValueError("Mismatched braces in LaTeX string.")
+
+            if results == "":
+                last_line_ans = latex_str.strip().split("\n")[-1]
+                dollar_pattern = r"\$(.*?)\$"
+                answers = re.findall(dollar_pattern, last_line_ans)
+
+                if answers:
+                    for ans in answers:
+                        results += ans + ","
+                else:
+                    results = latex_str
+
+            return results
+
+        def special_symbol_replace(expression):
+            if "\\in " in expression:
+                expression = expression.split("\\in ")[1]
+
+            for signal in self.special_signal_map:
+                expression = expression.replace(signal, self.special_signal_map[signal])
+
+            expression = expression.strip("\n$,.:;^_=+`!@#$%^&*~,。")
+
+            pattern = r'\\(?:mathrm|mathbf)\{~?([^}]*)\}'
+            expression = re.sub(pattern, r'\1', expression)
+
+            return expression
+
+        exp1, exp2 = extract_boxed_content(expression1), extract_boxed_content(expression2)
+        exp1, exp2 = special_symbol_replace(exp1), special_symbol_replace(exp2)
+
+        return exp1, exp2
+
+    def can_compute_power(self, expr):
+        # Checks if a power expression can be computed
+        if isinstance(expr, Pow):
+            base, exp = expr.as_base_exp()
+            if base.is_number and exp.is_number:
+                MAX_EXP = 1000  # Adjust based on computing environment
+                if abs(exp.evalf()) > MAX_EXP:
+                    return False
+                else:
+                    return True
+            else:
+                return False
+        else:
+            return True  # Not a power expression, can compute
\ No newline at end of file
diff --git a/lmms_eval/tasks/olympiadbench/olympiadbench_test_cn.yaml b/lmms_eval/tasks/olympiadbench/olympiadbench_test_cn.yaml
new file mode 100644
index 00000000..574d0c19
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/olympiadbench_test_cn.yaml
@@ -0,0 +1,25 @@
+dataset_path: lmms-lab/OlympiadBench
+dataset_kwargs:
+  token: True
+task: "olympiadbench_test_cn"
+test_split: test_cn
+output_type: generate_until
+doc_to_visual: !function cn_utils.olympiadbench_doc_to_visual
+doc_to_text: !function cn_utils.olympiadbench_doc_to_text
+doc_to_target: "answer"
+generation_kwargs:
+  until:
+    - "ASSISTANT:"
+  max_new_tokens: 1024
+  temperature: 0
+  top_p: 0
+  num_beams: 1
+  do_sample: false
+process_results: !function cn_utils.olympiadbench_process_results
+metric_list:
+  - metric: submission
+    aggregation: !function cn_utils.olympiadbench_aggregate_results
+    higher_is_better: true
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
\ No newline at end of file
diff --git a/lmms_eval/tasks/olympiadbench/olympiadbench_test_en.yaml b/lmms_eval/tasks/olympiadbench/olympiadbench_test_en.yaml
new file mode 100644
index 00000000..6d293fb7
--- /dev/null
+++ b/lmms_eval/tasks/olympiadbench/olympiadbench_test_en.yaml
@@ -0,0 +1,25 @@
+dataset_path: lmms-lab/OlympiadBench
+dataset_kwargs:
+  token: True
+task: "olympiadbench_test_en"
+test_split: test_en
+output_type: generate_until
+doc_to_visual: !function en_utils.olympiadbench_doc_to_visual
+doc_to_text: !function en_utils.olympiadbench_doc_to_text
+doc_to_target: "answer"
+generation_kwargs:
+  until:
+    - "ASSISTANT:"
+  max_new_tokens: 1024
+  temperature: 0
+  top_p: 0
+  num_beams: 1
+  do_sample: false
+process_results: !function en_utils.olympiadbench_process_results
+metric_list:
+  - metric: submission
+    aggregation: !function en_utils.olympiadbench_aggregate_results
+    higher_is_better: true
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index ba37a96c..c50c4e76 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
     "openai>=1.0.0",
     "pycocoevalcap",
     "tqdm-multiprocess",
-    "transformers==4.37.2",
+    "transformers",
     "zstandard",
     "pillow",
     "pyyaml",
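
For a quick sanity check of the scoring path added above, here is a minimal sketch that mirrors what `olympiadbench_process_results` in `en_utils.py` does with a non-proof item; the prediction string, ground truth, and precision below are made-up stand-ins for `results[0]`, `doc["final_answer"][0]`, and `doc["error"]`, and it assumes the package (plus sympy's LaTeX parsing backend) is installed locally.

```python
from lmms_eval.tasks.olympiadbench.olympiadbench_evals import OlympiadBenchEvaluator

scorer = OlympiadBenchEvaluator()

# Hypothetical model response and ground truth (stand-ins for results[0] and doc["final_answer"][0]).
prediction = 'Solving for x gives x = 3. So the final answer is \\boxed{3}.'
ground_truth = "3"
precision = 0  # doc["error"] is None for this item, so process_results falls back to 0

# Same post-processing applied in olympiadbench_process_results before judging:
prediction = prediction.split("final answer is")[-1]
prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")

print(int(scorer.judge(prediction, ground_truth, precision)))  # 1 -> counted toward exact_match
```

When `doc["error"]` is present, that tolerance is passed instead of 0, and comma-separated multi-part answers are split and matched element-wise inside `judge`.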