DH-5080/refactor_engine_with_new_resources (#277)
* DH-5080/refactor_engine_with_new_resources

* DH-5080/update the tests

* DH-5080/update the tests

* DH-5080/updating the evaluator
MohammadrezaPourreza authored Dec 18, 2023
1 parent 6eb2116 commit 3418756
Showing 15 changed files with 240 additions and 297 deletions.
4 changes: 2 additions & 2 deletions dataherald/context_store/__init__.py
@@ -4,7 +4,7 @@

 from dataherald.config import Component, System
 from dataherald.db import DB
-from dataherald.types import GoldenRecord, GoldenRecordRequest, Question
+from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt
 from dataherald.vector_store import VectorStore

@@ -24,7 +24,7 @@ def __init__(self, system: System):

     @abstractmethod
     def retrieve_context_for_question(
-        self, nl_question: Question, number_of_samples: int = 3
+        self, prompt: Prompt, number_of_samples: int = 3
     ) -> Tuple[List[dict] | None, List[dict] | None]:
         pass

12 changes: 6 additions & 6 deletions dataherald/context_store/default.py
@@ -8,7 +8,7 @@
 from dataherald.context_store import ContextStore
 from dataherald.repositories.golden_records import GoldenRecordRepository
 from dataherald.repositories.instructions import InstructionRepository
-from dataherald.types import GoldenRecord, GoldenRecordRequest, Question
+from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt

 logger = logging.getLogger(__name__)

@@ -19,12 +19,12 @@ def __init__(self, system: System):

     @override
     def retrieve_context_for_question(
-        self, nl_question: Question, number_of_samples: int = 3
+        self, prompt: Prompt, number_of_samples: int = 3
     ) -> Tuple[List[dict] | None, List[dict] | None]:
-        logger.info(f"Getting context for {nl_question.question}")
+        logger.info(f"Getting context for {prompt.text}")
         closest_questions = self.vector_store.query(
-            query_texts=[nl_question.question],
-            db_connection_id=nl_question.db_connection_id,
+            query_texts=[prompt.text],
+            db_connection_id=prompt.db_connection_id,
             collection=self.golden_record_collection,
             num_results=number_of_samples,
         )
@@ -47,7 +47,7 @@ def retrieve_context_for_question(
         instruction_repository = InstructionRepository(self.db)
         all_instructions = instruction_repository.find_all()
         for instruction in all_instructions:
-            if instruction.db_connection_id == nl_question.db_connection_id:
+            if instruction.db_connection_id == prompt.db_connection_id:
                 instructions.append(
                     {
                         "instruction": instruction.instruction,
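The refactor threads a Prompt through the context store in place of the old Question. A minimal sketch of the shape this diff assumes; the field names are taken from the usages above (prompt.text, prompt.db_connection_id, plus user_prompt.id in the evaluators below), while the base class and defaults are assumptions:

```python
from pydantic import BaseModel


class Prompt(BaseModel):
    id: str | None = None   # used as question_id when building an Evaluation
    text: str               # the natural-language question, formerly Question.question
    db_connection_id: str   # scopes golden records and instructions to one database
```

With that shape, a hypothetical call against the refactored method is `samples, instructions = context_store.retrieve_context_for_question(prompt, number_of_samples=3)`, matching the `Tuple[List[dict] | None, List[dict] | None]` return annotation.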
14 changes: 7 additions & 7 deletions dataherald/eval/__init__.py
@@ -6,7 +6,7 @@
 from dataherald.model.chat_model import ChatModel
 from dataherald.sql_database.base import SQLDatabase
 from dataherald.sql_database.models.types import DatabaseConnection
-from dataherald.types import Question, Response
+from dataherald.types import Prompt, SQLGeneration


 class Evaluation(BaseModel):
@@ -27,23 +27,23 @@ def __init__(self, system: System):

     def get_confidence_score(
         self,
-        question: Question,
-        generated_answer: Response,
+        user_prompt: Prompt,
+        sql_generation: SQLGeneration,
         database_connection: DatabaseConnection,
     ) -> confloat:
         """Determines if a generated response from the engine is acceptable considering the ACCEPTANCE_THRESHOLD"""
         evaluation = self.evaluate(
-            question=question,
-            generated_answer=generated_answer,
+            user_prompt=user_prompt,
+            sql_generation=sql_generation,
             database_connection=database_connection,
         )
         return evaluation.score

     @abstractmethod
     def evaluate(
         self,
-        question: Question,
-        generated_answer: Response,
+        user_prompt: Prompt,
+        sql_generation: SQLGeneration,
         database_connection: DatabaseConnection,
     ) -> Evaluation:
         """Evaluates a question with an SQL pair."""
14 changes: 7 additions & 7 deletions dataherald/eval/eval_agent.py
@@ -28,7 +28,7 @@
 from dataherald.eval import Evaluation, Evaluator
 from dataherald.sql_database.base import SQLDatabase
 from dataherald.sql_database.models.types import DatabaseConnection
-from dataherald.types import Question, Response
+from dataherald.types import Prompt, SQLGeneration

 logger = logging.getLogger(__name__)

@@ -240,22 +240,22 @@ def create_evaluation_agent(
     @override
     def evaluate(
         self,
-        question: Question,
-        generated_answer: Response,
+        user_prompt: Prompt,
+        sql_generation: SQLGeneration,
         database_connection: DatabaseConnection,
     ) -> Evaluation:
         start_time = time.time()
         logger.info(
-            f"Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
+            f"Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}"
         )
         self.llm = self.model.get_model(
             database_connection=database_connection,
             temperature=0,
             model_name=os.getenv("LLM_MODEL", "gpt-4"),
         )
         database = SQLDatabase.get_sql_engine(database_connection)
-        user_question = question.question
-        sql = generated_answer.sql_query
+        user_question = user_prompt.text
+        sql = sql_generation.sql
         database._sample_rows_in_table_info = self.sample_rows
         toolkit = SQLEvaluationToolkit(db=database)
         agent_executor = self.create_evaluation_agent(
@@ -269,5 +269,5 @@ def evaluate(
         end_time = time.time()
         logger.info(f"Evaluation time elapsed: {str(end_time - start_time)}")
         return Evaluation(
-            question_id=question.id, answer_id=generated_answer.id, score=score
+            question_id=user_prompt.id, answer_id=sql_generation.id, score=score
         )
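Note that evaluate() resolves its LLM at call time from the LLM_MODEL environment variable, falling back to "gpt-4". A hedged sketch of swapping the evaluation model without touching code (the chosen model name is just an example):

```python
import os

# Hypothetical override; with LLM_MODEL unset, the diff's default of "gpt-4" applies.
os.environ["LLM_MODEL"] = "gpt-3.5-turbo"
```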
52 changes: 41 additions & 11 deletions dataherald/eval/simple_evaluator.py
@@ -2,6 +2,8 @@
 import os
 import re
 import time
+from datetime import date, datetime
+from decimal import Decimal
 from typing import Any

 from bson.objectid import ObjectId
@@ -13,15 +15,16 @@
 )
 from overrides import override
 from sql_metadata import Parser
+from sqlalchemy import text

 from dataherald.config import System
 from dataherald.db import DB
 from dataherald.db_scanner.models.types import TableDescriptionStatus
 from dataherald.db_scanner.repository.base import TableDescriptionRepository
 from dataherald.eval import Evaluation, Evaluator
-from dataherald.sql_database.base import SQLDatabase
+from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
 from dataherald.sql_database.models.types import DatabaseConnection
-from dataherald.types import Question, Response
+from dataherald.types import Prompt, SQLGeneration

 logger = logging.getLogger(__name__)

@@ -55,6 +58,7 @@
 SQL Query Result: {SQL_result}
 give me a one or two lines explanation and the score after 'Score: '.
 """
+TOP_K = 100


 class SimpleEvaluator(Evaluator):
@@ -85,13 +89,13 @@ def answer_parser(self, answer: str) -> int:
     @override
     def evaluate(
         self,
-        question: Question,
-        generated_answer: Response,
+        user_prompt: Prompt,
+        sql_generation: SQLGeneration,
         database_connection: DatabaseConnection,
     ) -> Evaluation:
         database = SQLDatabase.get_sql_engine(database_connection)
         logger.info(
-            f"(Simple evaluator) Generating score for the question/sql pair: {str(question.question)}/ {str(generated_answer.sql_query)}"
+            f"(Simple evaluator) Generating score for the question/sql pair: {str(user_prompt.text)}/ {str(sql_generation.sql)}"
         )
         storage = self.system.instance(DB)
         repository = TableDescriptionRepository(storage)
@@ -114,28 +118,54 @@ def evaluate(
         chat_prompt = ChatPromptTemplate.from_messages(
             [system_message_prompt, human_message_prompt]
         )
-        user_question = question.question
-        sql = generated_answer.sql_query
+        user_question = user_prompt.text
+        sql = sql_generation.sql
         dialect = database.dialect
         tables = Parser(sql).tables
         schema = ""
         for scanned_table in db_scan:
             if scanned_table.table_name in tables:
                 schema += f"Table: {scanned_table.table_schema}\n"
-        if generated_answer.sql_query_result is None:
+        if sql_generation.status == "INVALID":
             logger.info(
                 f"(Simple evaluator) SQL query: {sql} is not valid. Returning score 0"
             )
             return Evaluation(
-                question_id=question.id, answer_id=generated_answer.id, score=0
+                question_id=user_prompt.id, answer_id=sql_generation.id, score=0
             )
         chain = LLMChain(llm=self.llm, prompt=chat_prompt)
+        try:
+            query = database.parser_to_filter_commands(sql_generation.sql)
+            with database._engine.connect() as connection:
+                execution = connection.execute(text(query))
+                result = execution.fetchmany(TOP_K)
+            rows = []
+            for row in result:
+                modified_row = {}
+                for key, value in zip(row.keys(), row, strict=True):
+                    if type(value) in [
+                        date,
+                        datetime,
+                    ]:  # Check if the value is an instance of datetime.date
+                        modified_row[key] = str(value)
+                    elif (
+                        type(value) is Decimal
+                    ):  # Check if the value is an instance of decimal.Decimal
+                        modified_row[key] = float(value)
+                    else:
+                        modified_row[key] = value
+                rows.append(modified_row)
+
+        except SQLInjectionError as e:
+            raise SQLInjectionError(
+                "Sensitive SQL keyword detected in the query."
+            ) from e
         answer = chain.run(
             {
                 "dialect": dialect,
                 "question": user_question,
                 "SQL": sql,
-                "SQL_result": str(generated_answer.sql_query_result.json()),
+                "SQL_result": "\n".join([str(row) for row in rows]),
                 "schema": schema,
             }
         )
@@ -145,5 +175,5 @@ def evaluate(
         end_time = time.time()
         logger.info(f"Evaluation time elapsed: {str(end_time - start_time)}")
         return Evaluation(
-            question_id=question.id, answer_id=generated_answer.id, score=score
+            question_id=user_prompt.id, answer_id=sql_generation.id, score=score
         )
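The new execution path is the heart of this file's change: instead of reusing a precomputed Response.sql_query_result, SimpleEvaluator now runs the candidate SQL itself (capped at TOP_K rows) and normalizes the values so the result can be embedded in the LLM prompt. The normalization step, pulled out as a standalone sketch; it assumes SQLAlchemy 1.4-style Row objects that expose .keys(), as the code above does:

```python
from datetime import date, datetime
from decimal import Decimal


def serialize_rows(result) -> list[dict]:
    """Mirror of the loop above: make fetched rows prompt-friendly."""
    rows = []
    for row in result:
        modified_row = {}
        for key, value in zip(row.keys(), row, strict=True):
            if type(value) in [date, datetime]:
                modified_row[key] = str(value)    # e.g. "2023-12-18"
            elif type(value) is Decimal:
                modified_row[key] = float(value)  # Decimal is not JSON-serializable
            else:
                modified_row[key] = value
        rows.append(modified_row)
    return rows
```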
27 changes: 6 additions & 21 deletions dataherald/sql_generator/__init__.py
@@ -1,7 +1,6 @@
 """Base class that all sql generation classes inherit from."""
 import re
 from abc import ABC, abstractmethod
-from datetime import date, datetime
 from typing import Any, List, Tuple

 import sqlparse
@@ -12,7 +11,7 @@
 from dataherald.sql_database.base import SQLDatabase
 from dataherald.sql_database.models.types import DatabaseConnection
 from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
-from dataherald.types import Question, Response, SQLQueryResult
+from dataherald.types import Prompt, SQLGeneration
 from dataherald.utils.strings import contains_line_breaks


@@ -44,22 +43,9 @@ def remove_markdown(self, query: str) -> str:
         return query

     def create_sql_query_status(
-        self,
-        db: SQLDatabase,
-        query: str,
-        response: Response,
-        top_k: int = None,
-        generate_csv: bool = False,
-        database_connection: DatabaseConnection | None = None,
-    ) -> Response:
-        return create_sql_query_status(
-            db,
-            query,
-            response,
-            top_k,
-            generate_csv,
-            database_connection=database_connection,
-        )
+        self, db: SQLDatabase, query: str, sql_generation: SQLGeneration
+    ) -> SQLGeneration:
+        return create_sql_query_status(db, query, sql_generation)

     def format_intermediate_representations(
         self, intermediate_representation: List[Tuple[AgentAction, str]]
@@ -91,10 +77,9 @@ def format_sql_query(self, sql_query: str) -> str:
     @abstractmethod
     def generate_response(
         self,
-        user_question: Question,
+        user_prompt: Prompt,
         database_connection: DatabaseConnection,
         context: List[dict] = None,
-        generate_csv: bool = False,
-    ) -> Response:
+    ) -> SQLGeneration:
         """Generates a response to a user question."""
         pass
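Across all of these files the old Response gives way to SQLGeneration. A minimal sketch covering only the fields this diff actually touches (id, sql, status); anything else, including the defaults and the base class, is an assumption:

```python
from pydantic import BaseModel


class SQLGeneration(BaseModel):
    id: str | None = None   # used as answer_id when building an Evaluation
    sql: str = ""           # formerly Response.sql_query
    status: str = "VALID"   # SimpleEvaluator returns score 0 when this is "INVALID"
```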