-
Notifications
You must be signed in to change notification settings - Fork 1
/
poc_python_search_replace.py
198 lines (165 loc) · 7.33 KB
/
poc_python_search_replace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import json
import os
import pdb
import re
import shutil
import subprocess
import sys
import tempfile
from typing import List, Tuple, Dict, Optional, Union

import autogen
import pytest
import requests

from agent.custom import CustomAssistantAgent
from llm.openrouter import OpenRouterLLM
class CodeModifier:
    """Ask an LLM for SEARCH/REPLACE edits plus pytest tests, and apply them.

    Edits are applied to scratch copies of the target files inside a temporary
    directory; the real files are overwritten only once the generated tests pass.
    """

    # One complete SEARCH/REPLACE block: group 1 is the code to find, group 2 its
    # replacement. The newline before the closing marker is optional so that an
    # empty replacement (a pure deletion) still parses.
    _DIFF_BLOCK_RE = re.compile(
        r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n?>>>>>>> REPLACE",
        re.DOTALL,
    )

    def __init__(self, api_key: str):
        """Create the OpenRouter-backed assistant agent and the proxy that drives it.

        Args:
            api_key: OpenRouter API key, passed to ``OpenRouterLLM``.
        """
        self.llm = OpenRouterLLM(api_key)
        self.assistant = CustomAssistantAgent(
            name="assistant",
            llm_config={"temperature": 0.1},
            openrouter_llm=self.llm,
        )
        # NOTE(review): with code_execution_config set, the proxy may execute code
        # blocks found in the assistant's (LLM-generated) replies inside
        # ./workspace during auto-replies -- confirm this is intended.
        self.user_proxy = autogen.UserProxyAgent(
            name="user_proxy",
            human_input_mode="NEVER",
            code_execution_config={"work_dir": "workspace"},
            max_consecutive_auto_reply=2,  # Limit consecutive auto-replies
        )

    def apply_diff(self, file_path: str, diff_block: str) -> bool:
        """Apply a single SEARCH/REPLACE block to ``file_path`` in place.

        Only the first occurrence of the SEARCH text is replaced, so a snippet
        that happens to repeat elsewhere in the file is not rewritten everywhere.

        Args:
            file_path: File to edit.
            diff_block: Text containing one ``<<<<<<< SEARCH`` ... ``>>>>>>> REPLACE`` block.

        Returns:
            True if the replacement was written; False if the block is malformed,
            its SEARCH section is empty or not present in the file, or the file
            could not be read or written.
        """
        match = self._DIFF_BLOCK_RE.search(diff_block)
        if not match:
            return False
        search_code, replace_code = match.group(1), match.group(2)
        # An empty SEARCH section is trivially "found" in any file and would just
        # inject the replacement text; treat it as malformed.
        if not search_code:
            return False
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            if search_code not in content:
                return False
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content.replace(search_code, replace_code, 1))
        except (OSError, UnicodeError) as e:
            print(f"Error applying diff: {e}")
            return False
        return True

    def run_tests(self, test_file: str, timeout: Optional[float] = 300) -> Tuple[bool, str]:
        """Run pytest on ``test_file`` in a fresh interpreter.

        A subprocess is used because repeated in-process ``pytest.main`` calls
        reuse already-imported modules, so later iterations would test stale
        code. The test file's directory is the working directory, which lets
        the tests import sibling modules by name.

        Args:
            test_file: Path to the pytest file to run.
            timeout: Seconds before the run is abandoned (LLM-written tests may hang).

        Returns:
            ``(passed, report)``. ``report`` starts with "Tests passed" or
            "Tests failed" followed by pytest's output, or describes why pytest
            could not be run.
        """
        test_path = os.path.abspath(test_file)
        try:
            proc = subprocess.run(
                [sys.executable, "-m", "pytest", "-v", test_path],
                cwd=os.path.dirname(test_path),
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return False, f"Tests timed out after {timeout} seconds"
        except OSError as e:
            return False, str(e)
        passed = proc.returncode == 0  # pytest exit code 0 == all tests passed
        status = "Tests passed" if passed else "Tests failed"
        return passed, f"{status}\n{proc.stdout}{proc.stderr}"

    def process_code(self, problem_description: str, target_files: List[str], max_iterations: int = 3) -> bool:
        """Run the ask -> apply -> test loop until the generated tests pass.

        Every attempt starts from fresh copies of the original files (the prompt
        always shows the originals, so SEARCH blocks must match them). When an
        attempt is rejected, the reason is fed back into the next prompt.

        Args:
            problem_description: What the assistant should fix or implement.
            target_files: Files the assistant may modify; basenames must be unique.
            max_iterations: Maximum number of attempts.

        Returns:
            True if an attempt applied at least one edit and its tests passed, in
            which case the originals are overwritten; False otherwise.
        """
        for file_path in target_files:
            if not os.path.exists(file_path):
                print(f"Error: File {file_path} not found")
                return False
        # Working copies live flat in one directory, so equal basenames would
        # overwrite each other and then be copied back onto the wrong file.
        basenames = [os.path.basename(p) for p in target_files]
        if len(set(basenames)) != len(basenames):
            print("Error: target files must have distinct file names")
            return False

        feedback = ""
        with tempfile.TemporaryDirectory() as temp_dir:
            for iteration in range(max_iterations):
                print(f"\nIteration {iteration + 1}/{max_iterations}")
                working_copies = self._stage_files(target_files, temp_dir)

                prompt = self._build_prompt(problem_description, target_files, feedback)
                reply = self._ask_assistant(prompt)
                if not reply:
                    print("No response from assistant")
                    break

                diff_blocks = self._extract_diff_blocks(reply)
                test_code = self._extract_test_code(reply)
                print(f" diff_blocks ###### {diff_blocks} ####")
                if not diff_blocks or not test_code:
                    print("No valid modifications or tests found")
                    feedback = ("the reply did not contain both SEARCH/REPLACE blocks "
                                "and a ```python block of pytest tests.")
                    continue

                # The reply does not say which file a block targets, so each block
                # is tried against every working copy.
                applied_any = False
                for working_copy in working_copies:
                    for diff in diff_blocks:
                        if self.apply_diff(working_copy, diff):
                            applied_any = True
                print(f"applied_any ===== {applied_any}")
                if not applied_any:
                    print("No modifications were applied")
                    feedback = "none of the SEARCH sections matched the current code exactly."
                    continue

                test_file_path = os.path.join(temp_dir, "test_generated.py")
                with open(test_file_path, "w", encoding="utf-8") as f:
                    f.write(test_code)
                success, test_output = self.run_tests(test_file_path)
                print(f"Test result: {test_output}")
                if success:
                    for original, working_copy in zip(target_files, working_copies):
                        shutil.copy(working_copy, original)
                    print("Successfully modified code and passed tests")
                    return True
                # Keep only the tail: pytest's summary and failures are at the end.
                feedback = f"the generated tests failed:\n{test_output[-4000:]}"
            else:
                print(f"Reached maximum iterations ({max_iterations})")
        return False

    def _stage_files(self, file_paths: List[str], dest_dir: str) -> List[str]:
        """Copy each file into ``dest_dir`` by basename and return the copies' paths."""
        copies = []
        for file_path in file_paths:
            copy_path = os.path.join(dest_dir, os.path.basename(file_path))
            shutil.copy(file_path, copy_path)
            copies.append(copy_path)
        return copies

    def _build_prompt(self, problem_description: str, target_files: List[str], feedback: str = "") -> str:
        """Compose the request for one attempt, including why the last one was rejected."""
        prompt = (
            f"Problem: {problem_description}\n"
            "Current code files:\n"
            f"{self._read_files(target_files)}\n"
            "Please generate:\n"
            "1. Code modifications in diff format\n"
            "2. Pytest test cases in a single ```python code block that starts with "
            "`import pytest`; the code files sit next to the test file, so import "
            "them by module name\n"
            "**You MUST use this format for code modifications** (the SEARCH part "
            "must match the existing code exactly, including indentation):\n"
            "<<<<<<< SEARCH\n"
            "(original code)\n"
            "=======\n"
            "(modified code)\n"
            ">>>>>>> REPLACE\n"
        )
        if feedback:
            prompt += f"\nYour previous attempt was rejected: {feedback}\n"
        return prompt

    def _ask_assistant(self, prompt: str) -> str:
        """Send ``prompt`` to the assistant and return its most useful reply ("" if none)."""
        # initiate_chat takes the opening text via `message=`; an unknown
        # `messages=` kwarg leaves it unset and autogen then asks for it on stdin.
        chat_result = self.user_proxy.initiate_chat(self.assistant, message=prompt)
        replies = self._assistant_replies(chat_result.chat_history)
        # The proxy may auto-reply and the assistant answer again; prefer the
        # latest reply that actually contains edits.
        for reply in reversed(replies):
            if "<<<<<<< SEARCH" in reply:
                return reply
        return replies[-1] if replies else ""

    def _assistant_replies(self, chat_history: List[dict]) -> List[str]:
        """Return the non-empty contents of messages sent by the assistant, in order.

        autogen stores the history from the initiating agent's point of view, so
        the ``role`` field does not reliably identify the assistant; the sender
        ``name`` is used when messages carry one, otherwise ``role == "assistant"``.
        """
        assistant_name = getattr(self.assistant, "name", "assistant")
        messages = [m for m in chat_history if isinstance(m, dict)]
        has_names = any("name" in m for m in messages)
        replies = []
        for message in messages:
            if has_names:
                from_assistant = message.get("name") == assistant_name
            else:
                from_assistant = message.get("role") == "assistant"
            if from_assistant and message.get("content"):
                replies.append(message["content"])
        return replies

    def _read_files(self, file_paths: List[str]) -> str:
        """Return each file's contents under a ``File: <path>`` header, separated by blank lines."""
        sections = []
        for file_path in file_paths:
            with open(file_path, "r", encoding="utf-8") as f:
                sections.append(f"File: {file_path}\n{f.read()}")
        return "\n\n".join(sections)

    def _extract_diff_blocks(self, response: str) -> List[str]:
        """Return every raw ``<<<<<<< SEARCH`` ... ``>>>>>>> REPLACE`` block in ``response``."""
        return re.findall(r"<<<<<<< SEARCH.*?>>>>>>> REPLACE", response, re.DOTALL)

    def _extract_test_code(self, response: str) -> str:
        """Return the pytest code from the first suitable ```python fence, or "".

        A fence that begins with ``import pytest`` is preferred; otherwise the
        first python fence that defines a ``test_`` function (and is not itself
        a diff) is used.
        """
        match = re.search(r"```python\s*(import pytest.*?)```", response, re.DOTALL)
        if match:
            return match.group(1)
        for block in re.findall(r"```python\s*(.*?)```", response, re.DOTALL):
            if "def test_" in block and "<<<<<<< SEARCH" not in block:
                return block
        return ""
# Example usage
if __name__ == "__main__":
    api_key = os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        # Fail fast with a clear message rather than passing None to the LLM
        # client and failing somewhere inside the first request.
        raise SystemExit("Error: set the OPENROUTER_API_KEY environment variable")
    modifier = CodeModifier(api_key=api_key)
    problem = """
Fix the implementation of the calculate_average function to handle empty lists
and return the correct average of numbers.
"""
    target_files = ["math_utils.py"]
    success = modifier.process_code(problem, target_files, max_iterations=3)
    print(f"\nFinal result: Code modification {'succeeded' if success else 'failed'}")