Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Testing Github workflow] Updating workflows and makefile #214

Merged
merged 6 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/quality.yml → .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Quality
name: Tests

on:
push:
Expand All @@ -11,8 +11,8 @@ on:

jobs:

check_code_quality:
name: Check code quality
tests:
name: Run tests and quality checks
runs-on: ubuntu-latest
steps:
- name: Checkout code
Expand All @@ -24,8 +24,11 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[quality]"
python -m pip install ".[quality,tests]"
- name: Code quality
run: |
make quality
- name: Run tests
run: |
make test

2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ quality:
isort --check-only $(check_dirs) setup.py
flake8 --max-line-length 119 $(check_dirs) setup.py

test:
pytest -sv tests/

# Evaluation

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def deps_list(*pkgs):


extras = {}
extras["tests"] = deps_list("pytest", "parameterized")
extras["tests"] = deps_list("pytest", "parameterized", "math-verify")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
extras["train"] = deps_list("flash_attn")
Expand Down
2 changes: 1 addition & 1 deletion src/open_r1/rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def format_reward(completions, **kwargs):


def reasoning_steps_reward(completions, **kwargs):
"""Reward function that checks for clear step-by-step reasoning.
r"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
Expand Down
4 changes: 3 additions & 1 deletion tests/test_rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def test_positive_max_penalty_raises_value_error(self):

def test_zero_max_penalty_returns_zero(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=0.0)
self.assertEqual(reward_fn, 0)
completions = [[{"content": "test test test"}]]
rewards = reward_fn(completions)
self.assertEqual(rewards, [0.0])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was failing, because it was comparing int to function. I guess this was the intention

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can remove this test


def test_no_repetition(self):
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
Expand Down