
Commit
Implement conversation chain
Vidminas committed Nov 18, 2023
1 parent be39d05 · commit 4832b35
Showing 4 changed files with 35 additions and 29 deletions.
chatdocs/chains.py (15 changes: 6 additions & 9 deletions)
@@ -1,24 +1,21 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict

-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chains import RetrievalQA
+from langchain.chains import ConversationalRetrievalChain

 from .llms import get_llm
 from .vectorstores import get_vectorstore


-def get_retrieval_qa(
+def make_conversation_chain(
     config: Dict[str, Any],
     *,
     selected_llm_index: int = 0,
-    callbacks: Optional[list[BaseCallbackHandler]] = None,
-) -> RetrievalQA:
+) -> ConversationalRetrievalChain:
     db = get_vectorstore(config)
     retriever = db.as_retriever(**config["retriever"])
-    llm = get_llm(config, selected_llm_index=selected_llm_index, callbacks=callbacks)
-    return RetrievalQA.from_chain_type(
+    llm = get_llm(config, selected_llm_index=selected_llm_index)
+    return ConversationalRetrievalChain.from_llm(
         llm=llm,
-        chain_type="stuff",
         retriever=retriever,
         return_source_documents=True,
     )
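The shape of the call changes with this swap: RetrievalQA took a bare query string and returned its text under "result", while ConversationalRetrievalChain takes a dict with "question" and "chat_history" keys and returns the text under "answer". A minimal sketch of the new contract, using a fake LLM and a stub retriever as stand-ins for the real get_llm()/get_vectorstore() outputs (both stand-ins are assumptions for illustration, not part of this commit):

# Sketch of the ConversationalRetrievalChain call contract, assuming the
# langchain 0.0.x APIs this commit builds on.
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains import ConversationalRetrievalChain
from langchain.llms.fake import FakeListLLM
from langchain.schema import BaseRetriever, Document


class EchoRetriever(BaseRetriever):
    """Illustration-only retriever: returns the query as the sole document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=query, metadata={"source": "stub"})]


chain = ConversationalRetrievalChain.from_llm(
    llm=FakeListLLM(responses=["A canned answer."]),
    retriever=EchoRetriever(),
    return_source_documents=True,
)
res = chain({"question": "What changed?", "chat_history": []})
print(res["answer"])            # was res["result"] with RetrievalQA
print(res["source_documents"])  # still exposed via return_source_documents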
chatdocs/chat.py (20 changes: 14 additions & 6 deletions)
@@ -1,11 +1,12 @@
 from typing import Any, Dict, Optional

 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema.messages import AIMessage, HumanMessage
 from rich import print
 from rich.markup import escape
 from rich.panel import Panel

-from .chains import get_retrieval_qa
+from .chains import make_conversation_chain


 def print_answer(text: str) -> None:
@@ -18,7 +19,8 @@ def on_llm_new_token(self, token: str, **kwargs) -> None:


 def chat(config: Dict[str, Any], query: Optional[str] = None) -> None:
-    qa = get_retrieval_qa(config, callbacks=[PrintCallback])
+    llm = make_conversation_chain(config)
+    messages = []

     interactive = not query
     print()
@@ -37,9 +39,10 @@ def chat(config: Dict[str, Any], query: Optional[str] = None) -> None:
             break
         print("[bold]A:", end="", flush=True)

-        res = qa(query)
-        if config["llm"] != "ctransformers":
-            print_answer(res["result"])
+        res = llm(
+            { "question": query, "chat_history": messages },
+            callbacks=[PrintCallback],
+        )

         print()
         for doc in res["source_documents"]:
@@ -49,7 +52,12 @@ def chat(config: Dict[str, Any], query: Optional[str] = None) -> None:
                     f"[bright_blue]{escape(source)}[/bright_blue]\n\n{escape(content)}"
                 )
             )
-
         print()
+
+        print_answer(res["answer"])
+
         if not interactive:
             break
+
+        messages.append(HumanMessage(content=query))
+        messages.append(AIMessage(content=res["answer"]))
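The rewritten loop threads history by hand: each turn sends the messages accumulated so far as "chat_history", then records both sides of the exchange for the next turn. A sketch of that bookkeeping with the chain stubbed out (fake_chain is an assumption standing in for the real chain; only the append pattern mirrors the code above):

from langchain.schema.messages import AIMessage, HumanMessage


def fake_chain(inputs: dict) -> dict:
    # Stand-in for the ConversationalRetrievalChain built above.
    return {"answer": f"echo: {inputs['question']}"}


messages = []
for query in ["What is in my documents?", "Summarise that."]:
    res = fake_chain({"question": query, "chat_history": messages})
    # Append after the call, so the current question is not duplicated
    # inside the history the chain already received.
    messages.append(HumanMessage(content=query))
    messages.append(AIMessage(content=res["answer"]))

print(len(messages))  # 4: two Human/AI pairs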
chatdocs/llms.py (11 changes: 5 additions & 6 deletions)
@@ -1,6 +1,5 @@
-from typing import Any, Optional
+from typing import Any

-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.llms import CTransformers, HuggingFacePipeline, OpenAI
 from langchain.llms.base import LLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
@@ -12,7 +11,6 @@ def get_llm(
     config: dict[str, Any],
     *,
     selected_llm_index: int = 0,
-    callbacks: Optional[list[BaseCallbackHandler]] = None,
 ) -> LLM:
     local_files_only = not config["download"]

@@ -22,7 +20,7 @@

     if model_framework == "ctransformers":
         config = merge(config, {"config": {"local_files_only": local_files_only}})
-        llm = CTransformers(callbacks=callbacks, **config)
+        llm = CTransformers(**config)
     elif model_framework == "openai":
         llm = OpenAI(**config)
     elif model_framework == "huggingface":
@@ -33,8 +31,9 @@
         if not tokenizer.pad_token_id:
             tokenizer.pad_token_id = model.config.eos_token_id
         pipe = pipeline(
-            "text-generation", model=model, tokenizer=tokenizer,
-            callbacks=callbacks,
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
             **config["pipeline_kwargs"],
         )
         llm = HuggingFacePipeline(pipeline=pipe)
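With the callbacks parameter removed here, handlers are no longer baked into the LLM at construction time; chat.py and ui.py now supply them per invocation instead. A hedged sketch of that per-call style (TokenPrinter and the FakeListLLM stand-in are assumptions for illustration; FakeListLLM does not actually stream, so the token hook only shows the handler shape):

from langchain.callbacks.base import BaseCallbackHandler
from langchain.llms.fake import FakeListLLM


class TokenPrinter(BaseCallbackHandler):
    def on_llm_new_token(self, token: str, **kwargs) -> None:
        print(token, end="", flush=True)


llm = FakeListLLM(responses=["pong"])           # stand-in for get_llm(config)
print(llm("ping", callbacks=[TokenPrinter()]))  # callbacks scoped to this call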
chatdocs/ui.py (18 changes: 10 additions & 8 deletions)
@@ -6,6 +6,7 @@
 from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
 from langchain.callbacks import StreamingStdOutCallbackHandler
 from langchain.callbacks.base import BaseCallbackHandler
+from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.schema.output import LLMResult

 import streamlit as st
@@ -18,7 +19,7 @@

 __package__ = Path(__file__).parent.name

-from .chains import get_retrieval_qa
+from .chains import make_conversation_chain
 from .st_utils import load_config


@@ -37,7 +38,7 @@ def on_llm_start(
         **kwargs: Any,
     ) -> None:
         # Workaround to prevent showing the rephrased question as output
-        if prompts[0].startswith("Human"):
+        if prompts[0][:20] == CONDENSE_QUESTION_PROMPT.template[:20]:
             self.run_id_ignore_token = run_id
             return
         self.status = self.container.status(
@@ -98,8 +99,8 @@ def print_state_messages(msgs: StreamlitChatMessageHistory):


 @st.cache_resource
-def load_qa_chain(config, selected_llm):
-    return get_retrieval_qa(config, selected_llm_index=selected_llm)
+def load_llm(config, selected_llm):
+    return make_conversation_chain(config, selected_llm_index=selected_llm)


 def main():
@@ -124,7 +125,7 @@ def main():

     config = load_config()
     selected_llm = st.sidebar.radio("LLM", range(len(config["llms"])), format_func=lambda idx: config["llms"][idx]["model"])
-    qa = load_qa_chain(config, selected_llm)
+    llm = load_llm(config, selected_llm)

     if prompt := st.chat_input("Enter a query"):
         with st.chat_message("user"):
@@ -134,10 +135,11 @@ def main():
             retrieve_callback = PrintRetrievalHandler(st.container())
             print_callback = StreamHandler(st.empty())
             stdout_callback = StreamingStdOutCallbackHandler()
-            response = qa(
-                prompt, callbacks=[retrieve_callback, print_callback, stdout_callback]
+            response = llm(
+                { "question": prompt, "chat_history": msgs.messages },
+                callbacks=[retrieve_callback, print_callback, stdout_callback],
             )
-            msgs.add_ai_message(response["result"])
+            msgs.add_ai_message(response["answer"])


 if __name__ == "__main__":
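The ui.py workaround now keys off the condense-question prompt itself rather than a "Human" prefix: the rephrasing call is the only LLM call whose prompt begins with CONDENSE_QUESTION_PROMPT's literal template text, so comparing the first 20 characters identifies it reliably. A small sketch of the check (the sample history and question are arbitrary):

from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

# The template's literal text precedes its placeholders, so any formatted
# prompt keeps the template's opening characters.
prompt = CONDENSE_QUESTION_PROMPT.format(
    chat_history="Human: hi\nAssistant: hello", question="and then?"
)
print(prompt[:20] == CONDENSE_QUESTION_PROMPT.template[:20])  # True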
