diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py
index 5e14ff18..34693771 100644
--- a/dataherald/api/__init__.py
+++ b/dataherald/api/__init__.py
@@ -12,6 +12,7 @@
     PromptSQLGenerationNLGenerationRequest,
     PromptSQLGenerationRequest,
     SQLGenerationRequest,
+    StreamPromptSQLGenerationRequest,
     UpdateMetadataRequest,
 )
 from dataherald.api.types.responses import (
@@ -265,3 +266,10 @@ def update_nl_generation(
         self, nl_generation_id: str, update_metadata_request: UpdateMetadataRequest
     ) -> NLGenerationResponse:
         pass
+
+    @abstractmethod
+    async def stream_create_prompt_and_sql_generation(
+        self,
+        request: StreamPromptSQLGenerationRequest,
+    ):
+        pass
diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index 6eaf2c3b..747b55d3 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -1,8 +1,11 @@
+import asyncio
 import datetime
 import io
+import json
 import logging
 import os
 import time
+from queue import Queue
 from typing import List

 from bson.objectid import InvalidId, ObjectId
@@ -18,6 +21,7 @@
     PromptSQLGenerationNLGenerationRequest,
     PromptSQLGenerationRequest,
     SQLGenerationRequest,
+    StreamPromptSQLGenerationRequest,
     UpdateMetadataRequest,
 )
 from dataherald.api.types.responses import (
@@ -83,7 +87,7 @@
     UpdateInstruction,
 )
 from dataherald.utils.encrypt import FernetEncrypt
-from dataherald.utils.error_codes import error_response
+from dataherald.utils.error_codes import error_response, stream_error_response

 logger = logging.getLogger(__name__)

@@ -884,3 +888,26 @@ def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
                 detail=f"NL Generation {nl_generation_id} not found",
             )
         return NLGenerationResponse(**nl_generations[0].dict())
+
+    @override
+    async def stream_create_prompt_and_sql_generation(
+        self,
+        request: StreamPromptSQLGenerationRequest,
+    ):
+        try:
+            queue = Queue()
+            prompt_service = PromptService(self.storage)
+            prompt = prompt_service.create(request.prompt)
+            sql_generation_service = SQLGenerationService(self.system, self.storage)
+            sql_generation_service.start_streaming(prompt.id, request, queue)
+            while True:
+                value = queue.get()
+                if value is None:
+                    break
+                yield value
+                queue.task_done()
+                await asyncio.sleep(0.001)
+        except Exception as e:
+            yield json.dumps(
+                stream_error_response(e, request.dict(), "nl_generation_not_created")
+            )
diff --git a/dataherald/api/types/requests.py b/dataherald/api/types/requests.py
index 4f6a2e75..fcb90348 100644
--- a/dataherald/api/types/requests.py
+++ b/dataherald/api/types/requests.py
@@ -18,10 +18,21 @@ class SQLGenerationRequest(BaseModel):
     metadata: dict | None


+class StreamSQLGenerationRequest(BaseModel):
+    finetuning_id: str | None
+    low_latency_mode: bool = False
+    llm_config: LLMConfig | None
+    metadata: dict | None
+
+
 class PromptSQLGenerationRequest(SQLGenerationRequest):
     prompt: PromptRequest


+class StreamPromptSQLGenerationRequest(StreamSQLGenerationRequest):
+    prompt: PromptRequest
+
+
 class NLGenerationRequest(BaseModel):
     llm_config: LLMConfig | None
     max_rows: int = 100
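For illustration, a minimal standalone sketch of the queue-to-async-generator bridge that stream_create_prompt_and_sql_generation relies on: a worker thread puts chunks on a Queue followed by a None sentinel, while the async generator drains the queue and yields each chunk into the streaming response. The worker and chunk values below are placeholders, not part of the dataherald API.

import asyncio
from queue import Queue
from threading import Thread


def fake_agent(queue: Queue) -> None:
    # Stands in for the agent thread started by stream_response().
    for chunk in ["Thought: inspect schema\n", "Observation: `3 tables`\n", "Final Answer: SELECT 1"]:
        queue.put(chunk)
    queue.put(None)  # end-of-stream sentinel, mirroring stream_agent_steps()


async def drain(queue: Queue):
    # Mirrors the loop in stream_create_prompt_and_sql_generation() above.
    while True:
        value = queue.get()
        if value is None:
            break
        yield value
        queue.task_done()
        await asyncio.sleep(0.001)  # give control back to the event loop between chunks


async def main() -> None:
    queue: Queue = Queue()
    Thread(target=fake_agent, args=(queue,), daemon=True).start()
    async for chunk in drain(queue):
        print(chunk, end="")


asyncio.run(main())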
tags=["Finetunings"], ) + self.router.add_api_route( + "/api/v1/stream-sql-generation", + self.stream_sql_generation, + methods=["POST"], + tags=["Stream SQL Generation"], + ) + self.router.add_api_route( "/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"] ) @@ -601,3 +609,11 @@ def update_finetuning_job( ) -> Finetuning: """Gets fine tuning jobs""" return self._api.update_finetuning_job(finetuning_id, update_metadata_request) + + async def stream_sql_generation( + self, request: StreamPromptSQLGenerationRequest + ) -> StreamingResponse: + return StreamingResponse( + self._api.stream_create_prompt_and_sql_generation(request), + media_type="text/event-stream", + ) diff --git a/dataherald/services/sql_generations.py b/dataherald/services/sql_generations.py index fec36aac..fd690519 100644 --- a/dataherald/services/sql_generations.py +++ b/dataherald/services/sql_generations.py @@ -1,6 +1,7 @@ import os from concurrent.futures import ThreadPoolExecutor, TimeoutError from datetime import datetime +from queue import Queue import pandas as pd @@ -159,6 +160,68 @@ def create( initial_sql_generation.error = sql_generation.error return self.sql_generation_repository.update(initial_sql_generation) + def start_streaming( + self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue + ): + initial_sql_generation = SQLGeneration( + prompt_id=prompt_id, + created_at=datetime.now(), + llm_config=sql_generation_request.llm_config + if sql_generation_request.llm_config + else LLMConfig(), + metadata=sql_generation_request.metadata, + ) + self.sql_generation_repository.insert(initial_sql_generation) + prompt_repository = PromptRepository(self.storage) + prompt = prompt_repository.find_by_id(prompt_id) + if not prompt: + self.update_error(initial_sql_generation, f"Prompt {prompt_id} not found") + raise PromptNotFoundError( + f"Prompt {prompt_id} not found", initial_sql_generation.id + ) + db_connection_repository = DatabaseConnectionRepository(self.storage) + db_connection = db_connection_repository.find_by_id(prompt.db_connection_id) + if ( + sql_generation_request.finetuning_id is None + or sql_generation_request.finetuning_id == "" + ): + if sql_generation_request.low_latency_mode: + raise SQLGenerationError( + "Low latency mode is not supported for our old agent with no finetuning. 
diff --git a/dataherald/services/sql_generations.py b/dataherald/services/sql_generations.py
index fec36aac..fd690519 100644
--- a/dataherald/services/sql_generations.py
+++ b/dataherald/services/sql_generations.py
@@ -1,6 +1,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor, TimeoutError
 from datetime import datetime
+from queue import Queue

 import pandas as pd

@@ -159,6 +160,68 @@ def create(
         initial_sql_generation.error = sql_generation.error
         return self.sql_generation_repository.update(initial_sql_generation)

+    def start_streaming(
+        self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue
+    ):
+        initial_sql_generation = SQLGeneration(
+            prompt_id=prompt_id,
+            created_at=datetime.now(),
+            llm_config=sql_generation_request.llm_config
+            if sql_generation_request.llm_config
+            else LLMConfig(),
+            metadata=sql_generation_request.metadata,
+        )
+        self.sql_generation_repository.insert(initial_sql_generation)
+        prompt_repository = PromptRepository(self.storage)
+        prompt = prompt_repository.find_by_id(prompt_id)
+        if not prompt:
+            self.update_error(initial_sql_generation, f"Prompt {prompt_id} not found")
+            raise PromptNotFoundError(
+                f"Prompt {prompt_id} not found", initial_sql_generation.id
+            )
+        db_connection_repository = DatabaseConnectionRepository(self.storage)
+        db_connection = db_connection_repository.find_by_id(prompt.db_connection_id)
+        if (
+            sql_generation_request.finetuning_id is None
+            or sql_generation_request.finetuning_id == ""
+        ):
+            if sql_generation_request.low_latency_mode:
+                raise SQLGenerationError(
+                    "Low latency mode is not supported for our old agent with no finetuning. Please specify a finetuning id.",
+                    initial_sql_generation.id,
+                )
+            sql_generator = DataheraldSQLAgent(
+                self.system,
+                sql_generation_request.llm_config
+                if sql_generation_request.llm_config
+                else LLMConfig(),
+            )
+        else:
+            sql_generator = DataheraldFinetuningAgent(
+                self.system,
+                sql_generation_request.llm_config
+                if sql_generation_request.llm_config
+                else LLMConfig(),
+            )
+            sql_generator.finetuning_id = sql_generation_request.finetuning_id
+            sql_generator.use_fintuned_model_only = (
+                sql_generation_request.low_latency_mode
+            )
+            initial_sql_generation.finetuning_id = sql_generation_request.finetuning_id
+            initial_sql_generation.low_latency_mode = (
+                sql_generation_request.low_latency_mode
+            )
+        try:
+            sql_generator.stream_response(
+                user_prompt=prompt,
+                database_connection=db_connection,
+                response=initial_sql_generation,
+                queue=queue,
+            )
+        except Exception as e:
+            self.update_error(initial_sql_generation, str(e))
+            raise SQLGenerationError(str(e), initial_sql_generation.id) from e
+
     def get(self, query) -> list[SQLGeneration]:
         return self.sql_generation_repository.find_by(query)
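A sketch of driving start_streaming directly, without the HTTP layer (for example from a script or test). The system and storage arguments are assumed to be an already-initialized dataherald System and DB instance, and the PromptService import path is assumed from the existing services package.

from queue import Queue

from dataherald.api.types.requests import StreamPromptSQLGenerationRequest
from dataherald.services.prompts import PromptService  # assumed module path
from dataherald.services.sql_generations import SQLGenerationService


def run_stream(system, storage, request: StreamPromptSQLGenerationRequest) -> None:
    queue: Queue = Queue()
    prompt = PromptService(storage).create(request.prompt)
    # start_streaming returns as soon as the agent thread is running;
    # the generated chunks arrive on the queue.
    SQLGenerationService(system, storage).start_streaming(prompt.id, request, queue)
    while True:
        chunk = queue.get()
        if chunk is None:  # the None sentinel put by stream_agent_steps()
            break
        print(chunk, end="")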
diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py
index 3f7a295a..461fdd00 100644
--- a/dataherald/sql_generator/__init__.py
+++ b/dataherald/sql_generator/__init__.py
@@ -1,15 +1,25 @@
 """Base class that all sql generation classes inherit from."""
+import datetime
+import logging
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from queue import Queue
+from typing import Any, Dict, List, Tuple

 import sqlparse
-from langchain.schema import AgentAction
+from langchain.agents.agent import AgentExecutor
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import AgentAction, LLMResult
+from langchain.schema.messages import BaseMessage
+from langchain_community.callbacks import get_openai_callback

 from dataherald.config import Component, System
 from dataherald.model.chat_model import ChatModel
-from dataherald.sql_database.base import SQLDatabase
+from dataherald.repositories.sql_generations import (
+    SQLGenerationRepository,
+)
+from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
 from dataherald.sql_database.models.types import DatabaseConnection
 from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
 from dataherald.types import LLMConfig, Prompt, SQLGeneration
@@ -20,6 +30,12 @@ class EngineTimeOutORItemLimitError(Exception):
     pass


+def replace_unprocessable_characters(text: str) -> str:
+    """Replace unprocessable characters with a space."""
+    text = text.strip()
+    return text.replace(r"\_", "_")
+
+
 class SQLGenerator(Component, ABC):
     metadata: Any
     llm: ChatModel | None = None
@@ -114,3 +130,63 @@ def generate_response(
     ) -> SQLGeneration:
         """Generates a response to a user question."""
         pass
+
+    def stream_agent_steps(  # noqa: C901
+        self,
+        question: str,
+        agent_executor: AgentExecutor,
+        response: SQLGeneration,
+        sql_generation_repository: SQLGenerationRepository,
+        queue: Queue,
+    ):
+        try:
+            with get_openai_callback() as cb:
+                for chunk in agent_executor.stream({"input": question}):
+                    if "actions" in chunk:
+                        for message in chunk["messages"]:
+                            queue.put(message.content + "\n")
+                    elif "steps" in chunk:
+                        for step in chunk["steps"]:
+                            queue.put(f"Observation: `{step.observation}`\n")
+                    elif "output" in chunk:
+                        queue.put(f'Final Answer: {chunk["output"]}')
+                        if "```sql" in chunk["output"]:
+                            response.sql = replace_unprocessable_characters(
+                                self.remove_markdown(chunk["output"])
+                            )
+                    else:
+                        raise ValueError()
+        except SQLInjectionError as e:
+            raise SQLInjectionError(e) from e
+        except EngineTimeOutORItemLimitError as e:
+            raise EngineTimeOutORItemLimitError(e) from e
+        except Exception as e:
+            response.sql = ""
+            response.status = "INVALID"
+            response.error = str(e)
+        finally:
+            queue.put(None)
+            response.tokens_used = cb.total_tokens
+            response.completed_at = datetime.datetime.now()
+            if not response.error:
+                if response.sql:
+                    response = self.create_sql_query_status(
+                        self.database,
+                        response.sql,
+                        response,
+                    )
+                else:
+                    response.status = "INVALID"
+                    response.error = "No SQL query generated"
+            sql_generation_repository.update(response)
+
+    @abstractmethod
+    def stream_response(
+        self,
+        user_prompt: Prompt,
+        database_connection: DatabaseConnection,
+        response: SQLGeneration,
+        queue: Queue,
+    ):
+        """Streams a response to a user question."""
+        pass
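A small illustration of the SQL clean-up path in the "output" branch of stream_agent_steps. remove_markdown is an existing SQLGenerator method that is not part of this diff, so a simplified stand-in is used here; replace_unprocessable_characters is the helper added above.

from dataherald.sql_generator import replace_unprocessable_characters

final_output = 'Final Answer: ```sql\nSELECT \\_id, rent FROM listings\n```'


def strip_sql_fence(text: str) -> str:
    # Simplified stand-in for SQLGenerator.remove_markdown()
    return text.split("```sql")[1].split("```")[0].strip()


print(replace_unprocessable_characters(strip_sql_fence(final_output)))
# SELECT _id, rent FROM listings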
+ f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries." + ) + self.database = SQLDatabase.get_sql_engine(database_connection) + toolkit = SQLDatabaseToolkit( + db=self.database, + instructions=instructions, + db_scan=db_scan, + api_key=database_connection.decrypt_api_key(), + finetuning_model_id=finetuning.model_id, + use_finetuned_model_only=self.use_fintuned_model_only, + model_name=finetuning.base_llm.model_name, + openai_fine_tuning=openai_fine_tuning, + embedding=OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), + ) + agent_executor = self.create_sql_agent( + toolkit=toolkit, + verbose=True, + max_execution_time=int(os.environ.get("DH_ENGINE_TIMEOUT", 150)), + ) + agent_executor.return_intermediate_steps = True + agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE + thread = Thread( + target=self.stream_agent_steps, + args=( + user_prompt.text, + agent_executor, + response, + sql_generation_repository, + queue, + ), + ) + thread.start() diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 206d1275..4af5b3e9 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -3,6 +3,8 @@ import logging import os from functools import wraps +from queue import Queue +from threading import Thread from typing import Any, Callable, Dict, List import numpy as np @@ -32,6 +34,9 @@ from dataherald.db import DB from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus from dataherald.db_scanner.repository.base import TableDescriptionRepository +from dataherald.repositories.sql_generations import ( + SQLGenerationRepository, +) from dataherald.sql_database.base import SQLDatabase, SQLInjectionError from dataherald.sql_database.models.types import ( DatabaseConnection, @@ -166,7 +171,7 @@ def _run( args=(query,), kwargs={"top_k": top_k}, timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")), - ) + )[0] except TimeoutError: return "SQL query execution time exceeded, proceed without query execution" @@ -736,3 +741,73 @@ def generate_response( response.sql, response, ) + + @override + def stream_response( + self, + user_prompt: Prompt, + database_connection: DatabaseConnection, + response: SQLGeneration, + queue: Queue, + ): + context_store = self.system.instance(ContextStore) + storage = self.system.instance(DB) + sql_generation_repository = SQLGenerationRepository(storage) + self.llm = self.model.get_model( + database_connection=database_connection, + temperature=0, + model_name=self.llm_config.llm_name, + api_base=self.llm_config.api_base, + streaming=True, + ) + repository = TableDescriptionRepository(storage) + db_scan = repository.get_all_tables_by_db( + { + "db_connection_id": str(database_connection.id), + "status": TableDescriptionStatus.SCANNED.value, + } + ) + if not db_scan: + raise ValueError("No scanned tables found for database") + few_shot_examples, instructions = context_store.retrieve_context_for_question( + user_prompt, number_of_samples=self.max_number_of_examples + ) + if few_shot_examples is not None: + new_fewshot_examples = self.remove_duplicate_examples(few_shot_examples) + number_of_samples = len(new_fewshot_examples) + else: + new_fewshot_examples = None + number_of_samples = 0 + self.database = SQLDatabase.get_sql_engine(database_connection) + toolkit = SQLDatabaseToolkit( + queuer=queue, + db=self.database, + 
diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py
index 206d1275..4af5b3e9 100644
--- a/dataherald/sql_generator/dataherald_sqlagent.py
+++ b/dataherald/sql_generator/dataherald_sqlagent.py
@@ -3,6 +3,8 @@
 import logging
 import os
 from functools import wraps
+from queue import Queue
+from threading import Thread
 from typing import Any, Callable, Dict, List

 import numpy as np
@@ -32,6 +34,9 @@
 from dataherald.db import DB
 from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
 from dataherald.db_scanner.repository.base import TableDescriptionRepository
+from dataherald.repositories.sql_generations import (
+    SQLGenerationRepository,
+)
 from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
 from dataherald.sql_database.models.types import (
     DatabaseConnection,
@@ -166,7 +171,7 @@ def _run(
                 args=(query,),
                 kwargs={"top_k": top_k},
                 timeout_duration=int(os.getenv("SQL_EXECUTION_TIMEOUT", "60")),
-            )
+            )[0]
         except TimeoutError:
             return "SQL query execution time exceeded, proceed without query execution"

@@ -736,3 +741,73 @@ def generate_response(
             response.sql,
             response,
         )
+
+    @override
+    def stream_response(
+        self,
+        user_prompt: Prompt,
+        database_connection: DatabaseConnection,
+        response: SQLGeneration,
+        queue: Queue,
+    ):
+        context_store = self.system.instance(ContextStore)
+        storage = self.system.instance(DB)
+        sql_generation_repository = SQLGenerationRepository(storage)
+        self.llm = self.model.get_model(
+            database_connection=database_connection,
+            temperature=0,
+            model_name=self.llm_config.llm_name,
+            api_base=self.llm_config.api_base,
+            streaming=True,
+        )
+        repository = TableDescriptionRepository(storage)
+        db_scan = repository.get_all_tables_by_db(
+            {
+                "db_connection_id": str(database_connection.id),
+                "status": TableDescriptionStatus.SCANNED.value,
+            }
+        )
+        if not db_scan:
+            raise ValueError("No scanned tables found for database")
+        few_shot_examples, instructions = context_store.retrieve_context_for_question(
+            user_prompt, number_of_samples=self.max_number_of_examples
+        )
+        if few_shot_examples is not None:
+            new_fewshot_examples = self.remove_duplicate_examples(few_shot_examples)
+            number_of_samples = len(new_fewshot_examples)
+        else:
+            new_fewshot_examples = None
+            number_of_samples = 0
+        self.database = SQLDatabase.get_sql_engine(database_connection)
+        toolkit = SQLDatabaseToolkit(
+            queuer=queue,
+            db=self.database,
+            context=[{}],
+            few_shot_examples=new_fewshot_examples,
+            instructions=instructions,
+            db_scan=db_scan,
+            embedding=OpenAIEmbeddings(
+                openai_api_key=database_connection.decrypt_api_key(),
+                model=EMBEDDING_MODEL,
+            ),
+        )
+        agent_executor = self.create_sql_agent(
+            toolkit=toolkit,
+            verbose=True,
+            max_examples=number_of_samples,
+            number_of_instructions=len(instructions) if instructions is not None else 0,
+            max_execution_time=int(os.environ.get("DH_ENGINE_TIMEOUT", 150)),
+        )
+        agent_executor.return_intermediate_steps = True
+        agent_executor.handle_parsing_errors = ERROR_PARSING_MESSAGE
+        thread = Thread(
+            target=self.stream_agent_steps,
+            args=(
+                user_prompt.text,
+                agent_executor,
+                response,
+                sql_generation_repository,
+                queue,
+            ),
+        )
+        thread.start()
diff --git a/dataherald/utils/error_codes.py b/dataherald/utils/error_codes.py
index f4c4cb7e..6680feb5 100644
--- a/dataherald/utils/error_codes.py
+++ b/dataherald/utils/error_codes.py
@@ -46,3 +46,20 @@ def error_response(error, detail: dict, default_error_code=""):
             "detail": detail,
         },
     )
+
+
+def stream_error_response(error, detail: dict, default_error_code=""):
+    error_code = ERROR_MAPPING.get(error.__class__.__name__, default_error_code)
+    description = getattr(error, "description", None)
+    logger.error(
+        f"Error code: {error_code}, message: {error}, description: {description}, detail: {detail}"
+    )
+
+    detail.pop("metadata", None)
+
+    return {
+        "error_code": error_code,
+        "message": str(error),
+        "description": description,
+        "detail": detail,
+    }
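For reference, a minimal illustration of the payload produced by stream_error_response when streaming fails; the caller in fastapi.py json.dumps() this dict into the event stream. The error class and detail values below are placeholders, and the fallback error code assumes ValueError has no entry in ERROR_MAPPING.

from dataherald.utils.error_codes import stream_error_response

detail = {"prompt": {"text": "What is the average rent?"}, "metadata": {"request_id": "abc"}}
print(stream_error_response(ValueError("boom"), detail, "nl_generation_not_created"))
# {'error_code': 'nl_generation_not_created', 'message': 'boom', 'description': None,
#  'detail': {'prompt': {'text': 'What is the average rent?'}}}   <- "metadata" is popped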