From 785145605b2ba1e49b6e3c38b6684b2c083ce797 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 24 Oct 2023 17:08:36 -0400 Subject: [PATCH] DH-4902/adding_top_k_as_env_vars --- .env.example | 1 + dataherald/sql_generator/dataherald_sqlagent.py | 2 +- dataherald/sql_generator/generates_nl_answer.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index aa3f8ba7..105ab140 100644 --- a/.env.example +++ b/.env.example @@ -3,6 +3,7 @@ OPENAI_API_KEY = #This field is required for the engine to work. ORG_ID = LLM_MODEL = 'gpt-4' #the openAI llm model that you want to use for evaluation and generating the nl answer. possible values: gpt-4, gpt-3.5-turbo AGENT_LLM_MODEL = 'gpt-4-32k' # the llm model that you want to use for the agent, it should have a lrage context window. possible values: gpt-4-32k, gpt-3.5-turbo-16k +TOP_K = #top k results to be returned to the agent. This values is used to limit the number of rows returned to the LLMs DH_ENGINE_TIMEOUT = #timeout in seconds for the engine to return a response diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 7d6a6cac..9bbdf75c 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) -TOP_K = 50 +TOP_K = os.environ.get("TOP_K", 50) def catch_exceptions(): # noqa: C901 diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 8fd503a8..2a04b251 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -46,7 +46,10 @@ def execute(self, query_response: Response) -> Response: ) database = SQLDatabase.get_sql_engine(database_connection) query_response = create_sql_query_status( - database, query_response.sql_query, query_response, top_k=50 + database, + query_response.sql_query, + query_response, + top_k=os.environ.get("TOP_K", 50), ) system_message_prompt = SystemMessagePromptTemplate.from_template( SYSTEM_TEMPLATE