From 42a9f919c2f5742a1f668df48146722733566bed Mon Sep 17 00:00:00 2001
From: gurdeep330 <gurdeep330@gmail.com>
Date: Sat, 8 Feb 2025 20:15:46 +0100
Subject: [PATCH 1/6] feat: add-ollama to T2B

---
 .../talk2biomodels/agents/t2b_agent.py        |  8 +++---
 .../talk2biomodels/tools/custom_plotter.py    |  3 +--
 .../talk2biomodels/tools/get_annotation.py    |  4 +--
 .../talk2biomodels/tools/search_models.py     |  2 +-
 app/frontend/streamlit_app_talk2biomodels.py  | 27 +++++++++++--------
 app/frontend/utils/streamlit_utils.py         | 10 ++++++-
 6 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/aiagents4pharma/talk2biomodels/agents/t2b_agent.py b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
index 8e9a41b3..dea66aec 100644
--- a/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
+++ b/aiagents4pharma/talk2biomodels/agents/t2b_agent.py
@@ -8,6 +8,8 @@
 from typing import Annotated
 import hydra
 from langchain_openai import ChatOpenAI
+from langchain_ollama import ChatOllama
+from langchain_core.language_models.chat_models import BaseChatModel
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
@@ -26,7 +28,7 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

-def get_app(uniq_id, llm_model='gpt-4o-mini'):
+def get_app(uniq_id, llm_model: BaseChatModel = ChatOllama(model='llama3.2:1b', temperature=0)):
     '''
     This function returns the langraph app.
     '''
@@ -52,7 +54,7 @@ def agent_t2b_node(state: Annotated[dict, InjectedState]):
         ])

     # Define the model
-    llm = ChatOpenAI(model=llm_model, temperature=0)
+    # llm = ChatOpenAI(model=llm_model, temperature=0)
     # Load hydra configuration
     logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
     with hydra.initialize(version_base=None, config_path="../../configs"):
@@ -62,7 +64,7 @@ def agent_t2b_node(state: Annotated[dict, InjectedState]):
     logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
     # Create the agent
     model = create_react_agent(
-                llm,
+                llm_model,
                 tools=tools,
                 state_schema=Talk2Biomodels,
                 state_modifier=cfg.state_modifier,
diff --git a/aiagents4pharma/talk2biomodels/tools/custom_plotter.py b/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
index b0df439d..07051c3a 100644
--- a/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
+++ b/aiagents4pharma/talk2biomodels/tools/custom_plotter.py
@@ -8,7 +8,6 @@
 from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
 from pydantic import BaseModel, Field
 import pandas as pd
-from langchain_openai import ChatOpenAI
 from langchain_core.tools import BaseTool
 from langgraph.prebuilt import InjectedState

@@ -83,7 +82,7 @@ class CustomHeader(TypedDict):
             description="""List of species based on user question.
                            If no relevant species are found, it will be None.""")
     # Create an instance of the LLM model
-    llm = ChatOpenAI(model=state['llm_model'], temperature=0)
+    llm = state['llm_model']
     llm_with_structured_output = llm.with_structured_output(CustomHeader)
     results = llm_with_structured_output.invoke(question)
     extracted_species = []
diff --git a/aiagents4pharma/talk2biomodels/tools/get_annotation.py b/aiagents4pharma/talk2biomodels/tools/get_annotation.py
index 52b9bb85..c09611ef 100644
--- a/aiagents4pharma/talk2biomodels/tools/get_annotation.py
+++ b/aiagents4pharma/talk2biomodels/tools/get_annotation.py
@@ -17,7 +17,7 @@
 from langchain_core.tools.base import BaseTool
 from langchain_core.tools.base import InjectedToolCallId
 from langchain_core.messages import ToolMessage
-from langchain_openai import ChatOpenAI
+# from langchain_openai import ChatOpenAI
 from .load_biomodel import ModelData, load_biomodel
 from ..api.uniprot import search_uniprot_labels
 from ..api.ols import search_ols_labels
@@ -58,7 +58,7 @@ class CustomHeader(TypedDict):
                     If no relevant species are found, it must be None.""")
     # Create an instance of the LLM model
-    llm = ChatOpenAI(model=state['llm_model'], temperature=0)
+    llm = state['llm_model']
     # Get the structured output from the LLM model
     llm_with_structured_output = llm.with_structured_output(CustomHeader)
     # Define the question for the LLM model using the prompt
diff --git a/aiagents4pharma/talk2biomodels/tools/search_models.py b/aiagents4pharma/talk2biomodels/tools/search_models.py
index 1155c2e2..a38fcac0 100644
--- a/aiagents4pharma/talk2biomodels/tools/search_models.py
+++ b/aiagents4pharma/talk2biomodels/tools/search_models.py
@@ -44,7 +44,7 @@ def _run(self,
             dict: The answer to the question in the form of a dictionary.
         """
         search_results = biomodels.search_for_model(query)
-        llm = ChatOpenAI(model=state['llm_model'])
+        llm = state['llm_model']
         # Check if run_manager's metadata has the key 'prompt_content'
         prompt_content = f'''
                         Convert the input into a table.
diff --git a/app/frontend/streamlit_app_talk2biomodels.py b/app/frontend/streamlit_app_talk2biomodels.py
index f1a6d25c..b61ace2d 100644
--- a/app/frontend/streamlit_app_talk2biomodels.py
+++ b/app/frontend/streamlit_app_talk2biomodels.py
@@ -12,6 +12,7 @@
 from streamlit_feedback import streamlit_feedback
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 from langchain_core.messages import ChatMessage
+from langchain_ollama import ChatOllama
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.tracers.context import collect_runs
 from langchain.callbacks.tracers import LangChainTracer
@@ -23,17 +24,17 @@
 st.set_page_config(page_title="Talk2Biomodels", page_icon="🤖", layout="wide")

-# st.logo(
-#     image='docs/VPE.png',
-#     size='large',
-#     link='https://github.com/VirtualPatientEngine'
-# )
+st.logo(
+    image='docs/VPE.png',
+    size='large',
+    link='https://github.com/VirtualPatientEngine'
+)

 # Check if env variable OPENAI_API_KEY exists
-if "OPENAI_API_KEY" not in os.environ:
-    st.error("Please set the OPENAI_API_KEY environment \
-        variable in the terminal where you run the app.")
-    st.stop()
+# if "OPENAI_API_KEY" not in os.environ:
+#     st.error("Please set the OPENAI_API_KEY environment \
+#         variable in the terminal where you run the app.")
+#     st.stop()

 # Create a chat prompt template
 prompt = ChatPromptTemplate.from_messages([
@@ -128,7 +129,7 @@ def get_uploaded_files():
                 unsafe_allow_html=True)

     # LLM panel (Only at the front-end for now)
-    llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"]
+    llms = ["gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo", "llama3.2:1b"]
     st.selectbox(
         "Pick an LLM to power the agent",
         llms,
@@ -258,9 +259,13 @@ def get_uploaded_files():
                     config,
                     {"sbml_file_path": [st.session_state.sbml_file_path]}
                 )
+            # app.update_state(
+            #     config,
+            #     {"llm_model": st.session_state.llm_model}
+            # )
             app.update_state(
                 config,
-                {"llm_model": st.session_state.llm_model}
+                {"llm_model": ChatOllama(model='llama3.2:1b', temperature=0)}
             )
             # print (current_state.values)
             # current_state = app.get_state(config)
diff --git a/app/frontend/utils/streamlit_utils.py b/app/frontend/utils/streamlit_utils.py
index 76fc2cb9..2790f2d7 100644
--- a/app/frontend/utils/streamlit_utils.py
+++ b/app/frontend/utils/streamlit_utils.py
@@ -8,6 +8,8 @@
 import pandas as pd
 import plotly.express as px
 from langsmith import Client
+from langchain_ollama import ChatOllama
+from langchain_openai import ChatOpenAI

 def submit_feedback(user_response):
     '''
@@ -153,7 +155,13 @@ def update_llm_model():
     """
     Function to update the LLM model.
     """
-    llm_model = st.session_state.llm_model
+    # llm_model = st.session_state.llm_model
+    if st.session_state.llm_model.startswith("llama"):
+        llm_model = ChatOllama(model=st.session_state.llm_model,
+                               temperature=0)
+    else:
+        llm_model = ChatOpenAI(model=st.session_state.llm_model,
+                               temperature=0)
     st.warning(f"Clicking 'Continue' will reset all agents, \
             set the selected LLM to {llm_model}. \
            This action will reset the entire app, \

From e3c7c797311b2393b9d8bfdcd82885cc8c3add86 Mon Sep 17 00:00:00 2001
From: gurdeep330
Date: Mon, 10 Feb 2025 20:57:50 +0100
Subject: [PATCH 2/6] fix: pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index e2659bd6..8f3d4945 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "aiagents4pharma"
-description = "AI Agents for drug discovery, drug development, and other pharmaceutical R&D"
+description = "AI Agents for drug discovery, drug development, and other pharmaceutical R&D."
 readme = "README.md"
 requires-python = ">=3.12"
 classifiers = [

From dfd697e142e1c5dec825099ba257807cd6e47651 Mon Sep 17 00:00:00 2001
From: gurdeep330
Date: Tue, 11 Feb 2025 15:28:37 +0100
Subject: [PATCH 3/6] fix: add nvidia backend for langgraph and update requirements

---
 aiagents4pharma/talk2biomodels/pyproject.toml | 1 +
 pyproject.toml                                | 1 +
 requirements.txt                              | 1 +
 3 files changed, 3 insertions(+)

diff --git a/aiagents4pharma/talk2biomodels/pyproject.toml b/aiagents4pharma/talk2biomodels/pyproject.toml
index 14578acb..c77cb1da 100644
--- a/aiagents4pharma/talk2biomodels/pyproject.toml
+++ b/aiagents4pharma/talk2biomodels/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
   "langchain-community==0.3.5",
   "langchain-core==0.3.15",
   "langchain-experimental==0.3.3",
+  "langchain-nvidia-ai-endpoints==0.3.9",
   "langchain-openai==0.2.5",
   "langgraph==0.2.62",
   "matplotlib==3.9.2",
diff --git a/pyproject.toml b/pyproject.toml
index 4d75f962..c56c9131 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
   "langchain-community==0.3.5",
   "langchain-core==0.3.31",
   "langchain-experimental==0.3.3",
+  "langchain-nvidia-ai-endpoints==0.3.9",
   "langchain-openai==0.2.5",
   "langchain_ollama==0.2.2",
   "langgraph==0.2.66",
diff --git a/requirements.txt b/requirements.txt
index 5a534400..4eb8e0af 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ langchain==0.3.7
 langchain-community==0.3.5
 langchain-core==0.3.31
 langchain-experimental==0.3.3
+langchain-nvidia-ai-endpoints==0.3.9
 langchain-openai==0.3.0
 langchain_ollama==0.2.2
 langgraph==0.2.66

From c04fcd2b5c039f931347aa6e48131257380d6b77 Mon Sep 17 00:00:00 2001
From: gurdeep330
Date: Tue, 11 Feb 2025 15:46:00 +0100
Subject: [PATCH 4/6] fix: add secrets to the yml file

---
 .github/workflows/tests_talk2biomodels.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/tests_talk2biomodels.yml b/.github/workflows/tests_talk2biomodels.yml
index d3cc30a2..908d8108 100644
--- a/.github/workflows/tests_talk2biomodels.yml
+++ b/.github/workflows/tests_talk2biomodels.yml
@@ -14,6 +14,7 @@ on:

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }}

 # This workflow contains jobs covering linting and code coverage (along with testing).
 jobs:

From 486417be47349470364765f7b1ba91caeb603b1c Mon Sep 17 00:00:00 2001
From: gurdeep330
Date: Tue, 11 Feb 2025 19:05:00 +0100
Subject: [PATCH 5/6] fix: steady_state prompt

---
 .../configs/tools/ask_question/default.yaml   | 22 ++++++-------------
 .../talk2biomodels/tests/test_steady_state.py |  4 ++--
 app/frontend/streamlit_app_talk2biomodels.py  |  2 +-
 3 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml b/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml
index 28aaeb17..8369c852 100644
--- a/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml
+++ b/aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml
@@ -10,22 +10,14 @@ steady_state_prompt: >

   Here are some instructions to help you answer questions:

-  1. Before you answer any question, follow the plan and solve
-  technique. Start by understanding the question, then plan your
-  approach to solve the question, and finally solve the question
-  by following the plan. Always give a brief explanation of your
-  answer to the user.
+  1. If the user wants to know the time taken by the model to reach
+  steady state, you should look at the `steady_state_transition_time`
+  column of the data for the model species.
+
+  2. The highest value in the column `steady_state_transition_time`
+  is the time taken by the model to reach steady state.

-  2. If the user wants to know the time taken by the model to reach
-  steady state, you should look at the steady_state_transition_time
-  column of the data for the model species. The highest value in
-  this column is the time taken by the model to reach steady state.
-
-  3. To get accurate results, trim the data to the relevant columns
-  before performing any calculations. This will help you avoid
-  errors in your calculations, and ignore irrelevant data.
-
-  4. Please use the units provided below to answer the questions.
+  3. Please use the units provided below to answer the questions.
 simulation_prompt: >
   Following is the information about the data frame:
   1. First column is the time column, and the rest of the columns
diff --git a/aiagents4pharma/talk2biomodels/tests/test_steady_state.py b/aiagents4pharma/talk2biomodels/tests/test_steady_state.py
index eb443aae..af099206 100644
--- a/aiagents4pharma/talk2biomodels/tests/test_steady_state.py
+++ b/aiagents4pharma/talk2biomodels/tests/test_steady_state.py
@@ -40,8 +40,8 @@ def test_steady_state_tool():
     #########################################################
     # In this case, we will test if the tool is indeed invoked
     # successfully
-    prompt = """Run a steady state analysis of model 64.
-    Set the initial concentration of `Pyruvate` to 0.2. The
+    prompt = """Bring model 64 to a steady state. Set the
+    initial concentration of `Pyruvate` to 0.2. The
     concentration of `NAD` resets to 100 every 2 time units."""
     # Invoke the agent
     app.invoke(
diff --git a/app/frontend/streamlit_app_talk2biomodels.py b/app/frontend/streamlit_app_talk2biomodels.py
index 7dc6856c..6c69f9c3 100644
--- a/app/frontend/streamlit_app_talk2biomodels.py
+++ b/app/frontend/streamlit_app_talk2biomodels.py
@@ -128,7 +128,7 @@ def get_uploaded_files():
                 unsafe_allow_html=True)

     # LLM panel
-    llms = ["meta/llama-3.3-70b-instruct", "gpt-4o-mini"]
+    llms = ["gpt-4o-mini", "meta/llama-3.3-70b-instruct"]
     st.selectbox(
         "Pick an LLM to power the agent",
         llms,

From 8a69413471b8aa5d1f458e4eac9a2cddc97b6bd4 Mon Sep 17 00:00:00 2001
From: gurdeep330
Date: Wed, 12 Feb 2025 14:52:25 +0100
Subject: [PATCH 6/6] fix: address reviews

---
 README.md                                    |   5 +-
 app/frontend/streamlit_app_talk2biomodels.py |  43 ++++---
 app/frontend/utils/streamlit_utils.py        | 128 ++++++++++++++-----
 docs/index.md                                |   5 +-
 4 files changed, 123 insertions(+), 58 deletions(-)

diff --git a/README.md b/README.md
index 872998db..2d885c96 100644
--- a/README.md
+++ b/README.md
@@ -44,15 +44,14 @@ Check out the tutorials on each agent for detailed instrcutions.
 ```bash
 pip install .
 ```
-3. **Initialize OPENAI_API_KEY and/or NVIDIA_API_KEY**
+3. **Initialize OPENAI_API_KEY and NVIDIA_API_KEY**
 ```bash
 export OPENAI_API_KEY=....
 ```
 ```bash
 export NVIDIA_API_KEY=....
 ```
-_Please note to run the agents effectively, atleast one of the OPENAI/NVIDIA
-keys must be set. You can create a free account at NVIDIA and apply for their
+_You can create a free account at NVIDIA and apply for their
 free credits [here](https://build.nvidia.com/explore/discover)._

 4. **[Optional] Initialize LANGSMITH_API_KEY**
diff --git a/app/frontend/streamlit_app_talk2biomodels.py b/app/frontend/streamlit_app_talk2biomodels.py
index 6c69f9c3..bac532e1 100644
--- a/app/frontend/streamlit_app_talk2biomodels.py
+++ b/app/frontend/streamlit_app_talk2biomodels.py
@@ -8,34 +8,36 @@
 import sys
 import random
 import streamlit as st
-import pandas as pd
 from streamlit_feedback import streamlit_feedback
 from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 from langchain_core.messages import ChatMessage
-from langchain_ollama import ChatOllama
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_core.tracers.context import collect_runs
-from langchain.callbacks.tracers import LangChainTracer
 from utils import streamlit_utils
-sys.path.append('./')
-from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app
-# from talk2biomodels.agents.t2b_agent import get_app

 st.set_page_config(page_title="Talk2Biomodels", page_icon="🤖", layout="wide")
-
-
+# Set the logo
 st.logo(
     image='docs/assets/VPE.png',
     size='large',
     link='https://github.com/VirtualPatientEngine'
 )

-# Check if env variable OPENAI_API_KEY exists
-# if "OPENAI_API_KEY" not in os.environ:
-#     st.error("Please set the OPENAI_API_KEY environment \
-#         variable in the terminal where you run the app.")
-#     st.stop()
+# Check if env variables OPENAI_API_KEY and/or
+# NVIDIA_API_KEY exist
+if "OPENAI_API_KEY" not in os.environ or "NVIDIA_API_KEY" not in os.environ:
+    st.error("Please set the OPENAI_API_KEY and NVIDIA_API_KEY "
+             "environment variables in the terminal where you run "
+             "the app. For more information, please refer to our "
+             "[documentation](https://virtualpatientengine.github.io/AIAgents4Pharma/#option-2-git).")
+    st.stop()
+
+# Import the agent
+sys.path.append('./')
+from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app
+
+########################################################################################
+# Streamlit app
+########################################################################################
 # Create a chat prompt template
 prompt = ChatPromptTemplate.from_messages([
         ("system", "Welcome to Talk2Biomodels!"),
@@ -68,6 +70,7 @@
 if "llm_model" not in st.session_state:
     st.session_state.app = get_app(st.session_state.unique_id)
 else:
+    print (st.session_state.llm_model)
     st.session_state.app = get_app(st.session_state.unique_id,
                 llm_model=streamlit_utils.get_base_chat_model(
                     st.session_state.llm_model))
@@ -127,8 +130,9 @@ def get_uploaded_files():
                 """,
                 unsafe_allow_html=True)

-    # LLM panel
-    llms = ["gpt-4o-mini", "meta/llama-3.3-70b-instruct"]
+    # LLM model panel
+    llms = ["OpenAI/gpt-4o-mini",
+            "NVIDIA/llama-3.3-70b-instruct"]
     st.selectbox(
         "Pick an LLM to power the agent",
         llms,
@@ -139,13 +143,14 @@ def get_uploaded_files():
     )

     # Text embedding model panel
-    text_models = ["nvidia/llama-3.2-nv-embedqa-1b-v2", "text-embedding-ada-002"]
+    text_models = ["NVIDIA/llama-3.2-nv-embedqa-1b-v2",
+                   "OpenAI/text-embedding-ada-002"]
     st.selectbox(
         "Pick a text embedding model",
         text_models,
         index=0,
         key="text_embedding_model",
-        on_change=streamlit_utils.update_embed_model,
+        on_change=streamlit_utils.update_text_embedding_model,
         kwargs={"app": app},
         help="Used for Retrieval Augmented Generation (RAG) and other tasks."
     )
@@ -236,7 +241,7 @@ def get_uploaded_files():
                     config,
                     {"llm_model": streamlit_utils.get_base_chat_model(
                         st.session_state.llm_model),
-                     "text_embedding_model": streamlit_utils.get_embedding_model(
+                     "text_embedding_model": streamlit_utils.get_text_embedding_model(
                         st.session_state.text_embedding_model)}
                 )
                 intro_prompt = "Tell your name and about yourself. Always start with a greeting."
diff --git a/app/frontend/utils/streamlit_utils.py b/app/frontend/utils/streamlit_utils.py
index d48f5679..226d30bf 100644
--- a/app/frontend/utils/streamlit_utils.py
+++ b/app/frontend/utils/streamlit_utils.py
@@ -12,6 +12,8 @@
 from langchain_openai import ChatOpenAI
 from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_core.language_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
 from langchain_core.messages import AIMessageChunk, HumanMessage, ChatMessage, AIMessage
 from langchain_core.tracers.context import collect_runs
 from langchain.callbacks.tracers import LangChainTracer

 def submit_feedback(user_response):
     '''
     Function to submit feedback to the developers.
+
+    Args:
+        user_response: dict: The user response
     '''
     client = Client()
     client.create_feedback(
@@ -73,6 +78,12 @@ def render_toggle(key: str,
                   save_toggle: bool = False):
     """
     Function to render the toggle button to show/hide the table.
+
+    Args:
+        key: str: The key for the toggle button
+        toggle_text: str: The text for the toggle button
+        toggle_state: bool: The state of the toggle button
+        save_toggle: bool: Flag to save the toggle button to the chat history
     """
     st.toggle(
         toggle_text,
@@ -93,7 +104,6 @@ def render_plotly(df: pd.DataFrame,
                   key: str,
                   title: str,
-                  # tool_name: str,
                   save_chart: bool = False
                   ):
     """
     Function to render the plotly chart in the chat.

     Args:
         df: pd.DataFrame: The input dataframe
+        key: str: The key for the plotly chart
+        title: str: The title of the plotly chart
+        save_chart: bool: Flag to save the chart to the chat history
     """
     # toggle_state = st.session_state[f'toggle_plotly_{tool_name}_{key.split("_")[-1]}']\
     toggle_state = st.session_state[f'toggle_plotly_{key.split("plotly_")[1]}']
@@ -132,12 +145,16 @@ def render_plotly(df: pd.DataFrame,
         })

 def render_table(df: pd.DataFrame,
-                 # tool_name: str,
                  key: str,
                  save_table: bool = False
                  ):
     """
     Function to render the table in the chat.
+
+    Args:
+        df: pd.DataFrame: The input dataframe
+        key: str: The key for the table
+        save_table: bool: Flag to save the table to the chat history
     """
     # print (st.session_state['toggle_simulate_model_'+key.split("_")[-1]])
     # toggle_state = st.session_state[f'toggle_table_{tool_name}_{key.split("_")[-1]}']
@@ -155,16 +172,10 @@ def render_table(df: pd.DataFrame,
         # "tool_name": tool_name
         })

-def stream_response(response):
-    for chunk in response:
-        if not isinstance(chunk[0], AIMessageChunk):
-            # print (chunk)
-            continue
-        # print (chunk)
-        if 'branch:agent:should_continue:tools' not in chunk[1]['langgraph_triggers']:
-            yield chunk[0].content
-
 def sample_questions():
+    """
+    Function to get the sample questions.
+    """
     questions = [
         "Search for all the BioModels on Crohn's Disease",
         "Briefly describe biomodel 971 and simulate it for 50 days with an interval of 50.",
@@ -173,6 +184,22 @@ def sample_questions():
     ]
     return questions

+def stream_response(response):
+    """
+    Function to stream the response from the agent.
+
+    Args:
+        response: dict: The response from the agent
+    """
+    for chunk in response:
+        # Stream only the AIMessageChunk
+        if not isinstance(chunk[0], AIMessageChunk):
+            continue
+        # print (chunk)
+        # Exclude the tool calls that are not part of the conversation
+        if 'branch:agent:should_continue:tools' not in chunk[1]['langgraph_triggers']:
+            yield chunk[0].content
+
 def get_response(app, st, prompt):
     # Create config for the agent
     config = {"configurable": {"thread_id": st.session_state.unique_id}}
@@ -192,12 +219,20 @@ def get_response(app, st, prompt):
         # Add Langsmith tracer
         tracer = LangChainTracer(project_name=st.session_state.project_name)
         # Get response from the agent
-        response = app.stream(
+        if current_state.values['llm_model']._llm_type == 'chat-nvidia-ai-playground':
+            response = app.invoke(
                 {"messages": [HumanMessage(content=prompt)]},
                 config=config|{"callbacks": [tracer]},
-                stream_mode="messages"
-            )
-        st.write_stream(stream_response(response))
+                # stream_mode="messages"
+            )
+            st.markdown(response["messages"][-1].content)
+        else:
+            response = app.stream(
+                {"messages": [HumanMessage(content=prompt)]},
+                config=config|{"callbacks": [tracer]},
+                stream_mode="messages"
+            )
+            st.write_stream(stream_response(response))
         # print (cb.traced_runs)
         # Save the run id and use to save the feedback
         st.session_state.run_id = cb.traced_runs[-1].id
@@ -377,20 +412,46 @@ def get_response(app, st, prompt):
                 "tool_name": msg.name
             })

-def get_embedding_model(model_name):
-    if model_name.startswith("nvidia"):
-        return NVIDIAEmbeddings(model=model_name)
-    return OpenAIEmbeddings(model=model_name)
-
-def get_base_chat_model(model_name):
-    if model_name.startswith("llama"):
-        return ChatOllama(model=model_name,
-                          temperature=0)
-    elif model_name.startswith("meta"):
-        return ChatNVIDIA(model=model_name,
-                          temperature=0)
-    return ChatOpenAI(model=model_name,
-                      temperature=0)
+def get_text_embedding_model(model_name) -> Embeddings:
+    '''
+    Function to get the text embedding model.
+
+    Args:
+        model_name: str: The name of the model
+
+    Returns:
+        Embeddings: The text embedding model
+    '''
+    dic_text_embedding_models = {
+        "NVIDIA/llama-3.2-nv-embedqa-1b-v2": "nvidia/llama-3.2-nv-embedqa-1b-v2",
+        "OpenAI/text-embedding-ada-002": "text-embedding-ada-002"
+    }
+    if model_name.startswith("NVIDIA"):
+        return NVIDIAEmbeddings(model=dic_text_embedding_models[model_name])
+    return OpenAIEmbeddings(model=dic_text_embedding_models[model_name])
+
+def get_base_chat_model(model_name) -> BaseChatModel:
+    '''
+    Function to get the base chat model.
+
+    Args:
+        model_name: str: The name of the model
+
+    Returns:
+        BaseChatModel: The base chat model
+    '''
+    dic_llm_models = {
+        "NVIDIA/llama-3.3-70b-instruct": "meta/llama-3.3-70b-instruct",
+        "OpenAI/gpt-4o-mini": "gpt-4o-mini"
+    }
+    if model_name.startswith("Llama"):
+        return ChatOllama(model=dic_llm_models[model_name],
+                          temperature=0)
+    elif model_name.startswith("NVIDIA"):
+        return ChatNVIDIA(model=dic_llm_models[model_name],
+                          temperature=0)
+    return ChatOpenAI(model=dic_llm_models[model_name],
+                      temperature=0)

 @st.dialog("Warning ⚠️")
 def update_llm_model():
@@ -413,18 +474,19 @@ def update_llm_model():
             del st.session_state[key]
     st.rerun()

-# @st.dialog("Warning ⚠️")
-def update_embed_model(app):
+def update_text_embedding_model(app):
     """
     Function to update the text embedding model.
+ + Args: + app: The LangGraph app """ - # text_embed_model = st.session_state.text_embedding_model config = {"configurable": {"thread_id": st.session_state.unique_id} } app.update_state( config, - {"text_embedding_model": get_embedding_model( + {"text_embedding_model": get_text_embedding_model( st.session_state.text_embedding_model)} ) diff --git a/docs/index.md b/docs/index.md index 00e86770..a3510bf9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -44,15 +44,14 @@ Check out the tutorials on each agent for detailed instrcutions. ```bash pip install . ``` -3. **Initialize OPENAI_API_KEY and/or NVIDIA_API_KEY** +3. **Initialize OPENAI_API_KEY and NVIDIA_API_KEY** ```bash export OPENAI_API_KEY=.... ``` ```bash export NVIDIA_API_KEY=.... ``` -_Please note to run the agents effectively, atleast one of the OPENAI/NVIDIA -keys must be set. You can create a free account at NVIDIA and apply for their +_You can create a free account at NVIDIA and apply for their free credits [here](https://build.nvidia.com/explore/discover)._ 4. **[Optional] Initialize LANGSMITH_API_KEY**
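Reviewer note: taken together, these patches change `get_app()` to accept a LangChain `BaseChatModel` instance rather than a model-name string, with the frontend mapping panel labels to concrete models via `get_base_chat_model()`. Below is a minimal sketch of driving the agent directly under the new interface; the session id and the commented-out alternatives are illustrative assumptions, not part of the patches.

```python
import uuid

from langchain_openai import ChatOpenAI
# from langchain_nvidia_ai_endpoints import ChatNVIDIA  # requires NVIDIA_API_KEY
# from langchain_ollama import ChatOllama               # requires a local Ollama server

from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app

# Any BaseChatModel works here; these mirror the models exposed in the frontend panel.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # requires OPENAI_API_KEY
# llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct", temperature=0)
# llm = ChatOllama(model="llama3.2:1b", temperature=0)

# get_app() now takes the model instance itself, not a model-name string.
app = get_app(str(uuid.uuid4()), llm_model=llm)
```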