Skip to content

Commit

Permalink
DH-5011/updating engine to work with gpt-4-turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Nov 16, 2023
1 parent c70a26c commit 16a4a63
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 30 deletions.
3 changes: 1 addition & 2 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Openai info. All these fields are required for the engine to work.
OPENAI_API_KEY = #This field is required for the engine to work.
ORG_ID =
LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo
AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a large context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k
LLM_MODEL = 'gpt-4-1106-preview' #the openAI llm model that you want to use. possible values: gpt-4-1106-preview.

DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response
UPPER_LIMIT_QUERY_RETURN_ROWS = #The upper limit on number of rows returned from the query engine (equivalent to using LIMIT N in PostgreSQL/MySQL/SQLite). Defaults to 50
Expand Down
5 changes: 1 addition & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,12 @@ cp .env.example .env

Specifically, the following 4 fields must be manually set before the engine is started.

LLM_MODEL is employed by evaluators and natural language generators that do not necessitate an extensive context window.

AGENT_LLM_MODEL, on the other hand, is utilized by the NL-to-SQL generator, which relies on a larger context window.
LLM_MODEL is employed by the engine to generate SQL from natural language. You can use the default model (gpt-4-1106-preview) or use your own.

```
#OpenAI credentials and model
OPENAI_API_KEY =
LLM_MODEL =
AGENT_LLM_MODEL =
ORG_ID =
#Encryption key for storing DB connection data in Mongo
Expand Down
25 changes: 4 additions & 21 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import openai
import pandas as pd
import sqlalchemy
import tiktoken
from bson.objectid import ObjectId
from google.api_core.exceptions import GoogleAPIError
from langchain.agents.agent import AgentExecutor
Expand Down Expand Up @@ -41,7 +40,6 @@
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.sql_generator.adaptive_agent_executor import AdaptiveAgentExecutor
from dataherald.types import Question, Response
from dataherald.utils.agent_prompts import (
AGENT_PREFIX,
Expand All @@ -53,7 +51,6 @@
SUFFIX_WITH_FEW_SHOT_SAMPLES,
SUFFIX_WITHOUT_FEW_SHOT_SAMPLES,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -585,24 +582,15 @@ def create_sql_agent(
input_variables=input_variables,
)
llm_chain = LLMChain(
llm=self.short_context_llm,
llm=self.llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return AdaptiveAgentExecutor.from_agent_and_tools(
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
llm_list={
"short_context_llm": self.short_context_llm,
"long_context_llm": self.long_context_llm,
},
switch_to_larger_model_threshold=OPENAI_CONTEXT_WIDNOW_SIZES[
self.short_context_llm.model_name
]
- 500,
encoding=tiktoken.encoding_for_model(self.short_context_llm.model_name),
callback_manager=callback_manager,
verbose=verbose,
max_iterations=max_iterations,
Expand All @@ -622,15 +610,10 @@ def generate_response(
start_time = time.time()
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
self.short_context_llm = self.model.get_model(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("LLM_MODEL", "gpt-4"),
)
self.long_context_llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
)
repository = TableDescriptionRepository(storage)
db_scan = repository.get_all_tables_by_db(
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
)
self.database = SQLDatabase.get_sql_engine(database_connection)
tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
)
self.database = SQLDatabase.get_sql_engine(database_connection)
logger.info(
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def generate_response(
self.llm = self.model.get_model(
database_connection=database_connection,
temperature=0,
model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
model_name=os.getenv("LLM_MODEL", "gpt-4-1106-preview"),
)
token_counter = TokenCountingHandler(
tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
Expand Down

0 comments on commit 16a4a63

Please sign in to comment.