Skip to content

Commit

Permalink
Reward verification and evaluation fixes (#55)
Browse files Browse the repository at this point in the history
* bump up deps, fix aime24 evals, make grpo more strict

* minor fixes

* 🤨 fmt

* Update src/open_r1/grpo.py

Co-authored-by: Quentin Gallouédec <[email protected]>

---------

Co-authored-by: Hynek Kydlicek <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Jan 26, 2025
1 parent 15df4fb commit 90b0947
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"isort>=5.12.0",
"liger_kernel==0.5.2",
"lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
"math-verify", # Used for math verification in grpo
"math-verify>=0.3.2", # Used for math verification in grpo
"packaging>=23.0",
"parameterized>=0.9.0",
"pytest",
Expand Down
1 change: 1 addition & 0 deletions slurm/evaluate.slurm
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
--custom-tasks src/open_r1/evaluate.py \
--use-chat-template \
--system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
    --save-details \
--output-dir $OUTPUT_DIR


Expand Down
26 changes: 22 additions & 4 deletions src/open_r1/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lighteval.utils.language import Language


metric = multilingual_extractive_match_metric(
latex_gold_metric = multilingual_extractive_match_metric(
language=Language.ENGLISH,
fallback_mode="first_match",
precision=5,
Expand All @@ -33,6 +33,15 @@
aggregation_function=max,
)

# Metric for tasks whose gold answers are plain expressions (e.g. the integer
# answers of AIME): golds are extracted as bare expressions only, while model
# predictions may match either as an expression or as LaTeX.
expr_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    # presumably "first_match" scores the first extractable candidate when no
    # exact match is found — TODO confirm against lighteval's metric docs
    fallback_mode="first_match",
    # NOTE(review): looks like a numeric comparison precision (digits) — verify
    precision=5,
    gold_extraction_target=(ExprExtractionConfig(),),
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    # If several candidates are extracted, keep the best-scoring comparison.
    aggregation_function=max,
)


def prompt_fn(line, task_name: str = None):
"""Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
Expand All @@ -44,19 +53,28 @@ def prompt_fn(line, task_name: str = None):
)


def aime_prompt_fn(line, task_name: str = None):
    """Build a lighteval ``Doc`` for one AIME dataset record.

    The raw ``problem`` text becomes the query, and the dataset's ``answer``
    field is the single gold choice (hence ``gold_index=0``).
    """
    problem_text = line["problem"]
    gold_answer = line["answer"]
    return Doc(
        task_name=task_name,
        query=problem_text,
        choices=[gold_answer],
        gold_index=0,
    )


# Define tasks
aime24 = LightevalTaskConfig(
name="aime24",
suite=["custom"],
prompt_function=prompt_fn,
prompt_function=aime_prompt_fn,
hf_repo="HuggingFaceH4/aime_2024",
hf_subset="default",
hf_avail_splits=["train"],
evaluation_splits=["train"],
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[metric],
metric=[expr_gold_metric],
version=1,
)
math_500 = LightevalTaskConfig(
Expand All @@ -70,7 +88,7 @@ def prompt_fn(line, task_name: str = None):
few_shots_split=None,
few_shots_select=None,
generation_size=32768,
metric=[metric],
metric=[latex_gold_metric],
version=1,
)

Expand Down
38 changes: 31 additions & 7 deletions src/open_r1/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from datasets import load_dataset

from math_verify import parse, verify
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config


Expand All @@ -42,13 +43,36 @@ def accuracy_reward(completions, solution, **kwargs):
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
try:
answer = parse(content)
reward = float(verify(answer, parse(sol)))
except Exception: # if it fails for any reason, return 0.0
reward = 0.0
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed=True,
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
# Reward 1 if the content is the same as the ground truth, 0 otherwise

return rewards


Expand Down

0 comments on commit 90b0947

Please sign in to comment.