Skip to content

Commit

Permalink
DH-4851/using gpt-3.5--turbo for nl answer generation
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Oct 16, 2023
1 parent d906876 commit 09b5642
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 14 deletions.
5 changes: 4 additions & 1 deletion dataherald/eval/eval_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import re
import time
from difflib import SequenceMatcher
Expand Down Expand Up @@ -248,7 +249,9 @@ def evaluate(
f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
)
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0
)
database = SQLDatabase.get_sql_engine(database_connection)
user_question = question.question
Expand Down
5 changes: 4 additions & 1 deletion dataherald/eval/simple_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import re
import time
from typing import Any
Expand Down Expand Up @@ -101,7 +102,9 @@ def evaluate(
}
)
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0
)
start_time = time.time()
system_message_prompt = SystemMessagePromptTemplate.from_template(
Expand Down
3 changes: 2 additions & 1 deletion dataherald/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(self, system: System):
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
model_family: str ="openai",
model_name: str = "gpt-4-32k",
**kwargs: Any
) -> Any:
pass
12 changes: 6 additions & 6 deletions dataherald/model/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
class BaseModel(LLMModel):
def __init__(self, system):
super().__init__(system)
self.model_name = os.environ.get("LLM_MODEL", "text-davinci-003")
self.openai_api_key = os.environ.get("OPENAI_API_KEY")
self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY")
self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
Expand All @@ -22,7 +21,8 @@ def __init__(self, system):
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
model_family: str = "openai",
model_name: str = "davinchi-003",
**kwargs: Any
) -> Any:
if database_connection.llm_credentials is not None:
Expand All @@ -37,13 +37,13 @@ def get_model(
elif model_family == "google":
self.google_api_key = api_key
if self.openai_api_key:
self.model = OpenAI(model_name=self.model_name, **kwargs)
self.model = OpenAI(model_name=model_name, **kwargs)
elif self.aleph_alpha_api_key:
self.model = AlephAlpha(model=self.model_name, **kwargs)
self.model = AlephAlpha(model=model_name **kwargs)
elif self.anthropic_api_key:
self.model = Anthropic(model=self.model, **kwargs)
self.model = Anthropic(model=model_name, **kwargs)
elif self.cohere_api_key:
self.model = Cohere(model=self.model, **kwargs)
self.model = Cohere(model=model_name, **kwargs)
else:
raise ValueError("No valid API key environment variable found")
return self.model
4 changes: 2 additions & 2 deletions dataherald/model/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
class ChatModel(LLMModel):
def __init__(self, system):
super().__init__(system)
self.model_name = os.environ.get("LLM_MODEL", "gpt-4-32k")

@override
def get_model(
self,
database_connection: DatabaseConnection,
model_family="openai",
model_name: str = "gpt-4-32k",
**kwargs: Any
) -> Any:
if database_connection.llm_credentials is not None:
Expand All @@ -35,6 +35,6 @@ def get_model(
elif model_family == "cohere":
os.environ["COHERE_API_KEY"] = api_key
try:
return ChatLiteLLM(model_name=self.model_name, **kwargs)
return ChatLiteLLM(model_name=model_name, **kwargs)
except Exception as e:
raise ValueError("No valid API key environment variable found") from e
1 change: 1 addition & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ def generate_response(
storage = self.system.instance(DB)
self.llm = self.model.get_model(
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0,
)
repository = TableDescriptionRepository(storage)
Expand Down
3 changes: 3 additions & 0 deletions dataherald/sql_generator/generates_nl_answer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
Expand Down Expand Up @@ -39,6 +41,7 @@ def execute(self, query_response: Response) -> Response:
)
self.llm = self.model.get_model(
database_connection=database_connection,
model_name = os.environ.get("RESPONSE_GENERATOR_LLM_MODEL", "gpt-3.5-turbo-16k"),
temperature=0,
)
database = SQLDatabase.get_sql_engine(database_connection)
Expand Down
5 changes: 4 additions & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A wrapper for the SQL generation functions in langchain"""

import logging
import os
import time
from typing import Any, List

Expand Down Expand Up @@ -31,7 +32,9 @@ def generate_response(
) -> Response: # type: ignore
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0
)
self.database = SQLDatabase.get_sql_engine(database_connection)
tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
Expand Down
5 changes: 4 additions & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A wrapper for the SQL generation functions in langchain"""

import logging
import os
import time
from typing import Any, List

Expand Down Expand Up @@ -49,7 +50,9 @@ def generate_response(
) -> Response:
start_time = time.time()
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0
)
self.database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
Expand Down
5 changes: 4 additions & 1 deletion dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A wrapper for the SQL generation functions in langchain"""

import logging
import os
import time
from typing import Any, List

Expand Down Expand Up @@ -38,7 +39,9 @@ def generate_response(
start_time = time.time()
logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
self.llm = self.model.get_model(
database_connection=database_connection, temperature=0
database_connection=database_connection,
model_name = os.environ.get("LLM_MODEL", "gpt-4-32k"),
temperature=0
)
token_counter = TokenCountingHandler(
tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
Expand Down

0 comments on commit 09b5642

Please sign in to comment.