Skip to content

Commit

Permalink
fix: add model validator for the tool model_description
Browse files Browse the repository at this point in the history
  • Loading branch information
gurdeep330 committed Nov 8, 2024
1 parent b2dc343 commit df529b6
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions agents/talk2biomodels/tools/model_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing import Type, Optional
from dataclasses import dataclass
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
import streamlit as st
from langchain_core.tools import BaseTool
from langchain_core.output_parsers import StrOutputParser
Expand All @@ -24,6 +24,22 @@ class ModelData:
sbml_file_path: Optional[str] = None
model_object: Optional[BasicoModel] = None

# Check if model_object is an instance of BasicoModel
# otherwise make it None. This is important because
# sometimes the LLM may invoke the tool with an
# inappropriate model_object.
@model_validator(mode="before")
@classmethod
def check_model_object(cls, data):
"""
Check if model_object is an instance of BasicoModel.
"""
if 'model_object' in data:
if not isinstance(data['model_object'], BasicoModel):
data['model_object'] = None
return data


class ModelDescriptionInput(BaseModel):
"""
Input schema for the ModelDescription tool.
Expand Down Expand Up @@ -61,7 +77,7 @@ def _run(self,
Returns:
str: The answer to the question.
"""
# Check if sys_bio_model is provided
# Check if sys_bio_model is provided in the input schema
if sys_bio_model.modelid or sys_bio_model.sbml_file_path \
or sys_bio_model.model_object not in [None, "", {}]:
if sys_bio_model.modelid:
Expand All @@ -73,17 +89,15 @@ def _run(self,
model_object = sys_bio_model.model_object
if st_session_key:
st.session_state[st_session_key] = model_object
# Check if sys_bio_model is provided in the Streamlit session state
elif st_session_key:
if st_session_key not in st.session_state:
return f"Session key {st_session_key} " \
"not found in Streamlit session state."
model_object = st.session_state[st_session_key]
else:
# If the model_object is not provided,
# get it from the Streamlit session state
if st_session_key:
if st_session_key not in st.session_state:
return f"Session key {st_session_key} " \
"not found in Streamlit session state."
model_object = st.session_state[st_session_key]
else:
return "Please provide a valid model object or Streamlit "\
"session key that contains the model object."
return "Please provide a valid model object or Streamlit "\
"session key that contains the model object."
# check if model_object is None
if model_object is None:
return "Please provide a BioModels ID or an SBML file path for the model."
Expand Down

0 comments on commit df529b6

Please sign in to comment.