Skip to content

Commit

Permalink
DATA-2068/modify schema linking to use few-shot samples (#442)
Browse files Browse the repository at this point in the history
* DATA-2068/modify schema linking to use few-shot samples

* DATA-2068/changing the logger method
  • Loading branch information
MohammadrezaPourreza authored Mar 26, 2024
1 parent cf88a1b commit 38a0a31
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
31 changes: 28 additions & 3 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from openai import OpenAI
from overrides import override
from pydantic import BaseModel, Field
from sql_metadata import Parser
from sqlalchemy.exc import SQLAlchemyError

from dataherald.context_store import ContextStore
Expand Down Expand Up @@ -163,6 +164,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

def get_embedding(
self,
Expand All @@ -180,6 +182,18 @@ def get_docs_embedding(
def cosine_similarity(self, a: List[float], b: List[float]) -> float:
return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4)

def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[str]:
most_similar_tables = set()
if self.few_shot_examples is not None:
for example in self.few_shot_examples:
try:
tables = Parser(example["sql"]).tables
except Exception as e:
logger.error(f"Error parsing SQL: {str(e)}")
most_similar_tables.update(tables)
df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
return most_similar_tables

@catch_exceptions()
def _run(
self,
Expand Down Expand Up @@ -210,11 +224,17 @@ def _run(
)
df = df.sort_values(by="similarities", ascending=True)
df = df.tail(TOP_TABLES)
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

async def _arun(
Expand Down Expand Up @@ -389,6 +409,7 @@ class SQLDatabaseToolkit(BaseToolkit):
model_name: str = Field(exclude=True)
openai_fine_tuning: OpenAIFineTuning = Field(exclude=True)
embedding: OpenAIEmbeddings = Field(exclude=True)
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

@property
def dialect(self) -> str:
Expand All @@ -408,7 +429,10 @@ def get_tools(self) -> List[BaseTool]:
tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
tools.append(
TablesSQLDatabaseTool(
db=self.db, db_scan=self.db_scan, embedding=self.embedding
db=self.db,
db_scan=self.db_scan,
embedding=self.embedding,
few_shot_examples=self.few_shot_examples,
)
)
tools.append(QuerySQLDataBaseTool(db=self.db))
Expand Down Expand Up @@ -529,8 +553,8 @@ def generate_response(
)
if not db_scan:
raise ValueError("No scanned tables found for database")
_, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=1
few_shot_examples, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=5
)
finetunings_repository = FinetuningsRepository(storage)
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
Expand All @@ -545,6 +569,7 @@ def generate_response(
toolkit = SQLDatabaseToolkit(
db=self.database,
instructions=instructions,
few_shot_examples=few_shot_examples,
db_scan=db_scan,
api_key=database_connection.decrypt_api_key(),
finetuning_model_id=finetuning.model_id,
Expand Down
21 changes: 21 additions & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from langchain_openai import OpenAIEmbeddings
from overrides import override
from pydantic import BaseModel, Field
from sql_metadata import Parser
from sqlalchemy import MetaData
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import func
Expand Down Expand Up @@ -224,6 +225,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

def get_embedding(
self,
Expand All @@ -241,6 +243,18 @@ def get_docs_embedding(
def cosine_similarity(self, a: List[float], b: List[float]) -> float:
return round(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)), 4)

def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[str]:
most_similar_tables = set()
if self.few_shot_examples is not None:
for example in self.few_shot_examples:
try:
tables = Parser(example["sql"]).tables
except Exception as e:
logger.error(f"Error parsing SQL: {str(e)}")
most_similar_tables.update(tables)
df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
return most_similar_tables

@catch_exceptions()
def _run(
self,
Expand Down Expand Up @@ -271,11 +285,17 @@ def _run(
)
df = df.sort_values(by="similarities", ascending=True)
df = df.tail(TOP_TABLES)
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

async def _arun(
Expand Down Expand Up @@ -547,6 +567,7 @@ def get_tools(self) -> List[BaseTool]:
context=self.context,
db_scan=self.db_scan,
embedding=self.embedding,
few_shot_examples=self.few_shot_examples,
)
tools.append(tables_sql_db_tool)
schema_sql_db_tool = SchemaSQLDatabaseTool(
Expand Down

0 comments on commit 38a0a31

Please sign in to comment.