From 5a55c5dc100210f741bccd32ea046011bbace073 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 26 Sep 2023 15:02:47 -0400 Subject: [PATCH 01/11] DH-4724/adding support for db level instructions --- README.md | 50 ++++++++++++++ dataherald/api/__init__.py | 27 ++++++++ dataherald/api/fastapi.py | 55 +++++++++++++++- dataherald/context_store/__init__.py | 4 +- dataherald/context_store/default.py | 20 ++++-- dataherald/repositories/instructions.py | 43 ++++++++++++ dataherald/server/fastapi/__init__.py | 66 +++++++++++++++++++ .../sql_generator/dataherald_sqlagent.py | 61 +++++++++++++---- dataherald/types.py | 10 +++ docs/api.add_instructions.rst | 53 +++++++++++++++ docs/api.delete_instructions.rst | 40 +++++++++++ docs/api.list_instructions.rst | 43 ++++++++++++ docs/api.rst | 25 ++++++- docs/api.update_instructions.rst | 55 ++++++++++++++++ docs/context_store.rst | 6 +- 15 files changed, 535 insertions(+), 23 deletions(-) create mode 100644 dataherald/repositories/instructions.py create mode 100644 docs/api.add_instructions.rst create mode 100644 docs/api.delete_instructions.rst create mode 100644 docs/api.list_instructions.rst create mode 100644 docs/api.update_instructions.rst diff --git a/README.md b/README.md index 48710dc1..02d7d231 100644 --- a/README.md +++ b/README.md @@ -317,6 +317,56 @@ curl -X 'PATCH' \ }' ``` +#### adding database level instructions + +You can add database level instructions to the context store manually from the `POST /api/v1/{db_connection_id}/instructions` endpoint +These instructions are passed directly to the engine and can be used to steer the engine to generate SQL that is more in line with your business logic. 
+ +``` +curl -X 'POST' \ + '/api/v1/{db_connection_id}/instructions' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "This is a database level instruction" +}' +``` + +#### getting database level instructions + +You can get database level instructions from the `GET /api/v1/{db_connection_id}/instructions` endpoint + +``` +curl -X 'GET' \ + '/api/v1/{db_connection_id}/instructions?page=1&limit=10' \ + -H 'accept: application/json' +``` + +#### deleting database level instructions + +You can delete database level instructions from the `DELETE /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint + +``` +curl -X 'DELETE' \ + '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + -H 'accept: application/json' +``` + +#### updating database level instructions + +You can update database level instructions from the `PATCH /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint +Try different instructions to see how the engine generates SQL + +``` +curl -X 'PATCH' \ + '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "This is a database level instruction" +}' +``` + ### Querying the Database in Natural Language Once you have connected the engine to your data warehouse (and preferably added some context to the store), you can query your data warehouse using the `POST /api/v1/question` endpoint. 
diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index b6c10f7b..48743585 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -12,6 +12,8 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQueryResponse, QuestionRequest, ScannerRequest, @@ -97,3 +99,28 @@ def delete_golden_record(self, golden_record_id: str) -> dict: @abstractmethod def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]: pass + + @abstractmethod + def add_instruction( + self, db_connection_id: str, instruction_request: InstructionRequest + ) -> Instruction: + pass + + @abstractmethod + def get_instructions( + self, db_connection_id: str, page: int = 1, limit: int = 10 + ) -> List[Instruction]: + pass + + @abstractmethod + def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + pass + + @abstractmethod + def update_instruction( + self, + db_connection_id: str, + instruction_id: str, + instruction_request: InstructionRequest, + ) -> Instruction: + pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 21bfb84e..aeec48af 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -19,6 +19,7 @@ from dataherald.repositories.base import NLQueryResponseRepository from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.golden_records import GoldenRecordRepository +from dataherald.repositories.instructions import InstructionRepository from dataherald.repositories.nl_question import NLQuestionRepository from dataherald.sql_database.base import ( InvalidDBConnectionError, @@ -33,6 +34,8 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQuery, NLQueryResponse, QuestionRequest, @@ -117,7 +120,7 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: start_generated_answer = time.time() 
try: generated_answer = sql_generation.generate_response( - user_question, database_connection, context + user_question, database_connection, context[0] ) logger.info("Starts evaluator...") confidence_score = evaluator.get_confidence_score( @@ -312,3 +315,53 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor start_idx = (page - 1) * limit end_idx = start_idx + limit return all_records[start_idx:end_idx] + + @override + def add_instruction( + self, db_connection_id: str, instruction_request: InstructionRequest + ) -> Instruction: + instruction_repository = InstructionRepository(self.storage) + instruction = Instruction( + instruction=instruction_request.instruction, + db_connection_id=db_connection_id, + ) + return instruction_repository.insert(instruction) + + @override + def get_instructions( + self, db_connection_id: str, page: int = 1, limit: int = 10 + ) -> List[Instruction]: + instruction_repository = InstructionRepository(self.storage) + instructions = instruction_repository.find_all() + filtered_instructions = [] + for instruction in instructions: + if instruction.db_connection_id == db_connection_id: + filtered_instructions.append(instruction) + start_idx = (page - 1) * limit + end_idx = start_idx + limit + return filtered_instructions[start_idx:end_idx] + + @override + def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + instruction_repository = InstructionRepository(self.storage) + instruction = instruction_repository.find_by_id(instruction_id) + if instruction.db_connection_id != db_connection_id: + raise HTTPException(status_code=404, detail="Instruction not found") + instruction_repository.delete_by_id(instruction_id) + return {"status": "success"} + + @override + def update_instruction( + self, + db_connection_id: str, + instruction_id: str, + instruction_request: InstructionRequest, + ) -> Instruction: + instruction_repository = InstructionRepository(self.storage) + instruction = 
Instruction( + id=instruction_id, + instruction=instruction_request.instruction, + db_connection_id=db_connection_id, + ) + instruction_repository.update(instruction) + return json.loads(json_util.dumps(instruction)) diff --git a/dataherald/context_store/__init__.py b/dataherald/context_store/__init__.py index eb327cd7..693b5d6e 100644 --- a/dataherald/context_store/__init__.py +++ b/dataherald/context_store/__init__.py @@ -1,6 +1,6 @@ import os from abc import ABC, abstractmethod -from typing import Any, List +from typing import List, Tuple from dataherald.config import Component, System from dataherald.db import DB @@ -25,7 +25,7 @@ def __init__(self, system: System): @abstractmethod def retrieve_context_for_question( self, nl_question: NLQuery, number_of_samples: int = 3 - ) -> List[dict] | None: + ) -> Tuple[List[dict] | None, List[dict] | None]: pass @abstractmethod diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index b73de25f..42e7a40a 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Tuple from overrides import override from sql_metadata import Parser @@ -7,6 +7,7 @@ from dataherald.config import System from dataherald.context_store import ContextStore from dataherald.repositories.golden_records import GoldenRecordRepository +from dataherald.repositories.instructions import InstructionRepository from dataherald.types import GoldenRecord, GoldenRecordRequest, NLQuery logger = logging.getLogger(__name__) @@ -19,7 +20,7 @@ def __init__(self, system: System): @override def retrieve_context_for_question( self, nl_question: NLQuery, number_of_samples: int = 3 - ) -> List[dict] | None: + ) -> Tuple[List[dict] | None, List[dict] | None]: logger.info(f"Getting context for {nl_question.question}") closest_questions = self.vector_store.query( query_texts=[nl_question.question], @@ -41,9 +42,20 @@ def 
retrieve_context_for_question( } ) if len(samples) == 0: - return None + samples = None + instructions = [] + instruction_repository = InstructionRepository(self.db) + for instruction in instruction_repository.find_all(): + if instruction.db_connection_id == nl_question.db_connection_id: + instructions.append( + { + "instruction": instruction.instruction, + } + ) + if len(instructions) == 0: + instructions = None - return samples + return samples, instructions @override def add_golden_records( diff --git a/dataherald/repositories/instructions.py b/dataherald/repositories/instructions.py new file mode 100644 index 00000000..036fe936 --- /dev/null +++ b/dataherald/repositories/instructions.py @@ -0,0 +1,43 @@ +from bson.objectid import ObjectId + +from dataherald.types import Instruction + +DB_COLLECTION = "instructions" + + +class InstructionRepository: + def __init__(self, storage): + self.storage = storage + + def insert(self, instruction: Instruction) -> Instruction: + instruction.id = str( + self.storage.insert_one(DB_COLLECTION, instruction.dict(exclude={"id"})) + ) + return instruction + + def find_one(self, query: dict) -> Instruction | None: + row = self.storage.find_one(DB_COLLECTION, query) + if not row: + return None + return Instruction(**row) + + def update(self, instruction: Instruction) -> Instruction: + self.storage.update_or_create( + DB_COLLECTION, + {"_id": ObjectId(instruction.id)}, + instruction.dict(exclude={"id"}), + ) + return instruction + + def find_by_id(self, id: str) -> Instruction | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + return Instruction(**row) + + def find_all(self) -> list[Instruction]: + rows = self.storage.find_all(DB_COLLECTION) + return [Instruction(id=str(row["_id"]), **row) for row in rows] + + def delete_by_id(self, id: str) -> int: + return self.storage.delete_by_id(DB_COLLECTION, id) diff --git a/dataherald/server/fastapi/__init__.py 
b/dataherald/server/fastapi/__init__.py index 9e67afd6..3298bfb2 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -16,6 +16,8 @@ ExecuteTempQueryRequest, GoldenRecord, GoldenRecordRequest, + Instruction, + InstructionRequest, NLQueryResponse, QuestionRequest, ScannerRequest, @@ -134,6 +136,34 @@ def __init__(self, settings: Settings): tags=["SQL queries"], ) + self.router.add_api_route( + "/api/v1/{db_connection_id}/instructions", + self.add_instruction, + methods=["POST"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/{db_connection_id}/instructions", + self.get_instructions, + methods=["GET"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/{db_connection_id}/instructions/{instruction_id}", + self.delete_instruction, + methods=["DELETE"], + tags=["Instructions"], + ) + + self.router.add_api_route( + "/api/v1/{db_connection_id}/instructions/{instruction_id}", + self.update_instruction, + methods=["PATCH"], + tags=["Instructions"], + ) + self.router.add_api_route( "/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"] ) @@ -229,3 +259,39 @@ def add_golden_records( def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]: """Gets golden records""" return self._api.get_golden_records(page, limit) + + def add_instruction( + self, db_connection_id: str, instruction_request: InstructionRequest + ) -> Instruction: + """Adds an instruction""" + created_records = self._api.add_instruction( + db_connection_id, instruction_request + ) + + # Return a JSONResponse with status code 201 and the location header. 
+ instruction_as_dict = created_records.dict() + + return JSONResponse( + content=instruction_as_dict, status_code=status.HTTP_201_CREATED + ) + + def get_instructions( + self, db_connection_id: str, page: int = 1, limit: int = 10 + ) -> List[Instruction]: + """Gets instructions""" + return self._api.get_instructions(db_connection_id, page, limit) + + def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + """Deletes an instruction""" + return self._api.delete_instruction(db_connection_id, instruction_id) + + def update_instruction( + self, + db_connection_id: str, + instruction_id: str, + instruction_request: InstructionRequest, + ) -> Instruction: + """Updates an instruction""" + return self._api.update_instruction( + db_connection_id, instruction_id, instruction_request + ) diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 8009e31a..f4ab91d3 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -55,7 +55,8 @@ 4) Use the db_relevant_columns_info tool to gather more information about the possibly relevant columns, filtering them to find the relevant ones. 5) [Optional based on the question] Use the get_current_datetime tool if the question has any mentions of time or dates. 6) [Optional based on the question] Always use the db_column_entity_checker tool to make sure that relevant columns have the cell-values. -7) Write a {dialect} query and use sql_db_query tool the Execute the SQL query on the database to obtain the results. +7) Use the get_admin_instructions tool to retrieve the DB admin instructions before generating the SQL query. +8) Write a {dialect} query and use sql_db_query tool the Execute the SQL query on the database to obtain the results. # Some tips to always keep in mind: tip1) For complex questions that has many relevant columns and tables request for more examples of Question/SQL pairs. 
@@ -136,7 +137,7 @@ class Config(BaseTool.Config): class GetCurrentTimeTool(BaseSQLDatabaseTool, BaseTool): - """Tool for querying a SQL database.""" + """Tool for finding the current data and time.""" name = "get_current_datetime" description = """ @@ -187,7 +188,37 @@ async def _arun( query: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("DBQueryTool does not support async") + raise NotImplementedError("QuerySQLDataBaseTool does not support async") + + +class GetUserInstructions(BaseSQLDatabaseTool, BaseTool): + """Tool for retrieving the instructions from the user""" + + name = "get_admin_instructions" + description = """ + Input: is an empty string. + Output: Database admin instructions before generating the SQL query. + The generated SQL query MUST follow the admin instructions even it contradicts with the given question. + """ + instructions: List[dict] + + @catch_exceptions() + def _run( + self, + tool_input: str = "", # noqa: ARG002 + run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 + ) -> str: + response = "Admin: All of the generated SQL queries must follow the below instructions:\n" + for instruction in self.instructions: + response += f"{instruction['instruction']}\n" + return response + + async def _arun( + self, + tool_input: str = "", # noqa: ARG002 + run_manager: AsyncCallbackManagerForToolRun | None = None, + ) -> str: + raise NotImplementedError("GetUserInstructions does not support async") class TablesSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): @@ -248,13 +279,11 @@ async def _arun( user_question: str = "", run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError( - "TablesWithRelevanceScoresTool does not support async" - ) + raise NotImplementedError("TablesSQLDatabaseTool does not support async") class ColumnEntityChecker(BaseSQLDatabaseTool, BaseTool): - """Tool for getting sample rows for the given column.""" + """Tool for 
checking the existance of an entity inside a column.""" name = "db_column_entity_checker" description = """ @@ -318,7 +347,7 @@ async def _arun( tool_input: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("ColumnsSampleRowsTool does not support async") + raise NotImplementedError("ColumnEntityChecker does not support async") class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): @@ -357,7 +386,7 @@ async def _arun( table_name: str, run_manager: AsyncCallbackManagerForToolRun | None = None, ) -> str: - raise NotImplementedError("DBRelevantTablesSchemaTool does not support async") + raise NotImplementedError("SchemaSQLDatabaseTool does not support async") class InfoRelevantColumns(BaseSQLDatabaseTool, BaseTool): @@ -379,7 +408,7 @@ def _run( column_names: str, run_manager: CallbackManagerForToolRun | None = None, # noqa: ARG002 ) -> str: - """Get the schema for tables in a comma-separated list.""" + """Get the column level information.""" items_list = column_names.split(", ") column_full_info = "" for item in items_list: @@ -460,6 +489,7 @@ class SQLDatabaseToolkit(BaseToolkit): db: SQLDatabase = Field(exclude=True) context: List[dict] | None = Field(exclude=True, default=None) few_shot_examples: List[dict] | None = Field(exclude=True, default=None) + instructions: List[dict] | None = Field(exclude=True, default=None) db_scan: List[TableSchemaDetail] = Field(exclude=True) @property @@ -477,6 +507,12 @@ def get_tools(self) -> List[BaseTool]: tools = [] query_sql_db_tool = QuerySQLDataBaseTool(db=self.db, context=self.context) tools.append(query_sql_db_tool) + if self.instructions is not None: + tools.append( + GetUserInstructions( + db=self.db, context=self.context, instructions=self.instructions + ) + ) get_current_datetime = GetCurrentTimeTool(db=self.db, context=self.context) tools.append(get_current_datetime) tables_sql_db_tool = TablesSQLDatabaseTool( @@ -528,7 +564,7 @@ def create_sql_agent( 
input_variables: List[str] | None = None, max_examples: int = 20, top_k: int = 13, - max_iterations: int | None = 10, + max_iterations: int | None = 15, max_execution_time: float | None = None, early_stopping_method: str = "force", verbose: bool = False, @@ -585,7 +621,7 @@ def generate_response( ) if not db_scan: raise ValueError("No scanned tables found for database") - few_shot_examples = context_store.retrieve_context_for_question( + few_shot_examples, instructions = context_store.retrieve_context_for_question( user_question, number_of_samples=self.max_number_of_examples ) if few_shot_examples is not None: @@ -600,6 +636,7 @@ def generate_response( db=self.database, context=context, few_shot_examples=new_fewshot_examples, + instructions=instructions, db_scan=db_scan, ) agent_executor = self.create_sql_agent( diff --git a/dataherald/types.py b/dataherald/types.py index 602f39dc..ef8a6f03 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -40,6 +40,16 @@ class NLQuery(BaseModel): db_connection_id: str +class InstructionRequest(BaseModel): + instruction: str + + +class Instruction(BaseModel): + id: Any + instruction: str + db_connection_id: str + + class GoldenRecordRequest(DBConnectionValidation): question: str sql_query: str diff --git a/docs/api.add_instructions.rst b/docs/api.add_instructions.rst new file mode 100644 index 00000000..0330308c --- /dev/null +++ b/docs/api.add_instructions.rst @@ -0,0 +1,53 @@ +.. _api.add_instructions: + +Set Instructions +======================= + +To return an accurate response based on your business rules, you can set some constraints on the SQL generation process. + +Request this ``POST`` endpoint:: + + /api/v1/{db_connection_id}/instructions + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Database connection we want to add instructions, ``Required``" + +**Request body** + +.. 
code-block:: rst + + { + "instruction": "string" + } + +**Responses** + +HTTP 201 code response + +.. code-block:: rst + + { + "id": "instruction_id", + "instruction": "Instructions", + "db_connection_id": "database_connection_id" + } + +**Example** + +Set an instruction for a database connection + +.. code-block:: rst + + curl -X 'POST' \ + '/api/v1/{db_connection_id}/instructions' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "string" + }' + diff --git a/docs/api.delete_instructions.rst b/docs/api.delete_instructions.rst new file mode 100644 index 00000000..f868b522 --- /dev/null +++ b/docs/api.delete_instructions.rst @@ -0,0 +1,40 @@ +.. _api.delete_instructions: + +Delete Instructions +======================= + +If your business logic requires deleting an instruction, you can use this endpoint. + +Request this ``DELETE`` endpoint:: + + /api/v1/{db_connection_id}/instructions/{instruction_id} + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Database connection we want to delete the instructions, ``Required``" + "instruction_id", "string", "Instruction id we want to delete, ``Required``" + +**Responses** + +HTTP 200 code response + +.. code-block:: rst + + { + "status": "success" + } + +**Example** + +Delete an instruction for a database connection + +.. code-block:: rst + + curl -X 'DELETE' \ + 'http://localhost/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + -H 'accept: application/json' + diff --git a/docs/api.list_instructions.rst b/docs/api.list_instructions.rst new file mode 100644 index 00000000..23a67c5b --- /dev/null +++ b/docs/api.list_instructions.rst @@ -0,0 +1,43 @@ +.. _api.list_instructions: + +List instructions +======================= + +You can use this endpoint to retrieve a list of instructions for a database connection. 
+ +Request this ``GET`` endpoint:: + + GET /api/v1/{db_connection_id}/instructions + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Database connection we want to get instructions, ``Required``" + "page", "integer", "Page number, ``Optional``" + "limit", "integer", "Limit number of instructions, ``Optional``" + +**Responses** + +HTTP 200 code response + +.. code-block:: rst + + [ + { + "id": "string", + "instruction": "string", + "db_connection_id": "string" + } + ] + +**Example** + +.. code-block:: rst + + curl -X 'GET' \ + '/api/v1/{db_connection_id}/instructions?page=1&limit=10' \ + -H 'accept: application/json' + diff --git a/docs/api.rst b/docs/api.rst index fb94ecc0..0239a804 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,7 @@ API The Dataherald Engine exposes RESTful APIs that can be used to: * 🔌 Connect to and manage connections to databases -* 🔑 Add context to the engine through scanning the databases, adding descriptions to tables and columns and adding golden records +* 🔑 Add context to the engine through scanning the databases, adding database level instructions, adding descriptions to tables and columns and adding golden records * 🙋‍♀️ Ask natural language questions from the relational data Our APIs have resource-oriented URL built around standard HTTP response codes and verbs. The core resources are described below. @@ -92,6 +92,24 @@ Related endpoints are: "table_schema": "string" } +Database Instructions +--------------------- +The ``database-instructions`` object is used to set constraints on the SQL that is generated by the LLM. +These are then used to help the LLM build valid SQL to answer natural language questions based on your business rules. 
+ +Related endpoints are: + +* :doc:`Add database instructions <api.add_instructions>` -- ``POST api/v1/{db_connection_id}/instructions`` +* :doc:`List database instructions <api.list_instructions>` -- ``GET api/v1/{db_connection_id}/instructions`` +* :doc:`Update database instructions <api.update_instructions>` -- ``PATCH api/v1/{db_connection_id}/instructions/{instruction_id}`` +* :doc:`Delete database instructions <api.delete_instructions>` -- ``DELETE api/v1/{db_connection_id}/instructions/{instruction_id}`` + +.. code-block:: json + + { + "db_connection_id": "string", + "instruction": "string" + } .. toctree:: @@ -105,6 +123,11 @@ Related endpoints are: api.add_descriptions api.list_table_description + api.add_instructions + api.list_instructions + api.update_instructions + api.delete_instructions + api.golden_record api.question diff --git a/docs/api.update_instructions.rst b/docs/api.update_instructions.rst new file mode 100644 index 00000000..b0351902 --- /dev/null +++ b/docs/api.update_instructions.rst @@ -0,0 +1,55 @@ +.. _api.update_instructions: + +Update Instructions +======================= + +In order to get the best performance from the engine you should try using different instructions for each database connection. +You can update the instructions for a database connection using the ``PATCH`` method. + +Request this ``PATCH`` endpoint:: + + /api/v1/{db_connection_id}/instructions/{instruction_id} + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Database connection we want to update the instructions, ``Required``" + "instruction_id", "string", "Instruction id we want to update, ``Required``" + +**Request body** + +.. code-block:: rst + + { + "instruction": "string" + } + +**Responses** + +HTTP 200 code response + +.. code-block:: rst + + { + "id": "string", + "instruction": "string", + "db_connection_id": "string" + } + +**Example** + +Update an instruction for a database connection + +.. 
code-block:: rst + + curl -X 'PATCH' \ + '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "instruction": "string" + }' + diff --git a/docs/context_store.rst b/docs/context_store.rst index 5977c872..666b70b9 100644 --- a/docs/context_store.rst +++ b/docs/context_store.rst @@ -33,7 +33,7 @@ This abstract class provides a consistent interface for working with context ret :param system: The system object. :type system: System -.. py:method:: retrieve_context_for_question(self, nl_question: str, number_of_samples: int = 3) -> List[dict] | None +.. py:method:: retrieve_context_for_question(self, nl_question: str, number_of_samples: int = 3) -> Tuple[List[dict] | None, List[dict] | None] :noindex: Given a natural language question, this method retrieves a single string containing information about relevant data stores, tables, and columns necessary for building the SQL query. This information includes example questions, corresponding SQL queries, and metadata about the tables (e.g., categorical columns). The retrieved string is then passed to the text-to-SQL generator. @@ -42,8 +42,8 @@ This abstract class provides a consistent interface for working with context ret :type nl_question: str :param number_of_samples: The number of context samples to retrieve. :type number_of_samples: int - :return: A list of dictionaries containing context information for generating SQL. - :rtype: List[dict] | None + :return: A tuple ``(samples, instructions)``: the few-shot samples and the database level instructions for the question's database connection, each a list of dictionaries or ``None`` when none were found. + :rtype: Tuple[List[dict] | None, List[dict] | None] .. 
py:method:: add_golden_records(self, golden_records: List[GoldenRecordRequest]) -> List[GoldenRecord] :noindex: From 08f5244ba3222dde3aea04c5fc725c2615df3e0c Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 26 Sep 2023 15:17:26 -0400 Subject: [PATCH 02/11] DH-4724/chanign the database tests --- dataherald/tests/db/test_db.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 2c7b39c3..29b579a2 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -19,6 +19,13 @@ def __init__(self, system: System): "ssh_settings": None, } ] + self.memory["instructions"] = [ + { + "_id": "64dfa0e103f5134086f7090c", + "instructions": "foo", + "db_connection_id": "64dfa0e103f5134086f7090c", + } + ] @override def insert_one(self, collection: str, obj: dict) -> int: From 08fbb3c83d9f0a3960f36cf038bdbd3a5b6e0916 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 26 Sep 2023 15:23:44 -0400 Subject: [PATCH 03/11] DH-4724/changing the find_all() function --- dataherald/context_store/default.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index 42e7a40a..453d34b8 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -45,7 +45,8 @@ def retrieve_context_for_question( samples = None instructions = [] instruction_repository = InstructionRepository(self.db) - for instruction in instruction_repository.find_all(): + all_instructions = instruction_repository.find_all() + for instruction in all_instructions: if instruction.db_connection_id == nl_question.db_connection_id: instructions.append( { From 659aae9690499d00998fafb20ca578a7730c6b48 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Tue, 26 Sep 2023 15:29:59 -0400 Subject: [PATCH 04/11] DH-4724/changing the tests --- dataherald/tests/db/test_db.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 29b579a2..53d8d7fd 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -22,7 +22,7 @@ def __init__(self, system: System): self.memory["instructions"] = [ { "_id": "64dfa0e103f5134086f7090c", - "instructions": "foo", + "instruction": "foo", "db_connection_id": "64dfa0e103f5134086f7090c", } ] From e8e445ab10b1a4e9b382e825b3b55659884478e4 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 13:59:59 -0400 Subject: [PATCH 05/11] DH-4724/adding find_by method to instructionRepo --- dataherald/api/fastapi.py | 8 ++------ dataherald/repositories/instructions.py | 9 +++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index aeec48af..0150dfcf 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -332,14 +332,10 @@ def get_instructions( self, db_connection_id: str, page: int = 1, limit: int = 10 ) -> List[Instruction]: instruction_repository = InstructionRepository(self.storage) - instructions = instruction_repository.find_all() - filtered_instructions = [] - for instruction in instructions: - if instruction.db_connection_id == db_connection_id: - filtered_instructions.append(instruction) + instructions = instruction_repository.find_by({"db_connection_id": db_connection_id}) start_idx = (page - 1) * limit end_idx = start_idx + limit - return filtered_instructions[start_idx:end_idx] + return instructions[start_idx:end_idx] @override def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: diff --git a/dataherald/repositories/instructions.py b/dataherald/repositories/instructions.py index 036fe936..f3dfdfe7 100644 --- a/dataherald/repositories/instructions.py +++ b/dataherald/repositories/instructions.py @@ -35,6 +35,15 @@ def find_by_id(self, id: str) -> Instruction | None: return None return 
Instruction(**row) + def find_by(self, query: dict) -> list[Instruction]: + rows = self.storage.find(DB_COLLECTION, query) + result = [] + for row in rows: + obj = Instruction(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result + def find_all(self) -> list[Instruction]: rows = self.storage.find_all(DB_COLLECTION) return [Instruction(id=str(row["_id"]), **row) for row in rows] From e4a9b3a144eeec32960b8ea71c7b8a89e77abbde Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 14:01:21 -0400 Subject: [PATCH 06/11] DH-4724/reformat --- dataherald/api/fastapi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 0150dfcf..feacb949 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -332,7 +332,9 @@ def get_instructions( self, db_connection_id: str, page: int = 1, limit: int = 10 ) -> List[Instruction]: instruction_repository = InstructionRepository(self.storage) - instructions = instruction_repository.find_by({"db_connection_id": db_connection_id}) + instructions = instruction_repository.find_by( + {"db_connection_id": db_connection_id} + ) start_idx = (page - 1) * limit end_idx = start_idx + limit return instructions[start_idx:end_idx] From 14ed35e1937229882b899cf28fb820bd58dedb11 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 14:14:19 -0400 Subject: [PATCH 07/11] DH-4724/add paginating to the db --- dataherald/api/fastapi.py | 9 ++++----- dataherald/db/__init__.py | 2 +- dataherald/db/mongo.py | 10 +++++++--- dataherald/repositories/instructions.py | 4 ++-- dataherald/tests/db/test_db.py | 2 +- 5 files changed, 15 insertions(+), 12 deletions(-) diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index feacb949..aca19c31 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -332,12 +332,11 @@ def get_instructions( self, db_connection_id: str, page: int = 1, limit: int = 
10 ) -> List[Instruction]: instruction_repository = InstructionRepository(self.storage) - instructions = instruction_repository.find_by( - {"db_connection_id": db_connection_id} + return instruction_repository.find_by( + {"db_connection_id": db_connection_id}, + page=page, + limit=limit, ) - start_idx = (page - 1) * limit - end_idx = start_idx + limit - return instructions[start_idx:end_idx] @override def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index 6a650ab1..deb82185 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -29,7 +29,7 @@ def find_by_id(self, collection: str, id: str) -> dict: pass @abstractmethod - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: pass @abstractmethod diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index 7879d5e0..5508b5e8 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -40,10 +40,14 @@ def find_by_id(self, collection: str, id: str) -> dict: return self._data_store[collection].find_one({"_id": ObjectId(id)}) @override - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: + skip_count = (page - 1) * limit + cursor = self._data_store[collection].find(query) if sort: - return self._data_store[collection].find(query).sort(sort) - return self._data_store[collection].find(query) + cursor = cursor.sort(sort) + if page > 0 and limit > 0: + cursor = cursor.skip(skip_count).limit(limit) + return list(cursor) @override def find_all(self, collection: str) -> list: diff --git a/dataherald/repositories/instructions.py b/dataherald/repositories/instructions.py index f3dfdfe7..0c81e688 100644 --- a/dataherald/repositories/instructions.py +++ 
b/dataherald/repositories/instructions.py @@ -35,8 +35,8 @@ def find_by_id(self, id: str) -> Instruction | None: return None return Instruction(**row) - def find_by(self, query: dict) -> list[Instruction]: - rows = self.storage.find(DB_COLLECTION, query) + def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Instruction]: + rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit) result = [] for row in rows: obj = Instruction(**row) diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 53d8d7fd..c69e0aa3 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -55,7 +55,7 @@ def find_by_id(self, collection: str, id: str) -> dict: return None @override - def find(self, collection: str, query: dict, sort: list = None) -> list: + def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: return [] @override From a7b86e7ed6536054e0dcdeca42435932dd9948e3 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 14:14:48 -0400 Subject: [PATCH 08/11] DH-4724/reformat with black --- dataherald/db/__init__.py | 9 ++++++++- dataherald/db/mongo.py | 9 ++++++++- dataherald/tests/db/test_db.py | 9 ++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index deb82185..4488ce4d 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -29,7 +29,14 @@ def find_by_id(self, collection: str, id: str) -> dict: pass @abstractmethod - def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: pass @abstractmethod diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index 5508b5e8..b2a7a41a 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -40,7 +40,14 @@ def 
find_by_id(self, collection: str, id: str) -> dict: return self._data_store[collection].find_one({"_id": ObjectId(id)}) @override - def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: skip_count = (page - 1) * limit cursor = self._data_store[collection].find(query) if sort: diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index c69e0aa3..a8056e20 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -55,7 +55,14 @@ def find_by_id(self, collection: str, id: str) -> dict: return None @override - def find(self, collection: str, query: dict, sort: list = None, page: int = 0, limit: int = 0) -> list: + def find( + self, + collection: str, + query: dict, + sort: list = None, + page: int = 0, + limit: int = 0, + ) -> list: return [] @override From c830f972b6ee2912da1382226a7ea17143543ba4 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 16:57:54 -0400 Subject: [PATCH 09/11] DH-4724/updating the endpoints to be RESTfull --- README.md | 19 +++++++------ dataherald/api/__init__.py | 10 +++---- dataherald/api/fastapi.py | 40 +++++++++++++++------------ dataherald/server/fastapi/__init__.py | 26 ++++++++--------- dataherald/types.py | 4 ++- docs/api.add_instructions.rst | 14 +++------- docs/api.delete_instructions.rst | 7 ++--- docs/api.list_instructions.rst | 6 ++-- docs/api.update_instructions.rst | 11 ++++---- 9 files changed, 68 insertions(+), 69 deletions(-) diff --git a/README.md b/README.md index 02d7d231..3b44fca3 100644 --- a/README.md +++ b/README.md @@ -319,47 +319,48 @@ curl -X 'PATCH' \ #### adding database level instructions -You can add database level instructions to the context store manually from the `POST /api/v1/{db_connection_id}/instructions` endpoint +You can add database level instructions to the 
context store manually from the `POST /api/v1/instructions` endpoint These instructions are passed directly to the engine and can be used to steer the engine to generate SQL that is more in line with your business logic. ``` curl -X 'POST' \ - '/api/v1/{db_connection_id}/instructions' \ + '/api/v1/instructions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "instruction": "This is a database level instruction" + "db_connection_id": "db_connection_id" }' ``` #### getting database level instructions -You can get database level instructions from the `GET /api/v1/{db_connection_id}/instructions` endpoint +You can get database level instructions from the `GET /api/v1/instructions` endpoint ``` curl -X 'GET' \ - '/api/v1/{db_connection_id}/instructions?page=1&limit=10' \ + '/api/v1/instructions?page=1&limit=10' \ -H 'accept: application/json' ``` #### deleting database level instructions -You can delete database level instructions from the `DELETE /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint +You can delete database level instructions from the `DELETE /api/v1/instructions/{instruction_id}` endpoint ``` curl -X 'DELETE' \ - '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + '/api/v1/instructions/{instruction_id}' \ -H 'accept: application/json' ``` #### updating database level instructions -You can update database level instructions from the `PATCH /api/v1/{db_connection_id}/instructions/{instruction_id}` endpoint +You can update database level instructions from the `PUT /api/v1/instructions/{instruction_id}` endpoint Try different instructions to see how the engine generates SQL ``` -curl -X 'PATCH' \ - '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ +curl -X 'PUT' \ + '/api/v1/instructions/{instruction_id}' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index 48743585..b9da5068 100644 --- 
a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -18,6 +18,7 @@ QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -102,25 +103,24 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor @abstractmethod def add_instruction( - self, db_connection_id: str, instruction_request: InstructionRequest + self, instruction_request: InstructionRequest ) -> Instruction: pass @abstractmethod def get_instructions( - self, db_connection_id: str, page: int = 1, limit: int = 10 + self, db_connection_id: str = None, page: int = 1, limit: int = 10 ) -> List[Instruction]: pass @abstractmethod - def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + def delete_instruction(self, instruction_id: str) -> dict: pass @abstractmethod def update_instruction( self, - db_connection_id: str, instruction_id: str, - instruction_request: InstructionRequest, + instruction_request: UpdateInstruction, ) -> Instruction: pass diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index aca19c31..3838e00f 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -41,6 +41,7 @@ QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -318,47 +319,50 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor @override def add_instruction( - self, db_connection_id: str, instruction_request: InstructionRequest + self, instruction_request: InstructionRequest ) -> Instruction: instruction_repository = InstructionRepository(self.storage) instruction = Instruction( instruction=instruction_request.instruction, - db_connection_id=db_connection_id, + db_connection_id=instruction_request.db_connection_id, ) return instruction_repository.insert(instruction) @override def get_instructions( - self, db_connection_id: str, page: int = 1, limit: int = 10 + self, db_connection_id: str = None, page: 
int = 1, limit: int = 10 ) -> List[Instruction]: instruction_repository = InstructionRepository(self.storage) - return instruction_repository.find_by( - {"db_connection_id": db_connection_id}, - page=page, - limit=limit, - ) + if db_connection_id: + return instruction_repository.find_by( + {"db_connection_id": db_connection_id}, + page=page, + limit=limit, + ) + return instruction_repository.find_all() @override - def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + def delete_instruction(self,instruction_id: str) -> dict: instruction_repository = InstructionRepository(self.storage) - instruction = instruction_repository.find_by_id(instruction_id) - if instruction.db_connection_id != db_connection_id: + deleted = instruction_repository.delete_by_id(instruction_id) + if deleted == 0: raise HTTPException(status_code=404, detail="Instruction not found") - instruction_repository.delete_by_id(instruction_id) return {"status": "success"} @override def update_instruction( self, - db_connection_id: str, instruction_id: str, - instruction_request: InstructionRequest, + instruction_request: UpdateInstruction, ) -> Instruction: instruction_repository = InstructionRepository(self.storage) - instruction = Instruction( + instruction = instruction_repository.find_by_id(instruction_id) + if not instruction: + raise HTTPException(status_code=404, detail="Instruction not found") + updated_instruction = Instruction( id=instruction_id, instruction=instruction_request.instruction, - db_connection_id=db_connection_id, + db_connection_id=instruction.db_connection_id, ) - instruction_repository.update(instruction) - return json.loads(json_util.dumps(instruction)) + instruction_repository.update(updated_instruction) + return json.loads(json_util.dumps(updated_instruction)) diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 3298bfb2..55dffd12 100644 --- a/dataherald/server/fastapi/__init__.py +++ 
b/dataherald/server/fastapi/__init__.py @@ -22,6 +22,7 @@ QuestionRequest, ScannerRequest, TableDescriptionRequest, + UpdateInstruction, UpdateQueryRequest, ) @@ -137,30 +138,30 @@ def __init__(self, settings: Settings): ) self.router.add_api_route( - "/api/v1/{db_connection_id}/instructions", + "/api/v1/instructions", self.add_instruction, methods=["POST"], tags=["Instructions"], ) self.router.add_api_route( - "/api/v1/{db_connection_id}/instructions", + "/api/v1/instructions", self.get_instructions, methods=["GET"], tags=["Instructions"], ) self.router.add_api_route( - "/api/v1/{db_connection_id}/instructions/{instruction_id}", + "/api/v1/instructions/{instruction_id}", self.delete_instruction, methods=["DELETE"], tags=["Instructions"], ) self.router.add_api_route( - "/api/v1/{db_connection_id}/instructions/{instruction_id}", + "/api/v1/instructions/{instruction_id}", self.update_instruction, - methods=["PATCH"], + methods=["PUT"], tags=["Instructions"], ) @@ -261,11 +262,11 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor return self._api.get_golden_records(page, limit) def add_instruction( - self, db_connection_id: str, instruction_request: InstructionRequest + self, instruction_request: InstructionRequest ) -> Instruction: """Adds an instruction""" created_records = self._api.add_instruction( - db_connection_id, instruction_request + instruction_request ) # Return a JSONResponse with status code 201 and the location header. 
@@ -276,22 +277,21 @@ def add_instruction( ) def get_instructions( - self, db_connection_id: str, page: int = 1, limit: int = 10 + self, db_connection_id: str = "", page: int = 1, limit: int = 10 ) -> List[Instruction]: """Gets instructions""" return self._api.get_instructions(db_connection_id, page, limit) - def delete_instruction(self, db_connection_id: str, instruction_id: str) -> dict: + def delete_instruction(self, instruction_id: str) -> dict: """Deletes an instruction""" - return self._api.delete_instruction(db_connection_id, instruction_id) + return self._api.delete_instruction(instruction_id) def update_instruction( self, - db_connection_id: str, instruction_id: str, - instruction_request: InstructionRequest, + instruction_request: UpdateInstruction, ) -> Instruction: """Updates an instruction""" return self._api.update_instruction( - db_connection_id, instruction_id, instruction_request + instruction_id, instruction_request ) diff --git a/dataherald/types.py b/dataherald/types.py index ef8a6f03..d8bebb52 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -39,8 +39,10 @@ class NLQuery(BaseModel): question: str db_connection_id: str +class UpdateInstruction(BaseModel): + instruction: str -class InstructionRequest(BaseModel): +class InstructionRequest(DBConnectionValidation): instruction: str diff --git a/docs/api.add_instructions.rst b/docs/api.add_instructions.rst index 0330308c..93fece81 100644 --- a/docs/api.add_instructions.rst +++ b/docs/api.add_instructions.rst @@ -7,15 +7,7 @@ To return an accurate response based on our your business rules, you can set som Request this ``POST`` endpoint:: - /api/v1/{db_connection_id}/instructions - -** Parameters ** - -.. 
csv-table:: - :header: "Name", "Type", "Description" - :widths: 20, 20, 60 - - "db_connection_id", "string", "Database connection we want to add instructions, ``Required``" + /api/v1/instructions **Request body** @@ -23,6 +15,7 @@ Request this ``POST`` endpoint:: { "instruction": "string" + "db_connection_id": "database_connection_id" } **Responses** @@ -44,10 +37,11 @@ Only set a instruction for a database connection .. code-block:: rst curl -X 'POST' \ - '/api/v1/{db_connection_id}/instructions' \ + '/api/v1/instructions' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "instruction": "string", + "db_connection_id": "database_connection_id" }' diff --git a/docs/api.delete_instructions.rst b/docs/api.delete_instructions.rst index f868b522..36f01265 100644 --- a/docs/api.delete_instructions.rst +++ b/docs/api.delete_instructions.rst @@ -7,7 +7,7 @@ If your business logic requires to delete a instruction, you can use this endpoi Request this ``DELETE`` endpoint:: - /api/v1/{db_connection_id}/instructions/{instruction_id} + /api/v1/instructions/{instruction_id} ** Parameters ** @@ -15,8 +15,7 @@ Request this ``DELETE`` endpoint:: :header: "Name", "Type", "Description" :widths: 20, 20, 60 - "db_connection_id", "string", "Database connection we want to delete the instructions, ``Required``" - "instruction_id", "string", "Instruction id we want to delete, ``Required``" + "instruction_id", "string", "Instruction id we want to delete, ``Required``" **Responses** @@ -35,6 +34,6 @@ Only set a instruction for a database connection .. 
code-block:: rst curl -X 'DELETE' \ - 'http://localhost/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + '/api/v1/instructions/{instruction_id}' \ -H 'accept: application/json' diff --git a/docs/api.list_instructions.rst b/docs/api.list_instructions.rst index 23a67c5b..db2c6d9a 100644 --- a/docs/api.list_instructions.rst +++ b/docs/api.list_instructions.rst @@ -7,7 +7,7 @@ You can use this endpoint to retrieve a list of instructions for a database conn Request this ``GET`` endpoint:: - GET /api/v1/{db_connection_id}/instructions + GET /api/v1/instructions ** Parameters ** @@ -15,7 +15,7 @@ Request this ``GET`` endpoint:: :header: "Name", "Type", "Description" :widths: 20, 20, 60 - "db_connection_id", "string", "Database connection we want to get instructions, ``Required``" + "db_connection_id", "string", "Database connection we want to get instructions, ``Optional``" "page", "integer", "Page number, ``Optional``" "limit", "integer", "Limit number of instructions, ``Optional``" @@ -38,6 +38,6 @@ HTTP 201 code response .. code-block:: rst curl -X 'GET' \ - '/api/v1/{db_connection_id}/instructions?page=1&limit=10' \ + '/api/v1/instructions?page=1&limit=10' \ -H 'accept: application/json diff --git a/docs/api.update_instructions.rst b/docs/api.update_instructions.rst index b0351902..2340adca 100644 --- a/docs/api.update_instructions.rst +++ b/docs/api.update_instructions.rst @@ -4,11 +4,11 @@ Update Instructions ======================= In order to get the best performance from the engine you should try using different instructions for each database connection. -You can update the instructions for a database connection using the ``PATCH`` method. +You can update the instructions for a database connection using the ``PUT`` method. 
-Request this ``PATCH`` endpoint:: +Request this ``PUT`` endpoint:: - /api/v1/{db_connection_id}/instructions/{instruction_id} + /api/v1/instructions/{instruction_id} ** Parameters ** @@ -16,7 +16,6 @@ Request this ``PATCH`` endpoint:: :header: "Name", "Type", "Description" :widths: 20, 20, 60 - "db_connection_id", "string", "Database connection we want to update the instructions, ``Required``" "instruction_id", "string", "Instruction id we want to update, ``Required``" **Request body** @@ -45,8 +44,8 @@ Only set a instruction for a database connection .. code-block:: rst - curl -X 'PATCH' \ - '/api/v1/{db_connection_id}/instructions/{instruction_id}' \ + curl -X 'PUT' \ + '/api/v1/instructions/{instruction_id}' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ From 463fe71d7765a46eca6cc496201efb5903f92828 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 16:58:42 -0400 Subject: [PATCH 10/11] DH-4724/reformat with black --- dataherald/api/__init__.py | 4 +--- dataherald/api/fastapi.py | 6 ++---- dataherald/server/fastapi/__init__.py | 12 +++--------- dataherald/types.py | 2 ++ 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index b9da5068..374928c6 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -102,9 +102,7 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor pass @abstractmethod - def add_instruction( - self, instruction_request: InstructionRequest - ) -> Instruction: + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: pass @abstractmethod diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 3838e00f..15755528 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -318,9 +318,7 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor return all_records[start_idx:end_idx] @override - 
def add_instruction( - self, instruction_request: InstructionRequest - ) -> Instruction: + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: instruction_repository = InstructionRepository(self.storage) instruction = Instruction( instruction=instruction_request.instruction, @@ -342,7 +340,7 @@ def get_instructions( return instruction_repository.find_all() @override - def delete_instruction(self,instruction_id: str) -> dict: + def delete_instruction(self, instruction_id: str) -> dict: instruction_repository = InstructionRepository(self.storage) deleted = instruction_repository.delete_by_id(instruction_id) if deleted == 0: diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 55dffd12..2bf76eb1 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -261,13 +261,9 @@ def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecor """Gets golden records""" return self._api.get_golden_records(page, limit) - def add_instruction( - self, instruction_request: InstructionRequest - ) -> Instruction: + def add_instruction(self, instruction_request: InstructionRequest) -> Instruction: """Adds an instruction""" - created_records = self._api.add_instruction( - instruction_request - ) + created_records = self._api.add_instruction(instruction_request) # Return a JSONResponse with status code 201 and the location header. 
instruction_as_dict = created_records.dict() @@ -292,6 +288,4 @@ def update_instruction( instruction_request: UpdateInstruction, ) -> Instruction: """Updates an instruction""" - return self._api.update_instruction( - instruction_id, instruction_request - ) + return self._api.update_instruction(instruction_id, instruction_request) diff --git a/dataherald/types.py b/dataherald/types.py index d8bebb52..1e46a0ad 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -39,9 +39,11 @@ class NLQuery(BaseModel): question: str db_connection_id: str + class UpdateInstruction(BaseModel): instruction: str + class InstructionRequest(DBConnectionValidation): instruction: str From 520b633989a9a3dd0f52e11e5562dd142187e186 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Wed, 27 Sep 2023 17:12:52 -0400 Subject: [PATCH 11/11] DH-4724/changing the docs --- README.md | 2 +- docs/api.list_instructions.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 3b44fca3..fbb4aae0 100644 --- a/README.md +++ b/README.md @@ -339,7 +339,7 @@ You can get database level instructions from the `GET /api/v1/instructions` endp ``` curl -X 'GET' \ - '/api/v1/instructions?page=1&limit=10' \ + '/api/v1/instructions?page=1&limit=10&db_connection_id=12312312' \ -H 'accept: application/json' ``` diff --git a/docs/api.list_instructions.rst b/docs/api.list_instructions.rst index db2c6d9a..f6e4a06e 100644 --- a/docs/api.list_instructions.rst +++ b/docs/api.list_instructions.rst @@ -38,6 +38,6 @@ HTTP 201 code response .. code-block:: rst curl -X 'GET' \ - '/api/v1/instructions?page=1&limit=10' \ + '/api/v1/instructions?page=1&limit=10&db_connection_id=12312312' \ -H 'accept: application/json