Adds simple RAG example to contrib (#673)
This is a basic example to show the basic mechanics of a RAG pipeline.

It uses an in-memory vector store backed by FAISS for similarity search.
skrawcz authored Feb 2, 2024
1 parent 3ea0068 commit 33c9e36
Showing 6 changed files with 197 additions and 0 deletions.
80 changes: 80 additions & 0 deletions contrib/hamilton/contrib/dagworks/faiss_rag/README.md
@@ -0,0 +1,80 @@
# Purpose of this module

This module shows a simple retrieval augmented generation (RAG) example using
Hamilton. It shows you how you might structure your code with Hamilton to
create a simple RAG pipeline.

This example uses [FAISS](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) as an in-memory vector store and OpenAI as the LLM provider.
The FAISS vector store is accessed through its LangChain wrapper, because that was the
simplest way to get this example running without requiring anyone to host and manage
a separate vector store.

## Example Usage

### Inputs
These are the defined inputs.

- *input_texts*: A list of strings. Each string will be encoded into a vector and stored in the vector store.
- *question*: A string. This is the question you want to ask the LLM; it is also used to query the vector store for context.
- *top_k*: An integer. This is the number of vectors to retrieve from the vector store. Defaults to 5.

### Overrides
With Hamilton you can easily override a function and provide a value for it directly. For example, if you're
iterating, you might want to override these two values before modifying the functions (see the sketch after this list):

- *context*: if you want to skip going to the vector store and provide the context directly, you can do so by providing this override.
- *rag_prompt*: if you want to provide the prompt to pass to the LLM, pass it in as an override.
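
For example, here is a minimal sketch of overriding `context` (using the driver built in the Execution section below). The context string is purely illustrative; because `context` is overridden, the vector store should not need to be queried, so `input_texts` is omitted:

```python
# Sketch: skip retrieval entirely by overriding `context` with your own string.
result = dr.execute(
    ["rag_response"],
    inputs={"question": "where did stefan work?"},
    overrides={"context": "stefan worked at Stitch Fix"},  # illustrative context
)
print(result["rag_response"])
```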

### Execution
You can get back the result of any intermediate function by including its name in the `execute` call.
Here we just ask for the final result, but if you wanted to, you could ask for the outputs of any of the functions, which
you can then inspect or log for debugging/evaluation purposes. Note that if you want more platform integrations,
you can add adapters that will do this automatically for you, e.g. the `PrintLn` adapter used here.
```python
# import the module with the Hamilton functions (adjust this import to match
# where the module lives in your environment)
import faiss_rag

from hamilton import driver
from hamilton import lifecycle

dr = (
    driver.Builder()
    .with_modules(faiss_rag)
    .with_config({})
    # this prints the inputs and outputs of each step.
    .with_adapters(lifecycle.PrintLn(verbosity=2))
    .build()
)
result = dr.execute(
    ["rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
    },
)
print(result)
```
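
If you want to inspect intermediate results, request them alongside the final answer. For example, with the same driver and inputs as above:

```python
# Also request intermediate outputs for debugging/evaluation.
result = dr.execute(
    ["context", "rag_prompt", "rag_response"],
    inputs={
        "input_texts": [
            "harrison worked at kensho",
            "stefan worked at Stitch Fix",
        ],
        "question": "where did stefan work?",
    },
)
print(result["context"])       # what was retrieved from the vector store
print(result["rag_prompt"])    # the full prompt sent to the LLM
print(result["rag_response"])  # the LLM's answer
```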

# How to extend this module
What you'd most likely want to do is:

1. Change the vector store (and how embeddings are generated).
2. Change the LLM provider.
3. Change the context and prompt.

With (1), you can use any vector store/library that you want. Sketch out the process you'd like,
and then map it to Hamilton functions.
With (2), you can use any LLM provider that you want; just use `@config.when` if you
want to switch between multiple providers (see the sketch below).
With (3), you can add more functions that build up parts of the prompt.
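
As a rough sketch of (2), you could gate `llm_client` on configuration with `@config.when`; the second branch mentioned below is purely illustrative:

```python
import openai

from hamilton.function_modifiers import config


@config.when(llm_provider="openai")
def llm_client__openai() -> openai.OpenAI:
    """Used when the driver is built with .with_config({"llm_provider": "openai"})."""
    return openai.OpenAI()


# Hypothetical second provider: add another function, e.g. llm_client__myprovider,
# decorated with @config.when(llm_provider="myprovider"), that returns that provider's client.
```

Hamilton strips the `__openai` suffix, so downstream functions still depend on `llm_client`; you pick the implementation by building the driver with `.with_config({"llm_provider": "openai"})`.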

# Configuration Options
There is no configuration needed for this module.

# Limitations

You need to have `OPENAI_API_KEY` set in your environment.
It should be accessible from your code via `os.environ["OPENAI_API_KEY"]`.

The code does not check the context length, so it may fail if the context passed is too long
for the LLM you send it to.
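
If you need a guard, here is a rough sketch of a token-count check you could fold into `rag_prompt` (this assumes the `tiktoken` package, which is not in this module's requirements):

```python
import tiktoken


def _check_prompt_fits(prompt: str, model: str = "gpt-3.5-turbo", budget: int = 4096) -> None:
    """Illustrative guard: raise if the prompt likely exceeds the model's context window."""
    encoding = tiktoken.encoding_for_model(model)
    n_tokens = len(encoding.encode(prompt))
    if n_tokens > budget:
        raise ValueError(f"Prompt is {n_tokens} tokens, which exceeds the {budget}-token budget.")
```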
105 changes: 105 additions & 0 deletions contrib/hamilton/contrib/dagworks/faiss_rag/__init__.py
@@ -0,0 +1,105 @@
import logging

logger = logging.getLogger(__name__)

from hamilton import contrib

with contrib.catch_import_errors(__name__, __file__, logger):
    import openai

    # use langchain implementation of vector store
    from langchain_community.vectorstores import FAISS
    from langchain_core.vectorstores import VectorStoreRetriever

    # use langchain embedding wrapper with vector store
    from langchain_openai import OpenAIEmbeddings


def vector_store(input_texts: list[str]) -> VectorStoreRetriever:
    """A Vector store. This function populates and creates one for querying.
    This is a cute function encapsulating the creation of a vector store. In real life
    you could replace this with a more complex function, or one that returns a
    client to an existing vector store.
    :param input_texts: the input "text" i.e. documents to be stored.
    :return: a vector store that can be queried against.
    """
    vectorstore = FAISS.from_texts(input_texts, embedding=OpenAIEmbeddings())
    retriever = vectorstore.as_retriever()
    return retriever


def context(question: str, vector_store: VectorStoreRetriever, top_k: int = 5) -> str:
    """This function returns the string context to put into a prompt for the RAG model.
    :param question: the user question to use to search the vector store against.
    :param vector_store: the vector store to search against.
    :param top_k: the number of results to return.
    :return: a string with all the context.
    """
    _results = vector_store.invoke(question, search_kwargs={"k": top_k})
    return "\n".join(map(lambda d: d.page_content, _results))


def rag_prompt(context: str, question: str) -> str:
    """Creates a prompt that includes the question and context for the LLM to make sense of.
    :param context: the information context to use.
    :param question: the user question the LLM should answer.
    :return: the full prompt.
    """
    template = (
        "Answer the question based only on the following context:\n"
        "{context}\n\n"
        "Question: {question}"
    )
    return template.format(context=context, question=question)


def llm_client() -> openai.OpenAI:
    """The LLM client to use for the RAG model."""
    return openai.OpenAI()


def rag_response(rag_prompt: str, llm_client: openai.OpenAI) -> str:
    """Creates the RAG response from the LLM model for the given prompt.
    :param rag_prompt: the prompt to send to the LLM.
    :param llm_client: the LLM client to use.
    :return: the response from the LLM.
    """
    response = llm_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": rag_prompt}],
    )
    return response.choices[0].message.content


if __name__ == "__main__":
    import __init__ as hamilton_faiss_rag

    from hamilton import driver, lifecycle

    dr = (
        driver.Builder()
        .with_modules(hamilton_faiss_rag)
        .with_config({})
        # this prints the inputs and outputs of each step.
        .with_adapters(lifecycle.PrintLn(verbosity=2))
        .build()
    )
    dr.display_all_functions("dag.png")
    print(
        dr.execute(
            ["rag_response"],
            inputs={
                "input_texts": [
                    "harrison worked at kensho",
                    "stefan worked at Stitch Fix",
                ],
                "question": "where did stefan work?",
            },
        )
    )
Binary file not shown.
4 changes: 4 additions & 0 deletions contrib/hamilton/contrib/dagworks/faiss_rag/requirements.txt
@@ -0,0 +1,4 @@
faiss-cpu
langchain
langchain-community
langchain-openai
7 changes: 7 additions & 0 deletions contrib/hamilton/contrib/dagworks/faiss_rag/tags.json
@@ -0,0 +1,7 @@
{
  "schema": "1.0",
  "use_case_tags": ["LLM", "openai", "RAG", "retrieval augmented generation", "FAISS"],
  "secondary_tags": {
    "language": "English"
  }
}
@@ -0,0 +1 @@
{"description": "Default", "name": "default", "config": {}}
