
Commit
DH-4879/using different models for agent and other calls (#219)
MohammadrezaPourreza authored Oct 18, 2023
1 parent 838f0a2 commit 553a0f7
Showing 12 changed files with 44 additions and 18 deletions.
3 changes: 2 additions & 1 deletion .env.example
@@ -1,7 +1,8 @@
 # Openai info. All these fields are required for the engine to work.
 OPENAI_API_KEY = #This field is required for the engine to work.
 ORG_ID =
-LLM_MODEL = 'gpt-4-32k' #the openAI llm model that you want to use. possible values: gpt-4-32k, gpt-4, gpt-3.5-turbo, gpt-3.5-turbo-16k
+LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo
+AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a large context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k

 DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response

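For context, a minimal sketch (not part of the commit) of how these two variables are meant to be consumed, using the same defaults that appear in the hunks below:

```python
import os

# Evaluation and NL-answer generation: a standard context window is enough.
llm_model = os.getenv("LLM_MODEL", "gpt-4")

# NL-to-SQL agent: needs a large context window for schema and instructions.
agent_llm_model = os.getenv("AGENT_LLM_MODEL", "gpt-4-32k")

print(f"evaluator/NL model: {llm_model}, agent model: {agent_llm_model}")
```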
9 changes: 7 additions & 2 deletions README.md
@@ -66,12 +66,17 @@ You can also self-host the engine locally using Docker. By default the engine us
 cp .env.example .env
 ```

-Specifically the following 4 fields must be manually set before the engine is started.
+Specifically the following 5 fields must be manually set before the engine is started.
+
+LLM_MODEL is employed by evaluators and natural language generators that do not necessitate an extensive context window.
+
+AGENT_LLM_MODEL, on the other hand, is utilized by the NL-to-SQL generator, which relies on a larger context window.

 ```
 #OpenAI credentials and model
 OPENAI_API_KEY =
-LLM_MODEL =
+LLM_MODEL =
+AGENT_LLM_MODEL =
 ORG_ID =
 #Encryption key for storing DB connection data in Mongo
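A hedged sketch of the start-up check the README implies: load the .env file and fail fast if a required field is missing. `load_dotenv` (python-dotenv) and the list literal below are illustrative assumptions, not code from this commit; the fifth field (the Mongo encryption key) is truncated in the hunk above, so it is left out here.

```python
import os

from dotenv import load_dotenv  # assumed helper; the engine reads .env at startup

load_dotenv()

# Four of the five required fields are visible in the hunk above; the fifth
# (the encryption key for stored DB connection data) is cut off in this diff.
required = ["OPENAI_API_KEY", "LLM_MODEL", "AGENT_LLM_MODEL", "ORG_ID"]
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")
```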
5 changes: 4 additions & 1 deletion dataherald/eval/eval_agent.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import re
 import time
 from difflib import SequenceMatcher
@@ -248,7 +249,9 @@ def evaluate(
 f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
 )
 self.llm = self.model.get_model(
-database_connection=database_connection, temperature=0
+database_connection=database_connection,
+temperature=0,
+model_name=os.getenv("LLM_MODEL", "gpt-4"),
 )
 database = SQLDatabase.get_sql_engine(database_connection)
 user_question = question.question
5 changes: 4 additions & 1 deletion dataherald/eval/simple_evaluator.py
@@ -1,4 +1,5 @@
 import logging
+import os
 import re
 import time
 from typing import Any
@@ -101,7 +102,9 @@ def evaluate(
 }
 )
 self.llm = self.model.get_model(
-database_connection=database_connection, temperature=0
+database_connection=database_connection,
+temperature=0,
+model_name=os.getenv("LLM_MODEL", "gpt-4"),
 )
 start_time = time.time()
 system_message_prompt = SystemMessagePromptTemplate.from_template(
1 change: 1 addition & 0 deletions dataherald/model/__init__.py
@@ -17,6 +17,7 @@ def get_model(
 self,
 database_connection: DatabaseConnection,
 model_family="openai",
+model_name="gpt-4",
 **kwargs: Any
 ) -> Any:
 pass
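Read together with the hunk above, the abstract interface now looks roughly like this; everything other than the `get_model` signature change is a sketch (the real `LLMModel` base class and `DatabaseConnection` type live elsewhere in the package):

```python
from abc import ABC, abstractmethod
from typing import Any


class LLMModel(ABC):
    """Sketch of the model wrapper interface; only the signature change is from the diff."""

    @abstractmethod
    def get_model(
        self,
        database_connection: Any,  # DatabaseConnection in the real code
        model_family: str = "openai",
        model_name: str = "gpt-4",  # new: callers now choose the model per call
        **kwargs: Any,
    ) -> Any:
        """Return an LLM client configured for the requested model."""
```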
10 changes: 5 additions & 5 deletions dataherald/model/base_models.py
@@ -12,7 +12,6 @@
 class BaseModel(LLMModel):
 def __init__(self, system):
 super().__init__(system)
-self.model_name = os.environ.get("LLM_MODEL", "text-davinci-003")
 self.openai_api_key = os.environ.get("OPENAI_API_KEY")
 self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY")
 self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
@@ -23,6 +22,7 @@ def get_model(
 self,
 database_connection: DatabaseConnection,
 model_family="openai",
+model_name="davinci-003",
 **kwargs: Any
 ) -> Any:
 if database_connection.llm_credentials is not None:
@@ -37,13 +37,13 @@ def get_model(
 elif model_family == "google":
 self.google_api_key = api_key
 if self.openai_api_key:
-self.model = OpenAI(model_name=self.model_name, **kwargs)
+self.model = OpenAI(model_name=model_name, **kwargs)
 elif self.aleph_alpha_api_key:
-self.model = AlephAlpha(model=self.model_name, **kwargs)
+self.model = AlephAlpha(model=model_name, **kwargs)
 elif self.anthropic_api_key:
-self.model = Anthropic(model=self.model, **kwargs)
+self.model = Anthropic(model=model_name, **kwargs)
 elif self.cohere_api_key:
-self.model = Cohere(model=self.model, **kwargs)
+self.model = Cohere(model=model_name, **kwargs)
 else:
 raise ValueError("No valid API key environment variable found")
 return self.model
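Beyond threading `model_name` through, this hunk also fixes the Anthropic and Cohere branches, which previously passed `self.model` rather than a model name. The selection logic itself is "use the first provider whose API key is available"; below is a standalone sketch of that pattern (the factories are stand-ins, not the LangChain classes used in the real file):

```python
import os
from typing import Any, Callable, Dict


def pick_llm(model_name: str, factories: Dict[str, Callable[[str], Any]]) -> Any:
    """Return a client from the first provider whose API key is configured."""
    for env_var, build in factories.items():
        if os.environ.get(env_var):
            return build(model_name)
    raise ValueError("No valid API key environment variable found")


# Stand-in factories; the real code builds OpenAI, AlephAlpha, Anthropic or Cohere clients.
factories = {
    "OPENAI_API_KEY": lambda name: ("openai", name),
    "ALEPH_ALPHA_API_KEY": lambda name: ("aleph_alpha", name),
    "ANTHROPIC_API_KEY": lambda name: ("anthropic", name),
    "COHERE_API_KEY": lambda name: ("cohere", name),
}
```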
10 changes: 5 additions & 5 deletions dataherald/model/chat_model.py
@@ -12,13 +12,13 @@
 class ChatModel(LLMModel):
 def __init__(self, system):
 super().__init__(system)
-self.model_name = os.environ.get("LLM_MODEL", "gpt-4-32k")

 @override
 def get_model(
 self,
 database_connection: DatabaseConnection,
 model_family="openai",
+model_name="gpt-4-32k",
 **kwargs: Any
 ) -> Any:
 if database_connection.llm_credentials is not None:
@@ -35,11 +35,11 @@ def get_model(
 elif model_family == "cohere":
 os.environ["COHERE_API_KEY"] = api_key
 if os.environ.get("OPENAI_API_KEY") is not None:
-return ChatOpenAI(model_name=self.model_name, **kwargs)
+return ChatOpenAI(model_name=model_name, **kwargs)
 if os.environ.get("ANTHROPIC_API_KEY") is not None:
-return ChatAnthropic(model_name=self.model_name, **kwargs)
+return ChatAnthropic(model_name=model_name, **kwargs)
 if os.environ.get("GOOGLE_API_KEY") is not None:
-return ChatGooglePalm(model_name=self.model_name, **kwargs)
+return ChatGooglePalm(model_name=model_name, **kwargs)
 if os.environ.get("COHERE_API_KEY") is not None:
-return ChatCohere(model_name=self.model_name, **kwargs)
+return ChatCohere(model_name=model_name, **kwargs)
 raise ValueError("No valid API key environment variable found")
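This is the change that gives the PR its name: the same `ChatModel` wrapper can now hand a large-context model to the agent and a smaller one to everything else. A usage sketch under assumptions (the `system` argument and the `database_connection` placeholder are stand-ins for objects wired up elsewhere in the engine):

```python
import os

from dataherald.model.chat_model import ChatModel

model = ChatModel(system=None)  # the real `system` object is injected by the engine
database_connection = ...  # placeholder for an existing DatabaseConnection

# NL-to-SQL agent: large context window.
agent_llm = model.get_model(
    database_connection=database_connection,
    temperature=0,
    model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
)

# Evaluation and NL answer generation: standard context window.
eval_llm = model.get_model(
    database_connection=database_connection,
    temperature=0,
    model_name=os.getenv("LLM_MODEL", "gpt-4"),
)
```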
1 change: 1 addition & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
@@ -617,6 +617,7 @@ def generate_response(
 self.llm = self.model.get_model(
 database_connection=database_connection,
 temperature=0,
+model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
 )
 repository = TableDescriptionRepository(storage)
 db_scan = repository.get_all_tables_by_db(
3 changes: 3 additions & 0 deletions dataherald/sql_generator/generates_nl_answer.py
@@ -1,3 +1,5 @@
+import os
+
 from langchain.chains import LLMChain
 from langchain.prompts.chat import (
 ChatPromptTemplate,
@@ -40,6 +42,7 @@ def execute(self, query_response: Response) -> Response:
 self.llm = self.model.get_model(
 database_connection=database_connection,
 temperature=0,
+model_name=os.getenv("LLM_MODEL", "gpt-4"),
 )
 database = SQLDatabase.get_sql_engine(database_connection)
 query_response = create_sql_query_status(
5 changes: 4 additions & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
@@ -1,6 +1,7 @@
 """A wrapper for the SQL generation functions in langchain"""

 import logging
+import os
 import time
 from typing import Any, List

@@ -31,7 +32,9 @@ def generate_response(
 ) -> Response:  # type: ignore
 logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
 self.llm = self.model.get_model(
-database_connection=database_connection, temperature=0
+database_connection=database_connection,
+temperature=0,
+model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
 )
 self.database = SQLDatabase.get_sql_engine(database_connection)
 tools = SQLDatabaseToolkit(db=self.database, llm=self.llm).get_tools()
5 changes: 4 additions & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
@@ -1,6 +1,7 @@
 """A wrapper for the SQL generation functions in langchain"""

 import logging
+import os
 import time
 from typing import Any, List

@@ -49,7 +50,9 @@ def generate_response(
 ) -> Response:
 start_time = time.time()
 self.llm = self.model.get_model(
-database_connection=database_connection, temperature=0
+database_connection=database_connection,
+temperature=0,
+model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
 )
 self.database = SQLDatabase.get_sql_engine(database_connection)
 logger.info(
5 changes: 4 additions & 1 deletion dataherald/sql_generator/llamaindex.py
@@ -1,6 +1,7 @@
 """A wrapper for the SQL generation functions in langchain"""

 import logging
+import os
 import time
 from typing import Any, List

@@ -38,7 +39,9 @@ def generate_response(
 start_time = time.time()
 logger.info(f"Generating SQL response to question: {str(user_question.dict())}")
 self.llm = self.model.get_model(
-database_connection=database_connection, temperature=0
+database_connection=database_connection,
+temperature=0,
+model_name=os.getenv("AGENT_LLM_MODEL", "gpt-4-32k"),
 )
 token_counter = TokenCountingHandler(
 tokenizer=tiktoken.encoding_for_model(self.llm.model_name).encode,
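One knock-on effect visible in this last hunk: the llama_index path sizes its token counter from whichever model was selected (`self.llm.model_name`). A brief standalone check of that tokenizer lookup with tiktoken, using the diff's default agent model:

```python
import os

import tiktoken

model_name = os.getenv("AGENT_LLM_MODEL", "gpt-4-32k")
encoding = tiktoken.encoding_for_model(model_name)  # raises KeyError for unknown model names
n_tokens = len(encoding.encode("SELECT COUNT(*) FROM orders;"))
print(f"{model_name}: {n_tokens} tokens")
```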
