Skip to content

Commit

Permalink
feat: use websocket for ai assistant (pinterest#1311)
Browse files Browse the repository at this point in the history
* feat: use websocket for ai assistant

* fix node test

* comments

* remove memory and add keep button

* fix linter
  • Loading branch information
jczhong84 committed Aug 30, 2023
1 parent f770e80 commit 2fb0242
Show file tree
Hide file tree
Showing 24 changed files with 507 additions and 442 deletions.
14 changes: 14 additions & 0 deletions querybook/server/const/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum


# NOTE: keep these values consistent with webapp/const/aiAssistant.ts,
# which mirrors this enum on the client side.
class AICommandType(Enum):
    """Command types a client may send to the AI assistant websocket endpoint."""

    SQL_FIX = "SQL_FIX"
    SQL_TITLE = "SQL_TITLE"
    TEXT_TO_SQL = "TEXT_TO_SQL"
    RESET_MEMORY = "RESET_MEMORY"


# Socket.IO namespace and event names for the AI assistant. These are shared
# between the server-side handlers and the webapp client, so they must stay
# in sync with webapp/const/aiAssistant.ts.
AI_ASSISTANT_NAMESPACE = "/ai_assistant"
AI_ASSISTANT_REQUEST_EVENT = "ai_assistant_request"
AI_ASSISTANT_RESPONSE_EVENT = "ai_assistant_response"
2 changes: 0 additions & 2 deletions querybook/server/datasources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from . import event_log
from . import data_element
from . import comment
from . import ai_assistant

# Keep this at the end of imports to make sure the plugin APIs override the default ones
try:
Expand All @@ -43,5 +42,4 @@
event_log
data_element
comment
ai_assistant
api_plugin
48 changes: 0 additions & 48 deletions querybook/server/datasources/ai_assistant.py

This file was deleted.

2 changes: 2 additions & 0 deletions querybook/server/datasources_socketio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from . import query_execution
from . import datadoc
from . import connect
from . import ai_assistant

# Importing these modules registers their socket.io handlers as a side effect.
# The bare-name expressions below are no-ops at runtime; they only mark the
# imports as intentionally used so linters do not flag them as unused.
connect
query_execution
datadoc
ai_assistant
13 changes: 13 additions & 0 deletions querybook/server/datasources_socketio/ai_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from const.ai_assistant import (
AI_ASSISTANT_NAMESPACE,
AI_ASSISTANT_REQUEST_EVENT,
)

from .helper import register_socket


@register_socket(AI_ASSISTANT_REQUEST_EVENT, namespace=AI_ASSISTANT_NAMESPACE)
def ai_assistant_request(command_type: str, payload=None):
    """Dispatch an AI assistant websocket request to the configured assistant.

    Args:
        command_type: one of the AICommandType values sent by the webapp
            (e.g. "TEXT_TO_SQL", "SQL_FIX").
        payload: command-specific arguments from the client; treated as an
            empty dict when omitted.
    """
    # Imported lazily so the assistant provider is only resolved when a
    # request actually arrives (it may be unconfigured at import time).
    from lib.ai_assistant import ai_assistant

    # `payload=None` (normalized here) instead of a mutable `payload={}`
    # default, which would be shared across all calls to this handler.
    ai_assistant.handle_ai_command(command_type, payload or {})
6 changes: 6 additions & 0 deletions querybook/server/datasources_socketio/connect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flask_login import current_user
from flask_socketio import ConnectionRefusedError

from const.ai_assistant import AI_ASSISTANT_NAMESPACE
from const.data_doc import DATA_DOC_NAMESPACE
from const.query_execution import QUERY_EXECUTION_NAMESPACE

Expand All @@ -20,3 +21,8 @@ def connect_query_execution(auth):
@register_socket("connect", namespace=DATA_DOC_NAMESPACE)
def connect_datadoc(auth):
    # Shared connect-time check; refuses the connection for
    # unauthenticated users (see on_connect above).
    on_connect()


@register_socket("connect", namespace=AI_ASSISTANT_NAMESPACE)
def connect_ai_assistant(auth):
    # Same auth gate as the other namespaces: reject the websocket
    # connection unless the user is logged in.
    on_connect()
11 changes: 7 additions & 4 deletions querybook/server/lib/ai_assistant/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from env import QuerybookSettings

from .all_ai_assistants import get_ai_assistant_class


if QuerybookSettings.AI_ASSISTANT_PROVIDER:
from .ai_assistant import AIAssistant
ai_assistant = get_ai_assistant_class(QuerybookSettings.AI_ASSISTANT_PROVIDER)
ai_assistant.set_config(QuerybookSettings.AI_ASSISTANT_CONFIG)

ai_assistant = AIAssistant(
QuerybookSettings.AI_ASSISTANT_PROVIDER, QuerybookSettings.AI_ASSISTANT_CONFIG
)
else:
ai_assistant = None


__all__ = ["ai_assistant"]
52 changes: 0 additions & 52 deletions querybook/server/lib/ai_assistant/ai_assistant.py

This file was deleted.

176 changes: 3 additions & 173 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@

from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManager
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessage,
HumanMessagePromptTemplate,
)
import openai


Expand All @@ -29,174 +24,9 @@ def _get_error_msg(self, error) -> str:

return super()._get_error_msg(error)

@property
def title_generation_prompt_template(self) -> ChatPromptTemplate:
system_message_prompt = SystemMessage(
content="You are a helpful assistant that can summerize SQL queries."
)
human_template = (
"Generate a brief 10-word-maximum title for the SQL query below. "
"===Query\n"
"{query}\n\n"
"===Response Guidelines\n"
"1. Only respond with the title without any explanation\n"
"2. Dont use double quotes to enclose the title\n"
"3. Dont add a final period to the title\n\n"
"===Example response\n"
"This is a title\n"
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
return ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)

@property
def query_auto_fix_prompt_template(self) -> ChatPromptTemplate:
system_message_prompt = SystemMessage(
content=(
"You are a SQL expert that can help fix SQL query errors.\n\n"
"Please follow the format below for your response:\n"
"<@key-1@>\n"
"value-1\n\n"
"<@key-2@>\n"
"value-2\n\n"
)
)
human_template = (
"Please help fix the query below based on the given error message and table schemas. \n\n"
"===SQL dialect\n"
"{dialect}\n\n"
"===Query\n"
"{query}\n\n"
"===Error\n"
"{error}\n\n"
"===Table Schemas\n"
"{table_schemas}\n\n"
"===Response Format\n"
"<@key-1@>\n"
"value-1\n\n"
"<@key-2@>\n"
"value-2\n\n"
"===Example response:\n"
"<@explanation@>\n"
"This is an explanation about the error\n\n"
"<@fix_suggestion@>\n"
"This is a recommended fix for the error\n\n"
"<@fixed_query@>\n"
"The fixed SQL query\n\n"
"===Response Guidelines\n"
"1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n"
"2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n"
"3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n"
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
return ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)

@property
def generate_sql_query_prompt_template(self) -> ChatPromptTemplate:
system_message_prompt = SystemMessage(
content=(
"You are a SQL expert that can help generating SQL query.\n\n"
"Please follow the key/value pair format below for your response:\n"
"<@key-1@>\n"
"value-1\n\n"
"<@key-2@>\n"
"value-2\n\n"
)
)
human_template = (
"Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n"
"===SQL Dialect\n"
"{dialect}\n\n"
"===Tables\n"
"{table_schemas}\n\n"
"===Original Query\n"
"{original_query}\n\n"
"===Question\n"
"{question}\n\n"
"===Response Format\n"
"<@key-1@>\n"
"value-1\n\n"
"<@key-2@>\n"
"value-2\n\n"
"===Example Response:\n"
"Example 1: Sufficient Context\n"
"<@query@>\n"
"A generated SQL query based on the provided context with the asked question at the beginning is provided here.\n\n"
"Example 2: Insufficient Context\n"
"<@explanation@>\n"
"An explanation of the missing context is provided here.\n\n"
"===Response Guidelines\n"
"1. If the provided context is sufficient, please respond only with a valid SQL query without any explanations in the <@query@> section. The query should start with a comment containing the question being asked.\n"
"2. If the provided context is insufficient, please explain what information is missing.\n"
"3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n"
"4. The <@key_name@> in the response can only be <@explanation@> or <@query@>.\n\n"
)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
return ChatPromptTemplate.from_messages(
[system_message_prompt, human_message_prompt]
)

def _generate_title_from_query(
self, query, stream=True, callback_handler=None, user_id=None
):
"""Generate title from SQL query using OpenAI's chat model."""
messages = self.title_generation_prompt_template.format_prompt(
query=query
).to_messages()
chat = ChatOpenAI(
**self._config,
streaming=stream,
callback_manager=CallbackManager([callback_handler]),
)
ai_message = chat(messages)
return ai_message.content

def _query_auto_fix(
self,
language,
query,
error,
table_schemas,
stream,
callback_handler,
user_id=None,
):
"""Query auto fix using OpenAI's chat model."""
messages = self.query_auto_fix_prompt_template.format_prompt(
dialect=language, query=query, error=error, table_schemas=table_schemas
).to_messages()
chat = ChatOpenAI(
**self._config,
streaming=stream,
callback_manager=CallbackManager([callback_handler]),
)
ai_message = chat(messages)
return ai_message.content

def _generate_sql_query(
self,
language: str,
table_schemas: str,
question: str,
original_query: str,
stream,
callback_handler,
user_id=None,
):
"""Generate SQL query using OpenAI's chat model."""
messages = self.generate_sql_query_prompt_template.format_prompt(
dialect=language,
question=question,
table_schemas=table_schemas,
original_query=original_query,
).to_messages()
chat = ChatOpenAI(
def _get_llm(self, callback_handler):
return ChatOpenAI(
**self._config,
streaming=stream,
streaming=True,
callback_manager=CallbackManager([callback_handler]),
)
ai_message = chat(messages)
return ai_message.content
Loading

0 comments on commit 2fb0242

Please sign in to comment.