Skip to content

Commit

Permalink
Merge pull request #49 from dougollerenshaw/add_gemini
Browse files Browse the repository at this point in the history
Add gemini to available models
  • Loading branch information
dougollerenshaw authored Oct 10, 2024
2 parents edf893b + a320f04 commit 537c3ea
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 39 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ codeaide.spec
dist/

# Ignore .dmg files
*.dmg
*.dmg

# Ignore .egg-info directories
*.egg-info/
16 changes: 13 additions & 3 deletions codeaide/logic/chat_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from codeaide.utils.constants import (
MAX_RETRIES,
AI_PROVIDERS,
DEFAULT_MODEL,
DEFAULT_PROVIDER,
INITIAL_MESSAGE,
)
Expand Down Expand Up @@ -57,7 +56,9 @@ def __init__(self):
self.api_client = None
self.api_key_set = False
self.current_provider = DEFAULT_PROVIDER
self.current_model = DEFAULT_MODEL
self.current_model = list(AI_PROVIDERS[self.current_provider]["models"].keys())[
0
]
self.max_tokens = AI_PROVIDERS[self.current_provider]["models"][
self.current_model
]["max_tokens"]
Expand Down Expand Up @@ -346,8 +347,17 @@ def update_conversation_history(self, response):
self.conversation_history.append(
{"role": "assistant", "content": response.choices[0].message.content}
)
elif self.current_provider.lower() == "google":
self.conversation_history.append(
{
"role": "assistant",
"content": response.candidates[0].content.parts[0].text,
}
)
else:
raise ValueError(f"Unsupported provider: {self.current_provider}")
raise ValueError(
f"In update_conversation_history, unsupported provider: {self.current_provider}"
)
self.file_handler.save_chat_history(self.conversation_history)

def create_questions_response(self, text, questions):
Expand Down
17 changes: 10 additions & 7 deletions codeaide/ui/chat_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def setup_ui(self):

# Model dropdown
self.model_dropdown = QComboBox()
self.update_model_dropdown(DEFAULT_PROVIDER)
self.update_model_dropdown(DEFAULT_PROVIDER, add_message_to_chat=False)
self.model_dropdown.currentTextChanged.connect(self.update_chat_handler)
dropdown_layout.addWidget(QLabel("Model:"))
dropdown_layout.addWidget(self.model_dropdown)
Expand Down Expand Up @@ -335,9 +335,6 @@ def call_process_input_async(self, user_input):
f"ChatWindow: call_process_input_async called with input: {user_input[:50]}..."
)
response = self.chat_handler.process_input(user_input)
self.logger.info(
f"ChatWindow: Received response from chat handler: {str(response)[:50]}..."
)
self.handle_response(response)

def on_modify(self):
Expand Down Expand Up @@ -474,17 +471,23 @@ def force_close(self):
def sigint_handler(self, *args):
    # Handle Ctrl+C (SIGINT) by quitting the Qt application cleanly;
    # *args absorbs the (signum, frame) arguments the signal module passes.
    QApplication.quit()

def update_model_dropdown(self, provider):
def update_model_dropdown(self, provider, add_message_to_chat=False):
    """
    Populate the model dropdown with the models available for `provider`.

    Args:
        provider (str): Key into AI_PROVIDERS whose models should be listed.
        add_message_to_chat (bool): When True, also notify the chat handler
            of the new selection; leave False during initial UI setup so no
            spurious model-switch message is posted to the chat.
    """
    self.model_dropdown.clear()
    # Convert once: the declaration order in AI_PROVIDERS defines the
    # provider's default model (the first entry).
    models = list(AI_PROVIDERS[provider]["models"].keys())
    self.model_dropdown.addItems(models)

    # Set the current item to the first model in the list (default)
    if models:
        default_model = models[0]
        self.model_dropdown.setCurrentText(default_model)
        self.logger.info(f"Set default model for {provider} to {default_model}")
    else:
        self.logger.info(f"No models available for provider {provider}")

    # Update the chat handler with the selected model if add_message_to_chat is True
    if add_message_to_chat:
        self.update_chat_handler()

def update_chat_handler(self):
provider = self.provider_dropdown.currentText()
model = self.model_dropdown.currentText()
Expand Down
67 changes: 61 additions & 6 deletions codeaide/utils/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import anthropic
import openai
import google.generativeai as genai
from decouple import AutoConfig
import hjson
import re
from google.generativeai.types import GenerationConfig

from codeaide.utils.constants import (
AI_PROVIDERS,
DEFAULT_MODEL,
DEFAULT_PROVIDER,
SYSTEM_PROMPT,
)
Expand All @@ -23,7 +25,7 @@ def __init__(self, service):
)


def get_api_client(provider=DEFAULT_PROVIDER, model=DEFAULT_MODEL):
def get_api_client(provider=DEFAULT_PROVIDER, model=None):
try:
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Expand All @@ -47,8 +49,12 @@ def get_api_client(provider=DEFAULT_PROVIDER, model=DEFAULT_MODEL):
return anthropic.Anthropic(api_key=api_key)
elif provider.lower() == "openai":
return openai.OpenAI(api_key=api_key)
elif provider.lower() == "google":
genai.configure(api_key=api_key)
client = genai.GenerativeModel(model, system_instruction=SYSTEM_PROMPT)
return client
else:
raise ValueError(f"Unsupported provider: {provider}")
raise ValueError(f"In get_api_client, unsupported provider: {provider}")
except Exception as e:
logger.error(f"Error initializing {provider.capitalize()} API client: {str(e)}")
return None
Expand Down Expand Up @@ -112,6 +118,25 @@ def send_api_request(api_client, conversation_history, max_tokens, model, provid
)
if not response.choices:
return None
elif provider.lower() == "google":
# Convert conversation history to the format expected by Google Gemini
prompt = ""
for message in conversation_history:
role = message["role"]
content = message["content"]
prompt += f"{role.capitalize()}: {content}\n\n"

# Create a GenerationConfig object
generation_config = GenerationConfig(
max_output_tokens=max_tokens,
temperature=0.7, # You can adjust this as needed
top_p=0.95, # You can adjust this as needed
top_k=40, # You can adjust this as needed
)

response = api_client.generate_content(
contents=prompt, generation_config=generation_config
)
else:
raise NotImplementedError(f"API request for {provider} not implemented")

Expand All @@ -127,7 +152,7 @@ def parse_response(response, provider):
if not response:
raise ValueError("Empty or invalid response received")

logger.debug(f"Received response: {response}")
logger.info(f"Received response: {response}")

if provider.lower() == "anthropic":
if not response.content:
Expand All @@ -137,8 +162,10 @@ def parse_response(response, provider):
if not response.choices:
raise ValueError("Empty or invalid response received")
json_str = response.choices[0].message.content
elif provider.lower() == "google":
json_str = response.candidates[0].content.parts[0].text
else:
raise ValueError(f"Unsupported provider: {provider}")
raise ValueError(f"In parse_response, unsupported provider: {provider}")

# Remove the triple backticks and language identifier if present
if json_str.startswith("```json"):
Expand All @@ -164,16 +191,44 @@ def parse_response(response, provider):
requirements = outer_json.get("requirements", [])
questions = outer_json.get("questions", [])

# Clean the code if it exists
if code:
code = clean_code(code)

return text, questions, code, code_version, version_description, requirements


def clean_code(code):
    """
    Strip a surrounding Markdown code fence from a code string.

    Only a fence at the very start of the string (optionally carrying a
    language identifier, e.g. ```python) and a closing fence at the very
    end are removed; fences on interior lines are part of the code itself
    and are preserved.

    Args:
        code (str): The code string to clean.

    Returns:
        str: The cleaned code string, with leading/trailing whitespace trimmed.
    """
    # Anchor to the start of the whole string with \A (no re.MULTILINE):
    # with MULTILINE, ^ would match every line start and re.sub would strip
    # triple-backtick lines embedded inside the code, corrupting it.
    code = re.sub(r"\A```[\w-]*\n", "", code)

    # Likewise, only remove a closing fence at the very end of the string.
    code = re.sub(r"\n```\Z", "", code)

    # Trim any leading or trailing whitespace
    return code.strip()


def check_api_connection():
client = get_api_client()
if client is None:
return False, "API key is missing or invalid"
try:
provider = DEFAULT_PROVIDER
model = list(AI_PROVIDERS[provider]["models"].keys())[0]
response = client.messages.create(
model=DEFAULT_MODEL,
model=model,
max_tokens=100,
messages=[{"role": "user", "content": "Are we communicating?"}],
)
Expand Down
18 changes: 14 additions & 4 deletions codeaide/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# API Configuration
# This dictionary defines the supported API providers and the supported models for each.
# The max_tokens argument is the max output tokens, which is generally specified in the API documentation
# The default model for each provider will be the first model in the list
AI_PROVIDERS = {
"google": {
"api_key_name": "GEMINI_API_KEY",
"models": {
"gemini-1.5-pro": {"max_tokens": 8192},
"gemini-1.5-flash": {"max_tokens": 8192},
},
},
"anthropic": {
"api_key_name": "ANTHROPIC_API_KEY",
"models": {
Expand All @@ -19,9 +29,8 @@
},
}

# Default model
DEFAULT_MODEL = "claude-3-5-sonnet-20240620"
DEFAULT_PROVIDER = "anthropic"
# This sets the default provider when the application launches
DEFAULT_PROVIDER = "google"

# Other existing constants remain unchanged
MAX_RETRIES = 3
Expand Down Expand Up @@ -104,10 +113,11 @@
* Ensure that newlines in string literals are properly contained within the string delimiters and do not break the code structure.
* Add inline comments to explain complex parts of the code or to provide additional context where necessary. However, avoid excessive commenting that may clutter the code.
* All code must be contained within a single file. If the code requires multiple classes or functions, include them all in the same code block.
* Do not include triple backticks ("```") or language identifiers in the code block.
Remember, the goal is to provide valuable, working code solutions while maintaining a balance between making reasonable assumptions and seeking clarification when truly necessary.
Format your responses as a JSON object with six keys:
* 'text': a string that contains any natural language explanations or comments that you think are helpful for the user. This should never be null or incomplete. If you mention providing a list or explanation, ensure it is fully included here. If you have no text response, provide a brief explanation of the code or the assumptions made.
* 'text': a string that contains any natural language explanations or comments that you think are helpful for the user. This should never be null or incomplete. If you mention providing a list or explanation, ensure it is fully included here. If you have no text response, provide a brief explanation of the code or the assumptions made. Use plain text, not markdown.
* 'questions': an array of strings that pose necessary follow-up questions to the user
* 'code': a string with the properly formatted, complete code block. This must include all necessary components for the code to run, including any previously implemented methods or classes. This should be null only if you have questions or text responses but no code to provide.
* 'code_version': a string that represents the version of the code. Start at 1.0 and increment for each new version of the code you provide. Use your judgement on whether to increment the minor or major component of the version. It is critical that version numbers never be reused during a chat and that the numbers always increment upward. This field should be null if you have no code to provide.
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
anthropic==0.34.2
python-decouple
google-generativeai==0.8.3
python-decouple==3.8
virtualenv==20.16.2
numpy==1.26.4
numpy==1.26.4
openai
hjson
pyyaml
Expand Down
40 changes: 40 additions & 0 deletions sandbox/prototype_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
This is a little sandbox script to test out the Gemini API.
"""

import argparse
from decouple import config
import google.generativeai as genai
from codeaide.utils.constants import SYSTEM_PROMPT

genai.configure(api_key=config("GEMINI_API_KEY"))

model = genai.GenerativeModel("gemini-1.5-pro", system_instruction=SYSTEM_PROMPT)


def generate_a_story():
    """Ask the Gemini model for a short story and print its text reply."""
    story = model.generate_content("Write a story about a magic backpack.")
    print(story.text)


def request_code():
    """Ask the Gemini model for a code sample and print its text reply."""
    reply = model.generate_content(
        "Write a Python function to calculate the Fibonacci sequence."
    )
    print(reply.text)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Gemini API prototype script")
parser.add_argument(
"action",
choices=["story", "code"],
help="Action to perform: generate a story or request code",
)

args = parser.parse_args()

if args.action == "story":
generate_a_story()
elif args.action == "code":
request_code()
7 changes: 1 addition & 6 deletions tests/ui/test_chat_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from codeaide.utils.constants import (
AI_PROVIDERS,
DEFAULT_PROVIDER,
DEFAULT_MODEL,
MODEL_SWITCH_MESSAGE,
)

Expand Down Expand Up @@ -82,11 +81,7 @@ def test_model_switching(chat_window, mock_chat_handler, caplog):
test_provider = next(
provider for provider in AI_PROVIDERS.keys() if provider != DEFAULT_PROVIDER
)
test_model = next(
model
for model in AI_PROVIDERS[test_provider]["models"].keys()
if model != DEFAULT_MODEL
)
test_model = list(AI_PROVIDERS[test_provider]["models"].keys())[0]

window.provider_dropdown.setCurrentText(test_provider)
window.model_dropdown.setCurrentText(test_model)
Expand Down
Loading

0 comments on commit 537c3ea

Please sign in to comment.