Skip to content

Commit

Permalink
DH-5465/fixing the table names (#402)
Browse files Browse the repository at this point in the history
* DH-5465/fixing the table names

* DH-5465/reformat with black
  • Loading branch information
MohammadrezaPourreza authored Feb 20, 2024
1 parent b7e9900 commit cb4bc30
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
2 changes: 1 addition & 1 deletion dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def format_dataset(
db_scan: List[TableDescription],
prompt: str,
token_limit: int,
correct_tables: [str] = None,
correct_tables: [str] = None, # type: ignore
) -> str:
schema_of_database = ""
indexes_to_remove = []
Expand Down
76 changes: 66 additions & 10 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from functools import wraps
from typing import Any, Callable, Dict, List, Type

import numpy as np
import openai
import pandas as pd
from google.api_core.exceptions import GoogleAPIError
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.base import BaseToolkit
Expand All @@ -18,6 +20,7 @@
from langchain.schema import AgentAction
from langchain.tools.base import BaseTool
from langchain_community.callbacks import get_openai_callback
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI
from overrides import override
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -48,6 +51,8 @@


TOP_K = SQLGenerator.get_upper_bound_limit()
EMBEDDING_MODEL = "text-embedding-3-large"
TOP_TABLES = 10


class FinetuningNotAvailableError(Exception):
Expand Down Expand Up @@ -144,29 +149,71 @@ async def _arun(
class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool which takes in the given question and returns a list of tables with their relevance score to the question"""

name = "GetDbTableNames"
name = "DbTablesWithRelevanceScores"
description = """
Input: None.
Output: List of tables in the database.
Use this tool to get the list of tables in the database.
Input: Given question.
Output: Comma-separated list of tables with their relevance scores, indicating their relevance to the question.
Use this tool to identify the relevant tables for the given question.
"""
db_scan: List[TableDescription]
embedding: OpenAIEmbeddings

def get_embedding(self, text: str) -> List[float]:
    """Embed a single piece of text with the tool's embedding client.

    Newline characters are replaced by single spaces before the text is
    handed to ``self.embedding.embed_query``.

    Args:
        text: The text to embed (here, the user's question).

    Returns:
        Whatever ``self.embedding.embed_query`` returns for the flattened
        text (annotated as a list of floats).
    """
    # Splitting on "\n" and re-joining with a space is the same
    # substitution as replacing every newline with one space.
    flattened = " ".join(text.split("\n"))
    embedder = self.embedding
    return embedder.embed_query(flattened)

def get_docs_embedding(self, docs: List[str]) -> List[List[float]]:
    """Embed a batch of documents with the tool's embedding client.

    Args:
        docs: The texts to embed; passed through unchanged to
            ``self.embedding.embed_documents``.

    Returns:
        The result of ``self.embedding.embed_documents(docs)`` (annotated as
        one list of floats per input document).
    """
    embedder = self.embedding
    vectors = embedder.embed_documents(docs)
    return vectors

def cosine_similarity(self, a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of two vectors, rounded to 4 decimals.

    Args:
        a: First vector.
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)`` rounded to four decimal places, as a
        plain Python float. If either vector has zero length (norm 0) the
        similarity is undefined; 0.0 is returned instead of NaN so that
        callers can still sort and format the score.
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        # 0/0 would produce NaN (plus a RuntimeWarning), which makes any
        # later ordering by similarity unreliable.
        return 0.0
    # Round first (same rounding as before), then convert the NumPy scalar
    # to the built-in float promised by the annotation.
    return float(round(np.dot(a, b) / denominator, 4))

@catch_exceptions()
def _run(
self,
input: str, # noqa: ARG002
user_question: str,
run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002
) -> str:
"""Use the concatenation of table name, columns names, and the description of the table as the table representation"""
tables = []
question_embedding = self.get_embedding(user_question)
table_representations = []
for table in self.db_scan:
tables.append(table.table_name)
return f"Tables in the database: {','.join(tables)}"
col_rep = ""
for column in table.columns:
if column.description is not None:
col_rep += f"{column.name}: {column.description}, "
else:
col_rep += f"{column.name}, "
if table.description is not None:
table_rep = f"Table {table.table_name} contain columns: [{col_rep}], this tables has: {table.description}"
else:
table_rep = f"Table {table.table_name} contain columns: [{col_rep}]"
table_representations.append([table.table_name, table_rep])
df = pd.DataFrame(
table_representations, columns=["table_name", "table_representation"]
)
df["table_embedding"] = self.get_docs_embedding(df.table_representation)
df["similarities"] = df.table_embedding.apply(
lambda x: self.cosine_similarity(x, question_embedding)
)
df = df.sort_values(by="similarities", ascending=True)
df = df.tail(TOP_TABLES)
table_relevance = ""
for _, row in df.iterrows():
table_relevance += (
f'Table: {row["table_name"]}, relevance score: {row["similarities"]}\n'
)
return table_relevance

async def _arun(
self,
input: str = "",
user_question: str = "",
run_manager: AsyncCallbackManagerForToolRun | None = None,
) -> str:
raise NotImplementedError("TablesSQLDatabaseTool does not support async")
Expand Down Expand Up @@ -308,6 +355,7 @@ class SQLDatabaseToolkit(BaseToolkit):
use_finetuned_model_only: bool = Field(exclude=True, default=None)
model_name: str = Field(exclude=True)
openai_fine_tuning: OpenAIFineTuning = Field(exclude=True)
embedding: OpenAIEmbeddings = Field(exclude=True)

@property
def dialect(self) -> str:
Expand All @@ -325,7 +373,11 @@ def get_tools(self) -> List[BaseTool]:
if not self.use_finetuned_model_only:
tools.append(SystemTime(db=self.db))
tools.append(SchemaSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
tools.append(TablesSQLDatabaseTool(db=self.db, db_scan=self.db_scan))
tools.append(
TablesSQLDatabaseTool(
db=self.db, db_scan=self.db_scan, embedding=self.embedding
)
)
tools.append(QuerySQLDataBaseTool(db=self.db))
tools.append(
GenerateSQL(
Expand Down Expand Up @@ -464,6 +516,10 @@ def generate_response(
use_finetuned_model_only=self.use_fintuned_model_only,
model_name=finetuning.base_llm.model_name,
openai_fine_tuning=openai_fine_tuning,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
agent_executor = self.create_sql_agent(
toolkit=toolkit,
Expand Down

0 comments on commit cb4bc30

Please sign in to comment.