Skip to content

Commit

Permalink
[CAI-257] Chatbot/post queries with conversation (#1245)
Browse files Browse the repository at this point in the history
* feat(chatbot): conversational chat

* fix: dict

* fix(chatbot): messages
  • Loading branch information
batdevis authored Nov 20, 2024
1 parent 3197206 commit ba37f09
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 34 deletions.
17 changes: 12 additions & 5 deletions apps/chatbot/config/prompts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,16 @@ refine_prompt_str: |
Answer with either the original answer or with a refined answer to better answer the user query according to the `Chatbot Policy` listed above.
Answer:
condense_prompt_str: |
Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
The standalone question must be in Italian.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
# Reply to the user following these two steps:
# Step 1:
# Pay great attention in detail on the query's language and determine if it is formulated in Italian, English, Spanish, French, German, Greek, Croatian, or Slovenian ('yes' or 'no').
# Step 2:
# If Step 1 returns 'yes': reply according to the `Chatbot Policy` listed above. If `no`, reply you cannot speak that language and ask for a new query written in an accepted language.
15 changes: 12 additions & 3 deletions apps/chatbot/src/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import uuid
import boto3
import datetime
import time
import jwt
from typing import Annotated
from typing import Annotated, List
from boto3.dynamodb.conditions import Key
from botocore.exceptions import BotoCoreError, ClientError
from fastapi import FastAPI, HTTPException, Header
Expand All @@ -24,9 +23,15 @@
chatbot = Chatbot(params, prompts)


class QueryFromThePast(BaseModel):
    """One past question/answer exchange supplied as conversation history.

    Items of ``Query.history`` use this shape; each is converted with
    ``.dict()`` and passed as the ``messages`` argument of
    ``chatbot.chat_generate`` (see ``query_creation``), which reads the
    ``question`` and ``answer`` keys to rebuild the chat history.
    """

    # Optional identifier of the past query.
    id: str | None = None
    # The question the user asked in that past exchange.
    question: str
    # The assistant's answer to that question; may be absent.
    answer: str | None = None

class Query(BaseModel):
    """Request body for submitting a new question to the chatbot.

    ``history`` carries optional prior conversation turns; when present,
    each item is dict-ified and forwarded to ``chatbot.chat_generate``
    so the engine can condense the follow-up into a standalone question.
    """

    # The user's current question.
    question: str
    # Client-supplied timestamp; when omitted the server substitutes the
    # current UTC time in ISO format (see query_creation).
    queriedAt: str | None = None
    # Previous exchanges of this conversation, oldest first
    # (ordering assumed from typical chat-history usage — TODO confirm).
    history: List[QueryFromThePast] | None = None

class QueryFeedback(BaseModel):
badAnswer: bool
Expand Down Expand Up @@ -75,7 +80,11 @@ async def query_creation (
now = datetime.datetime.now(datetime.UTC)
userId = current_user_id(authorization)
session = find_or_create_session(userId, now=now)
answer = chatbot.generate(query.question)
answer = chatbot.chat_generate(
query_str = query.question,
messages = [item.dict() for item in query.history] if query.history else None
)


if query.queriedAt is None:
queriedAt = now.isoformat()
Expand Down
32 changes: 21 additions & 11 deletions apps/chatbot/src/modules/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dotenv import load_dotenv

load_dotenv()
nest_asyncio.apply()
#nest_asyncio.apply()


USE_PRESIDIO = True if (os.getenv("CHB_USE_PRESIDIO", "True")).lower() == "true" else False
Expand Down Expand Up @@ -62,12 +62,13 @@ def __init__(
chunk_sizes=params["vector_index"]["chunk_sizes"],
chunk_overlap=params["vector_index"]["chunk_overlap"]
)
self.qa_prompt_tmpl, self.ref_prompt_tmpl = self._get_prompt_templates()
self.qa_prompt_tmpl, self.ref_prompt_tmpl, self.condense_prompt_tmpl = self._get_prompt_templates()
self.engine = get_automerging_engine(
self.index,
llm=self.model,
text_qa_template=self.qa_prompt_tmpl,
refine_template=self.ref_prompt_tmpl,
condense_template=self.condense_prompt_tmpl,
verbose=self.params["engine"]["verbose"]
)

Expand All @@ -83,16 +84,25 @@ def _get_prompt_templates(self) -> Tuple[PromptTemplate, PromptTemplate]:
}
)

ref_prompt_tmpl = PromptTemplate(
self.prompts["refine_prompt_str"],
prompt_type="refine",
template_var_mappings = {
"existing_answer": "existing_answer",
"query_str": "query_str"
ref_prompt_tmpl = None
# PromptTemplate(
# self.prompts["refine_prompt_str"],
# prompt_type="refine",
# template_var_mappings = {
# "existing_answer": "existing_answer",
# "query_str": "query_str"
# }
# )

condense_prompt_tmpl = PromptTemplate(
self.prompts["condense_prompt_str"],
template_var_mappings={
"chat_history": "chat_history",
"question": "question"
}
)

return qa_prompt_tmpl, ref_prompt_tmpl
return qa_prompt_tmpl, ref_prompt_tmpl, condense_prompt_tmpl


def _get_response_str(self, engine_response: RESPONSE_TYPE) -> str:
Expand Down Expand Up @@ -169,11 +179,11 @@ def _messages_to_chathistory(self, messages: Optional[List[dict]] = None) -> Lis
chat_history += [
ChatMessage(
role = MessageRole.USER,
content = message["query"]
content = message["question"]
),
ChatMessage(
role = MessageRole.ASSISTANT,
content = message["response"]
content = message["answer"]
)
]

Expand Down
4 changes: 3 additions & 1 deletion apps/chatbot/src/modules/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_automerging_engine(
response_mode: str = "compact",
text_qa_template: PromptTemplate | None = None,
refine_template: PromptTemplate | None = None,
condense_template: PromptTemplate | None = None,
verbose: bool = True,
use_chat_engine: bool | None = None
) -> (RetrieverQueryEngine | CondenseQuestionChatEngine):
Expand Down Expand Up @@ -57,7 +58,8 @@ def get_automerging_engine(

if use_chat_engine:
automerging_engine = CondenseQuestionChatEngine.from_defaults(
query_engine = automerging_engine
query_engine = automerging_engine,
condense_question_prompt = condense_template
)

return automerging_engine
6 changes: 3 additions & 3 deletions apps/chatbot/src/modules/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def test_messages_to_chathistory():
###########################

messages = [
{"query": "aaaa", "response": "bbbb"},
{"query": "cccc", "response": "dddd"},
{"query": "eeee", "response": "ffff"},
{"question": "aaaa", "answer": "bbbb"},
{"question": "cccc", "answer": "dddd"},
{"question": "eeee", "answer": "ffff"},
]
chat_history = CHATBOT._messages_to_chathistory(messages)

Expand Down
25 changes: 14 additions & 11 deletions apps/chatbot/src/modules/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import time
import json
import tqdm
import logging
import hashlib
import html2text
import pytz
from logging import getLogger
from datetime import datetime
from bs4 import BeautifulSoup
from selenium import webdriver
Expand Down Expand Up @@ -35,7 +35,10 @@

from dotenv import load_dotenv


load_dotenv()
logger = getLogger(__name__)


PROVIDER = os.getenv("CHB_PROVIDER")
assert PROVIDER in ["google", "aws"]
Expand Down Expand Up @@ -154,7 +157,7 @@ def create_documentation(
driver_options.add_argument('--disable-gpu')
driver_options.add_argument('--no-sandbox')
driver_options.add_argument('--disable-dev-shm-usage')

for file in tqdm.tqdm(html_files, total=len(html_files), desc="Extracting HTML"):

if file in dynamic_htmls or "/webinars/" in file or "/api/" in file:
Expand Down Expand Up @@ -198,8 +201,8 @@ def create_documentation(
}
))

logging.info(f"[vector_database.py - create_documentation] Number of documents with content: {len(documents)}")
logging.info(f"[vector_database.py - create_documentation] Number of empty pages in the documentation: {len(empty_pages)}. These are left out.")
logger.info(f"Number of documents with content: {len(documents)}")
logger.info(f"Number of empty pages in the documentation: {len(empty_pages)}. These are left out.")
with open("empty_htmls.json", "w") as f:
json.dump(empty_pages, f, indent=4)

Expand All @@ -214,7 +217,7 @@ def build_automerging_index_redis(
chunk_overlap: int
) -> VectorStoreIndex:

logging.info("[vector_database.py - build_automerging_index_redis] Storing vector index and hash table on Redis..")
logger.info("Storing vector index and hash table on Redis..")

Settings.llm = llm
Settings.embed_model = embed_model
Expand All @@ -230,9 +233,9 @@ def build_automerging_index_redis(
key=key,
val=value
)
logging.info(f"[vector_database.py - build_automerging_index_redis] hash_table_{NEW_INDEX_ID} is now on Redis.")
logger.info(f"[vector_database.py - build_automerging_index_redis] hash_table_{NEW_INDEX_ID} is now on Redis.")

logging.info(f"[vector_database.py - build_automerging_index_redis] Creating index {NEW_INDEX_ID} ...")
logger.info(f"Creating index {NEW_INDEX_ID} ...")
nodes = Settings.node_parser.get_nodes_from_documents(documents)
leaf_nodes = get_leaf_nodes(nodes)

Expand Down Expand Up @@ -267,7 +270,7 @@ def build_automerging_index_redis(
)
automerging_index.set_index_id(NEW_INDEX_ID)
put_ssm_parameter(os.getenv("CHB_LLAMAINDEX_INDEX_ID"), NEW_INDEX_ID)
logging.info("[vector_database.py - build_automerging_index_redis] Created vector index successfully and stored on Redis.")
logger.info("Created vector index successfully and stored on Redis.")

delete_old_index()

Expand Down Expand Up @@ -296,7 +299,7 @@ def load_automerging_index_redis(
schema=REDIS_SCHEMA
)

logging.info("[vector_database.py - load_automerging_index_redis] Loading vector index from Redis...")
logger.info("Loading vector index from Redis...")
storage_context = StorageContext.from_defaults(
vector_store=redis_vector_store,
docstore=REDIS_DOCSTORE,
Expand All @@ -310,7 +313,7 @@ def load_automerging_index_redis(

return automerging_index
else:
logging.error("[vector_database.py - load_automerging_index_redis] No index_id provided.")
logger.error("No index_id provided.")


def delete_old_index():
Expand All @@ -320,4 +323,4 @@ def delete_old_index():
if f"{INDEX_ID}/vector" in str(key) or f"hash_table_{INDEX_ID}" == str(key):
REDIS_CLIENT.delete(key)

logging.info(f"[vector_database.py - delete_old_index] Deleted index with ID: {INDEX_ID} and its hash table from Redis.")
logger.info(f"Deleted index with ID: {INDEX_ID} and its hash table from Redis.")

0 comments on commit ba37f09

Please sign in to comment.