Skip to content

Commit

Permalink
DH-5735/add support for multiple schemas for agents
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 25, 2024
1 parent 4eb7a3e commit b24d9c8
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 54 deletions.
1 change: 1 addition & 0 deletions dataherald/api/types/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class PromptRequest(BaseModel):
    """Request body used to create a prompt for a database connection."""

    # Text of the prompt; copied as-is onto the stored Prompt.
    text: str
    # Identifier of the database connection the prompt is created for.
    db_connection_id: str
    # Optional list of schema names attached to the prompt; None when the
    # caller does not specify any schemas.
    schemas: list[str] | None
    # Optional free-form metadata dictionary; None when not provided.
    metadata: dict | None


Expand Down
1 change: 1 addition & 0 deletions dataherald/api/types/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def created_at_as_string(cls, v):
class PromptResponse(BaseResponse):
    """Response model describing a prompt, extending the shared BaseResponse."""

    # Text of the prompt.
    text: str
    # Identifier of the database connection the prompt belongs to.
    db_connection_id: str
    # Schema names attached to the prompt, or None when none were given.
    schemas: list[str] | None


class SQLGenerationResponse(BaseResponse):
Expand Down
8 changes: 8 additions & 0 deletions dataherald/services/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DatabaseConnectionRepository,
)
from dataherald.repositories.prompts import PromptNotFoundError, PromptRepository
from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
from dataherald.types import Prompt


Expand All @@ -22,9 +23,16 @@ def create(self, prompt_request: PromptRequest) -> Prompt:
f"Database connection {prompt_request.db_connection_id} not found"
)

if not db_connection.schemas and prompt_request.schemas:
raise SchemaNotSupportedError(
"Schema not supported for this db",
description=f"The {db_connection.dialect} dialect doesn't support schemas",
)

prompt = Prompt(
text=prompt_request.text,
db_connection_id=prompt_request.db_connection_id,
schemas=prompt_request.schemas,
metadata=prompt_request.metadata,
)
return self.prompt_repository.insert(prompt)
Expand Down
17 changes: 16 additions & 1 deletion dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
from langchain.schema.messages import BaseMessage
from langchain_community.callbacks import get_openai_callback

from dataherald.config import Component, System
from dataherald.db_scanner.models.types import TableDescription
from dataherald.model.chat_model import ChatModel
from dataherald.repositories.sql_generations import (
SQLGenerationRepository,
Expand Down Expand Up @@ -62,6 +62,21 @@ def remove_markdown(self, query: str) -> str:
return matches[0].strip()
return query

@staticmethod
def get_table_schema(table_name: str, db_scan: List[TableDescription]) -> str:
for table in db_scan:
if table.table_name == table_name:
return table.schema_name
return ""

@staticmethod
def filter_tables_by_schema(
db_scan: List[TableDescription], prompt: Prompt
) -> List[TableDescription]:
if prompt.schemas:
return [table for table in db_scan if table.schema_name in prompt.schemas]
return db_scan

def format_sql_query_intermediate_steps(self, step: str) -> str:
pattern = r"```sql(.*?)```"

Expand Down
66 changes: 48 additions & 18 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s
tables = Parser(example["sql"]).tables
except Exception as e:
logger.error(f"Error parsing SQL: {str(e)}")
most_similar_tables.update(tables)
df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
for table in tables:
found_tables = df[df.table_name == table]
for _, row in found_tables.iterrows():
most_similar_tables.add((row["schema_name"], row["table_name"]))
df.drop(
df[
df.table_name.isin([table[1] for table in most_similar_tables])
].index,
inplace=True,
)
return most_similar_tables

@catch_exceptions()
def _run(
def _run( # noqa: PLR0912
self,
user_question: str,
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
Expand All @@ -214,9 +222,12 @@ def _run(
table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}"
else:
table_rep = f"Table {table.table_name} contain columns: [{col_rep}]"
table_representations.append([table.table_name, table_rep])
table_representations.append(
[table.schema_name, table.table_name, table_rep]
)
df = pd.DataFrame(
table_representations, columns=["table_name", "table_representation"]
table_representations,
columns=["schema_name", "table_name", "table_representation"],
)
df["table_embedding"] = self.get_docs_embedding(df.table_representation)
df["similarities"] = df.table_embedding.apply(
Expand All @@ -227,12 +238,20 @@ def _run(
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
if row["schema_name"] is not None:
table_name = row["schema_name"] + "." + row["table_name"]
else:
table_name = row["table_name"]
table_relevance += (
f'Table: `{table_name}`, relevance score: {row["similarities"]}\n'
)
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
)
if table[0] is not None:
table_name = table[0] + "." + table[1]
else:
table_name = table[1]
table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n"
return table_relevance

async def _arun(
Expand Down Expand Up @@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
db_scan: List[TableDescription]

@catch_exceptions()
def _run(
def _run( # noqa: C901
self,
table_names: str,
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
"""Get the schema for tables in a comma-separated list."""
table_names_list = table_names.split(", ")
table_names_list = [
replace_unprocessable_characters(table_name)
for table_name in table_names_list
]
processed_table_names = []
for table in table_names_list:
formatted_table = replace_unprocessable_characters(table)
if "." in formatted_table:
processed_table_names.append(formatted_table.split(".")[1])
else:
processed_table_names.append(formatted_table)
tables_schema = ""
for table in self.db_scan:
if table.table_name in table_names_list:
if table.table_name in processed_table_names:
tables_schema += "```sql\n"
tables_schema += table.table_schema + "\n"
descriptions = []
if table.description is not None:
descriptions.append(
f"Table `{table.table_name}`: {table.description}\n"
)
if table.schema_name:
table_name = f"{table.schema_name}.{table.table_name}"
else:
table_name = table.table_name
descriptions.append(f"Table `{table_name}`: {table.description}\n")
for column in table.columns:
if column.description is not None:
descriptions.append(
Expand Down Expand Up @@ -555,6 +579,9 @@ def generate_response(
)
if not db_scan:
raise ValueError("No scanned tables found for database")
db_scan = SQLGenerator.filter_tables_by_schema(
db_scan=db_scan, prompt=user_prompt
)
few_shot_examples, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=5
)
Expand Down Expand Up @@ -658,6 +685,9 @@ def stream_response(
)
if not db_scan:
raise ValueError("No scanned tables found for database")
db_scan = SQLGenerator.filter_tables_by_schema(
db_scan=db_scan, prompt=user_prompt
)
_, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=1
)
Expand Down
Loading

0 comments on commit b24d9c8

Please sign in to comment.