-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: rigging pr decorator for robopage prs
- Loading branch information
1 parent
801a553
commit a0ba2a7
Showing
2 changed files
with
209 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import asyncio | ||
import base64 | ||
import os | ||
import typing as t | ||
|
||
from pydantic import ConfigDict, StringConstraints | ||
|
||
import rigging as rg | ||
from rigging import logger | ||
from rigging.generator import GenerateParams, Generator, register_generator | ||
|
||
logger.enable("rigging") | ||
|
||
MAX_TOKENS = 8000 | ||
TRUNCATION_WARNING = "\n\n**Note**: Due to the large size of this diff, some content has been truncated." | ||
str_strip = t.Annotated[str, StringConstraints(strip_whitespace=True)] | ||
|
||
|
||
class PRDiffData(rg.Model): | ||
"""XML model for PR diff data""" | ||
|
||
content: str_strip = rg.element() | ||
|
||
@classmethod | ||
def xml_example(cls) -> str: | ||
return """<diff><content>example diff content</content></diff>""" | ||
|
||
|
||
class PRDecorator(Generator): | ||
"""Generator for creating PR descriptions""" | ||
|
||
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) | ||
|
||
api_key: str = "" | ||
max_tokens: int = MAX_TOKENS | ||
|
||
def __init__(self, model: str, params: rg.GenerateParams) -> None: | ||
api_key = params.extra.get("api_key") | ||
if not api_key: | ||
raise ValueError("api_key is required in params.extra") | ||
|
||
super().__init__(model=model, params=params, api_key=api_key) | ||
self.api_key = api_key | ||
self.max_tokens = params.max_tokens or MAX_TOKENS | ||
|
||
async def generate_messages( | ||
self, | ||
messages: t.Sequence[t.Sequence[rg.Message]], | ||
params: t.Sequence[GenerateParams], | ||
) -> t.Sequence[rg.GeneratedMessage]: | ||
responses = [] | ||
for message_seq, p in zip(messages, params): | ||
base_generator = rg.get_generator(self.model, params=p) | ||
llm_response = await base_generator.generate_messages([message_seq], [p]) | ||
responses.extend(llm_response) | ||
return responses | ||
|
||
|
||
register_generator("pr_decorator", PRDecorator) | ||
|
||
|
||
async def generate_pr_description(diff_text: str) -> str: | ||
"""Generate a PR description from the diff text""" | ||
diff_tokens = len(diff_text) // 4 | ||
if diff_tokens >= MAX_TOKENS: | ||
char_limit = (MAX_TOKENS * 4) - len(TRUNCATION_WARNING) | ||
diff_text = diff_text[:char_limit] + TRUNCATION_WARNING | ||
|
||
diff_data = PRDiffData(content=diff_text) | ||
params = rg.GenerateParams( | ||
extra={ | ||
"api_key": os.environ["OPENAI_API_KEY"], | ||
"diff_text": diff_text, | ||
}, | ||
temperature=0.7, | ||
max_tokens=500, | ||
) | ||
|
||
generator = rg.get_generator("pr_decorator!gpt-4-turbo-preview", params=params) | ||
prompt = f"""You are a helpful AI that generates clear and concise PR descriptions. | ||
Analyze the provided diff between {PRDiffData.xml_example()} tags and create a summary using exactly this format: | ||
### PR Summary | ||
#### Overview of Changes | ||
<overview paragraph> | ||
#### Key Modifications | ||
1. **<modification title>**: <description> | ||
2. **<modification title>**: <description> | ||
3. **<modification title>**: <description> | ||
(continue as needed) | ||
#### Potential Impact | ||
- <impact point 1> | ||
- <impact point 2> | ||
- <impact point 3> | ||
(continue as needed) | ||
Here is the PR diff to analyze: | ||
{diff_data.to_xml()}""" | ||
|
||
chat = await generator.chat(prompt).run() | ||
return chat.last.content.strip() | ||
|
||
|
||
async def main(): | ||
"""Main function for CI environment""" | ||
if not os.environ.get("OPENAI_API_KEY"): | ||
raise ValueError("OPENAI_API_KEY environment variable must be set") | ||
|
||
try: | ||
diff_text = os.environ.get("GIT_DIFF", "") | ||
if not diff_text: | ||
raise ValueError("No diff found in GIT_DIFF environment variable") | ||
|
||
try: | ||
diff_text = base64.b64decode(diff_text).decode("utf-8") | ||
except Exception: | ||
padding = 4 - (len(diff_text) % 4) | ||
if padding != 4: | ||
diff_text += "=" * padding | ||
diff_text = base64.b64decode(diff_text).decode("utf-8") | ||
|
||
logger.debug(f"Processing diff of length: {len(diff_text)}") | ||
description = await generate_pr_description(diff_text) | ||
|
||
with open(os.environ["GITHUB_OUTPUT"], "a") as f: | ||
f.write("content<<EOF\n") | ||
f.write(description) | ||
f.write("\nEOF\n") | ||
f.write(f"debug_diff_length={len(diff_text)}\n") | ||
f.write(f"debug_description_length={len(description)}\n") | ||
debug_preview = description[:500] | ||
f.write("debug_preview<<EOF\n") | ||
f.write(debug_preview) | ||
f.write("\nEOF\n") | ||
|
||
except Exception as e: | ||
logger.error(f"Error in main: {e}") | ||
raise | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
name: Update PR Description with Rigging | ||
|
||
on: | ||
pull_request: | ||
types: [opened, synchronize] | ||
|
||
jobs: | ||
update-description: | ||
runs-on: ubuntu-latest | ||
permissions: | ||
pull-requests: write | ||
contents: read | ||
|
||
steps: | ||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2 | ||
with: | ||
fetch-depth: 0 | ||
|
||
# Get the diff first | ||
- name: Get Diff | ||
id: diff | ||
run: | | ||
git fetch origin ${{ github.base_ref }} | ||
MERGE_BASE=$(git merge-base HEAD origin/${{ github.base_ref }}) | ||
# Encode the diff as base64 to preserve all characters | ||
DIFF=$(git diff $MERGE_BASE..HEAD | base64 -w 0) | ||
echo "diff=$DIFF" >> $GITHUB_OUTPUT | ||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b #v5.0.3 | ||
with: | ||
python-version: "3.11" | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip cache purge | ||
pip install rigging[all] | ||
# Generate the description using the diff | ||
- name: Generate PR Description | ||
id: description | ||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} | ||
PR_NUMBER: ${{ github.event.pull_request.number }} | ||
GIT_DIFF: ${{ steps.diff.outputs.diff }} | ||
run: | | ||
python .github/scripts/rigging_pr_decorator.py | ||
# Update the PR description | ||
- name: Update PR Description | ||
uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 #v1.2.0 | ||
with: | ||
content: | | ||
## AI-Generated Summary | ||
${{ steps.description.outputs.content }} | ||
--- | ||
This summary was generated with ❤️ by [rigging](https://rigging.dreadnode.io/) | ||
regex: ".*" | ||
regexFlags: s | ||
token: ${{ secrets.GITHUB_TOKEN }} |