✨ etl pr cmd #3646

Merged · 4 commits · Nov 28, 2024
55 changes: 41 additions & 14 deletions apps/pr/cli.py
@@ -43,6 +43,7 @@
 """

 import hashlib
+import os
 import re
 import uuid
 from typing import Optional, cast
@@ -55,6 +56,7 @@
 from structlog import get_logger

 from apps.pr.categories import PR_CATEGORIES, PR_CATEGORIES_CHOICES
+from apps.utils.gpt import OpenAIWrapper
 from etl.config import GITHUB_TOKEN
 from etl.paths import BASE_DIR

@@ -131,6 +133,12 @@
     is_flag=True,
     help="By default, staging server site (not admin) will be publicly accessible. Use --private to have it private instead. This does not apply when using --direct mode.",
 )
+@click.option(
+    "--no-llm",
+    "-n",
+    is_flag=True,
+    help="We briefly use an LLM to simplify the title for use in the branch name. Disable this with the -n flag.",
+)
 def cli(
     title: str,
     category: Optional[str],
@@ -139,6 +147,7 @@ def cli(
     base_branch: str,
     direct: bool,
     private: bool,
+    no_llm: bool,
     # base_branch: Optional[str] = None,
 ) -> None:
     # Check that the user has set up a GitHub token.
@@ -166,6 +175,8 @@ def cli(
         work_branch=work_branch,
         direct=direct,
         pr_title=pr_title,
+        remote_branches=remote_branches,
+        no_llm=no_llm,
     )

     # Check branches main & work make sense!
@@ -247,13 +258,13 @@ def init_repo():
     return repo, remote_branches


-def ensure_work_branch(repo, work_branch, direct, pr_title):
+def ensure_work_branch(repo, work_branch, direct, pr_title, remote_branches, no_llm):
     """Get name of new branch if not provided."""
     # If no name for new branch is given
     if work_branch is None:
         if not direct:
             # Generate name for new branch
-            work_branch = bake_branch_name(repo, pr_title)
+            work_branch = bake_branch_name(repo, pr_title, no_llm, remote_branches)
         else:
             # If not explicitly given, the new branch will be the current branch.
             work_branch = repo.active_branch.name
@@ -353,24 +364,32 @@ def _generate_pr_title(title: str, category: str, scope: str | None) -> Optional
     return title


-def bake_branch_name(repo, pr_title):
+def bake_branch_name(repo, pr_title, no_llm, remote_branches):
     # Get user
-    git_config = repo.config_reader()
-    user = git_config.get_value("user", "name")
+    # git_config = repo.config_reader()
+    # user = git_config.get_value("user", "name").lower()

     # Get category
     category = pr_title.category

     # Get input title (without emoji, scope, etc.)
-    title = _extract_relevant_title_for_branch_name(pr_title.title)
+    title = _extract_relevant_title_for_branch_name(pr_title.title, not no_llm)

     # Bake complete PR branch name
-    name = f"{user}-{category}-{title}"
-
+    # name = f"{user}-{category}-{title}"
+    name = f"{category}-{title}"
+
+    # If branch name collision
+    # if name in remote_branches:
+    #     log.info("Generating a hash for this branch name to prevent name collisions.")
+    #     name = f"{name}-{user}"
+    if name in remote_branches:
+        log.info("Generating a hash for this branch name to prevent name collisions.")
+        name = f"{name}-{generate_short_hash()}"
     return name


-def _extract_relevant_title_for_branch_name(text_in: str) -> str:
+def _extract_relevant_title_for_branch_name(text_in: str, use_llm) -> str:
     """
     Process the input string by:
     1. Removing all symbols, keeping only letters and numbers.
@@ -384,19 +403,19 @@
     Returns:
         str: The processed string.
     """
-    # Remove all symbols, keep only letters and numbers
+    if use_llm:
+        if "OPENAI_API_KEY" in os.environ:
+            text_in = summarize_title_llm(text_in)
+
     cleaned_text = re.sub(r"[^a-zA-Z0-9\s]", "", text_in)
+
     # Split into tokens/words
     tokens = cleaned_text.split()
     # Keep only the first 3 tokens
     tokens = tokens[:3]
     # Combine tokens with '-'
     name = "-".join(tokens).lower()

-    # Add hash to prevent collisions
-    hash_txt = generate_short_hash()
-    name = f"{name}-{hash_txt}"
-
     return name


@@ -410,3 +429,11 @@ def generate_short_hash() -> str:
     random_data = uuid.uuid4().hex  # Generate random data
     random_hash = hashlib.sha256(random_data.encode()).hexdigest()  # Create hash
     return random_hash[:6]  # Return the first 6 characters
+
+
+def summarize_title_llm(title) -> str:
+    sys_prompt = "You are given a title of a pull request. I need a 2-3 keyword summary, separated by a space. These words will be used to create a branch name."
+    api = OpenAIWrapper()
+    log.info("Querying GPT!")
+    response = api.query_gpt_fast(title, sys_prompt, model="gpt-4o-mini")
+    return response
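
For reference, here is a minimal, self-contained sketch of the branch-naming behaviour introduced above. It re-implements the non-LLM path of _extract_relevant_title_for_branch_name, bake_branch_name and generate_short_hash outside the etl codebase; the category, title and remote branch names below are made up for illustration and are not taken from this PR.

import hashlib
import re
import uuid


def slug_from_title(text_in: str) -> str:
    # Same cleaning as _extract_relevant_title_for_branch_name without the LLM:
    # keep letters, digits and spaces, take the first 3 tokens, join with "-".
    cleaned_text = re.sub(r"[^a-zA-Z0-9\s]", "", text_in)
    return "-".join(cleaned_text.split()[:3]).lower()


def short_hash() -> str:
    # Same idea as generate_short_hash: first 6 hex chars of a hash of random data.
    return hashlib.sha256(uuid.uuid4().hex.encode()).hexdigest()[:6]


# Hypothetical inputs for illustration:
category = "data"
pr_title = "Add FAOSTAT garden step"
remote_branches = {"data-add-faostat-garden"}

name = f"{category}-{slug_from_title(pr_title)}"  # "data-add-faostat-garden"
if name in remote_branches:
    # Name collision on the remote: append a short hash, as bake_branch_name now does.
    name = f"{name}-{short_hash()}"
print(name)  # e.g. "data-add-faostat-garden-1a2b3c"

The main behaviour change versus the previous version is that the short hash is only appended on a collision instead of unconditionally, and the git user name no longer appears in the branch name.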
15 changes: 15 additions & 0 deletions apps/utils/gpt.py
@@ -217,6 +217,21 @@ def query_gpt(
         else:
             raise ValueError("message_content is expected to be a string!")

+    def query_gpt_fast(self, user_prompt: str, system_prompt: str, model: str = MODEL_DEFAULT) -> str:
+        """Query Chat GPT to get message content from the chat completion."""
+        query = GPTQuery(
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ]
+        )
+        response = self.query_gpt(query=query, model=model)
+
+        if isinstance(response, GPTResponse):
+            return response.message_content
+        else:
+            raise ValueError("message_content is expected to be a string!")
+

 def get_number_tokens(text: str, model_name: str) -> int:
     """Get number of tokens of text.
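
And a minimal usage sketch of the new query_gpt_fast helper, assuming OPENAI_API_KEY is set in the environment and using only the calls visible in this diff; the prompts are illustrative, not the ones used by the etl pr command.

from apps.utils.gpt import OpenAIWrapper

api = OpenAIWrapper()

# query_gpt_fast wraps query_gpt: it builds a GPTQuery with a system and a user
# message and returns the plain message content as a string.
summary = api.query_gpt_fast(
    user_prompt="Add FAOSTAT garden step",  # hypothetical PR title
    system_prompt="Summarize this pull request title in 2-3 keywords separated by spaces.",
    model="gpt-4o-mini",
)
print(summary)  # e.g. "faostat garden step"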