From 27038e9993cc35d3d61051fe40f641df40e7dd44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 8 Apr 2025 14:52:40 +0000 Subject: [PATCH] log correct/incorrect length --- trl/trainer/grpo_trainer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7ac95ba3d8..baa5046f09 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -959,6 +959,19 @@ def _generate_and_score_completions( self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) + # Get the length of the completions for "accuracy_reward" function + if "accuracy_reward" in reward_func_names: + accuracy_reward_idx = reward_func_names.index("accuracy_reward") + accuracy_reward = rewards_per_func[:, accuracy_reward_idx] + + length_correct = agg_completion_mask[accuracy_reward == 1.0].float().nanmean() + if not torch.isnan(length_correct): + self._metrics[mode]["length_correct"].append(length_correct.item()) + + length_incorrect = agg_completion_mask[accuracy_reward == 0.0].float().nanmean() + if not torch.isnan(length_incorrect): + self._metrics[mode]["length_incorrect"].append(length_incorrect.item()) + # Log prompt and completion texts self._textual_logs["prompt"].extend(gather_object(prompts_text)) self._textual_logs["completion"].extend(gather_object(completions_text))