diff --git a/dataherald/api/types/requests.py b/dataherald/api/types/requests.py
index fcb90348..108c775e 100644
--- a/dataherald/api/types/requests.py
+++ b/dataherald/api/types/requests.py
@@ -6,6 +6,7 @@
 class PromptRequest(BaseModel):
     text: str
     db_connection_id: str
+    schemas: list[str] | None
     metadata: dict | None
 
 
diff --git a/dataherald/api/types/responses.py b/dataherald/api/types/responses.py
index 7edd7c28..ebacf601 100644
--- a/dataherald/api/types/responses.py
+++ b/dataherald/api/types/responses.py
@@ -25,6 +25,7 @@ def created_at_as_string(cls, v):
 class PromptResponse(BaseResponse):
     text: str
     db_connection_id: str
+    schemas: list[str] | None
 
 
 class SQLGenerationResponse(BaseResponse):
diff --git a/dataherald/services/prompts.py b/dataherald/services/prompts.py
index f42c272b..b7392b4a 100644
--- a/dataherald/services/prompts.py
+++ b/dataherald/services/prompts.py
@@ -4,6 +4,7 @@
     DatabaseConnectionRepository,
 )
 from dataherald.repositories.prompts import PromptNotFoundError, PromptRepository
+from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
 from dataherald.types import Prompt
 
 
@@ -22,9 +23,16 @@ def create(self, prompt_request: PromptRequest) -> Prompt:
                 f"Database connection {prompt_request.db_connection_id} not found"
             )
 
+        if not db_connection.schemas and prompt_request.schemas:
+            raise SchemaNotSupportedError(
+                "Schema not supported for this db",
+                description=f"The {db_connection.dialect} dialect doesn't support schemas",
+            )
+
         prompt = Prompt(
             text=prompt_request.text,
             db_connection_id=prompt_request.db_connection_id,
+            schemas=prompt_request.schemas,
             metadata=prompt_request.metadata,
         )
         return self.prompt_repository.insert(prompt)
diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py
index 16fc94d4..e997920e 100644
--- a/dataherald/sql_generator/__init__.py
+++ b/dataherald/sql_generator/__init__.py
@@ -11,10 +11,10 @@
 from langchain.agents.agent import AgentExecutor
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, LLMResult
-from langchain.schema.messages import BaseMessage
 from langchain_community.callbacks import get_openai_callback
 
 from dataherald.config import Component, System
+from dataherald.db_scanner.models.types import TableDescription
 from dataherald.model.chat_model import ChatModel
 from dataherald.repositories.sql_generations import (
     SQLGenerationRepository,
@@ -62,6 +62,21 @@ def remove_markdown(self, query: str) -> str:
             return matches[0].strip()
         return query
 
+    @staticmethod
+    def get_table_schema(table_name: str, db_scan: List[TableDescription]) -> str:
+        for table in db_scan:
+            if table.table_name == table_name:
+                return table.schema_name
+        return ""
+
+    @staticmethod
+    def filter_tables_by_schema(
+        db_scan: List[TableDescription], prompt: Prompt
+    ) -> List[TableDescription]:
+        if prompt.schemas:
+            return [table for table in db_scan if table.schema_name in prompt.schemas]
+        return db_scan
+
     def format_sql_query_intermediate_steps(self, step: str) -> str:
         pattern = r"```sql(.*?)```"
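
The two static helpers added to `SQLGenerator` above are pure functions over the scanned metadata, so their intended behavior is easy to pin down in isolation. A minimal sketch, using a stand-in dataclass in place of dataherald's `TableDescription` pydantic model and a bare list of schema names in place of the `Prompt` object:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TableDescriptionStub:  # stand-in for db_scanner.models.types.TableDescription
    table_name: str
    schema_name: Optional[str] = None


def filter_tables_by_schema(
    db_scan: List[TableDescriptionStub], schemas: Optional[List[str]]
) -> List[TableDescriptionStub]:
    # Mirrors SQLGenerator.filter_tables_by_schema: no schemas on the prompt
    # means the scan passes through untouched; otherwise only tables whose
    # schema_name is listed survive.
    if schemas:
        return [table for table in db_scan if table.schema_name in schemas]
    return db_scan


scan = [
    TableDescriptionStub("orders", "sales"),
    TableDescriptionStub("users", "auth"),
    TableDescriptionStub("events"),  # schema-less table, e.g. on SQLite
]
assert [t.table_name for t in filter_tables_by_schema(scan, ["sales"])] == ["orders"]
assert filter_tables_by_schema(scan, None) == scan  # old single-schema behavior preserved
```
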
logger.error(f"Error parsing SQL: {str(e)}") - most_similar_tables.update(tables) - df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + for table in tables: + found_tables = df[df.table_name == table] + for _, row in found_tables.iterrows(): + most_similar_tables.add((row["schema_name"], row["table_name"])) + df.drop( + df[ + df.table_name.isin([table[1] for table in most_similar_tables]) + ].index, + inplace=True, + ) return most_similar_tables @catch_exceptions() - def _run( + def _run( # noqa: PLR0912 self, user_question: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 @@ -214,9 +222,12 @@ def _run( table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}" else: table_rep = f"Table {table.table_name} contain columns: [{col_rep}]" - table_representations.append([table.table_name, table_rep]) + table_representations.append( + [table.schema_name, table.table_name, table_rep] + ) df = pd.DataFrame( - table_representations, columns=["table_name", "table_representation"] + table_representations, + columns=["schema_name", "table_name", "table_representation"], ) df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -227,12 +238,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index 899e6e87..b686bf2f 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -10,7 +10,6 @@
 import numpy as np
 import openai
 import pandas as pd
-import sqlalchemy
 from google.api_core.exceptions import GoogleAPIError
 from langchain.agents.agent import AgentExecutor
 from langchain.agents.agent_toolkits.base import BaseToolkit
@@ -27,9 +26,7 @@
 from overrides import override
 from pydantic import BaseModel, Field
 from sql_metadata import Parser
-from sqlalchemy import MetaData
 from sqlalchemy.exc import SQLAlchemyError
-from sqlalchemy.sql import func
 
 from dataherald.context_store import ContextStore
 from dataherald.db import DB
@@ -254,12 +251,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s
                 tables = Parser(example["sql"]).tables
             except Exception as e:
                 logger.error(f"Error parsing SQL: {str(e)}")
-            most_similar_tables.update(tables)
-        df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
+            for table in tables:
+                found_tables = df[df.table_name == table]
+                for _, row in found_tables.iterrows():
+                    most_similar_tables.add((row["schema_name"], row["table_name"]))
+        df.drop(
+            df[
+                df.table_name.isin([table[1] for table in most_similar_tables])
+            ].index,
+            inplace=True,
+        )
         return most_similar_tables
 
     @catch_exceptions()
-    def _run(
+    def _run(  # noqa: PLR0912
         self,
         user_question: str,
         run_manager: CallbackManagerForToolRun | None = None,  # noqa: ARG002
@@ -278,9 +283,12 @@ def _run(
                 table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}"
             else:
                 table_rep = f"Table {table.table_name} contain columns: [{col_rep}]"
-            table_representations.append([table.table_name, table_rep])
+            table_representations.append(
+                [table.schema_name, table.table_name, table_rep]
+            )
         df = pd.DataFrame(
-            table_representations, columns=["table_name", "table_representation"]
+            table_representations,
+            columns=["schema_name", "table_name", "table_representation"],
         )
         df["table_embedding"] = self.get_docs_embedding(df.table_representation)
         df["similarities"] = df.table_embedding.apply(
@@ -291,12 +299,20 @@
         most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
         table_relevance = ""
         for _, row in df.iterrows():
-            table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n'
+            if row["schema_name"] is not None:
+                table_name = row["schema_name"] + "." + row["table_name"]
+            else:
+                table_name = row["table_name"]
+            table_relevance += (
+                f'Table: `{table_name}`, relevance score: {row["similarities"]}\n'
+            )
         if len(most_similar_tables) > 0:
             for table in most_similar_tables:
-                table_relevance += (
-                    f"Table: `{table}`, relevance score: {max(df['similarities'])}\n"
-                )
+                if table[0] is not None:
+                    table_name = table[0] + "." + table[1]
+                else:
+                    table_name = table[1]
+                table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n"
         return table_relevance
 
     async def _arun(
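
The reworked few-shot matching above collects a `(schema_name, table_name)` pair for every scanned table that appears in an example query, then drops those rows from the scores DataFrame so each matched table is reported once at the top relevance score. A condensed sketch of that flow on a toy DataFrame (data invented; in the agent the table list comes from `sql_metadata.Parser`):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "schema_name": ["sales", "auth", None],
        "table_name": ["orders", "users", "events"],
        "similarities": [0.91, 0.42, 0.37],
    }
)
tables_in_examples = ["orders"]  # tables parsed out of a few-shot SQL sample

most_similar_tables = set()
for table in tables_in_examples:
    found_tables = df[df.table_name == table]
    for _, row in found_tables.iterrows():
        most_similar_tables.add((row["schema_name"], row["table_name"]))

# Index 1 of each tuple is the bare table name; those rows leave the frame.
df.drop(
    df[df.table_name.isin([t[1] for t in most_similar_tables])].index,
    inplace=True,
)

assert most_similar_tables == {("sales", "orders")}
assert list(df.table_name) == ["users", "events"]
```
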
df["table_embedding"] = self.get_docs_embedding(df.table_representation) df["similarities"] = df.table_embedding.apply( @@ -291,12 +299,20 @@ def _run( most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): - table_relevance += f'Table: `{row["table_name"]}`, relevance score: {row["similarities"]}\n' + if row["schema_name"] is not None: + table_name = row["schema_name"] + "." + row["table_name"] + else: + table_name = row["table_name"] + table_relevance += ( + f'Table: `{table_name}`, relevance score: {row["similarities"]}\n' + ) if len(most_similar_tables) > 0: for table in most_similar_tables: - table_relevance += ( - f"Table: `{table}`, relevance score: {max(df['similarities'])}\n" - ) + if table[0] is not None: + table_name = table[0] + "." + table[1] + else: + table_name = table[1] + table_relevance += f"Table: `{table_name}`, relevance score: {max(df['similarities'])}\n" return table_relevance async def _arun( @@ -318,6 +334,8 @@ class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): Example Input: table1 -> column2, entity """ + db_scan: List[TableDescription] + is_multiple_schema: bool def find_similar_strings( self, input_list: List[tuple], target_string: str, threshold=0.4 @@ -341,21 +359,25 @@ def _run( try: schema, entity = tool_input.split(",") table_name, column_name = schema.split("->") + table_name = replace_unprocessable_characters(table_name) + column_name = replace_unprocessable_characters(column_name).strip() + if "." not in table_name and self.is_multiple_schema: + raise Exception( + "Table name should be in the format schema_name.table_name" + ) except ValueError: return "Invalid input format, use following format: table_name -> column_name, entity (entity should be a string without ',')" search_pattern = f"%{entity.strip().lower()}%" - meta = MetaData(bind=self.db.engine) - table = sqlalchemy.Table(table_name.strip(), meta, autoload=True) + search_query = f"SELECT DISTINCT {column_name} FROM {table_name} WHERE {column_name} ILIKE :search_pattern" # noqa: S608 try: - search_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] - ).where(func.lower(table.c[column_name.strip()]).like(search_pattern)) - search_results = self.db.engine.execute(search_query).fetchall() + search_results = self.db.engine.execute( + search_query, {"search_pattern": search_pattern} + ).fetchall() search_results = search_results[:25] except SQLAlchemyError: search_results = [] - distinct_query = sqlalchemy.select( - [func.distinct(table.c[column_name.strip()])] + distinct_query = ( + f"SELECT DISTINCT {column_name} FROM {table_name}" # noqa: S608 ) results = self.db.engine.execute(distinct_query).fetchall() results = self.find_similar_strings(results, entity) @@ -392,26 +414,31 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): db_scan: List[TableDescription] @catch_exceptions() - def _run( + def _run( # noqa: C901 self, table_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: """Get the schema for tables in a comma-separated list.""" table_names_list = table_names.split(", ") - table_names_list = [ - replace_unprocessable_characters(table_name) - for table_name in table_names_list - ] + processed_table_names = [] + for table in table_names_list: + formatted_table = replace_unprocessable_characters(table) + if "." 
@@ -392,26 +414,31 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
     db_scan: List[TableDescription]
 
     @catch_exceptions()
-    def _run(
+    def _run(  # noqa: C901
         self,
         table_names: str,
         run_manager: CallbackManagerForToolRun | None = None,  # noqa: ARG002
     ) -> str:
         """Get the schema for tables in a comma-separated list."""
         table_names_list = table_names.split(", ")
-        table_names_list = [
-            replace_unprocessable_characters(table_name)
-            for table_name in table_names_list
-        ]
+        processed_table_names = []
+        for table in table_names_list:
+            formatted_table = replace_unprocessable_characters(table)
+            if "." in formatted_table:
+                processed_table_names.append(formatted_table.split(".")[1])
+            else:
+                processed_table_names.append(formatted_table)
         tables_schema = "```sql\n"
         for table in self.db_scan:
-            if table.table_name in table_names_list:
+            if table.table_name in processed_table_names:
                 tables_schema += table.table_schema + "\n"
                 descriptions = []
                 if table.description is not None:
-                    descriptions.append(
-                        f"Table `{table.table_name}`: {table.description}\n"
-                    )
+                    if table.schema_name:
+                        table_name = f"{table.schema_name}.{table.table_name}"
+                    else:
+                        table_name = table.table_name
+                    descriptions.append(f"Table `{table_name}`: {table.description}\n")
                     for column in table.columns:
                         if column.description is not None:
                             descriptions.append(
@@ -446,7 +473,7 @@ class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool):
     db_scan: List[TableDescription]
 
     @catch_exceptions()
-    def _run(  # noqa: C901
+    def _run(  # noqa: C901, PLR0912
         self,
         column_names: str,
         run_manager: CallbackManagerForToolRun | None = None,  # noqa: ARG002
@@ -457,6 +484,8 @@ def _run(  # noqa: C901
         for item in items_list:
             if " -> " in item:
                 table_name, column_name = item.split(" -> ")
+                if "." in table_name:
+                    table_name = table_name.split(".")[1]
                 table_name = replace_unprocessable_characters(table_name)
                 column_name = replace_unprocessable_characters(column_name)
                 found = False
@@ -474,7 +503,11 @@ def _run(  # noqa: C901
                     for row in table.examples:
                         col_info += row[column_name] + ", "
                     col_info = col_info[:-2]
-                    column_full_info += f"Table: {table_name}, column: {column_name}, additional info: {col_info}\n"
+                    if table.schema_name:
+                        schema_table = f"{table.schema_name}.{table.table_name}"
+                    else:
+                        schema_table = table.table_name
+                    column_full_info += f"Table: {schema_table}, column: {column_name}, additional info: {col_info}\n"
             else:
                 return "Malformed input, input should be in the following format Example Input: table1 -> column1, table1 -> column2, table2 -> column1"  # noqa: E501
         if not found:
@@ -539,6 +572,7 @@ class SQLDatabaseToolkit(BaseToolkit):
     instructions: List[dict] | None = Field(exclude=True, default=None)
     db_scan: List[TableDescription] = Field(exclude=True)
     embedding: OpenAIEmbeddings = Field(exclude=True)
+    is_multiple_schema: bool = False
 
     @property
     def dialect(self) -> str:
@@ -579,7 +613,12 @@ def get_tools(self) -> List[BaseTool]:
             db=self.db, context=self.context, db_scan=self.db_scan
         )
         tools.append(info_relevant_tool)
-        column_sample_tool = ColumnEntityChecker(db=self.db, context=self.context)
+        column_sample_tool = ColumnEntityChecker(
+            db=self.db,
+            context=self.context,
+            db_scan=self.db_scan,
+            is_multiple_schema=self.is_multiple_schema,
+        )
         tools.append(column_sample_tool)
         if self.few_shot_examples is not None:
             get_fewshot_examples_tool = GetFewShotExamples(
@@ -700,6 +739,9 @@ def generate_response(
         )
         if not db_scan:
             raise ValueError("No scanned tables found for database")
+        db_scan = SQLGenerator.filter_tables_by_schema(
+            db_scan=db_scan, prompt=user_prompt
+        )
         few_shot_examples, instructions = context_store.retrieve_context_for_question(
             user_prompt, number_of_samples=self.max_number_of_examples
         )
@@ -716,6 +758,7 @@
             context=context,
             few_shot_examples=new_fewshot_examples,
             instructions=instructions,
+            is_multiple_schema=True if user_prompt.schemas else False,
             db_scan=db_scan,
             embedding=OpenAIEmbeddings(
                 openai_api_key=database_connection.decrypt_api_key(),
@@ -802,6 +845,9 @@ def stream_response(
         )
         if not db_scan:
             raise ValueError("No scanned tables found for database")
+        db_scan = SQLGenerator.filter_tables_by_schema(
+            db_scan=db_scan, prompt=user_prompt
+        )
         few_shot_examples, instructions = context_store.retrieve_context_for_question(
             user_prompt, number_of_samples=self.max_number_of_examples
         )
@@ -818,6 +864,7 @@
             context=[{}],
             few_shot_examples=new_fewshot_examples,
             instructions=instructions,
+            is_multiple_schema=True if user_prompt.schemas else False,
             db_scan=db_scan,
             embedding=OpenAIEmbeddings(
                 openai_api_key=database_connection.decrypt_api_key(),
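
The new `is_multiple_schema` flag travels from the prompt (`True if user_prompt.schemas else False`, i.e. `bool(user_prompt.schemas)`) through `SQLDatabaseToolkit.get_tools` into `ColumnEntityChecker`, where an unqualified table name becomes a hard error. A reduced sketch of that input-validation path (class pared down to the parsing step; the PR raises a bare `Exception`, narrowed here to `ValueError` for the sketch):

```python
class ColumnEntityCheckerSketch:
    """Pared-down model of the tool-input parsing added in this PR."""

    def __init__(self, is_multiple_schema: bool):
        self.is_multiple_schema = is_multiple_schema

    def parse(self, tool_input: str) -> tuple[str, str, str]:
        schema, entity = tool_input.split(",")
        table_name, column_name = schema.split("->")
        table_name = table_name.strip()
        column_name = column_name.strip()
        if "." not in table_name and self.is_multiple_schema:
            # Once the prompt pins schemas, a bare table name is ambiguous.
            raise ValueError("Table name should be in the format schema_name.table_name")
        return table_name, column_name, entity.strip()


checker = ColumnEntityCheckerSketch(is_multiple_schema=True)
assert checker.parse("dbo.customers -> city, berlin") == ("dbo.customers", "city", "berlin")
try:
    checker.parse("customers -> city, berlin")
except ValueError as exc:
    print(exc)  # Table name should be in the format schema_name.table_name
```
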
diff --git a/dataherald/types.py b/dataherald/types.py
index 9f74a166..8977fbca 100644
--- a/dataherald/types.py
+++ b/dataherald/types.py
@@ -178,6 +178,7 @@ class Prompt(BaseModel):
     id: str | None = None
     text: str
     db_connection_id: str
+    schemas: list[str] | None
     created_at: datetime = Field(default_factory=datetime.now)
     metadata: dict | None
 
diff --git a/docs/api.create_prompt.rst b/docs/api.create_prompt.rst
index e0bf7982..71867776 100644
--- a/docs/api.create_prompt.rst
+++ b/docs/api.create_prompt.rst
@@ -17,6 +17,9 @@ Request this ``POST`` endpoint to create a finetuning job::
 
     {
      "text": "string",
      "db_connection_id": "string",
+     "schemas": [
+       "string"
+     ],
      "metadata": {}
     }
 
@@ -31,7 +34,10 @@ HTTP 201 code response
      "metadata": {},
      "created_at": "string",
      "text": "string",
-     "db_connection_id": "string"
+     "db_connection_id": "string",
+     "schemas": [
+       "string"
+     ]
     }
 
 **Request example**
diff --git a/docs/api.create_prompt_sql_generation.rst b/docs/api.create_prompt_sql_generation.rst
index 3226e7d9..5b486e17 100644
--- a/docs/api.create_prompt_sql_generation.rst
+++ b/docs/api.create_prompt_sql_generation.rst
@@ -45,6 +45,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
      "prompt": {
        "text": "string",
        "db_connection_id": "string",
+       "schemas": [
+         "string"
+       ],
        "metadata": {}
      }
     }
diff --git a/docs/api.create_prompt_sql_generation_nl_generation.rst b/docs/api.create_prompt_sql_generation_nl_generation.rst
index 8b2f62af..7851b9c7 100644
--- a/docs/api.create_prompt_sql_generation_nl_generation.rst
+++ b/docs/api.create_prompt_sql_generation_nl_generation.rst
@@ -61,6 +61,9 @@ Request this ``POST`` endpoint to create a SQL query and a NL response for a giv
      "prompt": {
        "text": "string",
        "db_connection_id": "string",
+       "schemas": [
+         "string"
+       ],
        "metadata": {}
      }
     }
diff --git a/docs/api.get_prompt.rst b/docs/api.get_prompt.rst
index fa02f0be..f86a9043 100644
--- a/docs/api.get_prompt.rst
+++ b/docs/api.get_prompt.rst
@@ -16,7 +16,10 @@ HTTP 200 code response
      "metadata": {},
      "created_at": "string",
      "text": "string",
-     "db_connection_id": "string"
+     "db_connection_id": "string",
+     "schemas": [
+       "string"
+     ]
     }
 
 **Request example**
@@ -37,4 +40,7 @@ HTTP 200 code response
      "created_at": "2024-01-03 15:40:27.624000+00:00",
      "text": "What is the expected total full year sales in 2023?",
-     "db_connection_id": "659435a5453d359345834"
+     "db_connection_id": "659435a5453d359345834",
+     "schemas": [
+       "public"
+     ]
     }
\ No newline at end of file
diff --git a/docs/api.list_prompts.rst b/docs/api.list_prompts.rst
index 9f434a21..2122f117 100644
--- a/docs/api.list_prompts.rst
+++ b/docs/api.list_prompts.rst
@@ -29,7 +29,10 @@ HTTP 201 code response
        "metadata": {},
        "created_at": "string",
        "text": "string",
-       "db_connection_id": "string"
+       "db_connection_id": "string",
+       "schemas": [
+         "string"
+       ]
      }
    ]
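
Taken together, `schemas` is optional across the whole surface: requests, stored `Prompt` records, and responses all allow `null`, and `PromptService.create` rejects the field up front when the connection's dialect has no schema support. A hypothetical client call exercising the new field (endpoint path and connection id are illustrative, not taken from this diff):

```python
import requests

payload = {
    "text": "What is the expected total full year sales in 2023?",
    "db_connection_id": "659435a5453d359345834",
    "schemas": ["public"],  # omit or send null on schema-less dialects
    "metadata": {},
}
# Path assumed for a local dataherald deployment; adjust host/port as needed.
response = requests.post("http://localhost/api/v1/prompts", json=payload, timeout=30)
print(response.status_code)  # 201 on success; an error when the dialect lacks schemas
```
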