-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauto_test_generation.py
75 lines (55 loc) · 2.46 KB
/
auto_test_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys

from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PythonLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
# Prompt sent to the model. The placeholders are filled in by generate_tests:
#   {context}       - the loaded source files, rendered by format_docs
#   {class_to_test} - name of the class the tests should target
#   {test_class}    - path of the file the generated tests are written to
template = """
I will provide you with a list of python code, including tests. I will also provide you with a class which I want you
to generate a test file for. When generating tests, use the provided code as a reference
and try to keep in mind the style and structure of the existing tests.
Here is the code:
<code>
{context}
</code>
Class to test: {class_to_test}
File in which we will store the class: {test_class}
Now, generate a unit test for the class {class_to_test} using the provided code as a reference.
* This code will be written directly to file, so include all necessary imports and setup.
* This code should be a complete unit test for the class {class_to_test}.
* This code should be written in Python.
* This code will be written directly to file, so only return code - do not return any extra information, no ```python or ``` at all
* Consider the file in which we will store the class, make sure that your imports are relative to that file.
"""
def format_docs(docs: "list[Document]") -> str:
    """Concatenate loaded source files into one prompt-context string.

    Each document is rendered as its ``metadata["source"]`` value, a colon,
    a blank line, and then its ``page_content``. Consecutive documents are
    separated by a blank line.

    Args:
        docs: LangChain ``Document`` objects, or anything exposing a
            ``metadata`` mapping with a ``"source"`` key and a
            ``page_content`` string. Fields are read by attribute, so plain
            dicts are not accepted.

    Returns:
        The joined text, or an empty string when ``docs`` is empty.
    """
    return "\n\n".join(
        f"{doc.metadata['source']}:\n\n{doc.page_content}" for doc in docs
    )
# Prompt object built from the module-level template string.
prompt = PromptTemplate.from_template(template)
# Chat model used to generate the tests. The model name can be overridden via
# the GEMINI_MODEL environment variable so a different or newer Gemini model
# can be used without editing this file; the default stays "gemini-pro".
# NOTE(review): this client is constructed at import time, so configuration or
# credential problems presumably surface when the module is imported — confirm.
llm = ChatGoogleGenerativeAI(model=os.environ.get("GEMINI_MODEL", "gemini-pro"))
def _strip_code_fences(text: str) -> str:
    """Remove a surrounding Markdown code fence from model output, if present.

    Handles an opening fence with or without a language tag (```` ```python ````
    or ```` ``` ````), a closing ```` ``` ````, and stray surrounding whitespace.
    The result always ends with exactly one newline.
    """
    body = text.strip()
    if body.startswith("```"):
        # Drop the whole opening fence line, including any language tag.
        body = body.partition("\n")[2]
    if body.endswith("```"):
        body = body[: -len("```")]
    return body.strip() + "\n"


def generate_tests(class_to_test, test_class):
    """Ask the LLM to write a unit-test file for ``class_to_test``.

    Every ``*.py`` file under the current directory (recursively, skipping
    any path containing "venv") is loaded and passed to the model as context.

    Args:
        class_to_test: Name of the class to generate tests for.
        test_class: Path of the file the generated test code is written to.
            It is overwritten if it already exists.

    Side effects:
        Prints a progress line, writes the exact context sent to the model to
        ``all_files.txt`` (for debugging), calls the model, and writes the
        generated code to ``test_class``.
    """
    documents = DirectoryLoader(
        ".",
        glob="*.py",
        loader_cls=PythonLoader,
        recursive=True,
        show_progress=True,
    ).load()
    # Keep virtual-environment sources out of the prompt context.
    docs = [doc for doc in documents if "venv" not in doc.metadata["source"]]

    chain = prompt | llm | StrOutputParser()

    print("Generating tests for class: ", class_to_test)
    # Render the context once and reuse it for both the debug dump and the call.
    context = format_docs(docs)
    with open("all_files.txt", "w", encoding="utf-8") as dump:
        dump.write(context)

    result = chain.invoke(
        input={
            "test_class": test_class,
            "class_to_test": class_to_test,
            "context": context,
        }
    )
    # The prompt asks for bare code, but models often wrap it in fences anyway.
    with open(test_class, "w", encoding="utf-8") as out:
        out.write(_strip_code_fences(result))
if __name__ == "__main__":
    # Local import: only the command-line entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate a unit-test file for a class using an LLM, "
        "with the Python files under the current directory as context."
    )
    parser.add_argument("class_to_test", help="name of the class to generate tests for")
    parser.add_argument("test_class", help="path of the test file to write")
    # Missing arguments now produce a usage message instead of an IndexError.
    args = parser.parse_args()
    generate_tests(args.class_to_test, args.test_class)