Skip to content

Commit

Permalink
feat: Implement label case conversion and update label descriptions i…
Browse files Browse the repository at this point in the history
…n settings files
  • Loading branch information
mrT23 committed Dec 18, 2023
1 parent 419ae51 commit b28829a
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 6 deletions.
10 changes: 8 additions & 2 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None):

# Set custom labels
variables["custom_labels_class"] = "class Label(str, Enum):"
counter = 0
labels_minimal_to_labels_dict = {}
for k, v in labels.items():
description = v['description'].strip('\n').replace('\n', '\\n')
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
counter += 1
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict

def get_user_labels(current_labels: List[str] = None):
"""
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/settings/pr_custom_labels.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Label(str, Enum):
{%- endif %}
class Labels(BaseModel):
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
======
Expand Down
9 changes: 6 additions & 3 deletions pr_agent/settings/pr_description_prompts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel):
{%- endif %}
{%- if enable_semantic_files_types %}
Class FileDescription(BaseModel):
filename: str = Field(description="the relevant file full path")
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
Expand All @@ -48,7 +49,7 @@ Class PRDescription(BaseModel):
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
{%- if enable_custom_labels %}
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
{%- endif %}
{%- if enable_file_walkthrough %}
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
Expand All @@ -69,8 +70,10 @@ type:
- ...
{%- if enable_custom_labels %}
labels:
- ...
- ...
- |
...
- |
...
{%- endif %}
description: |-
...
Expand Down
11 changes: 11 additions & 0 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ async def _get_prediction(self, model: str) -> str:

environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)

Expand Down Expand Up @@ -203,6 +204,16 @@ def _prepare_labels(self) -> List[str]:
pr_types = self.data['type']
elif type(self.data['type']) == str:
pr_types = self.data['type'].split(',')

# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
return pr_types

def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:
Expand Down
11 changes: 11 additions & 0 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ async def _get_prediction(self, model: str) -> str:

environment = Environment(undefined=StrictUndefined)
set_custom_labels(variables, self.git_provider)
self.variables = variables
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)

Expand Down Expand Up @@ -170,4 +171,14 @@ def _prepare_labels(self) -> List[str]:
elif type(self.data['labels']) == str:
pr_types = self.data['labels'].split(',')

# convert lowercase labels to original case
try:
if "labels_minimal_to_labels_dict" in self.variables:
d: dict = self.variables["labels_minimal_to_labels_dict"]
for i, label_i in enumerate(pr_types):
if label_i in d:
pr_types[i] = d[label_i]
except Exception as e:
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")

return pr_types

0 comments on commit b28829a

Please sign in to comment.