From e5bee757c432aa77902a6b8c13f94f37f1174502 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Mon, 23 Sep 2024 15:15:29 -0700 Subject: [PATCH 1/7] Started making changes to make model selectable --- codeaide/logic/chat_handler.py | 38 ++++++++++++------ codeaide/utils/api_utils.py | 73 ++++++++++++++++++++++++---------- codeaide/utils/constants.py | 28 +++++++++++-- 3 files changed, 101 insertions(+), 38 deletions(-) diff --git a/codeaide/logic/chat_handler.py b/codeaide/logic/chat_handler.py index 78429e0..cd4a636 100644 --- a/codeaide/logic/chat_handler.py +++ b/codeaide/logic/chat_handler.py @@ -10,7 +10,12 @@ save_api_key, MissingAPIKeyException, ) -from codeaide.utils.constants import MAX_RETRIES, MAX_TOKENS +from codeaide.utils.constants import ( + MAX_RETRIES, + AI_PROVIDERS, + DEFAULT_MODEL, + DEFAULT_PROVIDER, +) from codeaide.utils.cost_tracker import CostTracker from codeaide.utils.environment_manager import EnvironmentManager from codeaide.utils.file_handler import FileHandler @@ -37,7 +42,8 @@ def __init__(self): self.latest_version = "0.0" self.api_client = None self.api_key_set = False - self.current_service = "anthropic" # Default service + self.current_provider = DEFAULT_PROVIDER + self.current_model = DEFAULT_MODEL def check_api_key(self): """ @@ -50,26 +56,26 @@ def check_api_key(self): tuple: A tuple containing a boolean indicating if the API key is valid and a message. """ if self.api_client is None: - self.api_client = get_api_client(self.current_service) + self.api_client = get_api_client(self.current_provider, self.current_model) if self.api_client: self.api_key_set = True return True, None else: self.api_key_set = False - return False, self.get_api_key_instructions(self.current_service) + return False, self.get_api_key_instructions(self.current_provider) - def get_api_key_instructions(self, service): + def get_api_key_instructions(self, provider): """ - Get instructions for setting up the API key for a given service. + Get instructions for setting up the API key for a given provider. Args: - service (str): The name of the service. + provider (str): The name of the provider. Returns: str: Instructions for setting up the API key. """ - if service == "anthropic": + if provider == "anthropic": return ( "It looks like you haven't set up your Anthropic API key yet. " "Here's how to get started:\n\n" @@ -83,7 +89,7 @@ def get_api_key_instructions(self, service): "Please paste your Anthropic API key now:" ) else: - return f"Please enter your API key for {service.capitalize()}:" + return f"Please enter your API key for {provider.capitalize()}:" def validate_api_key(self, api_key): """ @@ -122,9 +128,11 @@ def handle_api_key_input(self, api_key): cleaned_key = api_key.strip().strip("'\"") # Remove quotes and whitespace is_valid, error_message = self.validate_api_key(cleaned_key) if is_valid: - if save_api_key(self.current_service, cleaned_key): + if save_api_key(self.current_provider, cleaned_key): # Try to get a new API client with the new key - self.api_client = get_api_client(self.current_service) + self.api_client = get_api_client( + self.current_provider, self.current_model + ) if self.api_client: self.api_key_set = True return True, "API key saved and verified successfully." @@ -155,7 +163,7 @@ def process_input(self, user_input): if not self.check_and_set_api_key(): return { "type": "api_key_required", - "message": self.get_api_key_instructions(self.current_service), + "message": self.get_api_key_instructions(self.current_provider), } self.add_user_input_to_history(user_input) @@ -174,6 +182,7 @@ def process_input(self, user_input): try: return self.process_ai_response(response) except ValueError as e: + print(f"ValueError: {str(e)}\n") if not self.is_last_attempt(attempt): self.add_error_prompt_to_history(str(e)) else: @@ -226,7 +235,10 @@ def get_ai_response(self): Returns: dict: The response from the AI API, or None if the request failed. """ - return send_api_request(self.api_client, self.conversation_history, MAX_TOKENS) + max_tokens = AI_PROVIDERS[self.current_provider]["models"][self.current_model][ + "max_tokens" + ] + return send_api_request(self.api_client, self.conversation_history, max_tokens) def is_last_attempt(self, attempt): """ diff --git a/codeaide/utils/api_utils.py b/codeaide/utils/api_utils.py index bb269f1..5a63e17 100644 --- a/codeaide/utils/api_utils.py +++ b/codeaide/utils/api_utils.py @@ -5,7 +5,12 @@ from anthropic import APIError from decouple import config, AutoConfig -from codeaide.utils.constants import AI_MODEL, MAX_TOKENS, SYSTEM_PROMPT +from codeaide.utils.constants import ( + AI_PROVIDERS, + DEFAULT_MODEL, + DEFAULT_PROVIDER, + SYSTEM_PROMPT, +) class MissingAPIKeyException(Exception): @@ -16,7 +21,7 @@ def __init__(self, service): ) -def get_api_client(service="anthropic"): +def get_api_client(provider=DEFAULT_PROVIDER, model=DEFAULT_MODEL): try: # Force a reload of the configuration auto_config = AutoConfig( @@ -24,16 +29,21 @@ def get_api_client(service="anthropic"): os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ) ) - api_key = auto_config(f"{service.upper()}_API_KEY", default=None) + api_key_name = AI_PROVIDERS[provider]["api_key_name"] + api_key = auto_config(api_key_name, default=None) if api_key is None or api_key.strip() == "": return None # Return None if API key is missing or empty - if service.lower() == "anthropic": + if provider.lower() == "anthropic": return anthropic.Anthropic(api_key=api_key) + elif provider.lower() == "openai": + # You'll need to import the OpenAI client and implement this part + # For now, we'll just raise an error + raise NotImplementedError("OpenAI client not yet implemented") else: - raise ValueError(f"Unsupported service: {service}") + raise ValueError(f"Unsupported provider: {provider}") except Exception as e: - print(f"Error initializing {service.capitalize()} API client: {str(e)}") + print(f"Error initializing {provider.capitalize()} API client: {str(e)}") return None @@ -70,11 +80,15 @@ def save_api_key(service, api_key): return False -def send_api_request(client, conversation_history, max_tokens=MAX_TOKENS): +def send_api_request(api_client, conversation_history, max_tokens): system_prompt = SYSTEM_PROMPT + + print(f"Sending API request with max_tokens: {max_tokens}") + print(f"Conversation history: {conversation_history}\n") + try: - response = client.messages.create( - model=AI_MODEL, + response = api_client.messages.create( + model=DEFAULT_MODEL, max_tokens=max_tokens, messages=conversation_history, system=system_prompt, @@ -89,22 +103,37 @@ def send_api_request(client, conversation_history, max_tokens=MAX_TOKENS): def parse_response(response): if not response or not response.content: - return None, None, None, None, None, None + raise ValueError("Empty or invalid response received") + + print(f"Received response: {response}\n") + + # Extract the JSON string + json_str = response.content[0].text + + # Escape newlines within the "code" field + json_str = re.sub( + r'("code"\s*:\s*")(.+?)(")', + lambda m: m.group(1) + m.group(2).replace("\n", "\\n") + m.group(3), + json_str, + flags=re.DOTALL, + ) try: - content = json.loads(response.content[0].text) + # Parse the outer structure + outer_json = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError( + f"Failed to parse JSON: {str(e)}\nProblematic JSON string: {json_str}" + ) - text = content.get("text") - code = content.get("code") - code_version = content.get("code_version") - version_description = content.get("version_description") - requirements = content.get("requirements", []) - questions = content.get("questions", []) + text = outer_json.get("text") + code = outer_json.get("code") + code_version = outer_json.get("code_version") + version_description = outer_json.get("version_description") + requirements = outer_json.get("requirements", []) + questions = outer_json.get("questions", []) - return text, questions, code, code_version, version_description, requirements - except json.JSONDecodeError: - print("Error: Received malformed JSON from the API") - return None, None, None, None, None, None + return text, questions, code, code_version, version_description, requirements def check_api_connection(): @@ -113,7 +142,7 @@ def check_api_connection(): return False, "API key is missing or invalid" try: response = client.messages.create( - model=AI_MODEL, + model=DEFAULT_MODEL, max_tokens=100, messages=[{"role": "user", "content": "Hi Claude, are we communicating?"}], ) diff --git a/codeaide/utils/constants.py b/codeaide/utils/constants.py index adc1ad2..583867b 100644 --- a/codeaide/utils/constants.py +++ b/codeaide/utils/constants.py @@ -1,7 +1,29 @@ # API Configuration -MAX_TOKENS = 8192 # This is the maximum token limit for the API -AI_MODEL = "claude-3-5-sonnet-20240620" -MAX_RETRIES = 3 # Maximum number of retries for API requests (in case of errors or responses that can't be parsed) +AI_PROVIDERS = { + "anthropic": { + "api_key_name": "ANTHROPIC_API_KEY", + "models": { + "claude-3-opus-20240229": {"max_tokens": 8192}, + "claude-3-5-sonnet-20240620": {"max_tokens": 8192}, + "claude-3-haiku-20240307": {"max_tokens": 4096}, + }, + }, + "openai": { + "api_key_name": "OPENAI_API_KEY", + "models": { + "gpt-4": {"max_tokens": 8192}, + "gpt-4-32k": {"max_tokens": 32768}, + "gpt-3.5-turbo": {"max_tokens": 4096}, + }, + }, +} + +# Default model (we'll keep this for backwards compatibility) +DEFAULT_MODEL = "claude-3-haiku-20240307" # "claude-3-5-sonnet-20240620" +DEFAULT_PROVIDER = "anthropic" + +# Other existing constants remain unchanged +MAX_RETRIES = 3 # UI Configuration CHAT_WINDOW_WIDTH = 800 From 2a887e1ec78888bb0140f0cb13b5dc7636dca741 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 09:18:43 -0700 Subject: [PATCH 2/7] More changes to get openai api working --- codeaide/logic/chat_handler.py | 66 +++++++++++++++++---- codeaide/ui/chat_window.py | 105 +++++++++++++++++++++++++++++---- codeaide/utils/api_utils.py | 87 +++++++++++++++------------ codeaide/utils/constants.py | 8 ++- codeaide/utils/cost_tracker.py | 26 +------- requirements.txt | 2 + 6 files changed, 207 insertions(+), 87 deletions(-) diff --git a/codeaide/logic/chat_handler.py b/codeaide/logic/chat_handler.py index cd4a636..ea479d3 100644 --- a/codeaide/logic/chat_handler.py +++ b/codeaide/logic/chat_handler.py @@ -44,6 +44,9 @@ def __init__(self): self.api_key_set = False self.current_provider = DEFAULT_PROVIDER self.current_model = DEFAULT_MODEL + self.max_tokens = AI_PROVIDERS[self.current_provider]["models"][ + self.current_model + ]["max_tokens"] def check_api_key(self): """ @@ -235,10 +238,13 @@ def get_ai_response(self): Returns: dict: The response from the AI API, or None if the request failed. """ - max_tokens = AI_PROVIDERS[self.current_provider]["models"][self.current_model][ - "max_tokens" - ] - return send_api_request(self.api_client, self.conversation_history, max_tokens) + return send_api_request( + self.api_client, + self.conversation_history, + self.max_tokens, + self.current_model, + self.current_provider, + ) def is_last_attempt(self, attempt): """ @@ -265,9 +271,13 @@ def process_ai_response(self, response): Raises: ValueError: If the response cannot be parsed or the version is invalid. """ - parsed_response = parse_response(response) - if parsed_response[0] is None: - raise ValueError("Failed to parse JSON response") + try: + parsed_response = parse_response(response, provider=self.current_provider) + except (ValueError, json.JSONDecodeError) as e: + error_message = ( + f"Failed to parse AI response: {str(e)}\nRaw response: {response}" + ) + raise ValueError(error_message) ( text, @@ -304,9 +314,16 @@ def update_conversation_history(self, response): Returns: None """ - self.conversation_history.append( - {"role": "assistant", "content": response.content[0].text} - ) + if self.current_provider.lower() == "anthropic": + self.conversation_history.append( + {"role": "assistant", "content": response.content[0].text} + ) + elif self.current_provider.lower() == "openai": + self.conversation_history.append( + {"role": "assistant", "content": response.choices[0].message.content} + ) + else: + raise ValueError(f"Unsupported provider: {provider}") def create_questions_response(self, text, questions): """ @@ -382,7 +399,7 @@ def add_error_prompt_to_history(self, error_message): Returns: None """ - error_prompt = f"\n\nThere was an error in your response: {error_message}. Please ensure you're using proper JSON formatting and incrementing the version number correctly. The latest version was {self.latest_version}, so the new version must be higher than this." + error_prompt = f"\n\nThere was an error in your last response: {error_message}. Please ensure you're using proper JSON formatting to avoid this error and others like it." self.conversation_history[-1]["content"] += error_prompt def handle_unexpected_error(self, e): @@ -468,3 +485,30 @@ def is_task_in_progress(self): bool: True if there's an ongoing task, False otherwise. """ return bool(self.conversation_history) + + def set_model(self, provider, model): + if provider not in AI_PROVIDERS: + print(f"Invalid provider: {provider}") + return False + if model not in AI_PROVIDERS[provider]["models"]: + print(f"Invalid model {model} for provider {provider}") + return False + + self.current_provider = provider + self.current_model = model + self.max_tokens = AI_PROVIDERS[self.current_provider]["models"][ + self.current_model + ]["max_tokens"] + self.api_client = get_api_client(self.current_provider, self.current_model) + self.api_key_set = self.api_client is not None + return self.api_key_set + + def clear_conversation_history(self): + self.conversation_history = [] + # We maintain the latest version across model changes + + def get_latest_version(self): + return self.latest_version + + def set_latest_version(self, version): + self.latest_version = version diff --git a/codeaide/ui/chat_window.py b/codeaide/ui/chat_window.py index a5e7eb0..97e707a 100644 --- a/codeaide/ui/chat_window.py +++ b/codeaide/ui/chat_window.py @@ -15,8 +15,9 @@ QTextEdit, QVBoxLayout, QWidget, + QComboBox, + QLabel, ) - from codeaide.ui.code_popup import CodePopup from codeaide.ui.example_selection_dialog import show_example_dialog from codeaide.utils import general_utils @@ -31,6 +32,9 @@ INITIAL_MESSAGE, USER_FONT, USER_MESSAGE_COLOR, + AI_PROVIDERS, + DEFAULT_PROVIDER, + DEFAULT_MODEL, ) @@ -61,6 +65,35 @@ def setup_ui(self): main_layout.setSpacing(5) main_layout.setContentsMargins(8, 8, 8, 8) + # Create a widget for the dropdowns + dropdown_widget = QWidget() + dropdown_layout = QHBoxLayout(dropdown_widget) + dropdown_layout.setContentsMargins(0, 0, 0, 0) + dropdown_layout.setSpacing(5) # Minimal spacing between items + + # Provider dropdown + self.provider_dropdown = QComboBox() + self.provider_dropdown.addItems(AI_PROVIDERS.keys()) + self.provider_dropdown.setCurrentText(DEFAULT_PROVIDER) + self.provider_dropdown.currentTextChanged.connect(self.update_model_dropdown) + dropdown_layout.addWidget(QLabel("Provider:")) + dropdown_layout.addWidget(self.provider_dropdown) + + # Model dropdown + self.model_dropdown = QComboBox() + self.update_model_dropdown(DEFAULT_PROVIDER) + self.model_dropdown.setCurrentText(DEFAULT_MODEL) + self.model_dropdown.currentTextChanged.connect(self.update_chat_handler) + dropdown_layout.addWidget(QLabel("Model:")) + dropdown_layout.addWidget(self.model_dropdown) + + # Add stretch to push everything to the left + dropdown_layout.addStretch(1) + + # Add the dropdown widget to the main layout + main_layout.addWidget(dropdown_widget) + + # Chat display self.chat_display = QTextEdit(self) self.chat_display.setReadOnly(True) self.chat_display.setStyleSheet( @@ -68,14 +101,10 @@ def setup_ui(self): ) main_layout.addWidget(self.chat_display, stretch=3) + # Input text area self.input_text = QTextEdit(self) self.input_text.setStyleSheet( - f""" - background-color: {CHAT_WINDOW_BG}; - color: {CHAT_WINDOW_FG}; - border: 1px solid #ccc; - padding: 5px; - """ + f"background-color: {CHAT_WINDOW_BG}; color: {CHAT_WINDOW_FG}; border: 1px solid #ccc; padding: 5px;" ) self.input_text.setAcceptRichText(False) # Add this line self.input_text.setFont(general_utils.set_font(USER_FONT)) @@ -84,18 +113,17 @@ def setup_ui(self): self.input_text.installEventFilter(self) main_layout.addWidget(self.input_text, stretch=1) + # Buttons button_layout = QHBoxLayout() - button_layout.setSpacing(5) - - self.submit_button = QPushButton("Submit") + self.submit_button = QPushButton("Submit", self) self.submit_button.clicked.connect(self.on_submit) button_layout.addWidget(self.submit_button) - self.example_button = QPushButton("Use Example") + self.example_button = QPushButton("Load Example", self) self.example_button.clicked.connect(self.load_example) button_layout.addWidget(self.example_button) - self.exit_button = QPushButton("Exit") + self.exit_button = QPushButton("Exit", self) self.exit_button.clicked.connect(self.on_exit) button_layout.addWidget(self.exit_button) @@ -263,3 +291,56 @@ def closeEvent(self, event): def sigint_handler(self, *args): QApplication.quit() + + def update_model_dropdown(self, provider): + self.model_dropdown.clear() + models = AI_PROVIDERS[provider]["models"].keys() + self.model_dropdown.addItems(models) + + # Set the current item to the first model in the list + if models: + self.model_dropdown.setCurrentText(list(models)[0]) + else: + print(f"No models available for provider {provider}") + + def update_chat_handler(self): + provider = self.provider_dropdown.currentText() + model = self.model_dropdown.currentText() + + # Check if a valid model is selected + if not model: + print(f"No valid model selected for provider {provider}") + return + + current_version = self.chat_handler.get_latest_version() + success = self.chat_handler.set_model(provider, model) + if not success: + self.add_to_chat( + "System", + f"Failed to set model {model} for provider {provider}. Please check your API key.", + ) + return + + self.chat_handler.clear_conversation_history() + self.chat_handler.set_latest_version( + current_version + ) # Maintain the version number + + # Add a message about switching models and the current version + self.add_to_chat( + "System", + f""" +{'='*50} +Switched to {provider} - {model} +Starting a new conversation with this model. +Current code version: {current_version} +Any new code will be versioned starting from {self.increment_version(current_version)} +{'='*50} +""", + ) + + self.check_api_key() + + def increment_version(self, version): + major, minor = map(int, version.split(".")) + return f"{major}.{minor + 1}" diff --git a/codeaide/utils/api_utils.py b/codeaide/utils/api_utils.py index 5a63e17..b212a48 100644 --- a/codeaide/utils/api_utils.py +++ b/codeaide/utils/api_utils.py @@ -2,8 +2,10 @@ import json import re import anthropic -from anthropic import APIError +import openai from decouple import config, AutoConfig +import hjson +from google.auth.exceptions import DefaultCredentialsError from codeaide.utils.constants import ( AI_PROVIDERS, @@ -23,23 +25,19 @@ def __init__(self, service): def get_api_client(provider=DEFAULT_PROVIDER, model=DEFAULT_MODEL): try: - # Force a reload of the configuration - auto_config = AutoConfig( - search_path=os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - ) - ) api_key_name = AI_PROVIDERS[provider]["api_key_name"] - api_key = auto_config(api_key_name, default=None) + api_key = config(api_key_name, default=None) + print(f"Attempting to get API key for {provider} with key name: {api_key_name}") + print(f"API key found: {'Yes' if api_key else 'No'}") + if api_key is None or api_key.strip() == "": - return None # Return None if API key is missing or empty + print(f"API key for {provider} is missing or empty") + return None if provider.lower() == "anthropic": return anthropic.Anthropic(api_key=api_key) elif provider.lower() == "openai": - # You'll need to import the OpenAI client and implement this part - # For now, we'll just raise an error - raise NotImplementedError("OpenAI client not yet implemented") + return openai.OpenAI(api_key=api_key) else: raise ValueError(f"Unsupported provider: {provider}") except Exception as e: @@ -80,50 +78,63 @@ def save_api_key(service, api_key): return False -def send_api_request(api_client, conversation_history, max_tokens): +def send_api_request(api_client, conversation_history, max_tokens, model, provider): system_prompt = SYSTEM_PROMPT - print(f"Sending API request with max_tokens: {max_tokens}") + print(f"Sending API request with model: {model} and max_tokens: {max_tokens}") print(f"Conversation history: {conversation_history}\n") try: - response = api_client.messages.create( - model=DEFAULT_MODEL, - max_tokens=max_tokens, - messages=conversation_history, - system=system_prompt, - ) - if not response.content: - return None + if provider.lower() == "anthropic": + response = api_client.messages.create( + model=model, + max_tokens=max_tokens, + messages=conversation_history, + system=system_prompt, + ) + elif provider.lower() == "openai": + messages = [ + {"role": "system", "content": system_prompt} + ] + conversation_history + response = api_client.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + ) + else: + raise NotImplementedError(f"API request for {provider} not implemented") + + print(f"Received response from {provider}") + print(f"Response object: {response}") return response except Exception as e: - print(f"Error in API request: {str(e)}") + print(f"Error in API request to {provider}: {str(e)}") return None -def parse_response(response): - if not response or not response.content: +def parse_response(response, provider): + if not response: raise ValueError("Empty or invalid response received") print(f"Received response: {response}\n") - # Extract the JSON string - json_str = response.content[0].text + if provider.lower() == "anthropic": + json_str = response.content[0].text + elif provider.lower() == "openai": + json_str = response.choices[0].message.content + else: + raise ValueError(f"Unsupported provider: {provider}") - # Escape newlines within the "code" field - json_str = re.sub( - r'("code"\s*:\s*")(.+?)(")', - lambda m: m.group(1) + m.group(2).replace("\n", "\\n") + m.group(3), - json_str, - flags=re.DOTALL, - ) + # Remove the triple backticks and language identifier + if json_str.startswith("```json"): + json_str = json_str[7:-3].strip() try: - # Parse the outer structure - outer_json = json.loads(json_str) - except json.JSONDecodeError as e: + # Parse the outer structure using hjson + outer_json = hjson.loads(json_str) + except hjson.HjsonDecodeError as e: raise ValueError( - f"Failed to parse JSON: {str(e)}\nProblematic JSON string: {json_str}" + f"Failed to parse response: {str(e)}\nProblematic string: {json_str}" ) text = outer_json.get("text") diff --git a/codeaide/utils/constants.py b/codeaide/utils/constants.py index 583867b..8d93d92 100644 --- a/codeaide/utils/constants.py +++ b/codeaide/utils/constants.py @@ -3,7 +3,7 @@ "anthropic": { "api_key_name": "ANTHROPIC_API_KEY", "models": { - "claude-3-opus-20240229": {"max_tokens": 8192}, + "claude-3-opus-20240229": {"max_tokens": 4096}, "claude-3-5-sonnet-20240620": {"max_tokens": 8192}, "claude-3-haiku-20240307": {"max_tokens": 4096}, }, @@ -11,8 +11,10 @@ "openai": { "api_key_name": "OPENAI_API_KEY", "models": { - "gpt-4": {"max_tokens": 8192}, - "gpt-4-32k": {"max_tokens": 32768}, + "chatgpt-4o-latest": {"max_tokens": 16384}, + "gpt-4o-mini": {"max_tokens": 16384}, + "o1-preview": {"max_tokens": 32768}, + "gpt-4-turbo": {"max_tokens": 4096}, "gpt-3.5-turbo": {"max_tokens": 4096}, }, }, diff --git a/codeaide/utils/cost_tracker.py b/codeaide/utils/cost_tracker.py index 9a1ead3..851a6b9 100644 --- a/codeaide/utils/cost_tracker.py +++ b/codeaide/utils/cost_tracker.py @@ -7,30 +7,10 @@ def __init__(self): self.cost_per_1k_tokens = 0.03 # Update this with actual pricing def log_request(self, response): - # The new API doesn't provide direct access to prompt tokens - # We'll estimate based on the response tokens - completion_tokens = response.usage.output_tokens - # Estimate prompt tokens (this is not accurate, but it's a rough estimate) - estimated_prompt_tokens = completion_tokens // 2 - total_tokens = estimated_prompt_tokens + completion_tokens - estimated_cost = (total_tokens / 1000) * self.cost_per_1k_tokens - - self.cost_log.append( - { - "timestamp": datetime.now(), - "estimated_prompt_tokens": estimated_prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - "estimated_cost": estimated_cost, - } - ) + pass def get_total_cost(self): - return sum(entry["estimated_cost"] for entry in self.cost_log) + return 0 def print_summary(self): - total_cost = self.get_total_cost() - total_tokens = sum(entry["total_tokens"] for entry in self.cost_log) - print(f"\nTotal estimated cost: ${total_cost:.4f}") - print(f"Total tokens used: {total_tokens}") - print(f"Number of API calls: {len(self.cost_log)}") + print("Cost summary not implemented") diff --git a/requirements.txt b/requirements.txt index 04262bb..97c2ed8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ anthropic==0.34.2 python-decouple==3.8 virtualenv==20.16.2 +google-generativeai +hjson pyyaml pytest black From 8bf42a00c57b7f7a22c515646b9a5e185ca92bf3 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 09:30:59 -0700 Subject: [PATCH 3/7] Added to readme --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1d85e0f..310fe68 100644 --- a/README.md +++ b/README.md @@ -2,16 +2,20 @@ CodeAIde is an AI-powered coding assistant that helps developers write, test, and optimize code through natural language interactions. By leveraging the power of large language models, CodeAIde aims to streamline the coding process and boost productivity. +This is designed to be a simple, intuitive tool for writing, running, and refining simple Python scripts. It is not meant to be a full IDE or code editing environment and isn't a replacement for a tool like Cursor or Github Copilot. Instead, it is intended to be a simple tool for quickly writing code and getting it working without the need to worry about setting up environments, installing dependencies, etc. + ## Features - Natural language code generation +- Support for OpenAI and Anthropic APIs - Interactive clarification process for precise code output +- Version control for generated code - Local code execution and testing - Cost tracking for API usage (not yet implemented) ## Examples -Here are some example videos demonstrating use. Example prompts can be accessed by clicking "Use Example" and selecting from avaialable examples. +Here are some example videos demonstrating use. Example prompts can be accessed by clicking "Use Example" and selecting from available examples. First, a simple matplotlib plot with followup requests to modify aesthetics. From 06a2fedbbfaf51c6780faf900967b3fe00c20120 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 09:37:29 -0700 Subject: [PATCH 4/7] Rearranged constants order --- codeaide/utils/constants.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/codeaide/utils/constants.py b/codeaide/utils/constants.py index 8d93d92..28b8abb 100644 --- a/codeaide/utils/constants.py +++ b/codeaide/utils/constants.py @@ -11,17 +11,16 @@ "openai": { "api_key_name": "OPENAI_API_KEY", "models": { + "gpt-3.5-turbo": {"max_tokens": 4096}, + "gpt-4-turbo": {"max_tokens": 4096}, "chatgpt-4o-latest": {"max_tokens": 16384}, "gpt-4o-mini": {"max_tokens": 16384}, - "o1-preview": {"max_tokens": 32768}, - "gpt-4-turbo": {"max_tokens": 4096}, - "gpt-3.5-turbo": {"max_tokens": 4096}, }, }, } -# Default model (we'll keep this for backwards compatibility) -DEFAULT_MODEL = "claude-3-haiku-20240307" # "claude-3-5-sonnet-20240620" +# Default model +DEFAULT_MODEL = "claude-3-5-sonnet-20240620" DEFAULT_PROVIDER = "anthropic" # Other existing constants remain unchanged From 30fc52ace3f3f94a6e9576fee705a4eec0f3da24 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 09:39:34 -0700 Subject: [PATCH 5/7] Changed prompt in check_api_connection --- codeaide/utils/api_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeaide/utils/api_utils.py b/codeaide/utils/api_utils.py index b212a48..4cc47d4 100644 --- a/codeaide/utils/api_utils.py +++ b/codeaide/utils/api_utils.py @@ -155,7 +155,7 @@ def check_api_connection(): response = client.messages.create( model=DEFAULT_MODEL, max_tokens=100, - messages=[{"role": "user", "content": "Hi Claude, are we communicating?"}], + messages=[{"role": "user", "content": "Are we communicating?"}], ) return True, response.content[0].text.strip() except Exception as e: From d5cda59f7d160ade52e7b40cce52e03f2f52f165 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 09:41:33 -0700 Subject: [PATCH 6/7] Updated requirements to include openai --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 97c2ed8..5f6d1ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ anthropic==0.34.2 python-decouple==3.8 virtualenv==20.16.2 -google-generativeai +openai hjson pyyaml pytest From 87d2d51911b1567e483e7c1298ed539c6fec1b99 Mon Sep 17 00:00:00 2001 From: dougollerenshaw Date: Wed, 25 Sep 2024 16:32:40 -0700 Subject: [PATCH 7/7] Lots of test updates --- codeaide/utils/api_utils.py | 23 +- pytest.ini | 4 +- tests/integration/test_api_integration.py | 130 +++++++ tests/utils/test_api_utils.py | 418 +++++++++++++++++++--- 4 files changed, 525 insertions(+), 50 deletions(-) create mode 100644 tests/integration/test_api_integration.py diff --git a/codeaide/utils/api_utils.py b/codeaide/utils/api_utils.py index 4cc47d4..cec3bdf 100644 --- a/codeaide/utils/api_utils.py +++ b/codeaide/utils/api_utils.py @@ -5,7 +5,7 @@ import openai from decouple import config, AutoConfig import hjson -from google.auth.exceptions import DefaultCredentialsError +from anthropic import APIError from codeaide.utils.constants import ( AI_PROVIDERS, @@ -79,8 +79,6 @@ def save_api_key(service, api_key): def send_api_request(api_client, conversation_history, max_tokens, model, provider): - system_prompt = SYSTEM_PROMPT - print(f"Sending API request with model: {model} and max_tokens: {max_tokens}") print(f"Conversation history: {conversation_history}\n") @@ -90,17 +88,21 @@ def send_api_request(api_client, conversation_history, max_tokens, model, provid model=model, max_tokens=max_tokens, messages=conversation_history, - system=system_prompt, + system=SYSTEM_PROMPT, ) + if not response.content: + return None elif provider.lower() == "openai": messages = [ - {"role": "system", "content": system_prompt} + {"role": "system", "content": SYSTEM_PROMPT} ] + conversation_history response = api_client.chat.completions.create( model=model, messages=messages, max_tokens=max_tokens, ) + if not response.choices: + return None else: raise NotImplementedError(f"API request for {provider} not implemented") @@ -119,15 +121,21 @@ def parse_response(response, provider): print(f"Received response: {response}\n") if provider.lower() == "anthropic": + if not response.content: + raise ValueError("Empty or invalid response received") json_str = response.content[0].text elif provider.lower() == "openai": + if not response.choices: + raise ValueError("Empty or invalid response received") json_str = response.choices[0].message.content else: raise ValueError(f"Unsupported provider: {provider}") - # Remove the triple backticks and language identifier + # Remove the triple backticks and language identifier if present if json_str.startswith("```json"): json_str = json_str[7:-3].strip() + elif json_str.startswith("```"): + json_str = json_str[3:-3].strip() try: # Parse the outer structure using hjson @@ -137,6 +145,9 @@ def parse_response(response, provider): f"Failed to parse response: {str(e)}\nProblematic string: {json_str}" ) + if not isinstance(outer_json, dict): + raise ValueError("Parsed response is not a valid JSON object") + text = outer_json.get("text") code = outer_json.get("code") code_version = outer_json.get("code_version") diff --git a/pytest.ini b/pytest.ini index 10cc899..21e3e0d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,4 +2,6 @@ markers = send_api_request: marks tests related to sending API requests parse_response: marks tests related to parsing API responses - api_connection: marks tests related to API connection \ No newline at end of file + api_connection: marks tests related to API connection + integration: mark a test as an integration test that uses actual API keys +addopts = -m "not integration" \ No newline at end of file diff --git a/tests/integration/test_api_integration.py b/tests/integration/test_api_integration.py new file mode 100644 index 0000000..f3953a1 --- /dev/null +++ b/tests/integration/test_api_integration.py @@ -0,0 +1,130 @@ +""" +This file contains integration tests for the API functionality of the CodeAide application. + +These tests verify the correct operation of API clients, request sending, and response parsing +for both Anthropic and OpenAI APIs. They ensure that the application can successfully +communicate with these external services and handle their responses appropriately. + +IMPORTANT: These tests use live API calls and will incur charges on your API accounts. +They are designed to complement the existing unit tests and are not part of the +continuous integration pipeline. These tests should be run manually and infrequently, +primarily to verify that the API functionality is working as expected with the live APIs. + +To run these tests: +1. Ensure you have the necessary API keys set in your environment variables: + - ANTHROPIC_API_KEY for Anthropic tests + - OPENAI_API_KEY for OpenAI tests +2. Install pytest if not already installed: `pip install pytest` +3. Navigate to the project root directory +4. Run the tests using the command: `pytest -m integration` + +Note: These tests are marked with the 'integration' marker and are specifically run +using the `-m integration` flag. This allows them to be easily separated from other +tests and run independently when needed. + +Caution: Due to the use of live API calls, these tests should not be run frequently +or as part of automated CI/CD processes to avoid unnecessary API charges. +""" + +import pytest +from codeaide.utils.api_utils import get_api_client, send_api_request, parse_response +from codeaide.utils.constants import SYSTEM_PROMPT + +ANTHROPIC_MODEL = "claude-3-haiku-20240307" +OPENAI_MODEL = "gpt-3.5-turbo" + +MINIMAL_PROMPT = """ +Please respond with a JSON object containing the following fields: +- text: A brief description of the code. +- code: A piece of code that prints "Hello, World!". +- code_version: The version of the code. +- version_description: A brief description of the version. +- requirements: An empty list. +- questions: An empty list. +""" + + +@pytest.mark.integration +def test_anthropic_api(): + """ + Integration test for the Anthropic API. + + This test: + 1. Initializes the Anthropic API client + 2. Sends a request to the API with a minimal prompt + 3. Verifies that a non-empty response is received + 4. Checks that the response contains the word "Hello" + 5. Attempts to parse the response + 6. Verifies that all expected fields are present in the parsed response + """ + api_client = get_api_client(provider="anthropic") + assert api_client is not None, "Anthropic API client initialization failed" + + conversation_history = [{"role": "user", "content": MINIMAL_PROMPT}] + response = send_api_request( + api_client, conversation_history, 100, ANTHROPIC_MODEL, "anthropic" + ) + assert response is not None, "Anthropic API request failed" + assert "Hello" in response.content[0].text, "Unexpected response from Anthropic API" + + # Test parse_response function + parsed_response = parse_response(response, "anthropic") + assert parsed_response is not None, "Failed to parse response from Anthropic API" + ( + text, + questions, + code, + code_version, + version_description, + requirements, + ) = parsed_response + assert text is not None, "Parsed text is None" + assert isinstance(questions, list), "Parsed questions is not a list" + assert code is not None, "Parsed code is None" + assert code_version is not None, "Parsed code_version is None" + assert version_description is not None, "Parsed version_description is None" + assert isinstance(requirements, list), "Parsed requirements is not a list" + + +@pytest.mark.integration +def test_openai_api(): + """ + Integration test for the OpenAI API. + + This test: + 1. Initializes the OpenAI API client + 2. Sends a request to the API with a minimal prompt + 3. Verifies that a non-empty response is received + 4. Checks that the response contains the word "Hello" + 5. Attempts to parse the response + 6. Verifies that all expected fields are present in the parsed response + """ + api_client = get_api_client(provider="openai") + assert api_client is not None, "OpenAI API client initialization failed" + + conversation_history = [{"role": "user", "content": MINIMAL_PROMPT}] + response = send_api_request( + api_client, conversation_history, 100, OPENAI_MODEL, "openai" + ) + assert response is not None, "OpenAI API request failed" + assert ( + "Hello" in response.choices[0].message.content + ), "Unexpected response from OpenAI API" + + # Test parse_response function + parsed_response = parse_response(response, "openai") + assert parsed_response is not None, "Failed to parse response from OpenAI API" + ( + text, + questions, + code, + code_version, + version_description, + requirements, + ) = parsed_response + assert text is not None, "Parsed text is None" + assert isinstance(questions, list), "Parsed questions is not a list" + assert code is not None, "Parsed code is None" + assert code_version is not None, "Parsed code_version is None" + assert version_description is not None, "Parsed version_description is None" + assert isinstance(requirements, list), "Parsed requirements is not a list" diff --git a/tests/utils/test_api_utils.py b/tests/utils/test_api_utils.py index ba288f5..1bbbca5 100644 --- a/tests/utils/test_api_utils.py +++ b/tests/utils/test_api_utils.py @@ -13,7 +13,12 @@ get_api_client, MissingAPIKeyException, ) -from codeaide.utils.constants import AI_MODEL, MAX_TOKENS, SYSTEM_PROMPT +from codeaide.utils.constants import ( + DEFAULT_MODEL, + DEFAULT_PROVIDER, + SYSTEM_PROMPT, + AI_PROVIDERS, +) # Mock Response object Response = namedtuple("Response", ["content"]) @@ -25,9 +30,21 @@ pytest.mark.api_connection, ] +# Get the max_tokens value from the AI_PROVIDERS dictionary +MAX_TOKENS = AI_PROVIDERS[DEFAULT_PROVIDER]["models"][DEFAULT_MODEL]["max_tokens"] + @pytest.fixture def mock_anthropic_client(): + """ + A pytest fixture that mocks the Anthropic API client. + + This fixture patches the 'anthropic.Anthropic' class and returns a mock client. + The mock client includes a 'messages' attribute, which is also a mock object. + + Returns: + Mock: A mock object representing the Anthropic API client. + """ with patch("anthropic.Anthropic") as mock_anthropic: mock_client = Mock() mock_messages = Mock() @@ -36,67 +53,222 @@ def mock_anthropic_client(): yield mock_client +@pytest.fixture +def mock_openai_client(): + """ + A pytest fixture that mocks the OpenAI API client. + + This fixture patches the 'openai.OpenAI' class and returns a mock client. + The mock client can be used to simulate OpenAI API responses in tests + without making actual API calls. + + Returns: + Mock: A mock object representing the OpenAI API client. + """ + with patch("openai.OpenAI") as mock_openai: + mock_client = Mock() + mock_openai.return_value = mock_client + yield mock_client + + class TestGetApiClient: - def test_get_api_client_success(self, monkeypatch): - monkeypatch.setenv("ANTHROPIC_API_KEY", "test_key") - client = get_api_client() - assert client is not None - assert hasattr( - client, "messages" - ) # Check for a common attribute of Anthropic client + """ + A test class for the get_api_client function in the api_utils module. + + This class contains test methods to verify the behavior of the get_api_client function + under various scenarios, such as missing API keys, successful client creation, + and handling of unsupported services. + + The @patch decorators used in this class serve to mock the 'config' and 'AutoConfig' + functions from the codeaide.utils.api_utils module. This allows us to control the + behavior of these functions during testing, simulating different environments and + configurations without actually modifying the system or making real API calls. + Attributes: + None + + Methods: + Various test methods to cover different scenarios for get_api_client function. + """ + + @patch("codeaide.utils.api_utils.config") @patch("codeaide.utils.api_utils.AutoConfig") - def test_get_api_client_missing_key(self, mock_auto_config, monkeypatch): - mock_config = Mock() + def test_get_api_client_missing_key( + self, mock_auto_config, mock_config, monkeypatch + ): + """ + Test the behavior of get_api_client when the API key is missing. + + This test ensures that the get_api_client function returns None when the + ANTHROPIC_API_KEY is not set in the environment variables. + + Args: + mock_auto_config (MagicMock): A mock object for the AutoConfig class. + mock_config (MagicMock): A mock object for the config function. + monkeypatch (pytest.MonkeyPatch): Pytest fixture for modifying the test environment. + + The test performs the following steps: + 1. Mocks the config function to return None, simulating a missing API key. + 2. Sets up the mock AutoConfig to use the mocked config function. + 3. Removes the ANTHROPIC_API_KEY from the environment variables. + 4. Calls get_api_client with the "anthropic" provider. + 5. Asserts that the returned client is None, as expected when the API key is missing. + """ mock_config.return_value = None mock_auto_config.return_value = mock_config monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) - client = get_api_client() + client = get_api_client(provider="anthropic") assert client is None + def test_get_api_client_success(self, monkeypatch): + """ + Test the successful creation of an API client for Anthropic. + + This test verifies that the get_api_client function correctly creates and returns + an Anthropic API client when a valid API key is provided in the environment. + + Args: + monkeypatch (pytest.MonkeyPatch): Pytest fixture for modifying the test environment. + + The test performs the following steps: + 1. Sets the ANTHROPIC_API_KEY environment variable to a test value. + 2. Calls get_api_client with the "anthropic" provider. + 3. Asserts that the returned client is not None. + 4. Verifies that the client has a 'messages' attribute, which is expected for Anthropic clients. + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test_key") + client = get_api_client(provider="anthropic") + assert client is not None + assert hasattr(client, "messages") + def test_get_api_client_empty_key(self, monkeypatch): + """ + Test the behavior of get_api_client when the API key is empty. + + This test ensures that the get_api_client function returns None when the + ANTHROPIC_API_KEY is set to an empty string in the environment variables. + + Args: + monkeypatch (pytest.MonkeyPatch): Pytest fixture for modifying the test environment. + + The test performs the following steps: + 1. Sets the ANTHROPIC_API_KEY environment variable to an empty string. + 2. Calls get_api_client with the "anthropic" provider. + 3. Asserts that the returned client is None, as expected when the API key is empty. + """ monkeypatch.setenv("ANTHROPIC_API_KEY", "") - client = get_api_client() + client = get_api_client(provider="anthropic") assert client is None @patch("codeaide.utils.api_utils.AutoConfig") def test_get_api_client_unsupported_service(self, mock_auto_config): + """ + Test the behavior of get_api_client when an unsupported service is provided. + + This test ensures that the get_api_client function returns None when an + unsupported service provider is specified. + + Args: + mock_auto_config (MagicMock): A mock object for the AutoConfig class. + + The test performs the following steps: + 1. Mocks the AutoConfig to return a dummy API key. + 2. Calls get_api_client with an unsupported service provider. + 3. Asserts that the returned result is None, as expected for unsupported services. + """ mock_config = Mock() mock_config.return_value = "dummy_key" mock_auto_config.return_value = mock_config - result = get_api_client("unsupported_service") + result = get_api_client(provider="unsupported_service") assert result is None class TestSendAPIRequest: - def test_send_api_request_success(self, mock_anthropic_client): + """ + A test class for the send_api_request function. + + This class contains test methods to verify the behavior of the send_api_request function + under various scenarios, including successful API calls, empty responses, and API errors. + It tests the function's interaction with both OpenAI and Anthropic APIs. + + Test methods: + - test_send_api_request_success_openai: Verifies successful OpenAI API requests. + - test_send_api_request_empty_response: Checks handling of empty responses from Anthropic API. + - test_send_api_request_api_error: Tests error handling for API errors. + + Each test method uses mocking to simulate API responses and errors, ensuring + that the send_api_request function behaves correctly in different scenarios. + """ + + @patch("openai.OpenAI") + def test_send_api_request_success_openai(self, mock_openai): + """ + Test that send_api_request successfully sends a request to OpenAI API + and returns a non-None response. + """ + conversation_history = [{"role": "user", "content": "Hello, GPT!"}] + mock_client = Mock() + mock_response = Mock() + mock_response.choices = [ + Mock(message=Mock(content="Hello! How can I assist you today?")) + ] + mock_client.chat.completions.create.return_value = mock_response + mock_openai.return_value = mock_client + + result = send_api_request( + mock_client, conversation_history, MAX_TOKENS, DEFAULT_MODEL, "openai" + ) + + mock_client.chat.completions.create.assert_called_once_with( + model=DEFAULT_MODEL, + max_tokens=MAX_TOKENS, + messages=[{"role": "system", "content": SYSTEM_PROMPT}] + + conversation_history, + ) + assert result is not None + + @patch("anthropic.Anthropic") + def test_send_api_request_empty_response(self, mock_anthropic): + """ + Test that send_api_request returns None when receiving an empty response + from the Anthropic API. + """ conversation_history = [{"role": "user", "content": "Hello, Claude!"}] + mock_client = Mock() mock_response = Mock() - mock_response.content = [Mock(text="Hello! How can I assist you today?")] - mock_anthropic_client.messages.create.return_value = mock_response + mock_response.content = [] # Empty content + mock_client.messages.create.return_value = mock_response + mock_anthropic.return_value = mock_client - result = send_api_request(mock_anthropic_client, conversation_history) + result = send_api_request( + mock_client, conversation_history, MAX_TOKENS, DEFAULT_MODEL, "anthropic" + ) - mock_anthropic_client.messages.create.assert_called_once_with( - model=AI_MODEL, + mock_client.messages.create.assert_called_once_with( + model=DEFAULT_MODEL, max_tokens=MAX_TOKENS, messages=conversation_history, system=SYSTEM_PROMPT, ) - assert result == mock_response + assert result is None, "Expected None for empty response content" - def test_send_api_request_empty_response(self, mock_anthropic_client): - conversation_history = [{"role": "user", "content": "Hello, Claude!"}] - mock_response = Mock() - mock_response.content = [] - mock_anthropic_client.messages.create.return_value = mock_response + def test_send_api_request_api_error(self, mock_anthropic_client): + """ + Test that send_api_request handles API errors correctly. - result = send_api_request(mock_anthropic_client, conversation_history) + This test simulates an APIError being raised by the Anthropic client + and verifies that the function returns None in this case. - assert result is None + Args: + mock_anthropic_client (Mock): A mocked Anthropic client object. - def test_send_api_request_api_error(self, mock_anthropic_client): + The test: + 1. Sets up a conversation history. + 2. Configures the mock client to raise an APIError. + 3. Calls send_api_request with the mocked client. + 4. Asserts that the function returns None when an APIError occurs. + """ conversation_history = [{"role": "user", "content": "Hello, Claude!"}] mock_request = Mock() mock_anthropic_client.messages.create.side_effect = APIError( @@ -105,11 +277,35 @@ def test_send_api_request_api_error(self, mock_anthropic_client): body={"error": {"message": "API Error"}}, ) - result = send_api_request(mock_anthropic_client, conversation_history) + result = send_api_request( + mock_anthropic_client, + conversation_history, + MAX_TOKENS, + DEFAULT_MODEL, + "anthropic", + ) assert result is None def test_send_api_request_custom_max_tokens(self, mock_anthropic_client): + """ + Test the send_api_request function with a custom max_tokens value. + + This test verifies that: + 1. The function correctly uses a custom max_tokens value. + 2. The Anthropic client is called with the correct parameters. + 3. The function returns the expected mock response. + + Args: + mock_anthropic_client (Mock): A mocked Anthropic client object. + + The test: + 1. Sets up a conversation history and custom max_tokens value. + 2. Creates a mock response from the Anthropic API. + 3. Calls send_api_request with the custom parameters. + 4. Asserts that the Anthropic client was called with the correct arguments. + 5. Verifies that the function returns the expected mock response. + """ conversation_history = [{"role": "user", "content": "Hello, Claude!"}] custom_max_tokens = 500 mock_response = Mock() @@ -117,11 +313,15 @@ def test_send_api_request_custom_max_tokens(self, mock_anthropic_client): mock_anthropic_client.messages.create.return_value = mock_response result = send_api_request( - mock_anthropic_client, conversation_history, max_tokens=custom_max_tokens + mock_anthropic_client, + conversation_history, + custom_max_tokens, + DEFAULT_MODEL, + "anthropic", ) mock_anthropic_client.messages.create.assert_called_once_with( - model=AI_MODEL, + model=DEFAULT_MODEL, max_tokens=custom_max_tokens, messages=conversation_history, system=SYSTEM_PROMPT, @@ -130,16 +330,69 @@ def test_send_api_request_custom_max_tokens(self, mock_anthropic_client): class TestParseResponse: + """ + A test class for the parse_response function in the api_utils module. + + This class contains various test methods to ensure the correct behavior + of the parse_response function under different scenarios, including: + - Handling of empty or invalid responses + - Parsing of valid responses from different AI providers (Anthropic and OpenAI) + - Correct extraction of fields from the parsed JSON + - Handling of responses with missing fields + + Each test method in this class focuses on a specific aspect of the + parse_response function's behavior, helping to ensure its robustness + and correctness across various input conditions. + """ + def test_parse_response_empty(self): - result = parse_response(None) - assert result == (None, None, None, None, None, None) + """ + Test that parse_response raises a ValueError when given an empty response. + + This test verifies that the parse_response function correctly handles + the case of an empty (None) response for the Anthropic provider. + + It checks that: + 1. A ValueError is raised when parse_response is called with None. + 2. The error message matches the expected string. + + This helps ensure that the function fails gracefully and provides + appropriate error information when given invalid input. + """ + with pytest.raises(ValueError, match="Empty or invalid response received"): + parse_response(None, "anthropic") def test_parse_response_no_content(self): - response = Response(content=[]) - result = parse_response(response) - assert result == (None, None, None, None, None, None) + """ + Test that parse_response raises a ValueError when given an Anthropic + response with no content. + """ + response = Mock(content=[]) + with pytest.raises(ValueError, match="Empty or invalid response received"): + parse_response(response, "anthropic") + + def test_parse_response_no_choices(self): + """ + Test that parse_response raises a ValueError when given an OpenAI + response with no choices. + """ + response = Mock(choices=[]) + with pytest.raises(ValueError, match="Empty or invalid response received"): + parse_response(response, "openai") def test_parse_response_valid(self): + """ + Test that parse_response correctly handles a valid Anthropic response. + + This test verifies that the parse_response function correctly parses + a valid JSON response from the Anthropic API. It checks that: + 1. The function correctly extracts all fields from the JSON. + 2. The extracted values match the expected values. + 3. The function handles various data types (strings, lists) correctly. + + This test helps ensure that the parse_response function can accurately + process and return the structured data from a well-formed API response. + """ content = { "text": "Sample text", "code": "print('Hello, World!')", @@ -156,7 +409,7 @@ def test_parse_response_valid(self): code_version, version_description, requirements, - ) = parse_response(response) + ) = parse_response(response, "anthropic") assert text == "Sample text" assert questions == ["What does this code do?"] @@ -166,6 +419,17 @@ def test_parse_response_valid(self): assert requirements == ["pytest"] def test_parse_response_missing_fields(self): + """ + Test that parse_response correctly handles a response with missing fields. + + This test verifies that the parse_response function: + 1. Correctly extracts the fields that are present in the response. + 2. Sets default values (None or empty list) for missing fields. + 3. Doesn't raise an exception when optional fields are missing. + + It helps ensure that the function is robust and can handle incomplete + responses without breaking. + """ content = {"text": "Sample text", "code": "print('Hello, World!')"} response = Response(content=[TextBlock(text=json.dumps(content))]) ( @@ -175,7 +439,7 @@ def test_parse_response_missing_fields(self): code_version, version_description, requirements, - ) = parse_response(response) + ) = parse_response(response, "anthropic") assert text == "Sample text" assert questions == [] @@ -185,6 +449,19 @@ def test_parse_response_missing_fields(self): assert requirements == [] def test_parse_response_complex_code(self): + """ + Test parse_response function with a complex code example. + + This test verifies that the parse_response function correctly handles + a response containing a more complex code structure. It checks that: + 1. The function correctly extracts all fields from the response. + 2. The extracted code maintains its structure and indentation. + 3. Version information and descriptions are correctly parsed. + 4. Empty lists for requirements and questions are handled properly. + + This test ensures that the parse_response function can handle + responses with multi-line code snippets and various metadata fields. + """ content = { "text": "Complex code example", "code": 'def hello():\n print("Hello, World!")', @@ -201,14 +478,29 @@ def test_parse_response_complex_code(self): code_version, version_description, requirements, - ) = parse_response(response) + ) = parse_response(response, "anthropic") assert text == "Complex code example" assert code == 'def hello():\n print("Hello, World!")' assert code_version == "1.1" assert version_description == "Added function" + assert questions == [] + assert requirements == [] def test_parse_response_escaped_quotes(self): + """ + Test parse_response function with escaped quotes in the content. + + This test verifies that the parse_response function correctly handles + a response containing escaped quotes in various fields. It checks that: + 1. The function correctly extracts all fields from the response. + 2. The extracted text and code maintain their escaped quotes. + 3. Version information is correctly parsed. + 4. Empty lists for requirements and questions are handled properly. + + This test ensures that the parse_response function can handle + responses with complex string content, including escaped quotes. + """ content = { "text": 'Text with "quotes"', "code": 'print("Hello, \\"World!\\"")\nprint(\'Single quotes\')', @@ -225,7 +517,7 @@ def test_parse_response_escaped_quotes(self): code_version, version_description, requirements, - ) = parse_response(response) + ) = parse_response(response, "anthropic") assert text == 'Text with "quotes"' assert code == 'print("Hello, \\"World!\\"")\nprint(\'Single quotes\')' @@ -233,14 +525,42 @@ def test_parse_response_escaped_quotes(self): assert version_description == "Added escaped quotes" def test_parse_response_malformed_json(self): + """ + Test parse_response function with malformed JSON input. + + This test verifies that the parse_response function correctly handles + a response containing invalid JSON. It checks that: + 1. The function raises a ValueError when given non-JSON content. + 2. The error message specifically mentions that the parsed response + is not a valid JSON object. + + This test ensures that the parse_response function fails gracefully + and provides meaningful error messages when given invalid input. + """ response = Response(content=[TextBlock(text="This is not JSON")]) - result = parse_response(response) - assert result == (None, None, None, None, None, None) + with pytest.raises( + ValueError, match="Parsed response is not a valid JSON object" + ): + parse_response(response, "anthropic") class TestAPIConnection: + """ + Test suite for the API connection functionality. + + This class contains tests to verify the behavior of the check_api_connection function + under various scenarios, including successful connections, connection failures, + and missing API keys. + """ + @patch("codeaide.utils.api_utils.get_api_client") def test_check_api_connection_success(self, mock_get_api_client): + """ + Test successful API connection. + + This test verifies that the check_api_connection function returns a successful + result when the API client is properly initialized and responds correctly. + """ mock_client = Mock() mock_response = Mock() mock_response.content = [Mock(text="Yes, we are communicating.")] @@ -254,6 +574,12 @@ def test_check_api_connection_success(self, mock_get_api_client): @patch("codeaide.utils.api_utils.get_api_client") def test_check_api_connection_failure(self, mock_get_api_client): + """ + Test API connection failure. + + This test ensures that the check_api_connection function handles connection + failures gracefully and returns an appropriate error message. + """ mock_client = Mock() mock_client.messages.create.side_effect = Exception("Connection failed") mock_get_api_client.return_value = mock_client @@ -265,9 +591,15 @@ def test_check_api_connection_failure(self, mock_get_api_client): @patch("codeaide.utils.api_utils.get_api_client") def test_check_api_connection_missing_key(self, mock_get_api_client): + """ + Test API connection with missing API key. + + This test verifies that the check_api_connection function correctly handles + the scenario where the API key is missing or invalid. + """ mock_get_api_client.return_value = None result = check_api_connection() assert result[0] == False - assert "API key is missing or invalid" in result[1] + assert result[1] == "API key is missing or invalid"