Skip to content

Commit

Permalink
DH-4784 Store all responses and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Oct 5, 2023
1 parent 9b56161 commit 4919c9c
Show file tree
Hide file tree
Showing 45 changed files with 506 additions and 475 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,11 @@ curl -X 'PUT' \


### Querying the Database in Natural Language
Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint.
Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/questions` endpoint.

```
curl -X 'POST' \
'<host>/api/v1/question' \
'<host>/api/v1/questions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
Expand Down
34 changes: 17 additions & 17 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,21 @@

from dataherald.api.types import Query
from dataherald.config import Component
from dataherald.db_scanner.models.types import TableSchemaDetail
from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CreateResponseRequest,
DatabaseConnectionRequest,
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQueryResponse,
Question,
QuestionRequest,
Response,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
UpdateQueryRequest,
)


Expand All @@ -36,7 +36,11 @@ def scan_db(
pass

@abstractmethod
def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
def answer_question(self, question_request: QuestionRequest) -> Response:
pass

@abstractmethod
def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
pass

@abstractmethod
Expand All @@ -62,13 +66,17 @@ def update_table_description(
self,
table_description_id: str,
table_description_request: TableDescriptionRequest,
) -> TableSchemaDetail:
) -> TableDescription:
pass

@abstractmethod
def list_table_descriptions(
self, db_connection_id: str | None = None, table_name: str | None = None
) -> list[TableSchemaDetail]:
self, db_connection_id: str, table_name: str | None = None
) -> list[TableDescription]:
pass

@abstractmethod
def get_responses(self, question_id: str | None = None) -> list[Response]:
pass

@abstractmethod
Expand All @@ -82,15 +90,7 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
pass

@abstractmethod
def update_nl_query_response(
self, query_id: str, query: UpdateQueryRequest
) -> NLQueryResponse:
pass

@abstractmethod
def get_nl_query_response(
self, query_request: ExecuteTempQueryRequest
) -> NLQueryResponse:
def create_response(self, query_request: CreateResponseRequest) -> Response:
pass

@abstractmethod
Expand Down
118 changes: 51 additions & 67 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from bson import json_util
from bson.objectid import ObjectId
from fastapi import BackgroundTasks, HTTPException
from overrides import override

Expand All @@ -13,14 +14,14 @@
from dataherald.context_store import ContextStore
from dataherald.db import DB
from dataherald.db_scanner import Scanner
from dataherald.db_scanner.models.types import TableDescriptionStatus, TableSchemaDetail
from dataherald.db_scanner.repository.base import DBScannerRepository
from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
from dataherald.db_scanner.repository.base import TableDescriptionRepository
from dataherald.eval import Evaluator
from dataherald.repositories.base import NLQueryResponseRepository
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.repositories.question import QuestionRepository
from dataherald.sql_database.base import (
InvalidDBConnectionError,
SQLDatabase,
Expand All @@ -30,19 +31,18 @@
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import (
CreateResponseRequest,
DatabaseConnectionRequest,
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQuery,
NLQueryResponse,
Question,
QuestionRequest,
Response,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
UpdateQueryRequest,
)

logger = logging.getLogger(__name__)
Expand All @@ -53,7 +53,7 @@ def async_scanning(scanner, database, scanner_request, storage):
database,
scanner_request.db_connection_id,
scanner_request.table_names,
DBScannerRepository(storage),
TableDescriptionRepository(storage),
)


Expand Down Expand Up @@ -103,7 +103,7 @@ def scan_db(
scanner.synchronizing(
scanner_request.table_names,
scanner_request.db_connection_id,
DBScannerRepository(self.storage),
TableDescriptionRepository(self.storage),
)

background_tasks.add_task(
Expand All @@ -112,20 +112,20 @@ def scan_db(
return True

@override
def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
def answer_question(self, question_request: QuestionRequest) -> Response:
"""Takes in an English question and answers it based on content from the registered databases"""
logger.info(f"Answer question: {question_request.question}")
sql_generation = self.system.instance(SQLGenerator)
evaluator = self.system.instance(Evaluator)
context_store = self.system.instance(ContextStore)

user_question = NLQuery(
user_question = Question(
question=question_request.question,
db_connection_id=question_request.db_connection_id,
)

nl_question_repository = NLQuestionRepository(self.storage)
user_question = nl_question_repository.insert(user_question)
question_repository = QuestionRepository(self.storage)
user_question = question_repository.insert(user_question)

db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
Expand All @@ -149,9 +149,8 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
raise HTTPException(status_code=404, detail=str(e)) from e
generated_answer.confidence_score = confidence_score
generated_answer.exec_time = time.time() - start_generated_answer
nl_query_response_repository = NLQueryResponseRepository(self.storage)
nl_query_response = nl_query_response_repository.insert(generated_answer)
return json.loads(json_util.dumps(nl_query_response))
response_repository = ResponseRepository(self.storage)
return response_repository.insert(generated_answer)

@override
def create_database_connection(
Expand Down Expand Up @@ -217,8 +216,8 @@ def update_table_description(
self,
table_description_id: str,
table_description_request: TableDescriptionRequest,
) -> TableSchemaDetail:
scanner_repository = DBScannerRepository(self.storage)
) -> TableDescription:
scanner_repository = TableDescriptionRepository(self.storage)
table = scanner_repository.find_by_id(table_description_id)

if not table:
Expand All @@ -238,11 +237,11 @@ def update_table_description(

@override
def list_table_descriptions(
self, db_connection_id: str | None = None, table_name: str | None = None
) -> list[TableSchemaDetail]:
scanner_repository = DBScannerRepository(self.storage)
self, db_connection_id: str, table_name: str | None = None
) -> list[TableDescription]:
scanner_repository = TableDescriptionRepository(self.storage)
table_descriptions = scanner_repository.find_by(
{"db_connection_id": db_connection_id, "table_name": table_name}
{"db_connection_id": ObjectId(db_connection_id), "table_name": table_name}
)

if db_connection_id:
Expand All @@ -260,7 +259,7 @@ def list_table_descriptions(
all_tables.remove(table_description.table_name)
for table in all_tables:
table_descriptions.append(
TableSchemaDetail(
TableDescription(
table_name=table,
status=TableDescriptionStatus.NOT_SYNCHRONIZED.value,
db_connection_id=db_connection_id,
Expand All @@ -270,6 +269,22 @@ def list_table_descriptions(

return table_descriptions

@override
def get_responses(self, question_id: str | None = None) -> list[Response]:
response_repository = ResponseRepository(self.storage)
query = {}
if question_id:
query = {"question_id": ObjectId(question_id)}
return response_repository.find_by(query)

@override
def get_questions(self, db_connection_id: str | None = None) -> list[Question]:
question_repository = QuestionRepository(self.storage)
query = {}
if db_connection_id:
query = {"db_connection_id": ObjectId(db_connection_id)}
return question_repository.find_by(query)

@override
def add_golden_records(
self, golden_records: List[GoldenRecordRequest]
Expand All @@ -295,53 +310,22 @@ def execute_sql_query(self, query: Query) -> tuple[str, dict]:
return result

@override
def update_nl_query_response(
self, query_id: str, query: UpdateQueryRequest # noqa: ARG002
) -> NLQueryResponse:
nl_query_response_repository = NLQueryResponseRepository(self.storage)
nl_question_repository = NLQuestionRepository(self.storage)
nl_query_response = nl_query_response_repository.find_by_id(query_id)
nl_question = nl_question_repository.find_by_id(
nl_query_response.nl_question_id
def create_response(
self, query_request: CreateResponseRequest # noqa: ARG002
) -> Response:
response = Response(
question_id=query_request.question_id, sql_query=query_request.sql_query
)
if nl_query_response.sql_query.strip() != query.sql_query.strip():
nl_query_response.sql_query = query.sql_query
evaluator = self.system.instance(Evaluator)
db_connection_repository = DatabaseConnectionRepository(self.storage)
database_connection = db_connection_repository.find_by_id(
nl_question.db_connection_id
)
if not database_connection:
raise HTTPException(
status_code=404, detail="Database connection not found"
)
try:
confidence_score = evaluator.get_confidence_score(
nl_question, nl_query_response, database_connection
)
nl_query_response.confidence_score = confidence_score
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
nl_query_response = generates_nl_answer.execute(nl_query_response)
except SQLInjectionError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
nl_query_response_repository.update(nl_query_response)
return json.loads(json_util.dumps(nl_query_response))
response_repository = ResponseRepository(self.storage)
response_repository.insert(response)

@override
def get_nl_query_response(
self, query_request: ExecuteTempQueryRequest # noqa: ARG002
) -> NLQueryResponse:
nl_query_response_repository = NLQueryResponseRepository(self.storage)
nl_query_response = nl_query_response_repository.find_by_id(
query_request.query_id
)
nl_query_response.sql_query = query_request.sql_query
try:
generates_nl_answer = GeneratesNlAnswer(self.system, self.storage)
nl_query_response = generates_nl_answer.execute(nl_query_response)
response = generates_nl_answer.execute(response)
response_repository.update(response)
except SQLInjectionError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
return json.loads(json_util.dumps(nl_query_response))
return response

@override
def delete_golden_record(self, golden_record_id: str) -> dict:
Expand All @@ -356,7 +340,7 @@ def get_golden_records(
golden_records_repository = GoldenRecordRepository(self.storage)
if db_connection_id:
return golden_records_repository.find_by(
{"db_connection_id": db_connection_id},
{"db_connection_id": ObjectId(db_connection_id)},
page=page,
limit=limit,
)
Expand All @@ -378,7 +362,7 @@ def get_instructions(
instruction_repository = InstructionRepository(self.storage)
if db_connection_id:
return instruction_repository.find_by(
{"db_connection_id": db_connection_id},
{"db_connection_id": ObjectId(db_connection_id)},
page=page,
limit=limit,
)
Expand Down
4 changes: 2 additions & 2 deletions dataherald/context_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataherald.config import Component, System
from dataherald.db import DB
from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery
from dataherald.types import GoldenRecord, GoldenRecordRequest, Question
from dataherald.vector_store import VectorStore


Expand All @@ -24,7 +24,7 @@ def __init__(self, system: System):

@abstractmethod
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
self, nl_question: Question, number_of_samples: int = 3
) -> Tuple[List[dict] | None, List[dict] | None]:
pass

Expand Down
4 changes: 2 additions & 2 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataherald.context_store import ContextStore
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery
from dataherald.types import GoldenRecord, GoldenRecordRequest, Question

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +19,7 @@ def __init__(self, system: System):

@override
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
self, nl_question: Question, number_of_samples: int = 3
) -> Tuple[List[dict] | None, List[dict] | None]:
logger.info(f"Getting context for {nl_question.question}")
closest_questions = self.vector_store.query(
Expand Down
6 changes: 6 additions & 0 deletions dataherald/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def insert_one(self, collection: str, obj: dict) -> int:
def rename(self, old_collection_name: str, new_collection_name) -> None:
pass

@abstractmethod
def rename_field(
self, collection_name: str, old_field_name: str, new_field_name: str
) -> None:
pass

@abstractmethod
def update_or_create(self, collection: str, query: dict, obj: dict) -> int:
pass
Expand Down
Loading

0 comments on commit 4919c9c

Please sign in to comment.