From a4680ded939401c60e61ab8192da989ad283ad8a Mon Sep 17 00:00:00 2001 From: mrT23 Date: Fri, 12 Apr 2024 20:32:47 +0300 Subject: [PATCH] protections --- pr_agent/agent/pr_agent.py | 3 ++ pr_agent/algo/pr_processing.py | 62 +++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/pr_agent/agent/pr_agent.py b/pr_agent/agent/pr_agent.py index 5be0f7792..d2542cf2b 100644 --- a/pr_agent/agent/pr_agent.py +++ b/pr_agent/agent/pr_agent.py @@ -73,6 +73,9 @@ async def handle_request(self, pr_url, request, notify=None) -> bool: args = update_settings_from_args(args) action = action.lstrip("/").lower() + if action not in command2class: + get_logger().debug(f"Unknown command: {action}") + return False with get_logger().contextualize(command=action): get_logger().info("PR-Agent request handler started", analytics=True) if action == "reflect_and_review": diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py index 90482a029..2c5a29575 100644 --- a/pr_agent/algo/pr_processing.py +++ b/pr_agent/algo/pr_processing.py @@ -9,7 +9,7 @@ from pr_agent.algo.language_handler import sort_files_by_main_languages from pr_agent.algo.file_filter import filter_ignored from pr_agent.algo.token_handler import TokenHandler -from pr_agent.algo.utils import get_max_tokens, ModelType +from pr_agent.algo.utils import get_max_tokens, clip_tokens, ModelType from pr_agent.config_loader import get_settings from pr_agent.git_providers.git_provider import GitProvider from pr_agent.algo.types import EDIT_TYPE, FilePatchInfo @@ -87,22 +87,34 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s # if we are over the limit, start pruning get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, " f"pruning diff.") - patches_compressed, modified_file_names, deleted_file_names, added_file_names = \ + patches_compressed, modified_file_names, deleted_file_names, added_file_names, total_tokens_new = \ pr_generate_compressed_diff(pr_languages, token_handler, model, add_line_numbers_to_hunks) + # Insert additional information about added, modified, and deleted files if there is enough space + max_tokens = get_max_tokens(model) - OUTPUT_BUFFER_TOKENS_HARD_THRESHOLD + curr_token = total_tokens_new # == token_handler.count_tokens(final_diff)+token_handler.prompt_tokens final_diff = "\n".join(patches_compressed) - if added_file_names: + delta_tokens = 10 + if added_file_names and (max_tokens - curr_token) > delta_tokens: added_list_str = ADDED_FILES_ + "\n".join(added_file_names) - final_diff = final_diff + "\n\n" + added_list_str - if modified_file_names: + added_list_str = clip_tokens(added_list_str, max_tokens - curr_token) + if added_list_str: + final_diff = final_diff + "\n\n" + added_list_str + curr_token += token_handler.count_tokens(added_list_str) + 2 + if modified_file_names and (max_tokens - curr_token) > delta_tokens: modified_list_str = MORE_MODIFIED_FILES_ + "\n".join(modified_file_names) - final_diff = final_diff + "\n\n" + modified_list_str - if deleted_file_names: + modified_list_str = clip_tokens(modified_list_str, max_tokens - curr_token) + if modified_list_str: + final_diff = final_diff + "\n\n" + modified_list_str + curr_token += token_handler.count_tokens(modified_list_str) + 2 + if deleted_file_names and (max_tokens - curr_token) > delta_tokens: deleted_list_str = DELETED_FILES_ + "\n".join(deleted_file_names) - final_diff = final_diff + "\n\n" + deleted_list_str + deleted_list_str = clip_tokens(deleted_list_str, max_tokens - curr_token) + if deleted_list_str: + final_diff = final_diff + "\n\n" + deleted_list_str try: get_logger().debug(f"After pruning, added_list_str: {added_list_str}, modified_list_str: {modified_list_str}, " - f"deleted_list_str: {deleted_list_str}") + f"deleted_list_str: {deleted_list_str}") except Exception as e: pass return final_diff @@ -149,7 +161,7 @@ def pr_generate_extended_diff(pr_languages: list, def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str, - convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list]: + convert_hunks_to_line_numbers: bool) -> Tuple[list, list, list, list, int]: """ Generate a compressed diff string for a pull request, using diff minimization techniques to reduce the number of tokens used. @@ -195,10 +207,11 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename, file.edit_type) if patch is None: - if not deleted_files_list: - total_tokens += token_handler.count_tokens(DELETED_FILES_) - deleted_files_list.append(file.filename) - total_tokens += token_handler.count_tokens(file.filename) + 1 + # if not deleted_files_list: + # total_tokens += token_handler.count_tokens(DELETED_FILES_) + if file.filename not in deleted_files_list: + deleted_files_list.append(file.filename) + # total_tokens += token_handler.count_tokens(file.filename) + 1 continue if convert_hunks_to_line_numbers: @@ -219,14 +232,17 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo if get_settings().config.verbosity_level >= 2: get_logger().warning(f"Patch too large, minimizing it, {file.filename}") if file.edit_type == EDIT_TYPE.ADDED: - if not added_files_list: - total_tokens += token_handler.count_tokens(ADDED_FILES_) - added_files_list.append(file.filename) + # if not added_files_list: + # total_tokens += token_handler.count_tokens(ADDED_FILES_) + if file.filename not in added_files_list: + added_files_list.append(file.filename) + # total_tokens += token_handler.count_tokens(file.filename) + 1 else: - if not modified_files_list: - total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_) - modified_files_list.append(file.filename) - total_tokens += token_handler.count_tokens(file.filename) + 1 + # if not modified_files_list: + # total_tokens += token_handler.count_tokens(MORE_MODIFIED_FILES_) + if file.filename not in modified_files_list: + modified_files_list.append(file.filename) + # total_tokens += token_handler.count_tokens(file.filename) + 1 continue if patch: @@ -239,7 +255,7 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo if get_settings().config.verbosity_level >= 2: get_logger().info(f"Tokens: {total_tokens}, last filename: {file.filename}") - return patches, modified_files_list, deleted_files_list, added_files_list + return patches, modified_files_list, deleted_files_list, added_files_list, total_tokens async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR): @@ -382,4 +398,4 @@ def get_pr_multi_diffs(git_provider: GitProvider, final_diff = "\n".join(patches) final_diff_list.append(final_diff) - return final_diff_list + return final_diff_list \ No newline at end of file