Feat/add nvidia nim #100

Merged: 8 commits merged Feb 12, 2025. Changes from 7 commits.
1 change: 1 addition & 0 deletions .github/workflows/tests_talk2biomodels.yml
@@ -14,6 +14,7 @@ on:

env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
NVIDIA_API_KEY: ${{ secrets.NVIDIA_API_KEY }}

# This workflow contains jobs covering linting and code coverage (along with testing).
jobs:
9 changes: 8 additions & 1 deletion README.md
@@ -44,10 +44,17 @@ Check out the tutorials on each agent for detailed instructions.
```bash
pip install .
```
3. **Initialize OPENAI_API_KEY**
3. **Initialize OPENAI_API_KEY and/or NVIDIA_API_KEY**
```bash
export OPENAI_API_KEY=....
```
```bash
export NVIDIA_API_KEY=....
```
_Please note that to run the agents effectively, at least one of the OpenAI/NVIDIA
API keys must be set. You can create a free NVIDIA account and apply for their
free credits [here](https://build.nvidia.com/explore/discover)._

4. **[Optional] Initialize LANGSMITH_API_KEY**
```bash
export LANGCHAIN_TRACING_V2=true
8 changes: 4 additions & 4 deletions aiagents4pharma/talk2biomodels/agents/t2b_agent.py
@@ -8,6 +8,7 @@
from typing import Annotated
import hydra
from langchain_openai import ChatOpenAI
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import create_react_agent, ToolNode, InjectedState
@@ -26,7 +27,8 @@
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_app(uniq_id, llm_model='gpt-4o-mini'):
def get_app(uniq_id,
llm_model: BaseChatModel = ChatOpenAI(model='gpt-4o-mini', temperature=0)):
'''
This function returns the langraph app.
'''
@@ -51,8 +53,6 @@ def agent_t2b_node(state: Annotated[dict, InjectedState]):
QueryArticle()
])

# Define the model
llm = ChatOpenAI(model=llm_model, temperature=0)
# Load hydra configuration
logger.log(logging.INFO, "Load Hydra configuration for Talk2BioModels agent.")
with hydra.initialize(version_base=None, config_path="../configs"):
@@ -62,7 +62,7 @@ def agent_t2b_node(state: Annotated[dict, InjectedState]):
logger.log(logging.INFO, "state_modifier: %s", cfg.state_modifier)
# Create the agent
model = create_react_agent(
llm,
llm_model,
tools=tools,
state_schema=Talk2Biomodels,
state_modifier=cfg.state_modifier,
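With this change, `get_app` accepts any LangChain `BaseChatModel` instance instead of a model-name string, so the same agent graph can be backed by OpenAI or NVIDIA NIM. Below is a minimal usage sketch; the thread ids are illustrative, and the import path is inferred from the file layout shown above:

```python
# Sketch: construct the Talk2BioModels app with interchangeable chat models.
from langchain_openai import ChatOpenAI
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app

# Default behaviour: falls back to ChatOpenAI(model='gpt-4o-mini', temperature=0)
app_openai = get_app(1234)

# Same agent graph, backed by an NVIDIA NIM endpoint instead
app_nvidia = get_app(5678,
                     llm_model=ChatNVIDIA(model="meta/llama-3.3-70b-instruct"))
```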
@@ -10,22 +10,14 @@ steady_state_prompt: >

Here are some instructions to help you answer questions:

1. Before you answer any question, follow the plan and solve
technique. Start by understanding the question, then plan your
approach to solve the question, and finally solve the question
by following the plan. Always give a brief explanation of your
answer to the user.
1. If the user wants to know the time taken by the model to reach
steady state, you should look at the `steady_state_transition_time`
column of the data for the model species.

2. The highest value in the column `steady_state_transition_time`
is the time taken by the model to reach steady state.

2. If the user wants to know the time taken by the model to reach
steady state, you should look at the steady_state_transition_time
column of the data for the model species. The highest value in
this column is the time taken by the model to reach steady state.

3. To get accurate results, trim the data to the relevant columns
before performing any calculations. This will help you avoid
errors in your calculations, and ignore irrelevant data.

4. Please use the units provided below to answer the questions.
3. Please use the units provided below to answer the questions.
simulation_prompt: >
Following is the information about the data frame:
1. First column is the time column, and the rest of the columns
1 change: 1 addition & 0 deletions aiagents4pharma/talk2biomodels/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"langchain-community==0.3.5",
"langchain-core==0.3.15",
"langchain-experimental==0.3.3",
"langchain-nvidia-ai-endpoints==0.3.9",
"langchain-openai==0.2.5",
"langgraph==0.2.62",
"matplotlib==3.9.2",
@@ -7,6 +7,8 @@
from typing import Annotated
import operator
from langgraph.prebuilt.chat_agent_executor import AgentState
from langchain_core.language_models import BaseChatModel
from langchain_core.embeddings import Embeddings

def add_data(data1: dict, data2: dict) -> dict:
"""
@@ -26,7 +28,8 @@ class Talk2Biomodels(AgentState):
"""
The state for the Talk2BioModels agent.
"""
llm_model: str
llm_model: BaseChatModel
text_embedding_model: Embeddings
pdf_file_name: str
# A StateGraph may receive a concurrent updates
# which is not supported by the StateGraph. Hence,
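Because the state now stores model objects rather than strings, callers seed `llm_model` (and, for PDF queries, `text_embedding_model`) with ready-to-use instances, as the updated tests below do. A minimal sketch of that pattern; the thread id is an illustrative value:

```python
# Sketch: populate the new BaseChatModel / Embeddings state fields.
from langchain_openai import ChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from aiagents4pharma.talk2biomodels.agents.t2b_agent import get_app

app = get_app(42)
config = {"configurable": {"thread_id": 42}}
app.update_state(config, {
    "llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0),
    "text_embedding_model": NVIDIAEmbeddings(
        model="nvidia/llama-3.2-nv-embedqa-1b-v2"),
})
```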
6 changes: 4 additions & 2 deletions aiagents4pharma/talk2biomodels/tests/test_ask_question.py
@@ -3,14 +3,15 @@
'''

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from ..agents.t2b_agent import get_app

def test_ask_question_tool():
'''
Test the ask_question tool without the simulation results.
'''
unique_id = 12345
app = get_app(unique_id, llm_model='gpt-4o-mini')
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}

##########################################
@@ -20,7 +21,8 @@ def test_ask_question_tool():
# case, the tool should return an error
##########################################
# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"})
app.update_state(config,
{"llm_model": ChatOpenAI(model='gpt-4o-mini', temperature=0)})
# Define the prompt
prompt = "Call the ask_question tool to answer the "
prompt += "question: What is the concentration of CRP "
6 changes: 4 additions & 2 deletions aiagents4pharma/talk2biomodels/tests/test_get_annotation.py
@@ -5,6 +5,7 @@
import random
import pytest
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from ..agents.t2b_agent import get_app
from ..tools.get_annotation import prepare_content_msg

@@ -16,7 +17,9 @@ def make_graph_fixture():
unique_id = random.randint(1000, 9999)
graph = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
graph.update_state(config, {"llm_model": "gpt-4o-mini"})
graph.update_state(config, {"llm_model": ChatOpenAI(model='gpt-4o-mini',
temperature=0)
})
return graph, config

def test_no_model_provided(make_graph):
@@ -85,7 +88,6 @@ def test_invalid_species_provided(make_graph):
# (likely due to an invalid species).
test_condition = True
break
# assert test_condition
assert test_condition

def test_invalid_and_valid_species_provided(make_graph):
64 changes: 34 additions & 30 deletions aiagents4pharma/talk2biomodels/tests/test_integration.py
@@ -4,16 +4,19 @@

import pandas as pd
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from ..agents.t2b_agent import get_app

LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)

def test_integration():
'''
Test the integration of the tools.
'''
unique_id = 1234567
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
app.update_state(config, {"llm_model": "gpt-4o-mini"})
app.update_state(config, {"llm_model": LLM_MODEL})
# ##########################################
# ## Test simulate_model tool
# ##########################################
@@ -34,7 +37,7 @@ def test_integration():
# results are available
##########################################
# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"})
app.update_state(config, {"llm_model": LLM_MODEL})
prompt = """What is the concentration of CRP in serum after 100 hours?
Round off the value to 2 decimal places."""
# Test the tool get_modelinfo
@@ -49,12 +52,15 @@

##########################################
# Test custom_plotter tool when the
# simulation results are available
# simulation results are available but
# the species is not available
##########################################
prompt = "Plot only CRP related species."

prompt = """Call the custom_plotter tool to make a plot
showing only species `TP53` and `Pyruvate`. Let me
know if these species were not found. Do not
invoke any other tool."""
# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"}
app.update_state(config, {"llm_model": LLM_MODEL}
)
# Test the tool get_modelinfo
response = app.invoke(
@@ -66,11 +72,8 @@
# Get the messages from the current state
# and reverse the order
reversed_messages = current_state.values["messages"][::-1]
# Loop through the reversed messages
# until a ToolMessage is found.
expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
expected_header += ['CRP{liver}']
# Loop through the reversed messages until a
# ToolMessage is found.
predicted_artifact = []
for msg in reversed_messages:
if isinstance(msg, ToolMessage):
@@ -80,38 +83,33 @@
if msg.name == "custom_plotter":
predicted_artifact = msg.artifact
break
# Convert the artifact into a pandas dataframe
# for easy comparison
df = pd.DataFrame(predicted_artifact)
# Extract the headers from the dataframe
predicted_header = df.columns.tolist()
# Check if the header is in the expected_header
# assert expected_header in predicted_artifact
assert set(expected_header).issubset(set(predicted_header))
# Check if the predicted artifact is `None`
assert predicted_artifact is None

##########################################
# Test custom_plotter tool when the
# simulation results are available but
# the species is not available
# simulation results are available
##########################################
prompt = """Make a custom plot showing the
concentration of the species `TP53` over
time. Do not show any other species."""
prompt = "Plot only CRP related species."

# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"}
app.update_state(config, {"llm_model": LLM_MODEL}
)
# Test the tool get_modelinfo
response = app.invoke(
{"messages": [HumanMessage(content=prompt)]},
config=config
)
assistant_msg = response["messages"][-1].content
# print (response["messages"])
current_state = app.get_state(config)
# Get the messages from the current state
# and reverse the order
reversed_messages = current_state.values["messages"][::-1]
# Loop through the reversed messages until a
# ToolMessage is found.
# Loop through the reversed messages
# until a ToolMessage is found.
expected_header = ['Time', 'CRP{serum}', 'CRPExtracellular']
expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
expected_header += ['CRP{liver}']
predicted_artifact = []
for msg in reversed_messages:
if isinstance(msg, ToolMessage):
@@ -121,5 +119,11 @@
if msg.name == "custom_plotter":
predicted_artifact = msg.artifact
break
# Check if the predicted artifact is `None`
assert predicted_artifact is None
# Convert the artifact into a pandas dataframe
# for easy comparison
df = pd.DataFrame(predicted_artifact)
# Extract the headers from the dataframe
predicted_header = df.columns.tolist()
# Check if the header is in the expected_header
# assert expected_header in predicted_artifact
assert set(expected_header).issubset(set(predicted_header))
8 changes: 7 additions & 1 deletion aiagents4pharma/talk2biomodels/tests/test_query_article.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel, Field
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from ..agents.t2b_agent import get_app

class Article(BaseModel):
@@ -21,8 +22,10 @@ def test_query_article_with_an_article():
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
# Update state by providing the pdf file name
# and the text embedding model
app.update_state(config,
{"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf"})
{"pdf_file_name": "aiagents4pharma/talk2biomodels/tests/article_on_model_537.pdf",
"text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
prompt = "What is the title of the article?"
# Test the tool query_article
response = app.invoke(
@@ -55,6 +58,9 @@ def test_query_article_without_an_article():
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
prompt = "What is the title of the uploaded article?"
# Update state by providing the text embedding model
app.update_state(config,
{"text_embedding_model": NVIDIAEmbeddings(model='nvidia/llama-3.2-nv-embedqa-1b-v2')})
# Test the tool query_article
app.invoke(
{"messages": [HumanMessage(content=prompt)]},
4 changes: 3 additions & 1 deletion aiagents4pharma/talk2biomodels/tests/test_search_models.py
@@ -3,6 +3,7 @@
'''

from langchain_core.messages import HumanMessage
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from ..agents.t2b_agent import get_app

def test_search_models_tool():
@@ -13,7 +14,8 @@
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
# Update state
app.update_state(config, {"llm_model": "gpt-4o-mini"})
app.update_state(config,
{"llm_model": ChatNVIDIA(model="meta/llama-3.3-70b-instruct")})
prompt = "Search for models on Crohn's disease."
# Test the tool get_modelinfo
response = app.invoke(
9 changes: 6 additions & 3 deletions aiagents4pharma/talk2biomodels/tests/test_steady_state.py
@@ -3,16 +3,19 @@
'''

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from ..agents.t2b_agent import get_app

LLM_MODEL = ChatOpenAI(model='gpt-4o-mini', temperature=0)

def test_steady_state_tool():
'''
Test the steady_state tool.
'''
unique_id = 123
app = get_app(unique_id)
config = {"configurable": {"thread_id": unique_id}}
app.update_state(config, {"llm_model": "gpt-4o-mini"})
app.update_state(config, {"llm_model": LLM_MODEL})
#########################################################
# In this case, we will test if the tool returns an error
# when the model does not achieve a steady state. The tool
@@ -37,8 +40,8 @@
#########################################################
# In this case, we will test if the tool is indeed invoked
# successfully
prompt = """Run a steady state analysis of model 64.
Set the initial concentration of `Pyruvate` to 0.2. The
prompt = """Bring model 64 to a steady state. Set the
initial concentration of `Pyruvate` to 0.2. The
concentration of `NAD` resets to 100 every 2 time units."""
# Invoke the agent
app.invoke(
3 changes: 1 addition & 2 deletions aiagents4pharma/talk2biomodels/tools/ask_question.py
@@ -12,7 +12,6 @@
from pydantic import BaseModel, Field
from langchain_core.tools.base import BaseTool
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import InjectedState

# Initialize logger
@@ -101,7 +100,7 @@ def _run(self,
prompt_content += f"{basico.model_info.get_model_units()}\n\n"
# Create a pandas dataframe agent
df_agent = create_pandas_dataframe_agent(
ChatOpenAI(model=state['llm_model']),
state['llm_model'],
allow_dangerous_code=True,
agent_type='tool-calling',
df=df,
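The tool now reuses whichever chat model is stored in the agent state instead of instantiating a new `ChatOpenAI`. A standalone sketch of that pattern follows; the dataframe is a hypothetical stand-in for simulation results:

```python
# Sketch: build the pandas dataframe agent from the model held in state.
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_experimental.agents import create_pandas_dataframe_agent

state = {"llm_model": ChatOpenAI(model="gpt-4o-mini", temperature=0)}
df = pd.DataFrame({"Time": [0, 50, 100], "CRP{serum}": [1.00, 0.62, 0.35]})

df_agent = create_pandas_dataframe_agent(
    state["llm_model"],           # any BaseChatModel, not only ChatOpenAI
    allow_dangerous_code=True,
    agent_type="tool-calling",
    df=df,
)
```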