Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dougollerenshaw committed Sep 18, 2024
1 parent 83f27b8 commit b3118e4
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 36 deletions.
46 changes: 13 additions & 33 deletions codeaide/utils/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,40 +41,20 @@ def parse_response(response):
if not response or not response.content:
return None, None, None, None, None, None

content = response.content[0].text

def extract_json_field(field_name, content, is_code=False):
pattern = rf'"{field_name}"\s*:\s*"((?:\\.|[^"\\])*)"'
match = re.search(pattern, content, re.DOTALL)
if match:
field_content = match.group(1)
if is_code:
# For code, replace escaped newlines with actual newlines, but only within strings
field_content = re.sub(r'(?<!\\)\\n', '\n', field_content)
field_content = re.sub(r'\\(?=["\'])', '', field_content)
else:
# For non-code fields, unescape all content
field_content = field_content.encode().decode('unicode_escape')
return field_content
return None

def extract_json_array(field_name, content):
pattern = rf'"{field_name}"\s*:\s*(\[[^\]]*\])'
match = re.search(pattern, content)
if match:
return json.loads(match.group(1))
return []

text = extract_json_field('text', content)
code = extract_json_field('code', content, is_code=True)
code_version = extract_json_field('code_version', content)
version_description = extract_json_field('version_description', content)
requirements = extract_json_array('requirements', content)

questions_match = re.search(r'"questions"\s*:\s*(\[(?:\s*"(?:\\.|[^"\\])*"\s*,?\s*)*\])', content)
questions = json.loads(questions_match.group(1)) if questions_match else []
try:
content = json.loads(response.content[0].text)

text = content.get('text')
code = content.get('code')
code_version = content.get('code_version')
version_description = content.get('version_description')
requirements = content.get('requirements', [])
questions = content.get('questions', [])

return text, questions, code, code_version, version_description, requirements
return text, questions, code, code_version, version_description, requirements
except json.JSONDecodeError:
print("Error: Received malformed JSON from the API")
return None, None, None, None, None, None

def test_api_connection():
try:
Expand Down
11 changes: 9 additions & 2 deletions codeaide/utils/file_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
import shutil

class FileHandler:
def __init__(self):
self.base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def __init__(self, base_dir=None):
if base_dir is None:
self.base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
else:
self.base_dir = base_dir
self.output_dir = os.path.join(self.base_dir, "generated_code")
self.versions_dict = {}
self._ensure_output_dir_exists()

def _ensure_output_dir_exists(self):
os.makedirs(self.output_dir, exist_ok=True)

def clear_output_dir(self):
print(f"Clearing output directory: {self.output_dir}")
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
anthropic==0.34.2
python-decouple==3.8
virtualenv==20.16.2
pyyaml
pyyaml
pytest
Empty file added tests/__init__.py
Empty file.
Empty file added tests/utils/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions tests/utils/test_api_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import json
from codeaide.utils.api_utils import parse_response
from collections import namedtuple

# Mock Response object
Response = namedtuple('Response', ['content'])
TextBlock = namedtuple('TextBlock', ['text'])

def test_parse_response_empty():
result = parse_response(None)
assert result == (None, None, None, None, None, None)

def test_parse_response_no_content():
response = Response(content=[])
result = parse_response(response)
assert result == (None, None, None, None, None, None)

def test_parse_response_valid():
content = {
"text": "Sample text",
"code": "print('Hello, World!')",
"code_version": "1.0",
"version_description": "Initial version",
"requirements": ["pytest"],
"questions": ["What does this code do?"]
}
response = Response(content=[TextBlock(text=json.dumps(content))])
text, questions, code, code_version, version_description, requirements = parse_response(response)

assert text == "Sample text"
assert questions == ["What does this code do?"]
assert code == "print('Hello, World!')"
assert code_version == "1.0"
assert version_description == "Initial version"
assert requirements == ["pytest"]

def test_parse_response_missing_fields():
content = {
"text": "Sample text",
"code": "print('Hello, World!')"
}
response = Response(content=[TextBlock(text=json.dumps(content))])
text, questions, code, code_version, version_description, requirements = parse_response(response)

assert text == "Sample text"
assert questions == []
assert code == "print('Hello, World!')"
assert code_version is None
assert version_description is None
assert requirements == []

def test_parse_response_complex_code():
content = {
"text": "Complex code example",
"code": 'def hello():\n print("Hello, World!")',
"code_version": "1.1",
"version_description": "Added function",
"requirements": [],
"questions": []
}
response = Response(content=[TextBlock(text=json.dumps(content))])
text, questions, code, code_version, version_description, requirements = parse_response(response)

assert text == "Complex code example"
assert code == 'def hello():\n print("Hello, World!")'
assert code_version == "1.1"
assert version_description == "Added function"

def test_parse_response_escaped_quotes():
content = {
"text": 'Text with "quotes"',
"code": 'print("Hello, \\"World!\\"")\nprint(\'Single quotes\')',
"code_version": "1.2",
"version_description": "Added escaped quotes",
"requirements": [],
"questions": []
}
response = Response(content=[TextBlock(text=json.dumps(content))])
text, questions, code, code_version, version_description, requirements = parse_response(response)

assert text == 'Text with "quotes"'
assert code == 'print("Hello, \\"World!\\"")\nprint(\'Single quotes\')'
assert code_version == "1.2"
assert version_description == "Added escaped quotes"

def test_parse_response_malformed_json():
response = Response(content=[TextBlock(text="This is not JSON")])
result = parse_response(response)
assert result == (None, None, None, None, None, None)
79 changes: 79 additions & 0 deletions tests/utils/test_file_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
import os
import tempfile
from codeaide.utils.file_handler import FileHandler

@pytest.fixture
def file_handler():
with tempfile.TemporaryDirectory() as temp_dir:
handler = FileHandler(base_dir=temp_dir)
yield handler

def test_clear_output_dir(file_handler):
# Create a file in the output directory
test_file = os.path.join(file_handler.output_dir, "test.txt")
with open(test_file, "w") as f:
f.write("test")

file_handler.clear_output_dir()

assert os.path.exists(file_handler.output_dir)
assert len(os.listdir(file_handler.output_dir)) == 0

def test_save_code(file_handler):
code = "print('Hello, World!')"
version = "1.0"
description = "Initial version"
requirements = ["pytest"]

code_path = file_handler.save_code(code, version, description, requirements)

assert os.path.exists(code_path)
with open(code_path, "r") as f:
assert f.read() == code

assert version in file_handler.versions_dict
assert file_handler.versions_dict[version]['version_description'] == description
assert file_handler.versions_dict[version]['requirements'] == requirements

def test_save_requirements(file_handler):
requirements = ["pytest", "requests"]
version = "1.0"

req_path = file_handler.save_requirements(requirements, version)

assert os.path.exists(req_path)
with open(req_path, "r") as f:
assert f.read().splitlines() == requirements

def test_get_versions_dict(file_handler):
file_handler.save_code("code1", "1.0", "Version 1")
file_handler.save_code("code2", "2.0", "Version 2")

versions_dict = file_handler.get_versions_dict()

assert "1.0" in versions_dict
assert "2.0" in versions_dict

def test_get_code(file_handler):
original_code = "print('Test')"
file_handler.save_code(original_code, "1.0", "Test version")

retrieved_code = file_handler.get_code("1.0")

assert retrieved_code == original_code

def test_get_requirements(file_handler):
original_requirements = ["pytest", "requests"]
file_handler.save_code("code", "1.0", "Test", original_requirements)

retrieved_requirements = file_handler.get_requirements("1.0")

assert retrieved_requirements == original_requirements

def test_nonexistent_version(file_handler):
with pytest.raises(FileNotFoundError):
file_handler.get_code("nonexistent")

with pytest.raises(FileNotFoundError):
file_handler.get_requirements("nonexistent")

0 comments on commit b3118e4

Please sign in to comment.