Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DATA-2068/modify schema linking to use few-shot samples #442

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from openai import OpenAI
from overrides import override
from pydantic import BaseModel, Field
from sql_metadata import Parser
from sqlalchemy.exc import SQLAlchemyError

from dataherald.context_store import ContextStore
Expand Down Expand Up @@ -163,6 +164,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

def get_embedding(
self,
Expand All @@ -180,6 +182,18 @@ def get_docs_embedding(
def cosine_similarity(self, a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of vectors ``a`` and ``b``, rounded to 4 decimals.

    The value is the dot product of the two vectors divided by the product
    of their Euclidean norms.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    similarity = np.dot(a, b) / norm_product
    return round(similarity, 4)

def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> set[str]:
    """Collect the tables referenced by the few-shot examples and drop them from ``df``.

    Each entry of ``self.few_shot_examples`` is read through its ``"sql"``
    key and parsed with ``Parser`` to extract the table names it uses.
    Rows of ``df`` whose ``table_name`` is one of those tables are removed
    IN PLACE, so the caller sees the reduced frame after this call.

    Args:
        df: Frame with a ``table_name`` column; mutated in place.

    Returns:
        The set of table names found in the few-shot examples. It is empty
        when there are no examples or none of them could be parsed.
    """
    most_similar_tables: set[str] = set()
    if self.few_shot_examples is None:
        return most_similar_tables
    for example in self.few_shot_examples:
        try:
            tables = Parser(example["sql"]).tables
        except Exception as e:
            # Best effort: one bad example (unparsable SQL, missing "sql" key)
            # must not abort the whole lookup.
            logger.error(f"Error parsing SQL: {str(e)}")
            # Skip this example. Falling through would either hit an unbound
            # `tables` (first example) or re-add the previous example's tables.
            continue
        most_similar_tables.update(tables)
    # NOTE(review): if every row matches, `df` is left empty; callers that
    # aggregate over it afterwards should be prepared for that.
    df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
    return most_similar_tables

@catch_exceptions()
def _run(
self,
Expand Down Expand Up @@ -210,11 +224,17 @@ def _run(
)
df = df.sort_values(by="similarities", ascending=True)
df = df.tail(TOP_TABLES)
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

async def _arun(
Expand Down Expand Up @@ -389,6 +409,7 @@ class SQLDatabaseToolkit(BaseToolkit):
model_name: str = Field(exclude=True)
openai_fine_tuning: OpenAIFineTuning = Field(exclude=True)
embedding: OpenAIEmbeddings = Field(exclude=True)
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

@property
def dialect(self) -> str:
Expand All @@ -408,7 +429,10 @@ def get_tools(self) -> List[BaseTool]:
tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
tools.append(
TablesSQLDatabaseTool(
db=self.db, db_scan=self.db_scan, embedding=self.embedding
db=self.db,
db_scan=self.db_scan,
embedding=self.embedding,
few_shot_examples=self.few_shot_examples,
)
)
tools.append(QuerySQLDataBaseTool(db=self.db))
Expand Down Expand Up @@ -529,8 +553,8 @@ def generate_response(
)
if not db_scan:
raise ValueError("No scanned tables found for database")
_, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=1
few_shot_examples, instructions = context_store.retrieve_context_for_question(
user_prompt, number_of_samples=5
)
finetunings_repository = FinetuningsRepository(storage)
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
Expand All @@ -545,6 +569,7 @@ def generate_response(
toolkit = SQLDatabaseToolkit(
db=self.database,
instructions=instructions,
few_shot_examples=few_shot_examples,
db_scan=db_scan,
api_key=database_connection.decrypt_api_key(),
finetuning_model_id=finetuning.model_id,
Expand Down
21 changes: 21 additions & 0 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from langchain_openai import OpenAIEmbeddings
from overrides import override
from pydantic import BaseModel, Field
from sql_metadata import Parser
from sqlalchemy import MetaData
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.sql import func
Expand Down Expand Up @@ -224,6 +225,7 @@ class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings
few_shot_examples: List[dict] | None = Field(exclude=True, default=None)

def get_embedding(
self,
Expand All @@ -241,6 +243,18 @@ def get_docs_embedding(
def cosine_similarity(self, a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of vectors ``a`` and ``b``, rounded to 4 decimals.

    The value is the dot product of the two vectors divided by the product
    of their Euclidean norms.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    similarity = np.dot(a, b) / norm_product
    return round(similarity, 4)

def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> set[str]:
    """Collect the tables referenced by the few-shot examples and drop them from ``df``.

    Each entry of ``self.few_shot_examples`` is read through its ``"sql"``
    key and parsed with ``Parser`` to extract the table names it uses.
    Rows of ``df`` whose ``table_name`` is one of those tables are removed
    IN PLACE, so the caller sees the reduced frame after this call.

    Args:
        df: Frame with a ``table_name`` column; mutated in place.

    Returns:
        The set of table names found in the few-shot examples. It is empty
        when there are no examples or none of them could be parsed.
    """
    most_similar_tables: set[str] = set()
    if self.few_shot_examples is None:
        return most_similar_tables
    for example in self.few_shot_examples:
        try:
            tables = Parser(example["sql"]).tables
        except Exception as e:
            # Best effort: one bad example (unparsable SQL, missing "sql" key)
            # must not abort the whole lookup.
            logger.error(f"Error parsing SQL: {str(e)}")
            # Skip this example. Falling through would either hit an unbound
            # `tables` (first example) or re-add the previous example's tables.
            continue
        most_similar_tables.update(tables)
    # NOTE(review): if every row matches, `df` is left empty; callers that
    # aggregate over it afterwards should be prepared for that.
    df.drop(df[df.table_name.isin(most_similar_tables)].index, inplace=True)
    return most_similar_tables

@catch_exceptions()
def _run(
self,
Expand Down Expand Up @@ -271,11 +285,17 @@ def _run(
)
df = df.sort_values(by="similarities", ascending=True)
df = df.tail(TOP_TABLES)
most_similar_tables = self.similart_tables_based_on_few_shot_examples(df)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
if len(most_similar_tables) > 0:
for table in most_similar_tables:
table_relevance += (
f"Table: {table}, relevance score: {max(df['similarities'])}\n"
)
return table_relevance

async def _arun(
Expand Down Expand Up @@ -547,6 +567,7 @@ def get_tools(self) -> List[BaseTool]:
context=self.context,
db_scan=self.db_scan,
embedding=self.embedding,
few_shot_examples=self.few_shot_examples,
)
tools.append(tables_sql_db_tool)
schema_sql_db_tool = SchemaSQLDatabaseTool(
Expand Down
Loading