diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 298fc3a3..76ca9e97 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -175,6 +175,8 @@ def answer_question_with_timeout( question=question_request.question, db_connection_id=question_request.db_connection_id, ) + question_repository = QuestionRepository(self.storage) + user_question = question_repository.insert(user_question) stop_event = threading.Event() def run_and_catch_exceptions(): diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index 5d637f68..dd78fe86 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -1,7 +1,7 @@ import os from typing import Any -from langchain.chat_models import ChatLiteLLM +from langchain.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm, ChatOpenAI from overrides import override from dataherald.model import LLMModel @@ -34,7 +34,12 @@ def get_model( os.environ["GOOGLE_API_KEY"] = api_key elif model_family == "cohere": os.environ["COHERE_API_KEY"] = api_key - try: - return ChatLiteLLM(model_name=self.model_name, **kwargs) - except Exception as e: - raise ValueError("No valid API key environment variable found") from e + if os.environ.get("OPENAI_API_KEY") is not None: + return ChatOpenAI(model_name=self.model_name, **kwargs) + if os.environ.get("ANTHROPIC_API_KEY") is not None: + return ChatAnthropic(model_name=self.model_name, **kwargs) + if os.environ.get("GOOGLE_API_KEY") is not None: + return ChatGooglePalm(model_name=self.model_name, **kwargs) + if os.environ.get("COHERE_API_KEY") is not None: + return ChatCohere(model_name=self.model_name, **kwargs) + raise ValueError("No valid API key environment variable found") diff --git a/requirements.txt b/requirements.txt index 55472b57..efc79b52 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,4 +34,3 @@ sphinx-book-theme==1.0.1 boto3==1.28.38 botocore==1.31.38 PyAthena==3.0.6 -litellm==0.7.3