From 38a0a31958f1289833bc87f0a2025e8d3077ed4d Mon Sep 17 00:00:00 2001 From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com> Date: Tue, 26 Mar 2024 12:26:17 -0400 Subject: [PATCH] DATA-2068/modify schema linking to use few-shot samples (#442) * DATA-2068/modify schema linking to use few-shot samples * DATA-2068/changing the logger method --- .../dataherald_finetuning_agent.py | 31 +++++++++++++++++-- .../sql_generator/dataherald_sqlagent.py | 21 +++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index b3f00160..e0ff4367 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -25,6 +25,7 @@ from openai import OpenAI from overrides import override from pydantic import BaseModel, Field +from sql_metadata import Parser from sqlalchemy.exc import SQLAlchemyError from dataherald.context_store import ContextStore @@ -163,6 +164,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): """ db_scan: List[TableDescription] embedding: OpenAIEmbeddings + few_shot_examples: List[dict] | None = Field(exclude=True, default=None) def get_embedding( self, @@ -180,6 +182,18 @@ def get_docs_embedding( def cosine_similarity(self, a: List[float], b: List[float]) -> float: return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4) + def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[str]: + most_similar_tables = set() + if self.few_shot_examples is not None: + for example in self.few_shot_examples: + try: + tables = Parser(example["sql"]).tables + except Exception as e: + logger.error(f"Error parsing SQL: {str(e)}") + most_similar_tables.update(tables) + df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + return most_similar_tables + @catch_exceptions() def _run( self, @@ -210,11 +224,17 @@ def _run( ) df = df.sort_values(by="similarities", ascending=True) df = df.tail(TOP_TABLES) + most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): table_relevance += ( f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' ) + if len(most_similar_tables) > 0: + for table in most_similar_tables: + table_relevance += ( + f"Table: {table}, relevance score: {max(df['similarities'])}\n" + ) return table_relevance async def _arun( @@ -389,6 +409,7 @@ class SQLDatabaseToolkit(BaseToolkit): model_name: str = Field(exclude=True) openai_fine_tuning: OpenAIFineTuning = Field(exclude=True) embedding: OpenAIEmbeddings = Field(exclude=True) + few_shot_examples: List[dict] | None = Field(exclude=True, default=None) @property def dialect(self) -> str: @@ -408,7 +429,10 @@ def get_tools(self) -> List[BaseTool]: tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan)) tools.append( TablesSQLDatabaseTool( - db=self.db, db_scan=self.db_scan, embedding=self.embedding + db=self.db, + db_scan=self.db_scan, + embedding=self.embedding, + few_shot_examples=self.few_shot_examples, ) ) tools.append(QuerySQLDataBaseTool(db=self.db)) @@ -529,8 +553,8 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") - _, instructions = context_store.retrieve_context_for_question( - user_prompt, number_of_samples=1 + few_shot_examples, instructions = context_store.retrieve_context_for_question( + user_prompt, number_of_samples=5 ) finetunings_repository = FinetuningsRepository(storage) finetuning = finetunings_repository.find_by_id(self.finetuning_id) @@ -545,6 +569,7 @@ def generate_response( toolkit = SQLDatabaseToolkit( db=self.database, instructions=instructions, + few_shot_examples=few_shot_examples, db_scan=db_scan, api_key=database_connection.decrypt_api_key(), finetuning_model_id=finetuning.model_id, diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 0bf499c4..7d634138 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -26,6 +26,7 @@ from langchain_openai import OpenAIEmbeddings from overrides import override from pydantic import BaseModel, Field +from sql_metadata import Parser from sqlalchemy import MetaData from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.sql import func @@ -224,6 +225,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): """ db_scan: List[TableDescription] embedding: OpenAIEmbeddings + few_shot_examples: List[dict] | None = Field(exclude=True, default=None) def get_embedding( self, @@ -241,6 +243,18 @@ def get_docs_embedding( def cosine_similarity(self, a: List[float], b: List[float]) -> float: return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4) + def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[str]: + most_similar_tables = set() + if self.few_shot_examples is not None: + for example in self.few_shot_examples: + try: + tables = Parser(example["sql"]).tables + except Exception as e: + logger.error(f"Error parsing SQL: {str(e)}") + most_similar_tables.update(tables) + df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True) + return most_similar_tables + @catch_exceptions() def _run( self, @@ -271,11 +285,17 @@ def _run( ) df = df.sort_values(by="similarities", ascending=True) df = df.tail(TOP_TABLES) + most_similar_tables = self.similart_tables_based_on_few_shot_examples(df) table_relevance = "" for _, row in df.iterrows(): table_relevance += ( f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n' ) + if len(most_similar_tables) > 0: + for table in most_similar_tables: + table_relevance += ( + f"Table: {table}, relevance score: {max(df['similarities'])}\n" + ) return table_relevance async def _arun( @@ -547,6 +567,7 @@ def get_tools(self) -> List[BaseTool]: context=self.context, db_scan=self.db_scan, embedding=self.embedding, + few_shot_examples=self.few_shot_examples, ) tools.append(tables_sql_db_tool) schema_sql_db_tool = SchemaSQLDatabaseTool(