forked from CoinFabrik/geppetto
Commit: merge gemini handler with slack handler
Showing 11 changed files with 211 additions and 35 deletions. Only the two newly added files are reproduced below.
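The commit title describes merging the Gemini handler with the Slack handler. As a rough, hypothetical sketch of what that wiring could look like (pick_llm_handler and the LLM_PROVIDER variable are invented for illustration; only GeminiHandler and its personality argument come from this diff):

# Hypothetical sketch only: pick_llm_handler and LLM_PROVIDER are not part of
# this commit; GeminiHandler and its personality argument are.
import os

from geppetto.gemini_handler import GeminiHandler


def pick_llm_handler(personality):
    # The Slack-facing side could select its LLM backend from the environment.
    provider = os.getenv("LLM_PROVIDER", "gemini").lower()
    if provider == "gemini":
        return GeminiHandler(personality=personality)
    raise ValueError("Unsupported LLM provider: %s" % provider)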
geppetto/gemini_handler.py: new file (@@ -0,0 +1,61 @@)
from urllib.request import urlopen
import logging

from .exceptions import InvalidThreadFormatError
from .llm_api_handler import LLMHandler
from dotenv import load_dotenv
from typing import List, Dict
import os
import textwrap
import google.generativeai as genai
from IPython.display import display
from IPython.display import Markdown

load_dotenv(os.path.join("config", ".env"))

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-pro")
MSG_FIELD = "parts"
MSG_INPUT_FIELD = "content"


def to_markdown(text):
    # Convert Gemini's plain-text bullets to Markdown and quote the whole reply.
    text = text.replace('•', ' *')
    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))


class GeminiHandler(LLMHandler):

    def __init__(
        self,
        personality,
    ):
        super().__init__(
            'Gemini',
            GEMINI_MODEL,
            genai.GenerativeModel(GEMINI_MODEL),
        )
        self.personality = personality
        self.system_role = "system"
        self.assistant_role = "model"
        self.user_role = "user"
        genai.configure(api_key=GOOGLE_API_KEY)

    def llm_generate_content(self, user_prompt, status_callback=None, *status_callback_args):
        logging.info("Sending msg to gemini: %s" % user_prompt)
        # Gemini expects user/model turns to alternate, so if the prompt starts
        # with two consecutive "user" messages, merge them into a single turn.
        if len(user_prompt) >= 2 and user_prompt[0].get('role') == 'user' and user_prompt[1].get('role') == 'user':
            merged_prompt = {
                'role': 'user',
                'parts': [msg['parts'][0] for msg in user_prompt[:2]]
            }
            user_prompt = [merged_prompt] + user_prompt[2:]
        response = self.client.generate_content(user_prompt)
        markdown_response = to_markdown(response.text)
        return str(markdown_response.data)

    def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_tag: str):
        # Rename each message's "content" field to the "parts" list Gemini expects.
        prompt = super().get_prompt_from_thread(thread, assistant_tag, user_tag)
        for msg in prompt:
            if MSG_INPUT_FIELD in msg:
                msg[MSG_FIELD] = [msg.pop(MSG_INPUT_FIELD)]
            else:
                raise InvalidThreadFormatError("The input thread doesn't have the field %s" % MSG_INPUT_FIELD)
        return prompt
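For context, here is a minimal usage sketch of the handler defined above. The thread, tags, and personality string are illustrative values, and an actual run would need a valid GOOGLE_API_KEY (and optionally GEMINI_MODEL) in config/.env; the test module below exercises this same flow with genai mocked out.

# Illustrative usage; the thread content, tags and personality are made up,
# and a real call requires GOOGLE_API_KEY to be set in config/.env.
from geppetto.gemini_handler import GeminiHandler

handler = GeminiHandler(personality="You are a helpful Slack assistant.")

# A Slack-style thread whose "content" fields get_prompt_from_thread() turns
# into the "parts" lists expected by Gemini.
thread = [
    {"role": "slack_user", "content": "Summarize yesterday's deploy issues."},
]
prompt = handler.get_prompt_from_thread(
    thread, assistant_tag="geppetto", user_tag="slack_user"
)

# Sends the prompt to Gemini and returns the reply as Markdown-style text.
print(handler.llm_generate_content(prompt))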
Unit tests for the Gemini handler: new file (@@ -0,0 +1,98 @@)
import os
import sys
import unittest
from unittest.mock import Mock, patch

# Make the geppetto package importable when the tests are run from this directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
sys.path.append(parent_dir)

from geppetto.exceptions import InvalidThreadFormatError
from geppetto.gemini_handler import GeminiHandler


def OF(**kw):
    # Small object factory: returns an anonymous object with the given attributes.
    class OF:
        pass
    instance = OF()
    for k, v in kw.items():
        setattr(instance, k, v)
    return instance


class TestGemini(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Patch the google.generativeai module so no real API client is created.
        cls.patcher = patch("geppetto.gemini_handler.genai")
        cls.mock_genai = cls.patcher.start()
        cls.gemini_handler = GeminiHandler(personality="Your AI personality")

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()

    def test_personality(self):
        self.assertEqual(self.gemini_handler.personality, "Your AI personality")

    @patch("geppetto.gemini_handler.to_markdown")
    def test_llm_generate_content(self, mock_to_markdown):
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]}
        ]
        mock_response = Mock()
        mock_response.text = "Mocked Gemini response"
        self.gemini_handler.client.generate_content.return_value = mock_response
        mock_to_markdown.return_value.data = "Mocked Markdown data"

        response = self.gemini_handler.llm_generate_content(user_prompt)

        self.assertEqual(response, "Mocked Markdown data")
        mock_to_markdown.assert_called_once_with("Mocked Gemini response")

    def test_get_prompt_from_thread(self):
        thread = [
            {"role": "slack_user", "content": "Message 1"},
            {"role": "geppetto", "content": "Message 2"}
        ]

        ROLE_FIELD = "role"
        MSG_FIELD = "parts"

        prompt = self.gemini_handler.get_prompt_from_thread(
            thread, assistant_tag="geppetto", user_tag="slack_user"
        )

        self.assertIsInstance(prompt, list)

        # Every converted message must keep its role and carry a non-empty "parts" list.
        for msg in prompt:
            self.assertIsInstance(msg, dict)
            self.assertIn(ROLE_FIELD, msg)
            self.assertIn(MSG_FIELD, msg)
            self.assertIsInstance(msg[MSG_FIELD], list)
            self.assertTrue(msg[MSG_FIELD])

        # A message without a "content" field makes the thread invalid.
        with self.assertRaises(InvalidThreadFormatError):
            incomplete_thread = [{"role": "geppetto"}]
            self.gemini_handler.get_prompt_from_thread(
                incomplete_thread, assistant_tag="geppetto", user_tag="slack_user"
            )

    def test_llm_generate_content_user_repetition(self):
        # Two consecutive leading "user" messages must be merged into a single
        # turn before the prompt is sent to Gemini.
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]},
            {"role": "geppetto", "parts": ["I'm fine."]}
        ]

        with patch.object(self.gemini_handler.client, "generate_content") as mock_generate_content:
            mock_response = Mock()
            mock_response.text = "Mocked Gemini response"
            mock_generate_content.return_value = mock_response

            self.gemini_handler.llm_generate_content(user_prompt)

            mock_generate_content.assert_called_once_with(
                [{"role": "user", "parts": ["Hello", "How are you?"]},
                 {"role": "geppetto", "parts": ["I'm fine."]}]
            )


if __name__ == "__main__":
    unittest.main()
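Assuming the test module lives in a tests/ directory next to the geppetto package (which is what the sys.path adjustment at the top of the file suggests), the suite can presumably be run through standard unittest discovery, for example:

# Hypothetical runner; the "tests" directory name is an assumption about the
# repository layout, not something shown in this diff.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_gemini*.py")
unittest.TextTestRunner(verbosity=2).run(suite)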