diff --git a/README.md b/README.md index 48710dc1..fbb4aae0 100644 --- a/README.md +++ b/README.md @@ -317,6 +317,57 @@ curl -X 'PATCH' \ }' ``` +#### adding database level instructions + +You can add database level instructions to the context store manually from the `POST /api/v1/instructions` endpoint +These instructions are passed directly to the engine and can be used to steer the engine to generate SQL that is more in line with your business logic. + +``` +curl -X 'POST' \ + '/api/v1/instructions' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "This is a database level instruction" + "db_connection_id": "db_connection_id" +}' +``` + +#### getting database level instructions + +You can get database level instructions from the `GET /api/v1/instructions` endpoint + +``` +curl -X 'GET' \ + '/api/v1/instructions?page=1&limit=10&db_connection_id=12312312' \ + -H 'accept: application/json' +``` + +#### deleting database level instructions + +You can delete database level instructions from the `DELETE /api/v1/instructions/{instruction_id}` endpoint + +``` +curl -X 'DELETE' \ + '/api/v1/instructions/{instruction_id}' \ + -H 'accept: application/json' +``` + +#### updating database level instructions + +You can update database level instructions from the `PUT /api/v1/instructions/{instruction_id}` endpoint +Try different instructions to see how the engine generates SQL + +``` +curl -X 'PUT' \ + '/api/v1/instructions/{instruction_id}' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "This is a database level instruction" +}' +``` + ### Querying the Database in Natural Language Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint. diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index b6c10f7b..374928c6 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -12,10 +12,13 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQueryResponse, QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -97,3 +100,25 @@ def delete_golden_record(self, golden_record_id: str) -> dict: @abstractmethod def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]: pass + + @abstractmethod + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: + pass + + @abstractmethod + def get_instructions( + self, db_connection_id: str = None, page: int = 1, limit: int = 10 + ) -> List[Instruction]: + pass + + @abstractmethod + def delete_instruction(self, instruction_id: str) -> dict: + pass + + @abstractmethod + def update_instruction( + self, + instruction_id: str, + instruction_request: UpdateInstruction, + ) -> Instruction: + pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 21bfb84e..15755528 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -19,6 +19,7 @@ from dataherald.repositories.base import NLQueryResponseRepository from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.golden_records import GoldenRecordRepository +from dataherald.repositories.instructions import InstructionRepository from dataherald.repositories.nl_question import NLQuestionRepository from dataherald.sql_database.base import ( InvalidDBConnectionError, @@ -33,11 +34,14 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQuery, NLQueryResponse, QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -117,7 +121,7 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: start_generated_answer = time.time() try: generated_answer = sql_generation.generate_response( - user_question, database_connection, context + user_question, database_connection, context[0] ) logger.info("Starts evaluator...") confidence_score = evaluator.get_confidence_score( @@ -312,3 +316,51 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor start_idx = (page - 1) * limit end_idx = start_idx + limit return all_records[start_idx:end_idx] + + @override + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: + instruction_repository = InstructionRepository(self.storage) + instruction = Instruction( + instruction=instruction_request.instruction, + db_connection_id=instruction_request.db_connection_id, + ) + return instruction_repository.insert(instruction) + + @override + def get_instructions( + self, db_connection_id: str = None, page: int = 1, limit: int = 10 + ) -> List[Instruction]: + instruction_repository = InstructionRepository(self.storage) + if db_connection_id: + return instruction_repository.find_by( + {"db_connection_id": db_connection_id}, + page=page, + limit=limit, + ) + return instruction_repository.find_all() + + @override + def delete_instruction(self, instruction_id: str) -> dict: + instruction_repository = InstructionRepository(self.storage) + deleted = instruction_repository.delete_by_id(instruction_id) + if deleted == 0: + raise HTTPException(status_code=404, detail="Instruction not found") + return {"status": "success"} + + @override + def update_instruction( + self, + instruction_id: str, + instruction_request: UpdateInstruction, + ) -> Instruction: + instruction_repository = InstructionRepository(self.storage) + instruction = instruction_repository.find_by_id(instruction_id) + if not instruction: + raise HTTPException(status_code=404, detail="Instruction not found") + updated_instruction = Instruction( + id=instruction_id, + instruction=instruction_request.instruction, + db_connection_id=instruction.db_connection_id, + ) + instruction_repository.update(updated_instruction) + return json.loads(json_util.dumps(updated_instruction)) diff --git a/dataherald/context_store/__init__.py b/dataherald/context_store/__init__.py index eb327cd7..693b5d6e 100644 --- a/dataherald/context_store/__init__.py +++ b/dataherald/context_store/__init__.py @@ -1,6 +1,6 @@ import os from abc import ABC, abstractmethod -from typing import Any, List +from typing import List, Tuple from dataherald.config import Component, System from dataherald.db import DB @@ -25,7 +25,7 @@ def __init__(self, system: System): @abstractmethod def retrieve_context_for_question( self, nl_question: NLQuery, number_of_samples: int = 3 - ) -> List[dict] | None: + ) -> Tuple[List[dict] | None, List[dict] | None]: pass @abstractmethod diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index b73de25f..453d34b8 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Tuple from overrides import override from sql_metadata import Parser @@ -7,6 +7,7 @@ from dataherald.config import System from dataherald.context_store import ContextStore from dataherald.repositories.golden_records import GoldenRecordRepository +from dataherald.repositories.instructions import InstructionRepository from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ def __init__(self, system: System): @override def retrieve_context_for_question( self, nl_question: NLQuery, number_of_samples: int = 3 - ) -> List[dict] | None: + ) -> Tuple[List[dict] | None, List[dict] | None]: logger.info(f"Getting context for {nl_question.question}") closest_questions = self.vector_store.query( query_texts=[nl_question.question], @@ -41,9 +42,21 @@ def retrieve_context_for_question( } ) if len(samples) == 0: - return None + samples = None + instructions = [] + instruction_repository = InstructionRepository(self.db) + all_instructions = instruction_repository.find_all() + for instruction in all_instructions: + if instruction.db_connection_id == nl_question.db_connection_id: + instructions.append( + { + "instruction": instruction.instruction, + } + ) + if len(instructions) == 0: + instructions = None - return samples + return samples, instructions @override def add_golden_records( diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index 6a650ab1..4488ce4d 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -29,7 +29,14 @@ def find_by_id(self, collection: str, id: str) -> dict: pass @abstractmethod - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: pass @abstractmethod diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index 7879d5e0..b2a7a41a 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -40,10 +40,21 @@ def find_by_id(self, collection: str, id: str) -> dict: return self._data_store[collection].find_one({"_id": ObjectId(id)}) @override - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: + skip_count = (page - 1) * limit + cursor = self._data_store[collection].find(query) if sort: - return self._data_store[collection].find(query).sort(sort) - return self._data_store[collection].find(query) + cursor = cursor.sort(sort) + if page > 0 and limit > 0: + cursor = cursor.skip(skip_count).limit(limit) + return list(cursor) @override def find_all(self, collection: str) -> list: diff --git a/dataherald/repositories/instructions.py b/dataherald/repositories/instructions.py new file mode 100644 index 00000000..0c81e688 --- /dev/null +++ b/dataherald/repositories/instructions.py @@ -0,0 +1,52 @@ +from bson.objectid import ObjectId + +from dataherald.types import Instruction + +DB_COLLECTION = "instructions" + + +class InstructionRepository: + def __init__(self, storage): + self.storage = storage + + def insert(self, instruction: Instruction) -> Instruction: + instruction.id = str( + self.storage.insert_one(DB_COLLECTION, instruction.dict(exclude={"id"})) + ) + return instruction + + def find_one(self, query: dict) -> Instruction | None: + row = self.storage.find_one(DB_COLLECTION, query) + if not row: + return None + return Instruction(**row) + + def update(self, instruction: Instruction) -> Instruction: + self.storage.update_or_create( + DB_COLLECTION, + {"_id": ObjectId(instruction.id)}, + instruction.dict(exclude={"id"}), + ) + return instruction + + def find_by_id(self, id: str) -> Instruction | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + return Instruction(**row) + + def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Instruction]: + rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) + result = [] + for row in rows: + obj = Instruction(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result + + def find_all(self) -> list[Instruction]: + rows = self.storage.find_all(DB_COLLECTION) + return [Instruction(id=str(row["_id"]), **row) for row in rows] + + def delete_by_id(self, id: str) -> int: + return self.storage.delete_by_id(DB_COLLECTION, id) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 9e67afd6..2bf76eb1 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -16,10 +16,13 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQueryResponse, QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -134,6 +137,34 @@ def __init__(self, settings: Settings): tags=["SQL queries"], ) + self.router.add_api_route( + "/api/v1/instructions", + self.add_instruction, + methods=["POST"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/instructions", + self.get_instructions, + methods=["GET"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/instructions/{instruction_id}", + self.delete_instruction, + methods=["DELETE"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/instructions/{instruction_id}", + self.update_instruction, + methods=["PUT"], + tags=["Instructions"], + ) + self.router.add_api_route( "/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"] ) @@ -229,3 +260,32 @@ def add_golden_records( def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]: """Gets golden records""" return self._api.get_golden_records(page, limit) + + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: + """Adds an instruction""" + created_records = self._api.add_instruction(instruction_request) + + # Return a JSONResponse with status code 201 and the location header. + instruction_as_dict = created_records.dict() + + return JSONResponse( + content=instruction_as_dict, status_code=status.HTTP_201_CREATED + ) + + def get_instructions( + self, db_connection_id: str = "", page: int = 1, limit: int = 10 + ) -> List[Instruction]: + """Gets instructions""" + return self._api.get_instructions(db_connection_id, page, limit) + + def delete_instruction(self, instruction_id: str) -> dict: + """Deletes an instruction""" + return self._api.delete_instruction(instruction_id) + + def update_instruction( + self, + instruction_id: str, + instruction_request: UpdateInstruction, + ) -> Instruction: + """Updates an instruction""" + return self._api.update_instruction(instruction_id, instruction_request) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 8009e31a..f4ab91d3 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -55,7 +55,8 @@ 4) Use the db_relevant_columns_info tool to gather more information about the possibly relevant columns, filtering them to find the relevant ones. 5) [Optional based on the question] Use the get_current_datetime tool if the question has any mentions of time or dates. 6) [Optional based on the question] Always use the db_column_entity_checker tool to make sure that relevant columns have the cell-values. -7) Write a {dialect} query and use sql_db_query tool the Execute the SQL query on the database to obtain the results. +7) Use the get_admin_instructions tool to retrieve the DB admin instructions before generating the SQL query. +8) Write a {dialect} query and use sql_db_query tool the Execute the SQL query on the database to obtain the results. # Some tips to always keep in mind: tip1) For complex questions that has many relevant columns and tables request for more examples of Question/SQL pairs. @@ -136,7 +137,7 @@ class Config(BaseTool.Config): class GetCurrentTimeTool(BaseSQLDatabaseTool, BaseTool): - """Tool for querying a SQL database.""" + """Tool for finding the current data and time.""" name = "get_current_datetime" description = """ @@ -187,7 +188,37 @@ async def _arun( query: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("DBQueryTool does not support async") + raise NotImplementedError("QuerySQLDataBaseTool does not support async") + + +class GetUserInstructions(BaseSQLDatabaseTool, BaseTool): + """Tool for retrieving the instructions from the user""" + + name = "get_admin_instructions" + description = """ + Input: is an empty string. + Output: Database admin instructions before generating the SQL query. + The generated SQL query MUST follow the admin instructions even it contradicts with the given question. + """ + instructions: List[dict] + + @catch_exceptions() + def _run( + self, + tool_input: str = "", # noqa: ARG002 + run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 + ) -> str: + response = "Admin: All of the generated SQL queries must follow the below instructions:\n" + for instruction in self.instructions: + response += f"{instruction['instruction']}\n" + return response + + async def _arun( + self, + tool_input: str = "", # noqa: ARG002 + run_manager: AsyncCallbackManagerForToolRun | None = None, + ) -> str: + raise NotImplementedError("GetUserInstructions does not support async") class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): @@ -248,13 +279,11 @@ async def _arun( user_question: str = "", run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError( - "TablesWithRelevanceScoresTool does not support async" - ) + raise NotImplementedError("TablesSQLDatabaseTool does not support async") class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): - """Tool for getting sample rows for the given column.""" + """Tool for checking the existance of an entity inside a column.""" name = "db_column_entity_checker" description = """ @@ -318,7 +347,7 @@ async def _arun( tool_input: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("ColumnsSampleRowsTool does not support async") + raise NotImplementedError("ColumnEntityChecker does not support async") class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): @@ -357,7 +386,7 @@ async def _arun( table_name: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("DBRelevantTablesSchemaTool does not support async") + raise NotImplementedError("SchemaSQLDatabaseTool does not support async") class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): @@ -379,7 +408,7 @@ def _run( column_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: - """Get the schema for tables in a comma-separated list.""" + """Get the column level information.""" items_list = column_names.split(", ") column_full_info = "" for item in items_list: @@ -460,6 +489,7 @@ class SQLDatabaseToolkit(BaseToolkit): db: SQLDatabase = Field(exclude=True) context: List[dict] | None = Field(exclude=True, default=None) few_shot_examples: List[dict] | None = Field(exclude=True, default=None) + instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableSchemaDetail] = Field(exclude=True) @property @@ -477,6 +507,12 @@ def get_tools(self) -> List[BaseTool]: tools = [] query_sql_db_tool = QuerySQLDataBaseTool(db=self.db, context=self.context) tools.append(query_sql_db_tool) + if self.instructions is not None: + tools.append( + GetUserInstructions( + db=self.db, context=self.context, instructions=self.instructions + ) + ) get_current_datetime = GetCurrentTimeTool(db=self.db, context=self.context) tools.append(get_current_datetime) tables_sql_db_tool = TablesSQLDatabaseTool( @@ -528,7 +564,7 @@ def create_sql_agent( input_variables: List[str] | None = None, max_examples: int = 20, top_k: int = 13, - max_iterations: int | None = 10, + max_iterations: int | None = 15, max_execution_time: float | None = None, early_stopping_method: str = "force", verbose: bool = False, @@ -585,7 +621,7 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") - few_shot_examples = context_store.retrieve_context_for_question( + few_shot_examples, instructions = context_store.retrieve_context_for_question( user_question, number_of_samples=self.max_number_of_examples ) if few_shot_examples is not None: @@ -600,6 +636,7 @@ def generate_response( db=self.database, context=context, few_shot_examples=new_fewshot_examples, + instructions=instructions, db_scan=db_scan, ) agent_executor = self.create_sql_agent( diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 2c7b39c3..a8056e20 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -19,6 +19,13 @@ def __init__(self, system: System): "ssh_settings": None, } ] + self.memory["instructions"] = [ + { + "_id": "64dfa0e103f5134086f7090c", + "instruction": "foo", + "db_connection_id": "64dfa0e103f5134086f7090c", + } + ] @override def insert_one(self, collection: str, obj: dict) -> int: @@ -48,7 +55,14 @@ def find_by_id(self, collection: str, id: str) -> dict: return None @override - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: return [] @override diff --git a/dataherald/types.py b/dataherald/types.py index 602f39dc..1e46a0ad 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -40,6 +40,20 @@ class NLQuery(BaseModel): db_connection_id: str +class UpdateInstruction(BaseModel): + instruction: str + + +class InstructionRequest(DBConnectionValidation): + instruction: str + + +class Instruction(BaseModel): + id: Any + instruction: str + db_connection_id: str + + class GoldenRecordRequest(DBConnectionValidation): question: str sql_query: str diff --git a/docs/api.add_instructions.rst b/docs/api.add_instructions.rst new file mode 100644 index 00000000..93fece81 --- /dev/null +++ b/docs/api.add_instructions.rst @@ -0,0 +1,47 @@ +.. _api.add_instructions: + +Set Instructions +======================= + +To return an accurate response based on our your business rules, you can set some constraints on the SQL generation process. + +Request this ``POST`` endpoint:: + + /api/v1/instructions + +**Request body** + +.. code-block:: rst + + { + "instruction": "string" + "db_connection_id": "database_connection_id" + } + +**Responses** + +HTTP 201 code response + +.. code-block:: rst + + { + "id": "instruction_id", + "instruction": "Instructions", + "db_connection_id": "database_connection_id" + } + +**Example** + +Only set a instruction for a database connection + +.. code-block:: rst + + curl -X 'POST' \ + '/api/v1/instructions' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "string", + "db_connection_id": "database_connection_id" + }' + diff --git a/docs/api.delete_instructions.rst b/docs/api.delete_instructions.rst new file mode 100644 index 00000000..36f01265 --- /dev/null +++ b/docs/api.delete_instructions.rst @@ -0,0 +1,39 @@ +.. _api.delete_instructions: + +Delete Instructions +======================= + +If your business logic requires to delete a instruction, you can use this endpoint. + +Request this ``DELETE`` endpoint:: + + /api/v1/instructions/{instruction_id} + +** Parameters ** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "instruction_id", "string", "Instruction id we want to delete, ``Required``" + +**Responses** + +HTTP 201 code response + +.. code-block:: rst + + { + "status": bool, + } + +**Example** + +Only set a instruction for a database connection + +.. code-block:: rst + + curl -X 'DELETE' \ + '/api/v1/instructions/{instruction_id}' \ + -H 'accept: application/json' + diff --git a/docs/api.list_instructions.rst b/docs/api.list_instructions.rst new file mode 100644 index 00000000..f6e4a06e --- /dev/null +++ b/docs/api.list_instructions.rst @@ -0,0 +1,43 @@ +.. _api.list_instructions: + +List instructions +======================= + +You can use this endpoint to retrieve a list of instructions for a database connection. + +Request this ``GET`` endpoint:: + + GET /api/v1/instructions + +** Parameters ** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Database connection we want to get instructions, ``Optional``" + "page", "integer", "Page number, ``Optional``" + "limit", "integer", "Limit number of instructions, ``Optional``" + +**Responses** + +HTTP 201 code response + +.. code-block:: rst + + [ + { + "id": "string", + "instruction": "string", + "db_connection_id": "string" + } + ] + +**Example** + +.. code-block:: rst + + curl -X 'GET' \ + '/api/v1/instructions?page=1&limit=10&db_connection_id=12312312' \ + -H 'accept: application/json + diff --git a/docs/api.rst b/docs/api.rst index fb94ecc0..0239a804 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,7 @@ API The Dataherald Engine exposes RESTful APIs that can be used to: * 🔌 Connect to and manage connections to databases -* 🔑 Add context to the engine through scanning the databases, adding descriptions to tables and columns and adding golden records +* 🔑 Add context to the engine through scanning the databases, adding database level instructions, adding descriptions to tables and columns and adding golden records * 🙋‍♀️ Ask natural language questions from the relational data Our APIs have resource-oriented URL built around standard HTTP response codes and verbs. The core resources are described below. @@ -92,6 +92,24 @@ Related endpoints are: "table_schema": "string" } +Database Instructions +--------------------- +The ``database-instructions`` object is used to set constraints on the SQL that is generated by the LLM. +These are then used to help the LLM build valid SQL to answer natural language questions based on your business rules. + +Related endpoints are: + +* :doc:`Add database instructions ` -- ``POST api/v1/{db_connection_id}/instructions`` +* :doc:`List database instructions ` -- ``GET api/v1/{db_connection_id}/instructions`` +* :doc:`Update database instructions ` -- ``PUT api/v1/{db_connection_id}/instructions/{instruction_id}`` +* :doc:`Delete database instructions ` -- ``DELETE api/v1/{db_connection_id}/instructions/{instruction_id}`` + +.. code-block:: json + + { + "db_connection_id": "string", + "instruction": "string", + } .. toctree:: @@ -105,6 +123,11 @@ Related endpoints are: api.add_descriptions api.list_table_description + api.add_instructions + api.list_instructions + api.update_instructions + api.delete_instructions + api.golden_record api.question diff --git a/docs/api.update_instructions.rst b/docs/api.update_instructions.rst new file mode 100644 index 00000000..2340adca --- /dev/null +++ b/docs/api.update_instructions.rst @@ -0,0 +1,54 @@ +.. _api.update_instructions: + +Update Instructions +======================= + +In order to get the best performance from the engine you should try using different instructions for each database connection. +You can update the instructions for a database connection using the ``PUT`` method. + +Request this ``PUT`` endpoint:: + + /api/v1/instructions/{instruction_id} + +** Parameters ** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "instruction_id", "string", "Instruction id we want to update, ``Required``" + +**Request body** + +.. code-block:: rst + + { + "instruction": "string" + } + +**Responses** + +HTTP 201 code response + +.. code-block:: rst + + { + "id": "string", + "instruction": "string", + "db_connection_id": "string" + } + +**Example** + +Only set a instruction for a database connection + +.. code-block:: rst + + curl -X 'PUT' \ + '/api/v1/instructions/{instruction_id}' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "string" + }' + diff --git a/docs/context_store.rst b/docs/context_store.rst index 5977c872..666b70b9 100644 --- a/docs/context_store.rst +++ b/docs/context_store.rst @@ -33,7 +33,7 @@ This abstract class provides a consistent interface for working with context ret :param system: The system object. :type system: System -.. py:method:: retrieve_context_for_question(self, nl_question: str, number_of_samples: int = 3) -> List[dict] | None +.. py:method:: retrieve_context_for_question(self, nl_question: str, number_of_samples: int = 3) -> Tuple[List[dict] | None, List[dict] | None] :noindex: Given a natural language question, this method retrieves a single string containing information about relevant data stores, tables, and columns necessary for building the SQL query. This information includes example questions, corresponding SQL queries, and metadata about the tables (e.g., categorical columns). The retrieved string is then passed to the text-to-SQL generator. @@ -42,8 +42,8 @@ This abstract class provides a consistent interface for working with context ret :type nl_question: str :param number_of_samples: The number of context samples to retrieve. :type number_of_samples: int - :return: A list of dictionaries containing context information for generating SQL. - :rtype: List[dict] | None + :return: A list of dictionaries containing context information for generating SQL (contain few-shot samples and instructions). + :rtype: Tuple[List[dict] | None, List[dict] | None] .. py:method:: add_golden_records(self, golden_records: List[GoldenRecordRequest]) -> List[GoldenRecord] :noindex: