Skip to content

Commit

Permalink
DH-5599/adding_streaming_endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Mar 14, 2024
1 parent 61a92c9 commit 68c1509
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 6 deletions.
8 changes: 8 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
PromptSQLGenerationNLGenerationRequest,
PromptSQLGenerationRequest,
SQLGenerationRequest,
StreamPromptSQLGenerationRequest,
UpdateMetadataRequest,
)
from dataherald.api.types.responses import (
Expand Down Expand Up @@ -265,3 +266,10 @@ def update_nl_generation(
self, nl_generation_id: str, update_metadata_request: UpdateMetadataRequest
) -> NLGenerationResponse:
pass

    @abstractmethod
    async def stream_create_prompt_and_sql_generation(
        self,
        request: StreamPromptSQLGenerationRequest,
    ):
        """Create a prompt and stream the SQL generation output for it.

        Implementations are async generators: they yield output chunks as the
        generation progresses so the server can forward them as a streaming
        HTTP response.
        """
        pass
29 changes: 28 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import datetime
import io
import json
import logging
import os
import time
from queue import Queue
from typing import List

from bson.objectid import InvalidId, ObjectId
Expand All @@ -18,6 +21,7 @@
PromptSQLGenerationNLGenerationRequest,
PromptSQLGenerationRequest,
SQLGenerationRequest,
StreamPromptSQLGenerationRequest,
UpdateMetadataRequest,
)
from dataherald.api.types.responses import (
Expand Down Expand Up @@ -83,7 +87,7 @@
UpdateInstruction,
)
from dataherald.utils.encrypt import FernetEncrypt
from dataherald.utils.error_codes import error_response
from dataherald.utils.error_codes import error_response, stream_error_response

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -884,3 +888,26 @@ def get_nl_generation(self, nl_generation_id: str) -> NLGenerationResponse:
detail=f"NL Generation {nl_generation_id} not found",
)
return NLGenerationResponse(**nl_generations[0].dict())

@override
async def stream_create_prompt_and_sql_generation(
self,
request: StreamPromptSQLGenerationRequest,
):
try:
queue = Queue()
prompt_service = PromptService(self.storage)
prompt = prompt_service.create(request.prompt)
sql_generation_service = SQLGenerationService(self.system, self.storage)
sql_generation_service.start_streaming(prompt.id, request, queue)
while True:
value = queue.get()
if value is None:
break
yield value
queue.task_done()
await asyncio.sleep(0.001)
except Exception as e:
yield json.dumps(
stream_error_response(e, request.dict(), "nl_generation_not_created")
)
11 changes: 11 additions & 0 deletions dataherald/api/types/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,21 @@ class SQLGenerationRequest(BaseModel):
metadata: dict | None


class StreamSQLGenerationRequest(BaseModel):
    """Request body for the streaming SQL generation endpoint."""

    finetuning_id: str | None  # finetuned model id; None or "" selects the generic agent
    low_latency_mode: bool = False  # only valid together with a finetuning_id
    llm_config: LLMConfig | None  # falls back to a default LLMConfig when None
    metadata: dict | None  # free-form caller metadata stored with the generation


class PromptSQLGenerationRequest(SQLGenerationRequest):
    """SQL generation request that also creates the prompt in the same call."""

    prompt: PromptRequest


class StreamPromptSQLGenerationRequest(StreamSQLGenerationRequest):
    """Streaming SQL generation request that also creates the prompt."""

    prompt: PromptRequest


class NLGenerationRequest(BaseModel):
llm_config: LLMConfig | None
max_rows: int = 100
Expand Down
16 changes: 16 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
PromptSQLGenerationNLGenerationRequest,
PromptSQLGenerationRequest,
SQLGenerationRequest,
StreamPromptSQLGenerationRequest,
UpdateMetadataRequest,
)
from dataherald.api.types.responses import (
Expand Down Expand Up @@ -356,6 +357,13 @@ def __init__(self, settings: Settings):
tags=["Finetunings"],
)

self.router.add_api_route(
"/api/v1/stream-sql-generation",
self.stream_sql_generation,
methods=["POST"],
tags=["Stream SQL Generation"],
)

self.router.add_api_route(
"/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"]
)
Expand Down Expand Up @@ -601,3 +609,11 @@ def update_finetuning_job(
) -> Finetuning:
"""Gets fine tuning jobs"""
return self._api.update_finetuning_job(finetuning_id, update_metadata_request)

async def stream_sql_generation(
self, request: StreamPromptSQLGenerationRequest
) -> StreamingResponse:
return StreamingResponse(
self._api.stream_create_prompt_and_sql_generation(request),
media_type="text/event-stream",
)
63 changes: 63 additions & 0 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from datetime import datetime
from queue import Queue

import pandas as pd

Expand Down Expand Up @@ -159,6 +160,68 @@ def create(
initial_sql_generation.error = sql_generation.error
return self.sql_generation_repository.update(initial_sql_generation)

def start_streaming(
self, prompt_id: str, sql_generation_request: SQLGenerationRequest, queue: Queue
):
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
llm_config=sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
metadata=sql_generation_request.metadata,
)
self.sql_generation_repository.insert(initial_sql_generation)
prompt_repository = PromptRepository(self.storage)
prompt = prompt_repository.find_by_id(prompt_id)
if not prompt:
self.update_error(initial_sql_generation, f"Prompt {prompt_id} not found")
raise PromptNotFoundError(
f"Prompt {prompt_id} not found", initial_sql_generation.id
)
db_connection_repository = DatabaseConnectionRepository(self.storage)
db_connection = db_connection_repository.find_by_id(prompt.db_connection_id)
if (
sql_generation_request.finetuning_id is None
or sql_generation_request.finetuning_id == ""
):
if sql_generation_request.low_latency_mode:
raise SQLGenerationError(
"Low latency mode is not supported for our old agent with no finetuning. Please specify a finetuning id.",
initial_sql_generation.id,
)
sql_generator = DataheraldSQLAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
)
else:
sql_generator = DataheraldFinetuningAgent(
self.system,
sql_generation_request.llm_config
if sql_generation_request.llm_config
else LLMConfig(),
)
sql_generator.finetuning_id = sql_generation_request.finetuning_id
sql_generator.use_fintuned_model_only = (
sql_generation_request.low_latency_mode
)
initial_sql_generation.finetuning_id = sql_generation_request.finetuning_id
initial_sql_generation.low_latency_mode = (
sql_generation_request.low_latency_mode
)
try:
sql_generator.stream_response(
user_prompt=prompt,
database_connection=db_connection,
response=initial_sql_generation,
queue=queue,
)
except Exception as e:
self.update_error(initial_sql_generation, str(e))
raise SQLGenerationError(str(e), initial_sql_generation.id) from e

    def get(self, query) -> list[SQLGeneration]:
        """Return all SQLGeneration records matching the repository ``query`` filter."""
        return self.sql_generation_repository.find_by(query)

Expand Down
82 changes: 79 additions & 3 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
"""Base class that all sql generation classes inherit from."""
import datetime
import logging
import os
import re
from abc import ABC, abstractmethod
from typing import Any, List, Tuple
from queue import Queue
from typing import Any, Dict, List, Tuple

import sqlparse
from langchain.schema import AgentAction
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, LLMResult
from langchain.schema.messages import BaseMessage
from langchain_community.callbacks import get_openai_callback

from dataherald.config import Component, System
from dataherald.model.chat_model import ChatModel
from dataherald.sql_database.base import SQLDatabase
from dataherald.repositories.sql_generations import (
SQLGenerationRepository,
)
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
from dataherald.types import LLMConfig, Prompt, SQLGeneration
Expand All @@ -20,6 +30,12 @@ class EngineTimeOutORItemLimitError(Exception):
pass


def replace_unprocessable_characters(text: str) -> str:
    """Strip surrounding whitespace and unescape ``\\_`` back to ``_``.

    LLM output sometimes markdown-escapes underscores as ``\\_``; SQL
    identifiers need the literal underscore restored.  (The original docstring
    claimed characters were replaced "with a space", which the code never did.)
    """
    text = text.strip()
    return text.replace(r"\_", "_")


class SQLGenerator(Component, ABC):
metadata: Any
llm: ChatModel | None = None
Expand Down Expand Up @@ -114,3 +130,63 @@ def generate_response(
) -> SQLGeneration:
"""Generates a response to a user question."""
pass

def stream_agent_steps( # noqa: C901
self,
question: str,
agent_executor: AgentExecutor,
response: SQLGeneration,
sql_generation_repository: SQLGenerationRepository,
queue: Queue,
):
try:
with get_openai_callback() as cb:
for chunk in agent_executor.stream({"input": question}):
if "actions" in chunk:
for message in chunk["messages"]:
queue.put(message.content + "\n")
elif "steps" in chunk:
for step in chunk["steps"]:
queue.put(f"Observation: `{step.observation}`\n")
elif "output" in chunk:
queue.put(f'Final Answer: {chunk["output"]}')
if "```sql" in chunk["output"]:
response.sql = replace_unprocessable_characters(
self.remove_markdown(chunk["output"])
)
else:
raise ValueError()
except SQLInjectionError as e:
raise SQLInjectionError(e) from e
except EngineTimeOutORItemLimitError as e:
raise EngineTimeOutORItemLimitError(e) from e
except Exception as e:
response.sql = ("",)
response.status = ("INVALID",)
response.error = (str(e),)
finally:
queue.put(None)
response.tokens_used = cb.total_tokens
response.completed_at = datetime.datetime.now()
if not response.error:
if response.sql:
response = self.create_sql_query_status(
self.database,
response.sql,
response,
)
else:
response.status = "INVALID"
response.error = "No SQL query generated"
sql_generation_repository.update(response)

    @abstractmethod
    def stream_response(
        self,
        user_prompt: Prompt,
        database_connection: DatabaseConnection,
        response: SQLGeneration,
        queue: Queue,
    ):
        """Streams a response to a user question.

        Implementations generate SQL for ``user_prompt`` against
        ``database_connection``, pushing output chunks onto ``queue`` and
        updating ``response`` as the generation progresses.
        """
        pass
Loading

0 comments on commit 68c1509

Please sign in to comment.