Contrib/conversational rag example (#674)
Adds a conversational RAG example to contrib. It shows how you can incorporate chat history into a RAG setting; the primary purpose is to use the conversation to appropriately contextualize what the new question is asking.
Showing 7 changed files with 351 additions and 0 deletions.
108 changes: 108 additions & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/README.md
@@ -0,0 +1,108 @@
# Purpose of this module

This module shows a conversational retrieval augmented generation (RAG) example using
Hamilton. It shows you how you might structure your code with Hamilton to
create a RAG pipeline that takes the conversation into account.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store and the OpenAI LLM provider.
The implementation of the FAISS vector store uses the LangChain wrapper around it,
because that was the simplest way to get this example running without requiring
someone to host and manage a separate vector store.

The "smarts" in this pipeline are that it takes a conversation and a question,
and rewrites the question based on the conversation to be "standalone". For example, given the
history "Human: Who wrote this example?" / "AI: Stefan" and the follow-up "where did he work?",
the standalone question would be "Where did Stefan work?". That way the standalone question can
be used to query the vector store, and it also serves as a more specific question for the LLM
given the retrieved context.

## Example Usage

### Inputs
These are the defined inputs you can provide.

- *input_texts*: A list of strings. Each string will be encoded into a vector and stored in the vector store.
- *question*: A string. This is the question you want to ask the LLM; it is also used to query the vector store for context.
- *chat_history*: A list of strings. Each string is a line of conversation, prefixed with "Human" or "AI" to indicate who said it. The lines should alternate between the two.
- *top_k*: An integer. This is the number of vectors to retrieve from the vector store. Defaults to 5.

### Overrides
With Hamilton you can easily override a function and provide a value for it directly. For example, if you're
iterating you might want to override one of the following values before modifying the functions
(see the sketch after this list):

- *context*: if you want to skip going to the vector store and provide the context directly, you can do so by providing this override.
- *standalone_question*: if you want to skip the rewording of the question, you can provide the standalone question directly.
- *answer_prompt*: if you want to provide the full prompt to pass to the LLM yourself, pass it in as an override.

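As a minimal sketch (assuming the driver `dr` built in the Execution section below), overrides are passed
to `execute` alongside inputs; any function whose name appears in `overrides` is not computed and the
provided value is used instead. The context string here is illustrative only:

```python
# Skip the vector store lookup entirely by overriding `context`.
# With `context` overridden, `input_texts` is no longer required as an input.
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "question": "where did stefan work?",
        "chat_history": [],
    },
    overrides={"context": "stefan worked at Stitch Fix"},
)
print(result)
```
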
### Execution
You can ask for the result of any intermediate function by including its name in the `execute` call.
Here we just ask for the final result, but if you wanted to, you could request the outputs of any of the functions,
which you can then introspect or log for debugging/evaluation purposes. Note that if you want more platform integrations,
you can add adapters that will do this automatically for you, e.g. the `PrintLn` adapter used here.

```python
# import the module; with the contrib package installed, the import path mirrors this directory
from hamilton.contrib.dagworks import conversational_rag

from hamilton import driver
from hamilton import lifecycle

dr = (
    driver.Builder()
    .with_modules(conversational_rag)
    .with_config({})
    # this prints the inputs and outputs of each step.
    .with_adapters(lifecycle.PrintLn(verbosity=2))
    .build()
)
# no chat history -- nothing to rewrite
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
        "chat_history": [],
    },
)
print(result)

# this will now reword the question, which is then
# used to query the vector store and in the final LLM call.
result = dr.execute(
    ["conversational_rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did he work?",
        "chat_history": [
            "Human: Who wrote this example?",
            "AI: Stefan",
        ],
    },
)
print(result)
```

# How to extend this module
What you'd most likely want to do is:

1. Change the vector store (and how embeddings are generated).
2. Change the LLM provider.
3. Change the context and prompt.

With (1) you can use any vector store/library that you want. You should draw out
the process you would like, and that should then map to Hamilton functions.
With (2) you can use any LLM provider that you want; just use `@config.when` if you
want to switch between multiple providers, as sketched below.
With (3) you can add more functions that create parts of the prompt.

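For (2), a minimal sketch of switching providers with `@config.when` might look like the following.
Note this is an illustration, not part of the module: the `anthropic` client and the `llm_provider`
config key are assumptions, and downstream functions that consume `llm_client` would also need adjusting.

```python
import anthropic  # illustrative alternative provider
import openai

from hamilton.function_modifiers import config


@config.when(llm_provider="openai")
def llm_client__openai() -> openai.OpenAI:
    """LLM client used when the driver is built with {"llm_provider": "openai"}."""
    return openai.OpenAI()


@config.when(llm_provider="anthropic")
def llm_client__anthropic() -> anthropic.Anthropic:
    """LLM client used when the driver is built with {"llm_provider": "anthropic"}."""
    return anthropic.Anthropic()
```

You would then pass the chosen provider via `.with_config({"llm_provider": "openai"})` when building
the driver, and Hamilton resolves which `llm_client` implementation to include in the DAG.
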
# Configuration Options
There is no configuration needed for this module.

# Limitations

You need to have the OPENAI_API_KEY set in your environment.
It should be accessible from your code via `os.environ["OPENAI_API_KEY"]`.

The code does not check the context length, so it may fail if the context passed is too long
for the LLM you send it to.
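
If you hit that limit, a minimal sketch of a guard you could add is below. It assumes the `tiktoken`
package (not a dependency of this module) and a hypothetical token budget:

```python
import tiktoken


def check_context_length(answer_prompt: str, max_tokens: int = 4096) -> None:
    """Raises if the prompt likely exceeds the model's context window."""
    # gpt-3.5-turbo uses the cl100k_base encoding; counts here are approximate.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    n_tokens = len(encoding.encode(answer_prompt))
    if n_tokens > max_tokens:
        raise ValueError(
            f"Prompt has ~{n_tokens} tokens, exceeding the {max_tokens}-token budget."
        )
```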
161 changes: 161 additions & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/__init__.py
@@ -0,0 +1,161 @@
import logging

logger = logging.getLogger(__name__)

from hamilton import contrib

with contrib.catch_import_errors(__name__, __file__, logger):
    import openai

    # use langchain implementation of vector store
    from langchain_community.vectorstores import FAISS
    from langchain_core.vectorstores import VectorStoreRetriever

    # use langchain embedding wrapper with vector store
    from langchain_openai import OpenAIEmbeddings


def standalone_question_prompt(chat_history: list[str], question: str) -> str:
    """Prompt for getting a standalone question given the chat history.
    This is then used to query the vector store with.

    :param chat_history: the history of the conversation.
    :param question: the current user question.
    :return: prompt to use.
    """
    chat_history_str = "\n".join(chat_history)
    return (
        "Given the following conversation and a follow up question, "
        "rephrase the follow up question to be a standalone question, "
        "in its original language.\n\n"
        "Chat History:\n"
        "{chat_history}\n"
        "Follow Up Input: {question}\n"
        "Standalone question:"
    ).format(chat_history=chat_history_str, question=question)


def standalone_question(standalone_question_prompt: str, llm_client: openai.OpenAI) -> str:
    """Asks the LLM to create a standalone question from the prompt.

    :param standalone_question_prompt: the prompt with context.
    :param llm_client: the LLM client to use.
    :return: the standalone question.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": standalone_question_prompt}],
    )
    return response.choices[0].message.content


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """A vector store. This function populates and creates one for querying.
    This is a simple function encapsulating the creation of a vector store. In real life
    you could replace this with a more complex function, or one that returns a
    client to an existing vector store.

    :param input_texts: the input "texts", i.e. documents, to be stored.
    :return: a vector store that can be queried against.
    """
    vectorstore = FAISS.from_texts(input_texts, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    return retriever


def context(standalone_question: str, vector_store: VectorStoreRetriever, top_k: int = 5) -> str:
    """This function returns the string context to put into a prompt for the RAG model.
    It queries the provided vector store for information.

    :param standalone_question: the question to use to search the vector store against.
    :param vector_store: the vector store to search against.
    :param top_k: the number of results to return.
    :return: a string with all the context.
    """
    _results = vector_store.invoke(standalone_question, search_kwargs={"k": top_k})
    return "\n\n".join(map(lambda d: d.page_content, _results))


def answer_prompt(context: str, standalone_question: str) -> str:
    """Creates a prompt that includes the question and context for the LLM to make sense of.

    :param context: the information context to use.
    :param standalone_question: the user question the LLM should answer.
    :return: the full prompt.
    """
    template = (
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )

    return template.format(context=context, question=standalone_question)


def llm_client() -> openai.OpenAI:
    """The LLM client to use for the RAG model."""
    return openai.OpenAI()


def conversational_rag_response(answer_prompt: str, llm_client: openai.OpenAI) -> str:
    """Creates the RAG response from the LLM model for the given prompt.

    :param answer_prompt: the prompt to send to the LLM.
    :param llm_client: the LLM client to use.
    :return: the response from the LLM.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": answer_prompt}],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    import __init__ as conversational_rag

    from hamilton import driver, lifecycle

    dr = (
        driver.Builder()
        .with_modules(conversational_rag)
        .with_config({})
        # this prints the inputs and outputs of each step.
        .with_adapters(lifecycle.PrintLn(verbosity=2))
        .build()
    )
    dr.display_all_functions("dag.png")

    # with no chat history, there is nothing to reword
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did stefan work?",
                "chat_history": [],
            },
        )
    )

    # this will now reword the question, which is then
    # used to query the vector store.
    print(
        dr.execute(
            ["conversational_rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did he work?",
                "chat_history": ["Human: Who wrote this example?", "AI: Stefan"],
            },
        )
    )
(One changed file could not be displayed in the diff view.)
70 changes: 70 additions & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/langchain.py
@@ -0,0 +1,70 @@
# LangChain LCEL implementation of the same conversational RAG flow.
from operator import itemgetter

from langchain.prompts.prompt import PromptTemplate
from langchain.schema import format_document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Populate an in-memory FAISS vector store and expose it as a retriever.
vectorstore = FAISS.from_texts(["harrison worked at kensho"], embedding=OpenAIEmbeddings())
retriever = vectorstore.as_retriever()

_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
ANSWER_PROMPT = ChatPromptTemplate.from_template(template)

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)


# Rewrite the follow-up question into a standalone question using the chat history.
_inputs = RunnableParallel(
    standalone_question=RunnablePassthrough.assign(
        chat_history=lambda x: get_buffer_string(x["chat_history"])
    )
    | CONDENSE_QUESTION_PROMPT
    | ChatOpenAI(temperature=0)
    | StrOutputParser(),
)
# Retrieve context with the standalone question, then feed both into the answer prompt.
_context = {
    "context": itemgetter("standalone_question") | retriever | _combine_documents,
    "question": lambda x: x["standalone_question"],
}
conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | ChatOpenAI()

print(
    conversational_qa_chain.invoke(
        {
            "question": "where did harrison work?",
            "chat_history": [],
        }
    )
)
print(
    conversational_qa_chain.invoke(
        {
            "question": "where did he work?",
            "chat_history": [
                HumanMessage(content="Who wrote this notebook?"),
                AIMessage(content="Harrison"),
            ],
        }
    )
)
4 changes: 4 additions & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/requirements.txt
@@ -0,0 +1,4 @@ | ||
faiss-cpu | ||
langchain | ||
langchain-community | ||
langchain-openai |
7 changes: 7 additions & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/tags.json
@@ -0,0 +1,7 @@ | ||
{ | ||
"schema": "1.0", | ||
"use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"], | ||
"secondary_tags": { | ||
"language": "English" | ||
} | ||
} |
1 change: 1 addition & 0 deletions
contrib/hamilton/contrib/dagworks/conversational_rag/valid_configs.jsonl
@@ -0,0 +1 @@
{"description": "Default", "name": "default", "config": {}}