Skip to content

Commit

Permalink
remove memory and add keep button
Browse files Browse the repository at this point in the history
  • Loading branch information
jczhong84 committed Aug 16, 2023
1 parent 7a551ec commit 771fd35
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 69 deletions.
34 changes: 2 additions & 32 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@

from flask_login import current_user
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from pydantic.error_wrappers import ValidationError

from app.db import with_session
from app.flask_app import socketio
from const.ai_assistant import AICommandType
from clients.redis_client import with_redis
from lib.logger import get_logger
from lib.query_analysis.lineage import process_query
from logic import admin as admin_logic
from logic import datadoc as datadoc_logic
from logic import metastore as m_logic
from logic import query_execution as qe_logic
from models.metastore import DataTableColumn
Expand All @@ -22,7 +19,6 @@
from .prompts.sql_fix_prompt import SQL_FIX_PROMPT
from .prompts.sql_title_prompt import SQL_TITLE_PROMPT
from .prompts.text2sql_prompt import TEXT2SQL_PROMPT
from .redis_chat_history_storage import RedisChatHistoryStorage
from .streaming_web_socket_callback_handler import (
WebSocketStream,
StreamingWebsocketCallbackHandler,
Expand Down Expand Up @@ -59,26 +55,6 @@ def wrapper(self, *args, **kwargs):
def _get_llm(self, callback_handler: StreamingWebsocketCallbackHandler):
"""return the language model to use"""

@with_redis
def _get_chat_memory(
self,
session_id,
memory_key="chat_history",
input_key="question",
ttl=600,
redis_conn=None,
):
message_history_storage = RedisChatHistoryStorage(
redis_client=redis_conn, ttl=ttl, session_id=session_id
)

return ConversationBufferMemory(
memory_key=memory_key,
chat_memory=message_history_storage,
input_key=input_key,
return_messages=True,
)

def _get_sql_title_prompt(self):
"""Override this method to return specific prompt for your own assistant."""
return SQL_TITLE_PROMPT
Expand Down Expand Up @@ -186,18 +162,15 @@ def handle_ai_command(self, command_type: str, payload: dict = {}):
query = payload["query"]
self.generate_title_from_query(query=query)
elif command_type == AICommandType.TEXT_TO_SQL.value:
data_cell_id = payload["data_cell_id"]
data_cell = datadoc_logic.get_data_cell_by_id(data_cell_id)
query = data_cell.context if data_cell else None
original_query = payload["original_query"]
query_engine_id = payload["query_engine_id"]
tables = payload.get("tables")
question = payload["question"]
self.generate_sql_query(
query_engine_id=query_engine_id,
tables=tables,
question=question,
original_query=query,
memory_session_id=f"{current_user.id}_{data_cell_id}",
original_query=original_query,
)
elif command_type == AICommandType.SQL_FIX.value:
query_execution_id = payload["query_execution_id"]
Expand All @@ -219,7 +192,6 @@ def generate_sql_query(
tables: list[str],
question: str,
original_query: str = None,
memory_session_id=None,
session=None,
):
query_engine = admin_logic.get_query_engine_by_id(
Expand All @@ -230,11 +202,9 @@ def generate_sql_query(
)

prompt = self._get_text2sql_prompt()
memory = self._get_chat_memory(session_id=memory_session_id)
chain = self._get_llm_chain(
command_type=AICommandType.TEXT_TO_SQL.value,
prompt=prompt,
memory=memory,
)
return chain.run(
dialect=query_engine.language,
Expand Down
14 changes: 5 additions & 9 deletions querybook/server/lib/ai_assistant/prompts/text2sql_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
)


system_message_template = (
"You are a SQL expert that can help generating SQL query.\n\n"
system_message_template = "You are a SQL expert that can help generating SQL query."

human_message_template = (
"Please help to generate a new SQL query or modify the original query to answer the following question. Your response should ONLY be based on the given context.\n\n"
"Please always follow the key/value pair format below for your response:\n"
"===Response Format\n"
Expand All @@ -28,24 +29,19 @@
"2. If the provided context is insufficient, please explain what information is missing.\n"
"3. If the original query is provided, please modify the original query to answer the question. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to generate the new query.\n"
"4. Please always honor the table schmeas for the query generation\n\n"
)

human_message_template = (
"===SQL Dialect\n"
"{dialect}\n\n"
"===Tables\n"
"{table_schemas}\n\n"
"===Original Query\n"
"{original_query}\n\n"
"===Question\n"
"{question}\n\n"
)

TEXT2SQL_PROMPT = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(system_message_template),
HumanMessagePromptTemplate.from_template(human_message_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(
"{question}\nPlease remember always start your response with <@query@> or <@explanation@>.\n"
),
]
)
17 changes: 0 additions & 17 deletions querybook/server/lib/ai_assistant/redis_chat_history_storage.py

This file was deleted.

43 changes: 39 additions & 4 deletions querybook/webapp/components/AIAssistant/QueryGenerationModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { Message } from 'ui/Message/Message';
import { Modal } from 'ui/Modal/Modal';
import { ResizableTextArea } from 'ui/ResizableTextArea/ResizableTextArea';
import { StyledText } from 'ui/StyledText/StyledText';
import { Tag } from 'ui/Tag/Tag';

import { TableSelector } from './TableSelector';
import { TextToSQLMode, TextToSQLModeSelector } from './TextToSQLModeSelector';
Expand Down Expand Up @@ -72,6 +73,7 @@ export const QueryGenerationModal = ({
const [textToSQLMode, setTextToSQLMode] = useState(
!!query ? TextToSQLMode.EDIT : TextToSQLMode.GENERATE
);
const [newQuery, setNewQuery] = useState<string>('');

useEffect(() => {
setTables(uniq([...tablesInQuery, ...tables]));
Expand All @@ -83,14 +85,17 @@ export const QueryGenerationModal = ({
query_engine_id: engineId,
tables: tables,
question: question,
data_cell_id:
textToSQLMode === TextToSQLMode.EDIT ? dataCellId : undefined,
original_query: query,
}
);

const { explanation, query: rawNewQuery, data } = streamData;

const newQuery = trimSQLQuery(rawNewQuery);
// const newQuery = trimSQLQuery(rawNewQuery);

useEffect(() => {
setNewQuery(trimSQLQuery(rawNewQuery));
}, [rawNewQuery]);

const onKeyDown = useCallback(
(event: React.KeyboardEvent) => {
Expand Down Expand Up @@ -272,7 +277,37 @@ export const QueryGenerationModal = ({
}
toQuery={newQuery}
fromQueryTitle="Original Query"
toQueryTitle="New Query"
toQueryTitle={
<div className="horizontal-space-between">
{<Tag>New Query</Tag>}
<Button
title="Keep the query"
onClick={() => {
onUpdateQuery(
newQuery,
false
);
setTextToSQLMode(
TextToSQLMode.EDIT
);
setQuestion('');
setNewQuery('');
trackClick({
component:
ComponentType.AI_ASSISTANT,
element:
ElementType.QUERY_GENERATION_KEEP_BUTTON,
aux: {
mode: textToSQLMode,
question,
tables,
},
});
}}
color="confirm"
/>
</div>
}
disableHighlight={
streamStatus === StreamStatus.STREAMING
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
.QueryComparison {
display: flex;
gap: 8px;
.Tag {
margin-bottom: 12px;
}

.diff-side-view {
flex: 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import './QueryComparison.scss';
export const QueryComparison: React.FC<{
fromQuery: string;
toQuery: string;
fromQueryTitle?: string;
toQueryTitle?: string;
fromQueryTitle?: string | React.ReactNode;
toQueryTitle?: string | React.ReactNode;
disableHighlight?: boolean;
hideEmptyQuery?: boolean;
}> = ({
Expand Down Expand Up @@ -63,7 +63,15 @@ export const QueryComparison: React.FC<{
<div className="QueryComparison">
{!(hideEmptyQuery && !fromQuery) && (
<div className="diff-side-view">
{fromQueryTitle && <Tag>{fromQueryTitle}</Tag>}
{fromQueryTitle && (
<div className="mb12">
{typeof fromQueryTitle === 'string' ? (
<Tag>{fromQueryTitle}</Tag>
) : (
fromQueryTitle
)}
</div>
)}
<ThemedCodeHighlightWithMark
highlightRanges={removedRanges}
query={fromQuery}
Expand All @@ -74,7 +82,15 @@ export const QueryComparison: React.FC<{
)}
{!(hideEmptyQuery && !toQuery) && (
<div className="diff-side-view">
{toQueryTitle && <Tag>{toQueryTitle}</Tag>}
{toQueryTitle && (
<div className="mb12">
{typeof toQueryTitle === 'string' ? (
<Tag>{toQueryTitle}</Tag>
) : (
toQueryTitle
)}
</div>
)}
<ThemedCodeHighlightWithMark
highlightRanges={addedRanges}
query={toQuery}
Expand Down
1 change: 1 addition & 0 deletions querybook/webapp/const/analytics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ export enum ElementType {
QUERY_ERROR_AUTO_FIX_APPLY_BUTTON = 'QUERY_ERROR_AUTO_FIX_APPLY_BUTTON',
QUERY_ERROR_AUTO_FIX_APPLY_AND_RUN_BUTTON = 'QUERY_ERROR_AUTO_FIX_APPLY_AND_RUN_BUTTON',
QUERY_GENERATION_BUTTON = 'QUERY_GENERATION_BUTTON',
QUERY_GENERATION_KEEP_BUTTON = 'QUERY_GENERATION_KEEP_BUTTON',
QUERY_GENERATION_APPLY_BUTTON = 'QUERY_GENERATION_APPLY_BUTTON',
QUERY_GENERATION_APPLY_AND_RUN_BUTTON = 'QUERY_GENERATION_APPLY_AND_RUN_BUTTON',
}
Expand Down

0 comments on commit 771fd35

Please sign in to comment.