Skip to content

Commit

Permalink
Merge pull request #34 from dougollerenshaw/make_model_selectable
Browse files Browse the repository at this point in the history
Make model provider selectable
  • Loading branch information
dougollerenshaw authored Sep 25, 2024
2 parents 5df88e5 + 87d2d51 commit fcb7d3c
Show file tree
Hide file tree
Showing 10 changed files with 805 additions and 144 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

CodeAIde is an AI-powered coding assistant that helps developers write, test, and optimize code through natural language interactions. By leveraging the power of large language models, CodeAIde aims to streamline the coding process and boost productivity.

This is designed to be a simple, intuitive tool for writing, running, and refining simple Python scripts. It is not meant to be a full IDE or code editing environment and isn't a replacement for a tool like Cursor or Github Copilot. Instead, it is intended to be a simple tool for quickly writing code and getting it working without the need to worry about setting up environments, installing dependencies, etc.

## Features

- Natural language code generation
- Support for OpenAI and Anthropic APIs
- Interactive clarification process for precise code output
- Version control for generated code
- Local code execution and testing
- Cost tracking for API usage (not yet implemented)

## Examples

Here are some example videos demonstrating use. Example prompts can be accessed by clicking "Use Example" and selecting from avaialable examples.
Here are some example videos demonstrating use. Example prompts can be accessed by clicking "Use Example" and selecting from available examples.

First, a simple matplotlib plot with followup requests to modify aesthetics.

Expand Down
96 changes: 76 additions & 20 deletions codeaide/logic/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
save_api_key,
MissingAPIKeyException,
)
from codeaide.utils.constants import MAX_RETRIES, MAX_TOKENS
from codeaide.utils.constants import (
MAX_RETRIES,
AI_PROVIDERS,
DEFAULT_MODEL,
DEFAULT_PROVIDER,
)
from codeaide.utils.cost_tracker import CostTracker
from codeaide.utils.environment_manager import EnvironmentManager
from codeaide.utils.file_handler import FileHandler
Expand All @@ -37,7 +42,11 @@ def __init__(self):
self.latest_version = "0.0"
self.api_client = None
self.api_key_set = False
self.current_service = "anthropic" # Default service
self.current_provider = DEFAULT_PROVIDER
self.current_model = DEFAULT_MODEL
self.max_tokens = AI_PROVIDERS[self.current_provider]["models"][
self.current_model
]["max_tokens"]

def check_api_key(self):
"""
Expand All @@ -50,26 +59,26 @@ def check_api_key(self):
tuple: A tuple containing a boolean indicating if the API key is valid and a message.
"""
if self.api_client is None:
self.api_client = get_api_client(self.current_service)
self.api_client = get_api_client(self.current_provider, self.current_model)

if self.api_client:
self.api_key_set = True
return True, None
else:
self.api_key_set = False
return False, self.get_api_key_instructions(self.current_service)
return False, self.get_api_key_instructions(self.current_provider)

def get_api_key_instructions(self, service):
def get_api_key_instructions(self, provider):
"""
Get instructions for setting up the API key for a given service.
Get instructions for setting up the API key for a given provider.
Args:
service (str): The name of the service.
provider (str): The name of the provider.
Returns:
str: Instructions for setting up the API key.
"""
if service == "anthropic":
if provider == "anthropic":
return (
"It looks like you haven't set up your Anthropic API key yet. "
"Here's how to get started:\n\n"
Expand All @@ -83,7 +92,7 @@ def get_api_key_instructions(self, service):
"Please paste your Anthropic API key now:"
)
else:
return f"Please enter your API key for {service.capitalize()}:"
return f"Please enter your API key for {provider.capitalize()}:"

def validate_api_key(self, api_key):
"""
Expand Down Expand Up @@ -122,9 +131,11 @@ def handle_api_key_input(self, api_key):
cleaned_key = api_key.strip().strip("'\"") # Remove quotes and whitespace
is_valid, error_message = self.validate_api_key(cleaned_key)
if is_valid:
if save_api_key(self.current_service, cleaned_key):
if save_api_key(self.current_provider, cleaned_key):
# Try to get a new API client with the new key
self.api_client = get_api_client(self.current_service)
self.api_client = get_api_client(
self.current_provider, self.current_model
)
if self.api_client:
self.api_key_set = True
return True, "API key saved and verified successfully."
Expand Down Expand Up @@ -155,7 +166,7 @@ def process_input(self, user_input):
if not self.check_and_set_api_key():
return {
"type": "api_key_required",
"message": self.get_api_key_instructions(self.current_service),
"message": self.get_api_key_instructions(self.current_provider),
}

self.add_user_input_to_history(user_input)
Expand All @@ -174,6 +185,7 @@ def process_input(self, user_input):
try:
return self.process_ai_response(response)
except ValueError as e:
print(f"ValueError: {str(e)}\n")
if not self.is_last_attempt(attempt):
self.add_error_prompt_to_history(str(e))
else:
Expand Down Expand Up @@ -226,7 +238,13 @@ def get_ai_response(self):
Returns:
dict: The response from the AI API, or None if the request failed.
"""
return send_api_request(self.api_client, self.conversation_history, MAX_TOKENS)
return send_api_request(
self.api_client,
self.conversation_history,
self.max_tokens,
self.current_model,
self.current_provider,
)

def is_last_attempt(self, attempt):
"""
Expand All @@ -253,9 +271,13 @@ def process_ai_response(self, response):
Raises:
ValueError: If the response cannot be parsed or the version is invalid.
"""
parsed_response = parse_response(response)
if parsed_response[0] is None:
raise ValueError("Failed to parse JSON response")
try:
parsed_response = parse_response(response, provider=self.current_provider)
except (ValueError, json.JSONDecodeError) as e:
error_message = (
f"Failed to parse AI response: {str(e)}\nRaw response: {response}"
)
raise ValueError(error_message)

(
text,
Expand Down Expand Up @@ -292,9 +314,16 @@ def update_conversation_history(self, response):
Returns:
None
"""
self.conversation_history.append(
{"role": "assistant", "content": response.content[0].text}
)
if self.current_provider.lower() == "anthropic":
self.conversation_history.append(
{"role": "assistant", "content": response.content[0].text}
)
elif self.current_provider.lower() == "openai":
self.conversation_history.append(
{"role": "assistant", "content": response.choices[0].message.content}
)
else:
raise ValueError(f"Unsupported provider: {provider}")

def create_questions_response(self, text, questions):
"""
Expand Down Expand Up @@ -370,7 +399,7 @@ def add_error_prompt_to_history(self, error_message):
Returns:
None
"""
error_prompt = f"\n\nThere was an error in your response: {error_message}. Please ensure you're using proper JSON formatting and incrementing the version number correctly. The latest version was {self.latest_version}, so the new version must be higher than this."
error_prompt = f"\n\nThere was an error in your last response: {error_message}. Please ensure you're using proper JSON formatting to avoid this error and others like it."
self.conversation_history[-1]["content"] += error_prompt

def handle_unexpected_error(self, e):
Expand Down Expand Up @@ -456,3 +485,30 @@ def is_task_in_progress(self):
bool: True if there's an ongoing task, False otherwise.
"""
return bool(self.conversation_history)

def set_model(self, provider, model):
if provider not in AI_PROVIDERS:
print(f"Invalid provider: {provider}")
return False
if model not in AI_PROVIDERS[provider]["models"]:
print(f"Invalid model {model} for provider {provider}")
return False

self.current_provider = provider
self.current_model = model
self.max_tokens = AI_PROVIDERS[self.current_provider]["models"][
self.current_model
]["max_tokens"]
self.api_client = get_api_client(self.current_provider, self.current_model)
self.api_key_set = self.api_client is not None
return self.api_key_set

def clear_conversation_history(self):
self.conversation_history = []
# We maintain the latest version across model changes

def get_latest_version(self):
return self.latest_version

def set_latest_version(self, version):
self.latest_version = version
105 changes: 93 additions & 12 deletions codeaide/ui/chat_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
QTextEdit,
QVBoxLayout,
QWidget,
QComboBox,
QLabel,
)

from codeaide.ui.code_popup import CodePopup
from codeaide.ui.example_selection_dialog import show_example_dialog
from codeaide.utils import general_utils
Expand All @@ -31,6 +32,9 @@
INITIAL_MESSAGE,
USER_FONT,
USER_MESSAGE_COLOR,
AI_PROVIDERS,
DEFAULT_PROVIDER,
DEFAULT_MODEL,
)


Expand Down Expand Up @@ -61,21 +65,46 @@ def setup_ui(self):
main_layout.setSpacing(5)
main_layout.setContentsMargins(8, 8, 8, 8)

# Create a widget for the dropdowns
dropdown_widget = QWidget()
dropdown_layout = QHBoxLayout(dropdown_widget)
dropdown_layout.setContentsMargins(0, 0, 0, 0)
dropdown_layout.setSpacing(5) # Minimal spacing between items

# Provider dropdown
self.provider_dropdown = QComboBox()
self.provider_dropdown.addItems(AI_PROVIDERS.keys())
self.provider_dropdown.setCurrentText(DEFAULT_PROVIDER)
self.provider_dropdown.currentTextChanged.connect(self.update_model_dropdown)
dropdown_layout.addWidget(QLabel("Provider:"))
dropdown_layout.addWidget(self.provider_dropdown)

# Model dropdown
self.model_dropdown = QComboBox()
self.update_model_dropdown(DEFAULT_PROVIDER)
self.model_dropdown.setCurrentText(DEFAULT_MODEL)
self.model_dropdown.currentTextChanged.connect(self.update_chat_handler)
dropdown_layout.addWidget(QLabel("Model:"))
dropdown_layout.addWidget(self.model_dropdown)

# Add stretch to push everything to the left
dropdown_layout.addStretch(1)

# Add the dropdown widget to the main layout
main_layout.addWidget(dropdown_widget)

# Chat display
self.chat_display = QTextEdit(self)
self.chat_display.setReadOnly(True)
self.chat_display.setStyleSheet(
f"background-color: {CHAT_WINDOW_BG}; color: {CHAT_WINDOW_FG}; border: 1px solid #ccc; padding: 5px;"
)
main_layout.addWidget(self.chat_display, stretch=3)

# Input text area
self.input_text = QTextEdit(self)
self.input_text.setStyleSheet(
f"""
background-color: {CHAT_WINDOW_BG};
color: {CHAT_WINDOW_FG};
border: 1px solid #ccc;
padding: 5px;
"""
f"background-color: {CHAT_WINDOW_BG}; color: {CHAT_WINDOW_FG}; border: 1px solid #ccc; padding: 5px;"
)
self.input_text.setAcceptRichText(False) # Add this line
self.input_text.setFont(general_utils.set_font(USER_FONT))
Expand All @@ -84,18 +113,17 @@ def setup_ui(self):
self.input_text.installEventFilter(self)
main_layout.addWidget(self.input_text, stretch=1)

# Buttons
button_layout = QHBoxLayout()
button_layout.setSpacing(5)

self.submit_button = QPushButton("Submit")
self.submit_button = QPushButton("Submit", self)
self.submit_button.clicked.connect(self.on_submit)
button_layout.addWidget(self.submit_button)

self.example_button = QPushButton("Use Example")
self.example_button = QPushButton("Load Example", self)
self.example_button.clicked.connect(self.load_example)
button_layout.addWidget(self.example_button)

self.exit_button = QPushButton("Exit")
self.exit_button = QPushButton("Exit", self)
self.exit_button.clicked.connect(self.on_exit)
button_layout.addWidget(self.exit_button)

Expand Down Expand Up @@ -263,3 +291,56 @@ def closeEvent(self, event):

def sigint_handler(self, *args):
QApplication.quit()

def update_model_dropdown(self, provider):
self.model_dropdown.clear()
models = AI_PROVIDERS[provider]["models"].keys()
self.model_dropdown.addItems(models)

# Set the current item to the first model in the list
if models:
self.model_dropdown.setCurrentText(list(models)[0])
else:
print(f"No models available for provider {provider}")

def update_chat_handler(self):
provider = self.provider_dropdown.currentText()
model = self.model_dropdown.currentText()

# Check if a valid model is selected
if not model:
print(f"No valid model selected for provider {provider}")
return

current_version = self.chat_handler.get_latest_version()
success = self.chat_handler.set_model(provider, model)
if not success:
self.add_to_chat(
"System",
f"Failed to set model {model} for provider {provider}. Please check your API key.",
)
return

self.chat_handler.clear_conversation_history()
self.chat_handler.set_latest_version(
current_version
) # Maintain the version number

# Add a message about switching models and the current version
self.add_to_chat(
"System",
f"""
{'='*50}
Switched to {provider} - {model}
Starting a new conversation with this model.
Current code version: {current_version}
Any new code will be versioned starting from {self.increment_version(current_version)}
{'='*50}
""",
)

self.check_api_key()

def increment_version(self, version):
major, minor = map(int, version.split("."))
return f"{major}.{minor + 1}"
Loading

0 comments on commit fcb7d3c

Please sign in to comment.