Skip to content

Commit

Permalink
DH-5080/refactor_engine_with_new_resources
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Dec 15, 2023
1 parent 816a46b commit 019c696
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 257 deletions.
4 changes: 2 additions & 2 deletions dataherald/context_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataherald.config import Component, System
from dataherald.db import DB
from dataherald.types import GoldenRecord, GoldenRecordRequest, Question
from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt
from dataherald.vector_store import VectorStore


Expand All @@ -24,7 +24,7 @@ def __init__(self, system: System):

@abstractmethod
def retrieve_context_for_question(
self, nl_question: Question, number_of_samples: int = 3
self, prompt: Prompt, number_of_samples: int = 3
) -> Tuple[List[dict] | None, List[dict] | None]:
pass

Expand Down
12 changes: 6 additions & 6 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataherald.context_store import ContextStore
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenRecord, GoldenRecordRequest, Question
from dataherald.types import GoldenRecord, GoldenRecordRequest, Prompt

logger = logging.getLogger(__name__)

Expand All @@ -19,12 +19,12 @@ def __init__(self, system: System):

@override
def retrieve_context_for_question(
self, nl_question: Question, number_of_samples: int = 3
self, prompt: Prompt, number_of_samples: int = 3
) -> Tuple[List[dict] | None, List[dict] | None]:
logger.info(f"Getting context for {nl_question.question}")
logger.info(f"Getting context for {prompt.text}")
closest_questions = self.vector_store.query(
query_texts=[nl_question.question],
db_connection_id=nl_question.db_connection_id,
query_texts=[prompt.text],
db_connection_id=prompt.db_connection_id,
collection=self.golden_record_collection,
num_results=number_of_samples,
)
Expand All @@ -47,7 +47,7 @@ def retrieve_context_for_question(
instruction_repository = InstructionRepository(self.db)
all_instructions = instruction_repository.find_all()
for instruction in all_instructions:
if instruction.db_connection_id == nl_question.db_connection_id:
if instruction.db_connection_id == prompt.db_connection_id:
instructions.append(
{
"instruction": instruction.instruction,
Expand Down
27 changes: 6 additions & 21 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Base class that all sql generation classes inherit from."""
import re
from abc import ABC, abstractmethod
from datetime import date, datetime
from typing import Any, List, Tuple

import sqlparse
Expand All @@ -12,7 +11,7 @@
from dataherald.sql_database.base import SQLDatabase
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
from dataherald.types import Question, Response, SQLQueryResult
from dataherald.types import Prompt, SQLGeneration
from dataherald.utils.strings import contains_line_breaks


Expand Down Expand Up @@ -44,22 +43,9 @@ def remove_markdown(self, query: str) -> str:
return query

def create_sql_query_status(
self,
db: SQLDatabase,
query: str,
response: Response,
top_k: int = None,
generate_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
return create_sql_query_status(
db,
query,
response,
top_k,
generate_csv,
database_connection=database_connection,
)
self, db: SQLDatabase, query: str, sql_generation: SQLGeneration
) -> SQLGeneration:
return create_sql_query_status(db, query, sql_generation)

def format_intermediate_representations(
self, intermediate_representation: List[Tuple[AgentAction, str]]
Expand Down Expand Up @@ -91,10 +77,9 @@ def format_sql_query(self, sql_query: str) -> str:
@abstractmethod
def generate_response(
self,
user_question: Question,
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None,
generate_csv: bool = False,
) -> Response:
) -> SQLGeneration:
"""Generates a response to a user question."""
pass
105 changes: 16 additions & 89 deletions dataherald/sql_generator/create_sql_query_status.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,43 @@
import csv
import uuid
from datetime import date, datetime
from decimal import Decimal

from sqlalchemy import text

from dataherald.config import Settings
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.types import Response, SQLQueryResult
from dataherald.utils.s3 import S3
from dataherald.types import SQLGeneration


def format_error_message(response: Response, error_message: str) -> Response:
def format_error_message(
sql_generation: SQLGeneration, error_message: str
) -> SQLGeneration:
# Remove the complete query
if error_message.find("[") > 0 and error_message.find("]") > 0:
error_message = (
error_message[0 : error_message.find("[")]
+ error_message[error_message.rfind("]") + 1 :]
)
response.sql_generation_status = "INVALID"
response.response = ""
response.sql_query_result = None
response.error_message = error_message
return response


def create_csv_file(
generate_csv: bool,
columns: list,
rows: list,
response: Response,
database_connection: DatabaseConnection | None = None,
):
if generate_csv:
file_location = f"tmp/{str(uuid.uuid4())}.csv"
with open(file_location, "w", newline="") as file:
writer = csv.writer(file)

writer.writerow(rows[0].keys())
for row in rows:
writer.writerow(row.values())
if Settings().only_store_csv_files_locally:
response.csv_file_path = file_location
else:
s3 = S3()
response.csv_file_path = s3.upload(
file_location, database_connection.file_storage
)
response.sql_query_result = SQLQueryResult(columns=columns, rows=rows)
sql_generation.status = "INVALID"
sql_generation.error = error_message
return sql_generation


def create_sql_query_status(
db: SQLDatabase,
query: str,
response: Response,
top_k: int = None,
generate_csv: bool = False,
database_connection: DatabaseConnection | None = None,
) -> Response:
sql_generation: SQLGeneration,
) -> SQLGeneration:
"""Find the sql query status and populate the fields sql_query_result, sql_generation_status, and error_message"""
if query == "":
response.sql_generation_status = "INVALID"
response.response = ""
response.sql_query_result = None
response.error_message = None
sql_generation.status = "INVALID"
sql_generation.error = "Empty query"
else:
try:
query = db.parser_to_filter_commands(query)
with db._engine.connect() as connection:
execution = connection.execute(text(query))
columns = execution.keys()
if top_k:
result = execution.fetchmany(top_k)
else:
result = execution.fetchall()
if len(result) == 0:
response.sql_query_result = None
else:
columns = [item for item in columns] # noqa: C416
rows = []
for row in result:
modified_row = {}
for key, value in zip(row.keys(), row, strict=True):
if type(value) in [
date,
datetime,
]: # Check if the value is an instance of datetime.date
modified_row[key] = str(value)
elif (
type(value) is Decimal
): # Check if the value is an instance of decimal.Decimal
modified_row[key] = float(value)
else:
modified_row[key] = value
rows.append(modified_row)

create_csv_file(
generate_csv,
columns,
rows,
response,
database_connection,
)

response.sql_generation_status = "VALID"
response.error_message = None
connection.execute(text(query))
sql_generation.status = "VALID"
sql_generation.error = None
except SQLInjectionError as e:
raise SQLInjectionError(
"Sensitive SQL keyword detected in the query."
) from e
except Exception as e:
response = format_error_message(response, str(e))

return response
sql_generation = format_error_message(sql_generation, str(e))
return sql_generation
43 changes: 20 additions & 23 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
DatabaseConnection,
)
from dataherald.sql_generator import EngineTimeOutORItemLimitError, SQLGenerator
from dataherald.types import Question, Response
from dataherald.types import Prompt, SQLGeneration
from dataherald.utils.agent_prompts import (
FINETUNING_AGENT_PREFIX,
FINETUNING_AGENT_SUFFIX,
Expand Down Expand Up @@ -436,11 +436,10 @@ def create_sql_agent(
@override
def generate_response(
self,
user_question: Question,
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None, # noqa: ARG002
generate_csv: bool = False,
) -> Response:
) -> SQLGeneration:
"""
generate_response generates a response to a user question using a Finetuning model.
Expand Down Expand Up @@ -470,7 +469,7 @@ def generate_response(
if not db_scan:
raise ValueError("No scanned tables found for database")
_, instructions = context_store.retrieve_context_for_question(
user_question, number_of_samples=1
user_prompt, number_of_samples=1
)

self.database = SQLDatabase.get_sql_engine(database_connection)
Expand All @@ -489,21 +488,21 @@ def generate_response(
agent_executor.handle_parsing_errors = True
with get_openai_callback() as cb:
try:
result = agent_executor({"input": user_question.question})
result = agent_executor({"input": user_prompt.text})
result = self.check_for_time_out_or_tool_limit(result)
except SQLInjectionError as e:
raise SQLInjectionError(e) from e
except EngineTimeOutORItemLimitError as e:
raise EngineTimeOutORItemLimitError(e) from e
except Exception as e:
return Response(
question_id=user_question.id,
total_tokens=cb.total_tokens,
total_cost=cb.total_cost,
sql_query="",
sql_generation_status="INVALID",
sql_query_result=None,
error_message=str(e),
return SQLGeneration(
prompt_id=user_prompt.id,
tokens_used=cb.total_tokens,
model="FineTuning_Agent",
completed_at=datetime.datetime.now(),
sql="",
status="INVALID",
error=str(e),
)
if "```sql" in result["output"]:
sql_query = self.remove_markdown(result["output"])
Expand All @@ -514,17 +513,15 @@ def generate_response(
sql_query = self.remove_markdown(sql_query)
sql_query = self.format_sql_query(action.tool_input)
logger.info(f"cost: {str(cb.total_cost)} tokens: {str(cb.total_tokens)}")
response = Response(
question_id=user_question.id,
total_tokens=cb.total_tokens,
total_cost=cb.total_cost,
sql_query=sql_query,
response = SQLGeneration(
prompt_id=user_prompt.id,
tokens_used=cb.total_tokens,
model="RAG_AGENT",
completed_at=datetime.datetime.now(),
sql=sql_query,
)
return self.create_sql_query_status(
self.database,
response.sql_query,
response.sql,
response,
top_k=TOP_K,
generate_csv=generate_csv,
database_connection=database_connection,
)
Loading

0 comments on commit 019c696

Please sign in to comment.