Skip to content

Commit

Permalink
Add tests for patch extension and update configuration for extra lines handling
Browse files Browse the repository at this point in the history

- Added unit tests in `test_extend_patch.py` and `test_pr_generate_extended_diff.py` to verify patch extension functionality with extra lines.
- Updated `pr_processing.py` to include `patch_extra_lines_before` and `patch_extra_lines_after` settings.
- Modified `configuration.toml` to adjust `patch_extra_lines_before` to 4 and `max_context_tokens` to 16000.
- Enabled extra lines in `pr_code_suggestions.py`.
- Added new model `claude-3-5-sonnet` to `__init__.py`.
  • Loading branch information
mrT23 committed Aug 11, 2024
1 parent 61bdfd3 commit e238a88
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 50 deletions.
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
'claude-3-5-sonnet': 100000,
'groq/llama3-8b-8192': 8192,
'groq/llama3-70b-8192': 8192,
'groq/mixtral-8x7b-32768': 32768,
Expand Down
7 changes: 5 additions & 2 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,13 @@ def get_pr_multi_diffs(git_provider: GitProvider,
for lang in pr_languages:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))


# try first a single run with standard diff string, with patch extension, and no deletions
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=True)
pr_languages, token_handler, add_line_numbers_to_hunks=True,
patch_extra_lines_before=get_settings().config.patch_extra_lines_before,
patch_extra_lines_after=get_settings().config.patch_extra_lines_after)

# if we are under the limit, return the full diff
if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
return ["\n".join(patches_extended)] if patches_extended else []

Expand Down
4 changes: 2 additions & 2 deletions pr_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ max_commits_tokens = 500
max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
custom_model_max_tokens=-1 # for models not in the default list
#
patch_extra_lines_before = 6
patch_extra_lines_before = 4
patch_extra_lines_after = 2
secret_provider=""
cli_mode=false
Expand Down Expand Up @@ -97,7 +97,7 @@ enable_help_text=false


[pr_code_suggestions] # /improve #
max_context_tokens=10000
max_context_tokens=16000
num_code_suggestions=4
commitable_code_suggestions = false
extra_instructions = ""
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ async def _prepare_prediction(self, model: str) -> dict:
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True)
disable_extra_lines=False)

if self.patches_diff:
get_logger().debug(f"PR diff", artifact=self.patches_diff)
Expand Down
110 changes: 65 additions & 45 deletions tests/unittest/test_extend_patch.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,6 @@

# Generated by CodiumAI


import pytest
from pr_agent.algo.git_patch_processing import extend_patch

"""
Code Analysis
Objective:
The objective of the 'extend_patch' function is to extend a given patch to include a specified number of surrounding
lines. This function takes in an original file string, a patch string, and the number of lines to extend the patch by,
and returns the extended patch string.
Inputs:
- original_file_str: a string representing the original file
- patch_str: a string representing the patch to be extended
- num_lines: an integer representing the number of lines to extend the patch by
Flow:
1. Split the original file string and patch string into separate lines
2. Initialize variables to keep track of the current hunk's start and size for both the original file and the patch
3. Iterate through each line in the patch string
4. If the line starts with '@@', extract the start and size values for both the original file and the patch, and
calculate the extended start and size values
5. Append the extended hunk header to the extended patch lines list
6. Append the specified number of lines before the hunk to the extended patch lines list
7. Append the current line to the extended patch lines list
8. If the line is not a hunk header, append it to the extended patch lines list
9. Return the extended patch string
Outputs:
- extended_patch_str: a string representing the extended patch
Additional aspects:
- The function uses regular expressions to extract the start and size values from the hunk header
- The function handles cases where the start value of a hunk is less than the number of lines to extend by, by
setting the extended start value to 1
- The function handles cases where the hunk extends beyond the end of the original file by only including lines up to
the end of the original file in the extended patch
"""
from pr_agent.algo.token_handler import TokenHandler


class TestExtendPatch:
Expand All @@ -48,7 +10,8 @@ def test_happy_path(self):
patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3'
num_lines = 1
expected_output = '@@ -1,4 +1,4 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4'
actual_output = extend_patch(original_file_str, patch_str, num_lines)
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output

# Tests that the function returns an empty string when patch_str is empty
Expand All @@ -57,14 +20,16 @@ def test_empty_patch(self):
patch_str = ''
num_lines = 1
expected_output = ''
assert extend_patch(original_file_str, patch_str, num_lines) == expected_output
assert extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == expected_output

# Tests that the function returns the original patch when num_lines is 0
def test_zero_num_lines(self):
original_file_str = 'line1\nline2\nline3\nline4\nline5'
patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3'
num_lines = 0
assert extend_patch(original_file_str, patch_str, num_lines) == patch_str
assert extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == patch_str

# Tests that the function returns the original patch when patch_str contains no hunks
def test_no_hunks(self):
Expand All @@ -80,7 +45,8 @@ def test_single_hunk(self):
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4'
num_lines = 1
expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5'
actual_output = extend_patch(original_file_str, patch_str, num_lines)
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output

# Tests the functionality of extending a patch with multiple hunks.
Expand All @@ -89,5 +55,59 @@ def test_multiple_hunks(self):
patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4' # noqa: E501
num_lines = 1
expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5\n@@ -3,3 +3,3 @@ init2()\nline3\n-line4\n+new_line4\nline5' # noqa: E501
actual_output = extend_patch(original_file_str, patch_str, num_lines)
actual_output = extend_patch(original_file_str, patch_str,
patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
assert actual_output == expected_output


class PRProcessingTest:
class File:
def __init__(self, base_file, patch, filename):
self.base_file = base_file
self.patch = patch
self.filename = filename

@pytest.fixture
def token_handler(self):
# Create a TokenHandler instance with dummy data
th = TokenHandler(system="System prompt", user="User prompt")
th.prompt_tokens = 100
return th

@pytest.fixture
def pr_languages(self):
# Create a list of languages with files containing base_file and patch data
return [
{
'files': [
self.File(base_file="line000\nline00\nline0\nline1\noriginal content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
patch="@@ -5,5 +5,5 @@\n-original content\n+modified content\nline2\nline3\nline4\nline5",
filename="file1"),
self.File(base_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
patch="@@ -6,5 +6,5 @@\nline6\nline7\nline8\n-line9\n+modified line9\nline10",
filename="file2")
]
}
]

def test_extend_patches_with_extra_lines(self, token_handler, pr_languages):
patches_extended_no_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=False,
patch_extra_lines_before=0,
patch_extra_lines_after=0
)

# Check that with no extra lines, the patches are the same as the original patches
p0 = patches_extended_no_extra_lines[0].strip()
p1 = patches_extended_no_extra_lines[1].strip()
assert p0 == '## file1\n\n' + pr_languages[0]['files'][0].patch.strip()
assert p1 == '## file2\n\n' + pr_languages[0]['files'][1].patch.strip()

patches_extended_with_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
pr_languages, token_handler, add_line_numbers_to_hunks=False,
patch_extra_lines_before=2,
patch_extra_lines_after=1
)

p0_extended = patches_extended_with_extra_lines[0].strip()
assert p0_extended == '## file1\n\n@@ -3,8 +3,8 @@ \nline0\nline1\n-original content\n+modified content\nline2\nline3\nline4\nline5\nline6'

0 comments on commit e238a88

Please sign in to comment.