Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-5688/fixing the observations code blocks #454

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def stream_agent_steps( # noqa: C901
queue.put(message.content + "\n")
elif "steps" in chunk:
for step in chunk["steps"]:
queue.put(f"**Observation:**\n `{step.observation}`\n")
queue.put(f"\n**Observation:**\n {step.observation}\n")
elif "output" in chunk:
queue.put(f'**Final Answer:**\n {chunk["output"]}')
queue.put(f'\n**Final Answer:**\n {chunk["output"]}')
if "```sql" in chunk["output"]:
response.sql = replace_unprocessable_characters(
self.remove_markdown(chunk["output"])
Expand Down
14 changes: 8 additions & 6 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ def _run(
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

Expand All @@ -250,9 +248,10 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
Use this tool to execute the SQL query on the database, and return the results.
Add newline after both ```sql and ``` tags.
"""
args_schema: Type[BaseModel] = SQLInput

Expand Down Expand Up @@ -335,7 +334,8 @@ def _run(
{"role": "user", "content": user_prompt},
],
)
return response.choices[0].message.content
returned_sql = response.choices[0].message.content
return f"```sql\n{returned_sql}```"

async def _arun(
self,
Expand Down Expand Up @@ -372,6 +372,7 @@ def _run(
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
Expand All @@ -385,6 +386,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down
21 changes: 10 additions & 11 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,11 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):

name = "SqlDbQuery"
description = """
Input: SQL query.
Input: A SQL query between ```sql and ``` tags.
Output: Result from the database or an error message if the query is incorrect.
If an error occurs, rewrite the query and retry.
Use this tool to execute SQL queries.
Add newline after both ```sql and ``` tags.
"""

@catch_exceptions()
Expand Down Expand Up @@ -204,8 +205,8 @@ def _run(
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
response = "Admin: All of the generated SQL queries must follow the below instructions:\n"
for instruction in self.instructions:
response += f"{instruction['instruction']}\n"
for index, instruction in enumerate(self.instructions):
response += f"{index + 1}) {instruction['instruction']}\n"
return response

async def _arun(
Expand Down Expand Up @@ -290,13 +291,11 @@ def _run(
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

Expand Down Expand Up @@ -404,7 +403,7 @@ def _run(
replace_unprocessable_characters(table_name)
for table_name in table_names_list
]
tables_schema = ""
tables_schema = "```sql\n"
for table in self.db_scan:
if table.table_name in table_names_list:
tables_schema += table.table_schema + "\n"
Expand All @@ -420,6 +419,7 @@ def _run(
)
if len(descriptions) > 0:
tables_schema += f"/*\n{''.join(descriptions)}*/\n"
tables_schema += "```\n"
if tables_schema == "":
tables_schema += "Tables not found in the database"
return tables_schema
Expand Down Expand Up @@ -516,9 +516,8 @@ def _run(
return "Action input for the fewshot_examples_retriever tool should be an integer"
returned_output = ""
for example in self.few_shot_examples[:number_of_samples]:
returned_output += (
f"Question: {example['prompt_text']} -> SQL: {example['sql']}\n"
)
returned_output += f"Question: {example['prompt_text']} \n"
returned_output += f"```sql\n{example['sql']}\n```\n"
if returned_output == "":
returned_output = "No previously asked Question/SQL pairs are available"
return returned_output
Expand Down
Loading