From e27441aef34e6e66efc7542296b645016f76bf18 Mon Sep 17 00:00:00 2001 From: kaarthik108 Date: Fri, 1 Mar 2024 23:30:35 +1300 Subject: [PATCH] use mistral from Groq --- .gitignore | 8 +++++- chain.py | 60 ++++++++++++++------------------------------ main.py | 36 ++++++++++++++++++++++---- ui/styles.md | 10 ++++++-- utils/snowchat_ui.py | 53 ++++++++++++++++++++++++-------------- 5 files changed, 99 insertions(+), 68 deletions(-) diff --git a/.gitignore b/.gitignore index 33b021f..e08e5cf 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,10 @@ archived_logs/ build/ snowchat.egg-info/ -chroma_db \ No newline at end of file +chroma_db + +pplx.py + +test.json +test.* +app.py \ No newline at end of file diff --git a/chain.py b/chain.py index 7505f36..edd9b8b 100644 --- a/chain.py +++ b/chain.py @@ -1,8 +1,7 @@ from typing import Any, Callable, Dict, Optional -import boto3 import streamlit as st -from langchain.chat_models import BedrockChat, ChatOpenAI +from langchain.chat_models import ChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.llms import OpenAI from langchain.vectorstores import SupabaseVectorStore @@ -34,7 +33,7 @@ class ModelConfig(BaseModel): @validator("model_type", pre=True, always=True) def validate_model_type(cls, v): - if v not in ["gpt", "codellama", "mistral"]: + if v not in ["gpt", "mistral", "gemini"]: raise ValueError(f"Unsupported model type: {v}") return v @@ -53,8 +52,8 @@ def __init__(self, config: ModelConfig): def setup(self): if self.model_type == "gpt": self.setup_gpt() - elif self.model_type == "codellama": - self.setup_codellama() + elif self.model_type == "gemini": + self.setup_gemini() elif self.model_type == "mistral": self.setup_mixtral() @@ -63,7 +62,7 @@ def setup_gpt(self): model_name="gpt-3.5-turbo-0125", temperature=0.2, api_key=self.secrets["OPENAI_API_KEY"], - max_tokens=500, + max_tokens=1000, callbacks=[self.callback_handler], streaming=True, base_url=self.gateway_url, @@ -71,51 +70,30 @@ def setup_gpt(self): def setup_mixtral(self): self.llm = ChatOpenAI( - model_name="mistralai/mistral-medium", + model_name="mixtral-8x7b-32768", temperature=0.2, - api_key=self.secrets["OPENROUTER_API_KEY"], - max_tokens=500, + api_key=self.secrets["GROQ_API_KEY"], + max_tokens=3000, callbacks=[self.callback_handler], streaming=True, - base_url="https://openrouter.ai/api/v1", + base_url="https://api.groq.com/openai/v1", ) - def setup_codellama(self): + def setup_gemini(self): self.llm = ChatOpenAI( - model_name="codellama/codellama-70b-instruct", + model_name="google/gemini-pro", temperature=0.2, api_key=self.secrets["OPENROUTER_API_KEY"], - max_tokens=500, + max_tokens=1200, callbacks=[self.callback_handler], streaming=True, base_url="https://openrouter.ai/api/v1", + default_headers={ + "HTTP-Referer": "https://snowchat.streamlit.app/", + "X-Title": "Snowchat", + }, ) - # def setup_claude(self): - # bedrock_runtime = boto3.client( - # service_name="bedrock-runtime", - # aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"], - # aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"], - # region_name="us-east-1", - # ) - # parameters = { - # "max_tokens_to_sample": 1000, - # "stop_sequences": [], - # "temperature": 0, - # "top_p": 0.9, - # } - # self.q_llm = BedrockChat( - # model_id="anthropic.claude-instant-v1", client=bedrock_runtime - # ) - - # self.llm = BedrockChat( - # model_id="anthropic.claude-instant-v1", - # client=bedrock_runtime, - # callbacks=[self.callback_handler], - # streaming=True, - # model_kwargs=parameters, - # ) - def get_chain(self, vectorstore): def _combine_documents( docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n" @@ -153,12 +131,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None): query_name="v_match_documents", ) - if "codellama" in model_name.lower(): - model_type = "codellama" - elif "GPT-3.5" in model_name: + if "GPT-3.5" in model_name: model_type = "gpt" elif "mistral" in model_name.lower(): model_type = "mistral" + elif "gemini" in model_name.lower(): + model_type = "gemini" else: raise ValueError(f"Unsupported model name: {model_name}") diff --git a/main.py b/main.py index 157e63a..894db25 100644 --- a/main.py +++ b/main.py @@ -14,16 +14,43 @@ chat_history = [] snow_ddl = Snowddl() -st.title("snowChat") +gradient_text_html = """ + +
snowChat
+""" + +st.markdown(gradient_text_html, unsafe_allow_html=True) + st.caption("Talk your way through data") model = st.radio( "", - options=["✨ GPT-3.5", "♾️ codellama", "πŸ‘‘ Mistral"], + options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"], index=0, horizontal=True, ) st.session_state["model"] = model +if "toast_shown" not in st.session_state: + st.session_state["toast_shown"] = False + +# Show the toast only if it hasn't been shown before +if not st.session_state["toast_shown"]: + st.toast("The snowflake data retrieval is disabled for now.", icon="πŸ‘‹") + st.session_state["toast_shown"] = True + +if st.session_state["model"] == "πŸ‘‘ Mistral 8x7B - Groq": + st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️") + INITIAL_MESSAGE = [ {"role": "user", "content": "Hi!"}, { @@ -38,10 +65,8 @@ with open("ui/styles.md", "r") as styles_file: styles_content = styles_file.read() -# Display the DDL for the selected table st.sidebar.markdown(sidebar_content) -# Create a sidebar with a dropdown menu selected_table = st.sidebar.selectbox( "Select a table:", options=list(snow_ddl.ddl_dict.keys()) ) @@ -81,9 +106,10 @@ message["content"], True if message["role"] == "user" else False, True if message["role"] == "data" else False, + model, ) -callback_handler = StreamlitUICallbackHandler() +callback_handler = StreamlitUICallbackHandler(model) chain = load_chain(st.session_state["model"], callback_handler) diff --git a/ui/styles.md b/ui/styles.md index de84444..79238e4 100644 --- a/ui/styles.md +++ b/ui/styles.md @@ -11,8 +11,14 @@ background-color: white; z-index: 100; } - h1 { - font-family: 'Roboto Slab', serif; + h1, h2 { + font-weight: bold; + background: -webkit-linear-gradient(left, red, orange); + background: linear-gradient(to right, red, orange); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + display: inline; + font-size: 3em; } .user-avatar { float: right; diff --git a/utils/snowchat_ui.py b/utils/snowchat_ui.py index d630bb4..30e3446 100644 --- a/utils/snowchat_ui.py +++ b/utils/snowchat_ui.py @@ -4,6 +4,24 @@ import streamlit as st from langchain.callbacks.base import BaseCallbackHandler +image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/" +gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z" +mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png" +openai_url = ( + image_url + + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z" +) + + +def get_model_url(model_name): + if "gpt" in model_name.lower(): + return openai_url + elif "gemini" in model_name.lower(): + return gemini_url + elif "mistral" in model_name.lower(): + return mistral_url + return mistral_url + def format_message(text): """ @@ -26,7 +44,7 @@ def format_message(text): return formatted_text -def message_func(text, is_user=False, is_df=False): +def message_func(text, is_user=False, is_df=False, model="gpt"): """ This function is used to display the messages in the chatbot UI. @@ -35,6 +53,9 @@ def message_func(text, is_user=False, is_df=False): is_user (bool): Whether the message is from the user or not. is_df (bool): Whether the message is a dataframe or not. """ + model_url = get_model_url(model) + + avatar_url = model_url if is_user: avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned" message_alignment = "flex-end" @@ -45,13 +66,12 @@ def message_func(text, is_user=False, is_df=False):
{text} \n
- avatar + avatar
""", unsafe_allow_html=True, ) else: - avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light" message_alignment = "flex-start" message_bg_color = "#71797E" avatar_class = "bot-avatar" @@ -60,7 +80,7 @@ def message_func(text, is_user=False, is_df=False): st.write( f"""
- avatar + avatar
""", unsafe_allow_html=True, @@ -73,8 +93,8 @@ def message_func(text, is_user=False, is_df=False): st.write( f"""
- avatar -
+ avatar +
{text} \n
""", @@ -83,11 +103,13 @@ def message_func(text, is_user=False, is_df=False): class StreamlitUICallbackHandler(BaseCallbackHandler): - def __init__(self): + def __init__(self, model): self.token_buffer = [] self.placeholder = st.empty() self.has_streaming_ended = False self.has_streaming_started = False + self.model = model + self.avatar_url = get_model_url(model) def start_loading_message(self): loading_message_content = self._get_bot_message_container("Thinking...") @@ -109,17 +131,11 @@ def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs): def _get_bot_message_container(self, text): """Generate the bot's message container style for the given text.""" - avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light" - message_alignment = "flex-start" - message_bg_color = "#71797E" - avatar_class = "bot-avatar" - formatted_text = format_message( - text - ) # Ensure this handles "Thinking..." appropriately. + formatted_text = format_message(text) container_content = f""" -
- avatar -
+
+ avatar +
{formatted_text} \n
""" @@ -129,14 +145,13 @@ def display_dataframe(self, df): """ Display the dataframe in Streamlit UI within the chat container. """ - avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light" message_alignment = "flex-start" avatar_class = "bot-avatar" st.write( f"""
- avatar + avatar
""", unsafe_allow_html=True,