merge gemini handler with slack handler
ignaciopenia committed Apr 16, 2024
2 parents d2fa52d + ea80253 commit 18e4d98
Showing 11 changed files with 211 additions and 35 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -71,6 +71,8 @@ Before running the application, copy the `.configuration/.env.example` file into
- `SIGNING_SECRET`: Your Signing secret to verify Slack requests (from your Slack App Credentials).
- `DALLE_MODEL`: The OpenAI DALL-E-3 model.
- `CHATGPT_MODEL`: The OpenAI ChatGPT-4 model.
- `GOOGLE_API_KEY`: The Google Gemini API Key.
- `GEMINI_MODEL`: The Gemini model.

## Deployment

2 changes: 2 additions & 0 deletions config/.env.example
@@ -4,3 +4,5 @@ OPENAI_API_KEY = "YOUR_TOKEN"
CHATGPT_MODEL = "gpt-4"
DALLE_MODEL = "dall-e-3"
SIGNING_SECRET = "YOUR_SECRET"
GOOGLE_API_KEY = "YOUR_TOKEN"
GEMINI_MODEL = "gemini-pro"
61 changes: 61 additions & 0 deletions geppetto/gemini_handler.py
@@ -0,0 +1,61 @@
from urllib.request import urlopen
import logging

from .exceptions import InvalidThreadFormatError
from .llm_api_handler import LLMHandler
from dotenv import load_dotenv
from typing import List, Dict
import os
import textwrap
import google.generativeai as genai
from IPython.display import display
from IPython.display import Markdown

load_dotenv(os.path.join("config", ".env"))

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-pro")
MSG_FIELD = "parts"
MSG_INPUT_FIELD = "content"


def to_markdown(text):
    text = text.replace('•', ' *')
    return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))


class GeminiHandler(LLMHandler):

    def __init__(
        self,
        personality,
    ):
        super().__init__(
            'Gemini',
            GEMINI_MODEL,
            genai.GenerativeModel(GEMINI_MODEL),
        )
        self.personality = personality
        self.system_role = "system"
        self.assistant_role = "model"
        self.user_role = "user"
        genai.configure(api_key=GOOGLE_API_KEY)

    def llm_generate_content(self, user_prompt, status_callback=None, *status_callback_args):
        logging.info("Sending msg to gemini: %s" % user_prompt)
        if len(user_prompt) >= 2 and user_prompt[0].get('role') == 'user' and user_prompt[1].get('role') == 'user':
            merged_prompt = {
                'role': 'user',
                'parts': [msg['parts'][0] for msg in user_prompt[:2]]
            }
            user_prompt = [merged_prompt] + user_prompt[2:]
        response = self.client.generate_content(user_prompt)
        markdown_response = to_markdown(response.text)
        return str(markdown_response.data)

    def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_tag: str):
        prompt = super().get_prompt_from_thread(thread, assistant_tag, user_tag)
        for msg in prompt:
            if MSG_INPUT_FIELD in msg:
                msg[MSG_FIELD] = [msg.pop(MSG_INPUT_FIELD)]
            else:
                raise InvalidThreadFormatError("The input thread doesn't have the field %s" % MSG_INPUT_FIELD)
        return prompt
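A note on the merge step above: Gemini's chat format rejects two consecutive turns from the same role, which appears to be why `llm_generate_content` folds the first two `user` messages into one. A minimal standalone sketch of just that step, with no API call and illustrative message content:

# Standalone sketch of the consecutive-user-turn merge in llm_generate_content
# (assumption: simplified copy of the handler logic; no Gemini call is made).
user_prompt = [
    {'role': 'user', 'parts': ['Hello']},
    {'role': 'user', 'parts': ['How are you?']},
    {'role': 'model', 'parts': ["I'm fine."]},
]
if len(user_prompt) >= 2 and user_prompt[0].get('role') == 'user' and user_prompt[1].get('role') == 'user':
    merged_prompt = {
        'role': 'user',
        'parts': [msg['parts'][0] for msg in user_prompt[:2]],
    }
    user_prompt = [merged_prompt] + user_prompt[2:]
print(user_prompt[0])  # {'role': 'user', 'parts': ['Hello', 'How are you?']}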
16 changes: 13 additions & 3 deletions geppetto/llm_api_handler.py
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Dict
from typing import List, Dict, Callable
from .exceptions import InvalidThreadFormatError

ROLE_FIELD = "role"

class LLMHandler(ABC):
    def __init__(self, name, model, client):
@@ -15,6 +17,14 @@ def get_info(self):
    def llm_generate_content(self, prompt: str, callback: Callable, *callback_args):
        pass

    @abstractmethod
    def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_tag: str):
        pass
        prompt = []
        for msg in thread:
            formatted_msg = dict(msg)
            if ROLE_FIELD in formatted_msg:
                formatted_msg[ROLE_FIELD] = formatted_msg[ROLE_FIELD].replace(assistant_tag, self.assistant_role)
                formatted_msg[ROLE_FIELD] = formatted_msg[ROLE_FIELD].replace(user_tag, self.user_role)
                prompt.append(formatted_msg)
            else:
                raise InvalidThreadFormatError("The input thread doesn't have the field %s" % ROLE_FIELD)
        return prompt
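For illustration, roughly what the shared base implementation now produces. A standalone sketch, assuming the role names `GeminiHandler` assigns in `__init__` (`model` for the assistant, `user` for the user):

# Standalone sketch of the shared tag-to-role mapping (assumption: the role
# names are hard-coded here to the values GeminiHandler sets in __init__).
assistant_role = "model"
user_role = "user"
thread = [
    {"role": "slack_user", "content": "Message 1"},
    {"role": "geppetto", "content": "Message 2"},
]
prompt = []
for msg in thread:
    formatted_msg = dict(msg)
    formatted_msg["role"] = formatted_msg["role"].replace("geppetto", assistant_role)
    formatted_msg["role"] = formatted_msg["role"].replace("slack_user", user_role)
    prompt.append(formatted_msg)
print(prompt)
# [{'role': 'user', 'content': 'Message 1'}, {'role': 'model', 'content': 'Message 2'}]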
23 changes: 17 additions & 6 deletions geppetto/main.py
@@ -5,6 +5,7 @@
from .llm_controller import LLMController
from .slack_handler import SlackHandler
from .openai_handler import OpenAIHandler
from .gemini_handler import GeminiHandler
from slack_bolt.adapter.socket_mode import SocketModeHandler
from .utils import load_json

@@ -24,18 +25,28 @@

def initialized_llm_controller():
    controller = LLMController(
        [{
            "name": "OpenAI",
            "handler": OpenAIHandler,
            "handler_args": {
                "personality": DEFAULT_RESPONSES["features"]["personality"]
        [
            {
                "name": "OpenAI",
                "handler": OpenAIHandler,
                "handler_args": {
                    "personality": DEFAULT_RESPONSES["features"]["personality"]
                }
            },
            {
                "name": "Gemini",
                "handler": GeminiHandler,
                "handler_args": {
                    "personality": DEFAULT_RESPONSES["features"]["personality"]
                }
            }
        }]
        ]
    )
    controller.init_controller()
    return controller



def main():
    Slack_Handler = SlackHandler(
        load_json("allowed-slack-ids.json"),
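The controller list above is a simple registry: one dict per backend, instantiated with its `handler_args`. A minimal standalone sketch of that pattern; `EchoHandler` and the dict-based construction loop are hypothetical, not `LLMController`'s actual implementation:

# Hypothetical sketch of the registry pattern (assumption: EchoHandler and
# this construction loop are illustrative, not part of the commit).
class EchoHandler:
    def __init__(self, personality):
        self.personality = personality

registry = [
    {
        "name": "Echo",
        "handler": EchoHandler,
        "handler_args": {"personality": "Your AI personality"},
    },
]

handlers = {entry["name"]: entry["handler"](**entry["handler_args"]) for entry in registry}
print(handlers["Echo"].personality)  # Your AI personality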
12 changes: 0 additions & 12 deletions geppetto/openai_handler.py
@@ -134,15 +134,3 @@ def llm_generate_content(self, user_prompt, status_callback=None, *status_callback_args):
            return response
        else:
            return response.choices[0].message.content

    def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_tag: str):
        prompt = []
        for msg in thread:
            formatted_msg = dict(msg)
            if ROLE_FIELD in formatted_msg:
                formatted_msg[ROLE_FIELD] = formatted_msg[ROLE_FIELD].replace(assistant_tag, self.assistant_role)
                formatted_msg[ROLE_FIELD] = formatted_msg[ROLE_FIELD].replace(user_tag, self.user_role)
                prompt.append(formatted_msg)
            else:
                raise InvalidThreadFormatError("The input thread doesn't have the field %s" % ROLE_FIELD)
        return prompt
4 changes: 2 additions & 2 deletions geppetto/slack_handler.py
@@ -5,14 +5,14 @@
import re

from geppetto.utils import is_image_data, lower_string_list

# Set SSL certificate for secure requests
os.environ["SSL_CERT_FILE"] = certifi.where()

# UI roles
USER = "slack_user"
ASSISTANT = "geppetto"


class SlackHandler:

    def __init__(
@@ -140,7 +140,7 @@ def send_message(self, channel_id, thread_id, message, tag="general"):
        )

    def select_llm_from_msg(self, message, last_llm=''):
        mentions = re.findall(r'\#[^\ ]*', message)
        mentions = re.findall(r'(?<=\bllm_)\w+', message)
        clean_mentions = [re.sub(r'[\#\!\?\,\;\.]', "", mention) for mention in mentions]
        hashtags = lower_string_list(clean_mentions)
        controlled_llms = self.llm_ctrl.list_llms()
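The mention syntax thus moves from `#name` hashtags to `llm_name` prefixes. A quick standalone check of the new pattern on illustrative messages:

import re

# Standalone check of the new llm_ mention pattern (assumption: the messages
# and model names here are illustrative only).
for message in ["llm_llma Test message", "Test message llm_llmc?", "Test message"]:
    print(re.findall(r'(?<=\bllm_)\w+', message))
# -> ['llma'], ['llmc'], []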
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -26,13 +26,15 @@ classifiers = [
]

[tool.poetry.dependencies]
python = "^3.8"
python = "^3.9"
certifi = "^2023.11.17"
openai = "^1.4.0"
python-dotenv = "^1.0.0"
slack-bolt = "^1.18.1"
slack-sdk = "^3.26.1"
Pillow = "^10.1.0"
google-generativeai = "^0.5.0"
IPython = "^8.0.0"

[tool.poetry.scripts]
geppetto = "geppetto.main:main"
4 changes: 3 additions & 1 deletion requirements.txt
@@ -3,4 +3,6 @@ openai>=1.4.0
python-dotenv==1.0.0
slack-bolt>=1.18.1
slack-sdk>=3.26.1
pillow>=10.1.0
pillow>=10.1.0
google-generativeai>=0.5.0
IPython>=8.0.0
98 changes: 98 additions & 0 deletions tests/test_gemini.py
@@ -0,0 +1,98 @@
import os
import sys
import unittest
from unittest.mock import Mock, patch

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
sys.path.append(parent_dir)

from geppetto.exceptions import InvalidThreadFormatError
from geppetto.gemini_handler import GeminiHandler


def OF(**kw):
    class OF:
        pass
    instance = OF()
    for k, v in kw.items():
        setattr(instance, k, v)
    return instance


class TestGemini(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.patcher = patch("geppetto.gemini_handler.genai")
        cls.mock_genai = cls.patcher.start()
        cls.gemini_handler = GeminiHandler(personality="Your AI personality")

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()

    def test_personality(self):
        self.assertEqual(self.gemini_handler.personality, "Your AI personality")

    @patch("geppetto.gemini_handler.to_markdown")
    def test_llm_generate_content(self, mock_to_markdown):
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]}
        ]
        mock_response = Mock()
        mock_response.text = "Mocked Gemini response"
        self.gemini_handler.client.generate_content.return_value = mock_response
        mock_to_markdown.return_value.data = "Mocked Markdown data"

        response = self.gemini_handler.llm_generate_content(user_prompt)

        self.assertEqual(response, "Mocked Markdown data")
        mock_to_markdown.assert_called_once_with("Mocked Gemini response")

    def test_get_prompt_from_thread(self):
        thread = [
            {"role": "slack_user", "content": "Message 1"},
            {"role": "geppetto", "content": "Message 2"}
        ]

        ROLE_FIELD = "role"
        MSG_FIELD = "parts"

        prompt = self.gemini_handler.get_prompt_from_thread(
            thread, assistant_tag="geppetto", user_tag="slack_user"
        )

        self.assertIsInstance(prompt, list)

        for msg in prompt:
            self.assertIsInstance(msg, dict)
            self.assertIn(ROLE_FIELD, msg)
            self.assertIn(MSG_FIELD, msg)
            self.assertIsInstance(msg[MSG_FIELD], list)
            self.assertTrue(msg[MSG_FIELD])

        with self.assertRaises(InvalidThreadFormatError):
            incomplete_thread = [{"role": "geppetto"}]
            self.gemini_handler.get_prompt_from_thread(
                incomplete_thread, assistant_tag="geppetto", user_tag="slack_user"
            )

    def test_llm_generate_content_user_repetition(self):
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]},
            {"role": "geppetto", "parts": ["I'm fine."]}
        ]

        with patch.object(self.gemini_handler.client, "generate_content") as mock_generate_content:
            mock_response = Mock()
            mock_response.text = "Mocked Gemini response"
            mock_generate_content.return_value = mock_response

            self.gemini_handler.llm_generate_content(user_prompt)

            mock_generate_content.assert_called_once_with(
                [{"role": "user", "parts": ["Hello", "How are you?"]}, {"role": "geppetto", "parts": ["I'm fine."]}]
            )


if __name__ == "__main__":
    unittest.main()
20 changes: 10 additions & 10 deletions tests/test_slack.py
@@ -140,7 +140,7 @@ def test_handle_message_switch_simple(self):

        # Case B: LLM B
        self.MockLLMHandlerB().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE_B
        message_b = "Test message #llmb"
        message_b = "Test message llm_llmb"
        self.slack_handler.handle_message(message_b, channel_id, thread_id)
        self.assertIn(
            {"role": "slack_user", "content": message_b},
@@ -157,7 +157,7 @@ def test_handle_message_switch_same_thread_continue_non_default(self):

        # Case C: LLM C
        self.MockLLMHandlerC().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE_C
        message_c = "Test message #llmc"
        message_c = "Test message llm_llmc"
        self.slack_handler.handle_message(message_c, channel_id, thread_id)
        self.assertIn(
            {"role": "slack_user", "content": message_c},
@@ -187,7 +187,7 @@ def test_handle_message_switch_same_thread_reset_on_switch(self):

        # Case A: LLM A
        self.MockLLMHandlerA().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE
        message_a = "Test message #llma"
        message_a = "Test message llm_llma"
        self.slack_handler.handle_message(message_a, channel_id, thread_id)
        user_msg_a = {"role": "slack_user", "content": message_a}
        geppetto_msg_a = {"role": "geppetto", "content": MOCK_GENERIC_LLM_RESPONSE}
@@ -215,7 +215,7 @@ def test_handle_message_switch_same_thread_reset_on_switch(self):

        # SWITCH TO LLM C in an ongoing conversation
        self.MockLLMHandlerC().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE_C
        message_c = "Test message #llmc"
        message_c = "Test message llm_llmc"
        user_msg_c = {"role": "slack_user", "content": message_c}
        geppetto_msg_c = {"role": "geppetto", "content": MOCK_GENERIC_LLM_RESPONSE_C}
        self.slack_handler.handle_message(message_c, channel_id, thread_id)
@@ -250,14 +250,14 @@ def test_handle_message_switch_different_thread(self):

        # --- LLM B on thread I ---
        self.MockLLMHandlerB().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE
        message_b = "Test message #llmb"
        message_b = "Test message llm_llmb"
        self.slack_handler.handle_message(message_b, channel_id, thread_id_i)
        user_msg_b = {"role": "slack_user", "content": message_b}
        geppetto_msg_b = {"role": "geppetto", "content": MOCK_GENERIC_LLM_RESPONSE}

        # --- LLM C on thread II ---
        self.MockLLMHandlerC().llm_generate_content.return_value = MOCK_GENERIC_LLM_RESPONSE
        message_c = "Test message #llmc"
        message_c = "Test message llm_llmc"
        self.slack_handler.handle_message(message_c, channel_id, thread_id_ii)
        user_msg_c = {"role": "slack_user", "content": message_c}
        geppetto_msg_c = {"role": "geppetto", "content": MOCK_GENERIC_LLM_RESPONSE}
@@ -345,11 +345,11 @@ def test_handle_image(self):
        )

    def test_select_llm_from_msg(self):
        message_a = "#llma Test message"
        message_b = "Test #llmb# message"
        message_c = "Test message #llmc?"
        message_a = "llm_llma Test message"
        message_b = "Test llm_llmb message"
        message_c = "Test message llm_llmc?"
        message_default_empty = "Test message"
        message_default_many = "#llmc Test #llmb message #llma"
        message_default_many = "llm_llmc Test llm_llmb message llm_llma"
        message_default_wrong = "Test message #zeta"

        self.assertEqual(self.slack_handler.select_llm_from_msg(