Skip to content

Commit

Permalink
Adds git fixture to pytest example
Browse files Browse the repository at this point in the history
So that one can keep track of versions with data generated.
  • Loading branch information
skrawcz committed Dec 29, 2024
1 parent 91a6b33 commit 7cc3599
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 4 deletions.
83 changes: 83 additions & 0 deletions examples/pytest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,39 @@ def test_my_agent(input, expected_output):
assert actual_output == expected_output
# assert some other property of the output...
```

### pytest fixtures

Another useful construct to know are pytest fixtures. A "fixture" is a function that is
used to provide a fixed baseline upon which tests can reliably and repeatedly execute.
They are used to set up preconditions for a test, such as creating test data, initializing
objects, or establishing database connections, etc. To use one in pytest, you just need to
declare a function and annotate it:

```python
import pytest

@pytest.fixture(scope="module")
def database_connection():
"""Fixture that creates a DB connection"""
db_client = SomeDBClient()
yield db_client
print("\nStopped client:\n")
```

Then to use it, one just needs to "declare" it as a function parameter for a test.

```python
def test_my_function(database_connection):
"""pytest will inject the result of the 'database_connection' function into `database_connection` here in this test function"""
...

def test_my_other_function(database_connection):
"""pytest will inject the result of the 'database_connection' function into `database_connection` here in this test function"""
...
```


What we've shown above will fail on the first assertion failure. But what if we want to evaluate all the outputs before making a pass / fail decision?

### What kind of "asserts" do we want?
Expand Down Expand Up @@ -318,11 +351,61 @@ def test_an_actions_stability():
assert len(variances) == 0, "Outputs vary across iterations:\n" + variances_str
```

# Capturing versions of your prompts to go with the datasets you generate via pytest
As you start to iterate and generate datasets (that's what happens if you log the output of the dataframe), one needs to tie together the version of the code that generated the data set with the data set itself. This is useful for debugging, and for ensuring that you can reproduce results. One way to do this is to capture the version of the code that generated the data set in the data set itself. This can be done by using the `gitpython` library to capture the git commit hash of the code that generated the data set,
i.e. the prompts + business logic. If you treat prompts as code, then here's how you might do it:

1. Use git to commit changes.
2. Then create a pytest fixture that captures the git commit hash of the current state of the git repo.
3. When you log the results of your tests, log the git commit hash as well as a column.
4. When you then load / look at the data set, you can see the git commit hash that generated the data set to tie it back to the code that generated it.

```python
import pytest
import subprocess

@pytest.fixture
def git_info():
"""Fixture that returns the git commit, branch, latest_tag.
Note if there are uncommitted changes, the commit will have '-dirty' appended.
"""
try:
commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip().decode('utf-8')
dirty = subprocess.check_output(['git', 'status', '--porcelain']).strip() != b''
commit = f"{commit}{'-dirty' if dirty else ''}"
except subprocess.CalledProcessError:
commit = None
try:
latest_tag = subprocess.check_output(['git', 'describe', '--tags', '--abbrev=0']).strip().decode('utf-8')
except subprocess.CalledProcessError:
latest_tag = None
try:
branch = subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD']).strip().decode('utf-8')
except subprocess.CalledProcessError:
branch = None
return {'commit': commit, 'latest_tag': latest_tag, "branch": branch}
```

Then to use it - we'd add the fixture to the function that saves the results:

```python
def test_print_results(module_results_df, git_info):
"""Function that uses pytest-harvest and our custom git fixture that goes at the end of the module to evaluate & save the results."""
...
# add the git information
module_results_df["git_commit"] = git_info["commit"]
module_results_df["git_latest_tag"] = git_info["latest_tag"]
# save results
module_results_df.to_csv("results.csv")
```

# An example
Here in this directory we have:

- `some_actions.py` - a file that defines an augmented LLM application (it's not a full agent) with some actions. See image below - note the hypotheses action runs multiple in parallel.
- `test_some_actions.py` - a file that defines some tests for the actions in `some_actions.py`.
- `conftest.py` - a file that defines some fixtures & pytest configuration for the tests in `test_some_actions.py`.

![toy example](diagnosis.png)

Expand Down
40 changes: 40 additions & 0 deletions examples/pytest/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import subprocess

import pytest


Expand All @@ -23,3 +25,41 @@ def result_collector():
collector = ResultCollector()
yield collector
print("\nCollected Results:\n", collector)


@pytest.fixture
def git_info():
"""Fixture that returns the git commit, branch, latest_tag.
Note if there are uncommitted changes, the commit will have '-dirty' appended.
"""
try:
commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8")
dirty = subprocess.check_output(["git", "status", "--porcelain"]).strip() != b""
commit = f"{commit}{'-dirty' if dirty else ''}"
except subprocess.CalledProcessError:
commit = None
try:
latest_tag = (
subprocess.check_output(["git", "describe", "--tags", "--abbrev=0"])
.strip()
.decode("utf-8")
)
except subprocess.CalledProcessError:
latest_tag = None
try:
branch = (
subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
.strip()
.decode("utf-8")
)
except subprocess.CalledProcessError:
branch = None
return {"commit": commit, "latest_tag": latest_tag, "branch": branch}


def pytest_configure(config):
"""Code to stop custom mark warnings"""
config.addinivalue_line(
"markers", "file_name: mark test to run using Burr file-based parameterization."
)
11 changes: 7 additions & 4 deletions examples/pytest/test_some_actions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""This module shows example tests for testing actions and agents."""
import pytest
import some_actions

from burr.core import state
from burr.tracking import LocalTrackingClient

from examples.pytest import some_actions


def test_example1(result_collector):
"""Example test that uses a custom fixture."""
Expand Down Expand Up @@ -157,10 +156,12 @@ def test_run_hypothesis_burr_fixture_e2e_with_tracker(input_state, expected_stat
assert agent_state["final_diagnosis"] != ""


def test_print_results(module_results_df):
"""This is an example using pytest-harvest to return results to a central location.
def test_print_results(module_results_df, git_info):
"""This is an example using pytest-harvest and our custom git fixture to return results to a central location.
You could use other plugins, or create your own fixtures (e.g. see conftest.py for a simpler custom fixture).
"""
module_results_df["git_commit"] = git_info["commit"]
module_results_df["git_branch"] = git_info["branch"]
print(module_results_df.columns)
print(module_results_df.head())
# compute statistics
Expand All @@ -173,6 +174,8 @@ def test_print_results(module_results_df):
tests_of_interest[
[
"test_function",
"git_branch",
"git_commit",
"duration_ms",
"status",
"input",
Expand Down

0 comments on commit 7cc3599

Please sign in to comment.