Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancement: Implement label case conversion and update label descriptions in settings files #530

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pr_agent/algo/utils.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hagai comment:
enable to write the walkthrough for this file directly here

Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None):

# Set custom labels
variables["custom_labels_class"] = "class Label(str, Enum):"
counter = 0
labels_minimal_to_labels_dict = {}
for k, v in labels.items():
description = v['description'].strip('\n').replace('\n', '\\n')
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
counter += 1
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict

def get_user_labels(current_labels: List[str] = None):
"""
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/settings/pr_custom_labels.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Label(str, Enum):
{%- endif %}
class Labels(BaseModel):
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
======
Expand Down
9 changes: 6 additions & 3 deletions pr_agent/settings/pr_description_prompts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel):
{%- endif %}
{%- if enable_semantic_files_types %}
Class FileDescription(BaseModel):
filename: str = Field(description="the relevant file full path")
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
Expand All @@ -48,7 +49,7 @@ Class PRDescription(BaseModel):
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
{%- if enable_custom_labels %}
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
{%- endif %}
{%- if enable_file_walkthrough %}
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
Expand All @@ -69,8 +70,10 @@ type:
- ...
{%- if enable_custom_labels %}
labels:
- ...
- ...
- |
...
- |
...
{%- endif %}
description: |-
...
Expand Down
11 changes: 11 additions & 0 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ async def _get_prediction(self, model: str) -> str:

environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)

Expand Down Expand Up @@ -203,6 +204,16 @@ def _prepare_labels(self) -> List[str]:
pr_types = self.data['type']
elif type(self.data['type']) == str:
pr_types = self.data['type'].split(',')

# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
return pr_types

def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:
Expand Down
11 changes: 11 additions & 0 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ async def _get_prediction(self, model: str) -> str:

environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)

Expand Down Expand Up @@ -170,4 +171,14 @@ def _prepare_labels(self) -> List[str]:
elif type(self.data['labels']) == str:
pr_types = self.data['labels'].split(',')

# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")

return pr_types
Loading