Skip to content

Commit

Permalink
feat: rigging pr decorator for robopage prs
Browse files Browse the repository at this point in the history
  • Loading branch information
GangGreenTemperTatum committed Nov 25, 2024
1 parent 801a553 commit a0ba2a7
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
145 changes: 145 additions & 0 deletions .github/scripts/rigging_pr_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import base64
import os
import typing as t

from pydantic import ConfigDict, StringConstraints

import rigging as rg
from rigging import logger
from rigging.generator import GenerateParams, Generator, register_generator

logger.enable("rigging")

MAX_TOKENS = 8000
TRUNCATION_WARNING = "\n\n**Note**: Due to the large size of this diff, some content has been truncated."
str_strip = t.Annotated[str, StringConstraints(strip_whitespace=True)]


class PRDiffData(rg.Model):
"""XML model for PR diff data"""

content: str_strip = rg.element()

@classmethod
def xml_example(cls) -> str:
return """<diff><content>example diff content</content></diff>"""


class PRDecorator(Generator):
"""Generator for creating PR descriptions"""

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

api_key: str = ""
max_tokens: int = MAX_TOKENS

def __init__(self, model: str, params: rg.GenerateParams) -> None:
api_key = params.extra.get("api_key")
if not api_key:
raise ValueError("api_key is required in params.extra")

super().__init__(model=model, params=params, api_key=api_key)
self.api_key = api_key
self.max_tokens = params.max_tokens or MAX_TOKENS

async def generate_messages(
self,
messages: t.Sequence[t.Sequence[rg.Message]],
params: t.Sequence[GenerateParams],
) -> t.Sequence[rg.GeneratedMessage]:
responses = []
for message_seq, p in zip(messages, params):
base_generator = rg.get_generator(self.model, params=p)
llm_response = await base_generator.generate_messages([message_seq], [p])
responses.extend(llm_response)
return responses


register_generator("pr_decorator", PRDecorator)


async def generate_pr_description(diff_text: str) -> str:
"""Generate a PR description from the diff text"""
diff_tokens = len(diff_text) // 4
if diff_tokens >= MAX_TOKENS:
char_limit = (MAX_TOKENS * 4) - len(TRUNCATION_WARNING)
diff_text = diff_text[:char_limit] + TRUNCATION_WARNING

diff_data = PRDiffData(content=diff_text)
params = rg.GenerateParams(
extra={
"api_key": os.environ["OPENAI_API_KEY"],
"diff_text": diff_text,
},
temperature=0.7,
max_tokens=500,
)

generator = rg.get_generator("pr_decorator!gpt-4-turbo-preview", params=params)
prompt = f"""You are a helpful AI that generates clear and concise PR descriptions.
Analyze the provided diff between {PRDiffData.xml_example()} tags and create a summary using exactly this format:
### PR Summary
#### Overview of Changes
<overview paragraph>
#### Key Modifications
1. **<modification title>**: <description>
2. **<modification title>**: <description>
3. **<modification title>**: <description>
(continue as needed)
#### Potential Impact
- <impact point 1>
- <impact point 2>
- <impact point 3>
(continue as needed)
Here is the PR diff to analyze:
{diff_data.to_xml()}"""

chat = await generator.chat(prompt).run()
return chat.last.content.strip()


async def main():
"""Main function for CI environment"""
if not os.environ.get("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY environment variable must be set")

try:
diff_text = os.environ.get("GIT_DIFF", "")
if not diff_text:
raise ValueError("No diff found in GIT_DIFF environment variable")

try:
diff_text = base64.b64decode(diff_text).decode("utf-8")
except Exception:
padding = 4 - (len(diff_text) % 4)
if padding != 4:
diff_text += "=" * padding
diff_text = base64.b64decode(diff_text).decode("utf-8")

logger.debug(f"Processing diff of length: {len(diff_text)}")
description = await generate_pr_description(diff_text)

with open(os.environ["GITHUB_OUTPUT"], "a") as f:
f.write("content<<EOF\n")
f.write(description)
f.write("\nEOF\n")
f.write(f"debug_diff_length={len(diff_text)}\n")
f.write(f"debug_description_length={len(description)}\n")
debug_preview = description[:500]
f.write("debug_preview<<EOF\n")
f.write(debug_preview)
f.write("\nEOF\n")

except Exception as e:
logger.error(f"Error in main: {e}")
raise


if __name__ == "__main__":
asyncio.run(main())
64 changes: 64 additions & 0 deletions .github/workflows/rigging_pr_description.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: Update PR Description with Rigging

on:
pull_request:
types: [opened, synchronize]

jobs:
update-description:
runs-on: ubuntu-latest
permissions:
pull-requests: write
contents: read

steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
fetch-depth: 0

# Get the diff first
- name: Get Diff
id: diff
run: |
git fetch origin ${{ github.base_ref }}
MERGE_BASE=$(git merge-base HEAD origin/${{ github.base_ref }})
# Encode the diff as base64 to preserve all characters
DIFF=$(git diff $MERGE_BASE..HEAD | base64 -w 0)
echo "diff=$DIFF" >> $GITHUB_OUTPUT
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b #v5.0.3
with:
python-version: "3.11"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip cache purge
pip install rigging[all]
# Generate the description using the diff
- name: Generate PR Description
id: description
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PR_NUMBER: ${{ github.event.pull_request.number }}
GIT_DIFF: ${{ steps.diff.outputs.diff }}
run: |
python .github/scripts/rigging_pr_decorator.py
# Update the PR description
- name: Update PR Description
uses: nefrob/pr-description@4dcc9f3ad5ec06b2a197c5f8f93db5e69d2fdca7 #v1.2.0
with:
content: |
## AI-Generated Summary
${{ steps.description.outputs.content }}
---
This summary was generated with ❤️ by [rigging](https://rigging.dreadnode.io/)
regex: ".*"
regexFlags: s
token: ${{ secrets.GITHUB_TOKEN }}

0 comments on commit a0ba2a7

Please sign in to comment.