Skip to content

Commit

Permalink
Merge pull request #58 from Deeptechia/main
Browse files Browse the repository at this point in the history
v0.2.3
  • Loading branch information
kelyacf authored Apr 26, 2024
2 parents 16c4b01 + 79b9e7e commit bed9d7e
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 19 deletions.
12 changes: 8 additions & 4 deletions geppetto/gemini_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ def convert_gemini_to_slack(text):
formatted_text = formatted_text.replace("- ", "• ")
formatted_text = re.sub(r"\[(.*?)\]\((.*?)\)", r"<\2|\1>", formatted_text)

formatted_text += f"\n\n_(Geppetto v0.2.1 Source: Gemini Model {GEMINI_MODEL})_"
formatted_text += f"\n\n_(Geppetto v0.2.3 Source: Gemini Model {GEMINI_MODEL})_"

return formatted_text


class GeminiHandler(LLMHandler):

def __init__(
Expand Down Expand Up @@ -72,7 +71,12 @@ def llm_generate_content(self, user_prompt, status_callback=None, *status_callba
user_prompt = [merged_prompt] + user_prompt[2:]
response= self.client.generate_content(user_prompt)
markdown_response = convert_gemini_to_slack(response.text)
return markdown_response
if len(markdown_response) > 4000:
# Split the message if it's too long
response_parts = self.split_message(markdown_response)
return response_parts
else:
return markdown_response

def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_tag: str):
prompt = super().get_prompt_from_thread(thread, assistant_tag, user_tag)
Expand All @@ -81,4 +85,4 @@ def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_ta
msg[MSG_FIELD] = [msg.pop(MSG_INPUT_FIELD)]
else:
raise InvalidThreadFormatError("The input thread doesn't have the field %s" % MSG_INPUT_FIELD)
return prompt
return prompt
13 changes: 13 additions & 0 deletions geppetto/llm_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,16 @@ def get_prompt_from_thread(self, thread: List[Dict], assistant_tag: str, user_ta
else:
raise InvalidThreadFormatError("The input thread doesn't have the field %s" % ROLE_FIELD)
return prompt

def split_message(self, message):
"""
Split a message into parts if it exceeds 4000 characters.
Args:
message (str): The message to split.
Returns:
List[str]: A list of message parts.
"""
max_length = 4000
return [message[i:i+max_length] for i in range(0, len(message), max_length)]
11 changes: 8 additions & 3 deletions geppetto/openai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_openai_markdown_to_slack(text):
formatted_text = formatted_text.replace("__", "_")
formatted_text = formatted_text.replace("- ", "• ")
formatted_text = re.sub(r"\[(.*?)\]\((.*?)\)", r"<\2|\1>", formatted_text)
formatted_text += f"\n\n_(Geppetto v0.2.1 Source: OpenAI Model {CHATGPT_MODEL})_"
formatted_text += f"\n\n_(Geppetto v0.2.3 Source: OpenAI Model {CHATGPT_MODEL})_"

# Code blocks and italics remain unchanged but can be explicitly formatted if necessary
return formatted_text
Expand Down Expand Up @@ -157,12 +157,17 @@ def llm_generate_content(self, user_prompt, status_callback=None, *status_callba
function_args = json.loads(tool_call.function.arguments)
function = available_functions[function_name]
if function_name == OPENAI_IMG_FUNCTION and status_callback:
status_callback(*status_callback_args, ":geppetto: I'm preparing the image, please be patient "
status_callback(*status_callback_args, "I'm preparing the image, please be patient "
":lower_left_paintbrush: ...")
response = function(**function_args)
return response
else:
response = response.choices[0].message.content
markdown_response = convert_openai_markdown_to_slack(response)
return markdown_response
if len(markdown_response) > 4000:
# Split the message if it's too long
response_parts = self.split_message(markdown_response)
return response_parts
else:
return markdown_response

27 changes: 19 additions & 8 deletions geppetto/slack_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def handle_message(self, msg, channel_id, thread_id):
response = self.send_message(
channel_id,
thread_id,
":geppetto: :thought_balloon: ..."
":thought_balloon:"
)

if response["ok"]:
Expand Down Expand Up @@ -96,12 +96,23 @@ def handle_message(self, msg, channel_id, thread_id):
logging.info(
"response from %s: %s" % (self.name, response_from_llm_api)
)
self.app.client.chat_update(
channel=channel_id,
text=response_from_llm_api,
thread_ts=thread_id,
ts=timestamp,
)
# If there are multiple parts, send each part separately
if isinstance(response_from_llm_api, list):
for part in response_from_llm_api:
self.app.client.chat_postMessage(
channel=channel_id,
text=part,
thread_ts=thread_id,
mrkdwn=True
)
else:
# If it is a single message, send it normally
self.app.client.chat_update(
channel=channel_id,
text=response_from_llm_api,
thread_ts=thread_id,
ts=timestamp,
)
except Exception as e:
logging.error("Error posting message: %s", e)

Expand Down Expand Up @@ -152,4 +163,4 @@ def select_llm_from_msg(self, message, last_llm=''):
return last_llm
else:
# default first LLM
return controlled_llms[0]
return controlled_llms[0]
4 changes: 2 additions & 2 deletions tests/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_llm_generate_content(self, mock_to_markdown):
mock_response = Mock()
mock_response.text = "Mocked Gemini response"
self.gemini_handler.client.generate_content.return_value = mock_response
mock_to_markdown.return_value = Mock(data="Mocked Markdown data")
mock_to_markdown.return_value = "Mocked Markdown data"

response = self.gemini_handler.llm_generate_content(user_prompt)

self.assertEqual(response.data, "Mocked Markdown data")
self.assertEqual(response, "Mocked Markdown data")
mock_to_markdown.assert_called_once_with("Mocked Gemini response")

def test_get_prompt_from_thread(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_random_user_allowed_with_wildcard_permission(self):

self.MockApp().client.chat_postMessage.assert_called_with(
channel="test_channel",
text=":geppetto: :thought_balloon: ...",
text=":thought_balloon:",
thread_ts="1",
mrkdwn=True
)
Expand All @@ -110,7 +110,7 @@ def test_handle_message(self):

self.MockApp().client.chat_postMessage.assert_called_with(
channel=channel_id,
text=":geppetto: :thought_balloon: ...",
text=":thought_balloon:",
thread_ts=thread_id,
mrkdwn=True
)
Expand Down

0 comments on commit bed9d7e

Please sign in to comment.