Skip to content

Commit

Permalink
DH-4724/adding support for db level instructions (#185)
Browse files Browse the repository at this point in the history
* DH-4724/adding support for db level instructions

* DH-4724/chanign the database tests

* DH-4724/changing the find_all() function

* DH-4724/changing the tests

* DH-4724/adding find_by method to instructionRepo

* DH-4724/reformat

* DH-4724/add paginating to the db

* DH-4724/reformat with black

* DH-4724/updating the endpoints to be RESTfull

* DH-4724/reformat with black

* DH-4724/changing the docs
  • Loading branch information
MohammadrezaPourreza authored Sep 27, 2023
1 parent 680fc68 commit 0e5bce2
Show file tree
Hide file tree
Showing 18 changed files with 570 additions and 28 deletions.
51 changes: 51 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,57 @@ curl -X 'PATCH' \
}'
```

#### adding database level instructions

You can add database level instructions to the context store manually from the `POST /api/v1/instructions` endpoint
These instructions are passed directly to the engine and can be used to steer the engine to generate SQL that is more in line with your business logic.

```
curl -X 'POST' \
'<host>/api/v1/instructions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"instruction": "This is a database level instruction"
"db_connection_id": "db_connection_id"
}'
```

#### getting database level instructions

You can get database level instructions from the `GET /api/v1/instructions` endpoint

```
curl -X 'GET' \
'<host>/api/v1/instructions?page=1&limit=10&db_connection_id=12312312' \
-H 'accept: application/json'
```

#### deleting database level instructions

You can delete database level instructions from the `DELETE /api/v1/instructions/{instruction_id}` endpoint

```
curl -X 'DELETE' \
'<host>/api/v1/instructions/{instruction_id}' \
-H 'accept: application/json'
```

#### updating database level instructions

You can update database level instructions from the `PUT /api/v1/instructions/{instruction_id}` endpoint
Try different instructions to see how the engine generates SQL

```
curl -X 'PUT' \
'<host>/api/v1/instructions/{instruction_id}' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"instruction": "This is a database level instruction"
}'
```


### Querying the Database in Natural Language
Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint.
Expand Down
25 changes: 25 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQueryResponse,
QuestionRequest,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
UpdateQueryRequest,
)

Expand Down Expand Up @@ -97,3 +100,25 @@ def delete_golden_record(self, golden_record_id: str) -> dict:
@abstractmethod
def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]:
pass

@abstractmethod
def add_instruction(self, instruction_request: InstructionRequest) -> Instruction:
pass

@abstractmethod
def get_instructions(
self, db_connection_id: str = None, page: int = 1, limit: int = 10
) -> List[Instruction]:
pass

@abstractmethod
def delete_instruction(self, instruction_id: str) -> dict:
pass

@abstractmethod
def update_instruction(
self,
instruction_id: str,
instruction_request: UpdateInstruction,
) -> Instruction:
pass
54 changes: 53 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataherald.repositories.base import NLQueryResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.sql_database.base import (
InvalidDBConnectionError,
Expand All @@ -33,11 +34,14 @@
ExecuteTempQueryRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
InstructionRequest,
NLQuery,
NLQueryResponse,
QuestionRequest,
ScannerRequest,
TableDescriptionRequest,
UpdateInstruction,
UpdateQueryRequest,
)

Expand Down Expand Up @@ -133,7 +137,7 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse:
start_generated_answer = time.time()
try:
generated_answer = sql_generation.generate_response(
user_question, database_connection, context
user_question, database_connection, context[0]
)
logger.info("Starts evaluator...")
confidence_score = evaluator.get_confidence_score(
Expand Down Expand Up @@ -353,3 +357,51 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor
start_idx = (page - 1) * limit
end_idx = start_idx + limit
return all_records[start_idx:end_idx]

@override
def add_instruction(self, instruction_request: InstructionRequest) -> Instruction:
instruction_repository = InstructionRepository(self.storage)
instruction = Instruction(
instruction=instruction_request.instruction,
db_connection_id=instruction_request.db_connection_id,
)
return instruction_repository.insert(instruction)

@override
def get_instructions(
self, db_connection_id: str = None, page: int = 1, limit: int = 10
) -> List[Instruction]:
instruction_repository = InstructionRepository(self.storage)
if db_connection_id:
return instruction_repository.find_by(
{"db_connection_id": db_connection_id},
page=page,
limit=limit,
)
return instruction_repository.find_all()

@override
def delete_instruction(self, instruction_id: str) -> dict:
instruction_repository = InstructionRepository(self.storage)
deleted = instruction_repository.delete_by_id(instruction_id)
if deleted == 0:
raise HTTPException(status_code=404, detail="Instruction not found")
return {"status": "success"}

@override
def update_instruction(
self,
instruction_id: str,
instruction_request: UpdateInstruction,
) -> Instruction:
instruction_repository = InstructionRepository(self.storage)
instruction = instruction_repository.find_by_id(instruction_id)
if not instruction:
raise HTTPException(status_code=404, detail="Instruction not found")
updated_instruction = Instruction(
id=instruction_id,
instruction=instruction_request.instruction,
db_connection_id=instruction.db_connection_id,
)
instruction_repository.update(updated_instruction)
return json.loads(json_util.dumps(updated_instruction))
4 changes: 2 additions & 2 deletions dataherald/context_store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, List
from typing import List, Tuple

from dataherald.config import Component, System
from dataherald.db import DB
Expand All @@ -25,7 +25,7 @@ def __init__(self, system: System):
@abstractmethod
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
) -> List[dict] | None:
) -> Tuple[List[dict] | None, List[dict] | None]:
pass

@abstractmethod
Expand Down
21 changes: 17 additions & 4 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging
from typing import List
from typing import List, Tuple

from overrides import override
from sql_metadata import Parser

from dataherald.config import System
from dataherald.context_store import ContextStore
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery

logger = logging.getLogger(__name__)
Expand All @@ -19,7 +20,7 @@ def __init__(self, system: System):
@override
def retrieve_context_for_question(
self, nl_question: NLQuery, number_of_samples: int = 3
) -> List[dict] | None:
) -> Tuple[List[dict] | None, List[dict] | None]:
logger.info(f"Getting context for {nl_question.question}")
closest_questions = self.vector_store.query(
query_texts=[nl_question.question],
Expand All @@ -41,9 +42,21 @@ def retrieve_context_for_question(
}
)
if len(samples) == 0:
return None
samples = None
instructions = []
instruction_repository = InstructionRepository(self.db)
all_instructions = instruction_repository.find_all()
for instruction in all_instructions:
if instruction.db_connection_id == nl_question.db_connection_id:
instructions.append(
{
"instruction": instruction.instruction,
}
)
if len(instructions) == 0:
instructions = None

return samples
return samples, instructions

@override
def add_golden_records(
Expand Down
9 changes: 8 additions & 1 deletion dataherald/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ def find_by_id(self, collection: str, id: str) -> dict:
pass

@abstractmethod
def find(self, collection: str, query: dict, sort: list = None) -> list:
def find(
self,
collection: str,
query: dict,
sort: list = None,
page: int = 0,
limit: int = 0,
) -> list:
pass

@abstractmethod
Expand Down
17 changes: 14 additions & 3 deletions dataherald/db/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,21 @@ def find_by_id(self, collection: str, id: str) -> dict:
return self._data_store[collection].find_one({"_id": ObjectId(id)})

@override
def find(self, collection: str, query: dict, sort: list = None) -> list:
def find(
self,
collection: str,
query: dict,
sort: list = None,
page: int = 0,
limit: int = 0,
) -> list:
skip_count = (page - 1) * limit
cursor = self._data_store[collection].find(query)
if sort:
return self._data_store[collection].find(query).sort(sort)
return self._data_store[collection].find(query)
cursor = cursor.sort(sort)
if page > 0 and limit > 0:
cursor = cursor.skip(skip_count).limit(limit)
return list(cursor)

@override
def find_all(self, collection: str) -> list:
Expand Down
52 changes: 52 additions & 0 deletions dataherald/repositories/instructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from bson.objectid import ObjectId

from dataherald.types import Instruction

DB_COLLECTION = "instructions"


class InstructionRepository:
def __init__(self, storage):
self.storage = storage

def insert(self, instruction: Instruction) -> Instruction:
instruction.id = str(
self.storage.insert_one(DB_COLLECTION, instruction.dict(exclude={"id"}))
)
return instruction

def find_one(self, query: dict) -> Instruction | None:
row = self.storage.find_one(DB_COLLECTION, query)
if not row:
return None
return Instruction(**row)

def update(self, instruction: Instruction) -> Instruction:
self.storage.update_or_create(
DB_COLLECTION,
{"_id": ObjectId(instruction.id)},
instruction.dict(exclude={"id"}),
)
return instruction

def find_by_id(self, id: str) -> Instruction | None:
row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)})
if not row:
return None
return Instruction(**row)

def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Instruction]:
rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit)
result = []
for row in rows:
obj = Instruction(**row)
obj.id = str(row["_id"])
result.append(obj)
return result

def find_all(self) -> list[Instruction]:
rows = self.storage.find_all(DB_COLLECTION)
return [Instruction(id=str(row["_id"]), **row) for row in rows]

def delete_by_id(self, id: str) -> int:
return self.storage.delete_by_id(DB_COLLECTION, id)
Loading

0 comments on commit 0e5bce2

Please sign in to comment.