Skip to content

Commit

Permalink
protections
Browse files: browse the repository at this point in the history
  • Loading branch information
mrT23 committed Apr 12, 2024
1 parent 654f88b commit a4680de
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 23 deletions.
3 changes: 3 additions & 0 deletions pr_agent/agent/pr_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ async def handle_request(self, pr_url, request, notify=None) -> bool:
args = update_settings_from_args(args)

action = action.lstrip("/").lower()
if action not in command2class:
get_logger().debug(f"Unknown command: {action}")
return False
with get_logger().contextualize(command=action):
get_logger().info("PR-Agent request handler started", analytics=True)
if action == "reflect_and_review":
Expand Down
62 changes: 39 additions & 23 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens, ModelType
from pr_agent.algo.utils import get_max_tokens, clip_tokens, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider
from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo
Expand Down Expand Up @@ -87,22 +87,34 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
# if we are over the limit, start pruning
get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
f"pruning diff.")
patches_compressed, modified_file_names, deleted_file_names, added_file_names = \
patches_compressed, modified_file_names, deleted_file_names, added_file_names, total_tokens_new = \
pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks)

# Insert additional information about added, modified, and deleted files if there is enough space
max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD
curr_token = total_tokens_new # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens
final_diff = "\n".join(patches_compressed)
if added_file_names:
delta_tokens = 10
if added_file_names and (max_tokens - curr_token) > delta_tokens:
added_list_str = ADDED_FILES_ + "\n".join(added_file_names)
final_diff = final_diff + "\n\n" + added_list_str
if modified_file_names:
added_list_str = clip_tokens(added_list_str, max_tokens - curr_token)
if added_list_str:
final_diff = final_diff + "\n\n" + added_list_str
curr_token += token_handler.count_tokens(added_list_str) + 2
if modified_file_names and (max_tokens - curr_token) > delta_tokens:
modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names)
final_diff = final_diff + "\n\n" + modified_list_str
if deleted_file_names:
modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token)
if modified_list_str:
final_diff = final_diff + "\n\n" + modified_list_str
curr_token += token_handler.count_tokens(modified_list_str) + 2
if deleted_file_names and (max_tokens - curr_token) > delta_tokens:
deleted_list_str = DELETED_FILES_ + "\n".join(deleted_file_names)
final_diff = final_diff + "\n\n" + deleted_list_str
deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token)
if deleted_list_str:
final_diff = final_diff + "\n\n" + deleted_list_str
try:
get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, "
f"deleted_list_str: {deleted_list_str}")
f"deleted_list_str: {deleted_list_str}")
except Exception as e:
pass
return final_diff
Expand Down Expand Up @@ -149,7 +161,7 @@ def pr_generate_extended_diff(pr_languages: list,


def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list]:
convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, int]:
"""
Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of
tokens used.
Expand Down Expand Up @@ -195,10 +207,11 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
patch = handle_patch_deletions(patch, original_file_content_str,
new_file_content_str, file.filename, file.edit_type)
if patch is None:
if not deleted_files_list:
total_tokens += token_handler.count_tokens(DELETED_FILES_)
deleted_files_list.append(file.filename)
total_tokens += token_handler.count_tokens(file.filename) + 1
# if not deleted_files_list:
# total_tokens += token_handler.count_tokens(DELETED_FILES_)
if file.filename not in deleted_files_list:
deleted_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
continue

if convert_hunks_to_line_numbers:
Expand All @@ -219,14 +232,17 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
if get_settings().config.verbosity_level >= 2:
get_logger().warning(f"Patch too large, minimizing it, {file.filename}")
if file.edit_type == EDIT_TYPE.ADDED:
if not added_files_list:
total_tokens += token_handler.count_tokens(ADDED_FILES_)
added_files_list.append(file.filename)
# if not added_files_list:
# total_tokens += token_handler.count_tokens(ADDED_FILES_)
if file.filename not in added_files_list:
added_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
else:
if not modified_files_list:
total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
modified_files_list.append(file.filename)
total_tokens += token_handler.count_tokens(file.filename) + 1
# if not modified_files_list:
# total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_)
if file.filename not in modified_files_list:
modified_files_list.append(file.filename)
# total_tokens += token_handler.count_tokens(file.filename) + 1
continue

if patch:
Expand All @@ -239,7 +255,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
if get_settings().config.verbosity_level >= 2:
get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}")

return patches, modified_files_list, deleted_files_list, added_files_list
return patches, modified_files_list, deleted_files_list, added_files_list, total_tokens


async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
Expand Down Expand Up @@ -382,4 +398,4 @@ def get_pr_multi_diffs(git_provider: GitProvider,
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)

return final_diff_list
return final_diff_list

0 comments on commit a4680de

Please sign in to comment.