From 90b09473820dfd90767fd569b03645e56ff58c5f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Hynek=20Kydl=C3=AD=C4=8Dek?=
Date: Sun, 26 Jan 2025 18:35:48 +0100
Subject: [PATCH] Reward verification and evaluation fixes (#55)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* bump up deps, fix aime24 evals, make grpo more strict

* minor fixes

* 🤨 fmt

* Update src/open_r1/grpo.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Hynek Kydlicek
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
---
 setup.py                |  2 +-
 slurm/evaluate.slurm    |  1 +
 src/open_r1/evaluate.py | 26 ++++++++++++++++++++++----
 src/open_r1/grpo.py     | 38 +++++++++++++++++++++++++++++++-------
 4 files changed, 55 insertions(+), 12 deletions(-)

diff --git a/setup.py b/setup.py
index 0e59cc02..21ba5b4a 100644
--- a/setup.py
+++ b/setup.py
@@ -54,7 +54,7 @@
     "isort>=5.12.0",
     "liger_kernel==0.5.2",
     "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
-    "math-verify",  # Used for math verification in grpo
+    "math-verify>=0.3.2",  # Used for math verification in grpo
     "packaging>=23.0",
     "parameterized>=0.9.0",
     "pytest",
diff --git a/slurm/evaluate.slurm b/slurm/evaluate.slurm
index 421a96cb..315cc80a 100644
--- a/slurm/evaluate.slurm
+++ b/slurm/evaluate.slurm
@@ -43,6 +43,7 @@ lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
     --custom-tasks src/open_r1/evaluate.py \
     --use-chat-template \
     --system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
+    --save-details \
     --output-dir $OUTPUT_DIR
diff --git a/src/open_r1/evaluate.py b/src/open_r1/evaluate.py
index ef3089ff..3febbe43 100644
--- a/src/open_r1/evaluate.py
+++ b/src/open_r1/evaluate.py
@@ -24,7 +24,7 @@
 from lighteval.utils.language import Language
 
 
-metric = multilingual_extractive_match_metric(
+latex_gold_metric = multilingual_extractive_match_metric(
     language=Language.ENGLISH,
     fallback_mode="first_match",
     precision=5,
@@ -33,6 +33,15 @@
     aggregation_function=max,
 )
 
+expr_gold_metric = multilingual_extractive_match_metric(
+    language=Language.ENGLISH,
+    fallback_mode="first_match",
+    precision=5,
+    gold_extraction_target=(ExprExtractionConfig(),),
+    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
+    aggregation_function=max,
+)
+
 
 def prompt_fn(line, task_name: str = None):
     """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
@@ -44,11 +53,20 @@ def prompt_fn(line, task_name: str = None):
     )
 
 
+def aime_prompt_fn(line, task_name: str = None):
+    return Doc(
+        task_name=task_name,
+        query=line["problem"],
+        choices=[line["answer"]],
+        gold_index=0,
+    )
+
+
 # Define tasks
 aime24 = LightevalTaskConfig(
     name="aime24",
     suite=["custom"],
-    prompt_function=prompt_fn,
+    prompt_function=aime_prompt_fn,
     hf_repo="HuggingFaceH4/aime_2024",
     hf_subset="default",
     hf_avail_splits=["train"],
@@ -56,7 +74,7 @@ def prompt_fn(line, task_name: str = None):
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,
-    metric=[metric],
+    metric=[expr_gold_metric],
     version=1,
 )
 math_500 = LightevalTaskConfig(
@@ -70,7 +88,7 @@ def prompt_fn(line, task_name: str = None):
     few_shots_split=None,
     few_shots_select=None,
     generation_size=32768,
-    metric=[metric],
+    metric=[latex_gold_metric],
     version=1,
 )
 
diff --git a/src/open_r1/grpo.py b/src/open_r1/grpo.py
index 7dea36be..3c904516 100644
--- a/src/open_r1/grpo.py
+++ b/src/open_r1/grpo.py
@@ -17,7 +17,8 @@
 from datasets import load_dataset
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
 from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
@@ -42,13 +43,36 @@ def accuracy_reward(completions, solution, **kwargs):
     contents = [completion[0]["content"] for completion in completions]
     rewards = []
     for content, sol in zip(contents, solution):
-        try:
-            answer = parse(content)
-            reward = float(verify(answer, parse(sol)))
-        except Exception:  # if it fails for any reason, return 0.0
-            reward = 0.0
+        gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            answer_parsed = parse(
+                content,
+                extraction_config=[
+                    LatexExtractionConfig(
+                        normalization_config=NormalizationConfig(
+                            nits=False,
+                            malformed_operators=False,
+                            basic_latex=True,
+                            equations=True,
+                            boxed=True,
+                            units=True,
+                        ),
+                        # Ensures that boxed is tried first
+                        boxed_match_priority=0,
+                        try_extract_without_anchor=False,
+                    )
+                ],
+                extraction_mode="first_match",
+            )
+            # Reward 1 if the content is the same as the ground truth, 0 otherwise
+            reward = float(verify(answer_parsed, gold_parsed))
+        else:
+            # If the gold solution is not parseable, we reward 1 to skip this example
+            reward = 1.0
+            print("Failed to parse gold solution: ", sol)
         rewards.append(reward)
-    # Reward 1 if the content is the same as the ground truth, 0 otherwise
+
     return rewards
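
A minimal sketch of how the stricter accuracy_reward above could be exercised outside of training. It is illustrative only: it assumes math-verify>=0.3.2 and latex2sympy2_extended are installed, that src/ is importable so the patched module resolves as open_r1.grpo, and the toy completions and solutions below are made up; exact scores depend on how math-verify extracts each answer.

# Hypothetical smoke test for the patched accuracy_reward (assumed import path).
from open_r1.grpo import accuracy_reward

# GRPOTrainer passes conversational completions, so each entry is a list of messages
# whose first element carries the generated text.
completions = [
    [{"role": "assistant", "content": "So the answer is \\boxed{\\frac{1}{2}}."}],
    [{"role": "assistant", "content": "I think it is one half."}],
]
solution = ["$\\frac{1}{2}$", "$\\frac{1}{2}$"]

# The boxed LaTeX answer should verify against the parsed gold, while the free-text
# answer is expected to fail extraction under the stricter LaTeX-only configuration.
print(accuracy_reward(completions, solution))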