-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cc33d4c
commit 55f71fe
Showing
1 changed file
with
171 additions
and
0 deletions.
There are no files selected for viewing
171 changes: 171 additions & 0 deletions
171
packages/ragbits-document-search/examples/documents_chat.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# /// script | ||
# requires-python = ">=3.10" | ||
# dependencies = [ | ||
# "gradio", | ||
# "ragbits-document-search", | ||
# "ragbits-core[chromadb, litellm]", | ||
# ] | ||
# /// | ||
from pathlib import Path | ||
from typing import AsyncIterator | ||
|
||
import chromadb | ||
import gradio as gr | ||
from pydantic import BaseModel | ||
|
||
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings | ||
from ragbits.core.llms.litellm import LiteLLM | ||
from ragbits.core.prompt import Prompt | ||
from ragbits.core.vector_store.chromadb_store import ChromaDBStore | ||
from ragbits.document_search import DocumentSearch | ||
from ragbits.document_search.documents.document import DocumentMeta | ||
from ragbits.document_search.documents.element import TextElement | ||
|
||
|
||
class QueryWithContext(BaseModel): | ||
""" | ||
Input format for the QueryWithContext. | ||
""" | ||
|
||
query: str | ||
context: list[str] | ||
|
||
|
||
class ChatAnswer(BaseModel): | ||
""" | ||
Output format for the ChatAnswer. | ||
""" | ||
|
||
answer: str | ||
|
||
|
||
class RAGPrompt(Prompt[QueryWithContext, ChatAnswer]): | ||
""" | ||
A simple prompt for RAG system. | ||
""" | ||
|
||
system_prompt = """ | ||
You are a helpful assistant. Answer the QUESTION that will be provided using CONTEXT. | ||
If in the given CONTEXT there is not enough information refuse to answer. | ||
""" | ||
|
||
user_prompt = """ | ||
QUESTION: | ||
{{ query }} | ||
CONTEXT: | ||
{% for item in context %} | ||
{{ item }} | ||
{% endfor %} | ||
""" | ||
|
||
|
||
class RAGSystemWithUI: | ||
""" | ||
Simple RAG application | ||
""" | ||
|
||
DATABASE_CREATED_MESSAGE = "Database created and saved at: " | ||
DATABASE_LOADED_MESSAGE = "Database loaded" | ||
NO_DOCUMENTS_INGESTED_MESSAGE = ( | ||
"Before making queries you need to select documents and create database or " | ||
"provide a path to an existing database" | ||
) | ||
DOCUMENT_PICKER_LABEL = "Documents" | ||
DATABASE_TEXT_BOX_LABEL = "Database path" | ||
DATABASE_CREATE_BUTTON_LABEL = "Create Database" | ||
DATABASE_LOAD_BUTTON_LABEL = "Load Database" | ||
DATABASE_CREATION_STATUS_LABEL = "Database creation status" | ||
DATABASE_CREATION_STATUS_PLACEHOLDER = "Upload files and click 'Create Database' to start..." | ||
DATABASE_LOADING_STATUS_LABEL = "Database loading status" | ||
DATABASE_LOADING_STATUS_PLACEHOLDER = "Click 'Load Database' to start..." | ||
|
||
def __init__( | ||
self, | ||
database_path: str = "chroma", | ||
index_name: str = "documents", | ||
model_name: str = "gpt-4o-2024-08-06", | ||
columns_ratios: tuple[int, int] = (1, 5), | ||
chatbot_height_vh: int = 90, | ||
) -> None: | ||
self._database_path = database_path | ||
self._index_name = index_name | ||
self._columns_ratios = columns_ratios | ||
self._chatbot_height_vh = chatbot_height_vh | ||
self._documents_ingested = False | ||
self._prepare_document_search(self._database_path, self._index_name) | ||
self._llm = LiteLLM(model_name, use_structured_output=True) | ||
|
||
def _prepare_document_search(self, database_path: str, index_name: str) -> None: | ||
chroma_client = chromadb.PersistentClient(path=database_path) | ||
embedding_client = LiteLLMEmbeddings() | ||
|
||
vector_store = ChromaDBStore( | ||
index_name=index_name, | ||
chroma_client=chroma_client, | ||
embedding_function=embedding_client, | ||
) | ||
self.document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store) | ||
|
||
async def _create_database(self, document_paths: list[str]) -> str: | ||
for path in document_paths: | ||
await self.document_search.ingest_document(DocumentMeta.from_local_path(Path(path))) | ||
self._documents_ingested = True | ||
return self.DATABASE_CREATED_MESSAGE + self._database_path | ||
|
||
def _load_database(self, database_path: str) -> str: | ||
self._prepare_document_search(database_path, self._index_name) | ||
self._documents_ingested = True | ||
return self.DATABASE_LOADED_MESSAGE | ||
|
||
async def _handle_message( | ||
self, message: str, history: list[dict] # pylint: disable=unused-argument | ||
) -> AsyncIterator[str]: | ||
if not self._documents_ingested: | ||
yield self.NO_DOCUMENTS_INGESTED_MESSAGE | ||
results = await self.document_search.search(message[-1]) | ||
prompt = RAGPrompt( | ||
QueryWithContext(query=message, context=[i.content for i in results if isinstance(i, TextElement)]) | ||
) | ||
response = await self._llm.generate(prompt) | ||
yield response.answer | ||
|
||
def prepare_layout(self) -> gr.Blocks: | ||
""" | ||
Crates gradio layout as gr.Blocks and connects component events with proper handlers | ||
Returns: | ||
gradio layout | ||
""" | ||
with gr.Blocks(fill_height=True, fill_width=True) as app: | ||
with gr.Row(): | ||
with gr.Column(scale=self._columns_ratios[0]): | ||
with gr.Group(): | ||
documents_picker = gr.File(file_count="multiple", label=self.DOCUMENT_PICKER_LABEL) | ||
create_btn = gr.Button(self.DATABASE_CREATE_BUTTON_LABEL) | ||
creating_status_display = gr.Textbox( | ||
label=self.DATABASE_CREATION_STATUS_LABEL, | ||
interactive=False, | ||
placeholder=self.DATABASE_CREATION_STATUS_PLACEHOLDER, | ||
) | ||
|
||
with gr.Group(): | ||
database_path = gr.Textbox(label=self.DATABASE_TEXT_BOX_LABEL) | ||
load_btn = gr.Button(self.DATABASE_LOAD_BUTTON_LABEL) | ||
loading_status_display = gr.Textbox( | ||
label=self.DATABASE_LOADING_STATUS_LABEL, | ||
interactive=False, | ||
placeholder=self.DATABASE_LOADING_STATUS_PLACEHOLDER, | ||
) | ||
load_btn.click(fn=self._load_database, inputs=database_path, outputs=loading_status_display) | ||
create_btn.click(fn=self._create_database, inputs=documents_picker, outputs=creating_status_display) | ||
|
||
with gr.Column(scale=self._columns_ratios[1]): | ||
chat_interface = gr.ChatInterface(self._handle_message, type="messages") | ||
chat_interface.chatbot.height = f"{self._chatbot_height_vh}vh" | ||
return app | ||
|
||
|
||
if __name__ == "__main__": | ||
rag_system = RAGSystemWithUI() | ||
rag_system.prepare_layout().launch() |