From 55f71fe2b65d365104ce6d44398521698bf3e079 Mon Sep 17 00:00:00 2001
From: konrad-czarnota-ds
Date: Tue, 15 Oct 2024 14:36:39 +0200
Subject: [PATCH 1/4] example: create e2e rag application

---
 .../examples/documents_chat.py | 171 ++++++++++++++++++
 1 file changed, 171 insertions(+)
 create mode 100644 packages/ragbits-document-search/examples/documents_chat.py

diff --git a/packages/ragbits-document-search/examples/documents_chat.py b/packages/ragbits-document-search/examples/documents_chat.py
new file mode 100644
index 00000000..b5d4121d
--- /dev/null
+++ b/packages/ragbits-document-search/examples/documents_chat.py
@@ -0,0 +1,171 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "gradio",
+#     "ragbits-document-search",
+#     "ragbits-core[chromadb, litellm]",
+# ]
+# ///
+from pathlib import Path
+from typing import AsyncIterator
+
+import chromadb
+import gradio as gr
+from pydantic import BaseModel
+
+from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
+from ragbits.core.llms.litellm import LiteLLM
+from ragbits.core.prompt import Prompt
+from ragbits.core.vector_store.chromadb_store import ChromaDBStore
+from ragbits.document_search import DocumentSearch
+from ragbits.document_search.documents.document import DocumentMeta
+from ragbits.document_search.documents.element import TextElement
+
+
+class QueryWithContext(BaseModel):
+    """
+    Input for the RAG prompt: the user query and the retrieved context passages.
+    """
+
+    query: str
+    context: list[str]
+
+
+class ChatAnswer(BaseModel):
+    """
+    Structured output of the RAG prompt: the generated answer.
+    """
+
+    answer: str
+
+
+class RAGPrompt(Prompt[QueryWithContext, ChatAnswer]):
+    """
+    A simple prompt for a RAG system.
+    """
+
+    system_prompt = """
+    You are a helpful assistant. Answer the QUESTION that will be provided using CONTEXT.
+    If the given CONTEXT does not contain enough information, refuse to answer.
+    """
+
+    user_prompt = """
+    QUESTION:
+    {{ query }}
+
+    CONTEXT:
+    {% for item in context %}
+    {{ item }}
+    {% endfor %}
+    """
+
+
+class RAGSystemWithUI:
+    """
+    A simple RAG application with a Gradio UI.
+    """
+
+    DATABASE_CREATED_MESSAGE = "Database created and saved at: "
+    DATABASE_LOADED_MESSAGE = "Database loaded"
+    NO_DOCUMENTS_INGESTED_MESSAGE = (
+        "Before making queries you need to select documents and create a database or "
+        "provide a path to an existing database"
+    )
+    DOCUMENT_PICKER_LABEL = "Documents"
+    DATABASE_TEXT_BOX_LABEL = "Database path"
+    DATABASE_CREATE_BUTTON_LABEL = "Create Database"
+    DATABASE_LOAD_BUTTON_LABEL = "Load Database"
+    DATABASE_CREATION_STATUS_LABEL = "Database creation status"
+    DATABASE_CREATION_STATUS_PLACEHOLDER = "Upload files and click 'Create Database' to start..."
+    DATABASE_LOADING_STATUS_LABEL = "Database loading status"
+    DATABASE_LOADING_STATUS_PLACEHOLDER = "Click 'Load Database' to start..."
+
+    def __init__(
+        self,
+        database_path: str = "chroma",
+        index_name: str = "documents",
+        model_name: str = "gpt-4o-2024-08-06",
+        columns_ratios: tuple[int, int] = (1, 5),
+        chatbot_height_vh: int = 90,
+    ) -> None:
+        self._database_path = database_path
+        self._index_name = index_name
+        self._columns_ratios = columns_ratios
+        self._chatbot_height_vh = chatbot_height_vh
+        self._documents_ingested = False
+        self._prepare_document_search(self._database_path, self._index_name)
+        self._llm = LiteLLM(model_name, use_structured_output=True)
+
+    def _prepare_document_search(self, database_path: str, index_name: str) -> None:
+        chroma_client = chromadb.PersistentClient(path=database_path)
+        embedding_client = LiteLLMEmbeddings()
+
+        vector_store = ChromaDBStore(
+            index_name=index_name,
+            chroma_client=chroma_client,
+            embedding_function=embedding_client,
+        )
+        self.document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store)
+
+    async def _create_database(self, document_paths: list[str]) -> str:
+        for path in document_paths:
+            await self.document_search.ingest_document(DocumentMeta.from_local_path(Path(path)))
+        self._documents_ingested = True
+        return self.DATABASE_CREATED_MESSAGE + self._database_path
+
+    def _load_database(self, database_path: str) -> str:
+        self._prepare_document_search(database_path, self._index_name)
+        self._documents_ingested = True
+        return self.DATABASE_LOADED_MESSAGE
+
+    async def _handle_message(
+        self, message: str, history: list[dict]  # pylint: disable=unused-argument
+    ) -> AsyncIterator[str]:
+        if not self._documents_ingested:
+            yield self.NO_DOCUMENTS_INGESTED_MESSAGE
+            return
+        # Retrieve context elements relevant to the full user message.
+        results = await self.document_search.search(message)
+        prompt = RAGPrompt(
+            QueryWithContext(query=message, context=[i.content for i in results if isinstance(i, TextElement)])
+        )
+        response = await self._llm.generate(prompt)
+        yield response.answer
+
+    def prepare_layout(self) -> gr.Blocks:
+        """
+        Creates the Gradio layout as gr.Blocks and connects component events with their handlers.
+
+        Returns:
+            gradio layout
+        """
+        with gr.Blocks(fill_height=True, fill_width=True) as app:
+            with gr.Row():
+                with gr.Column(scale=self._columns_ratios[0]):
+                    with gr.Group():
+                        documents_picker = gr.File(file_count="multiple", label=self.DOCUMENT_PICKER_LABEL)
+                        create_btn = gr.Button(self.DATABASE_CREATE_BUTTON_LABEL)
+                        creating_status_display = gr.Textbox(
+                            label=self.DATABASE_CREATION_STATUS_LABEL,
+                            interactive=False,
+                            placeholder=self.DATABASE_CREATION_STATUS_PLACEHOLDER,
+                        )
+
+                    with gr.Group():
+                        database_path = gr.Textbox(label=self.DATABASE_TEXT_BOX_LABEL)
+                        load_btn = gr.Button(self.DATABASE_LOAD_BUTTON_LABEL)
+                        loading_status_display = gr.Textbox(
+                            label=self.DATABASE_LOADING_STATUS_LABEL,
+                            interactive=False,
+                            placeholder=self.DATABASE_LOADING_STATUS_PLACEHOLDER,
+                        )
+                    load_btn.click(fn=self._load_database, inputs=database_path, outputs=loading_status_display)
+                    create_btn.click(fn=self._create_database, inputs=documents_picker, outputs=creating_status_display)
+
+                with gr.Column(scale=self._columns_ratios[1]):
+                    chat_interface = gr.ChatInterface(self._handle_message, type="messages")
+                    chat_interface.chatbot.height = f"{self._chatbot_height_vh}vh"
+        return app
+
+
+if __name__ == "__main__":
+    rag_system = RAGSystemWithUI()
+    rag_system.prepare_layout().launch()

From 02b800c78514c387d1414e38b25ec65553bb73b3 Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Tue, 15 Oct 2024 15:26:18 +0200
Subject: [PATCH 2/4] feat(llms): option to set a default
 LLM factory (#101)

---
 .../ragbits-core/src/ragbits/core/config.py   |  3 +
 .../src/ragbits/core/llms/factory.py          | 60 +++++++++++++++++++
 .../src/ragbits/core/prompt/lab/app.py        | 35 +++++------
 .../tests/unit/llms/factory/__init__.py       |  5 ++
 .../unit/llms/factory/test_get_default_llm.py | 15 +++++
 .../llms/factory/test_get_llm_from_factory.py | 22 +++++++
 .../unit/llms/factory/test_has_default_llm.py | 20 +++++++
 uv.lock                                       | 14 +----
 8 files changed, 145 insertions(+), 29 deletions(-)
 create mode 100644 packages/ragbits-core/src/ragbits/core/llms/factory.py
 create mode 100644 packages/ragbits-core/tests/unit/llms/factory/__init__.py
 create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py
 create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py
 create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py

diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py
index 49c0b12b..830dfc9f 100644
--- a/packages/ragbits-core/src/ragbits/core/config.py
+++ b/packages/ragbits-core/src/ragbits/core/config.py
@@ -11,5 +11,8 @@ class CoreConfig(BaseModel):
     # Pattern used to search for prompt files
     prompt_path_pattern: str = "**/prompt_*.py"
 
+    # Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm"
+    default_llm_factory: str | None = None
+
 
 core_config = get_config_instance(CoreConfig, subproject="core")

diff --git a/packages/ragbits-core/src/ragbits/core/llms/factory.py b/packages/ragbits-core/src/ragbits/core/llms/factory.py
new file mode 100644
index 00000000..02bd4704
--- /dev/null
+++ b/packages/ragbits-core/src/ragbits/core/llms/factory.py
@@ -0,0 +1,60 @@
+import importlib
+
+from ragbits.core.config import core_config
+from ragbits.core.llms.base import LLM
+from ragbits.core.llms.litellm import LiteLLM
+
+
+def get_llm_from_factory(factory_path: str) -> LLM:
+    """
+    Get an instance of an LLM using a factory function specified by the user.
+
+    Args:
+        factory_path (str): The path to the factory function.
+
+    Returns:
+        LLM: An instance of the LLM.
+    """
+    module_name, function_name = factory_path.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    function = getattr(module, function_name)
+    return function()
+
+
+def has_default_llm() -> bool:
+    """
+    Check if the default LLM factory is set in the configuration.
+
+    Returns:
+        bool: Whether the default LLM factory is set.
+    """
+    return core_config.default_llm_factory is not None
+
+
+def get_default_llm() -> LLM:
+    """
+    Get an instance of the default LLM using the factory function
+    specified in the configuration.
+
+    Returns:
+        LLM: An instance of the default LLM.
+
+    Raises:
+        ValueError: If the default LLM factory is not set.
+    """
+    factory = core_config.default_llm_factory
+    if factory is None:
+        raise ValueError("Default LLM factory is not set")
+
+    return get_llm_from_factory(factory)
+
+
+def simple_litellm_factory() -> LLM:
+    """
+    A basic LLM factory that creates a LiteLLM instance with the default model and
+    default options, assuming that the API key is set in the environment.
+
+    Returns:
+        LLM: An instance of LiteLLM.
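+
+    To use this factory as the project-wide default, point the
+    "default_llm_factory" setting at
+    "ragbits.core.llms.factory.simple_litellm_factory".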
+ """ + return LiteLLM() diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 9b6a7dc4..05648b01 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -14,8 +14,8 @@ from rich.console import Console from ragbits.core.config import core_config -from ragbits.core.llms import LiteLLM -from ragbits.core.llms.clients import LiteLLMOptions +from ragbits.core.llms import LLM +from ragbits.core.llms.factory import get_llm_from_factory from ragbits.core.prompt import Prompt from ragbits.core.prompt.discovery import PromptDiscovery @@ -30,14 +30,12 @@ class PromptState: Attributes: prompts (list): A list containing discovered prompts. rendered_prompt (Prompt): The most recently rendered Prompt instance. - llm_model_name (str): The name of the selected LLM model. - llm_api_key (str | None): The API key for the chosen LLM model. + llm (LLM): The LLM instance to be used for generating responses. """ prompts: list = field(default_factory=list) rendered_prompt: Prompt | None = None - llm_model_name: str | None = None - llm_api_key: str | None = None + llm: LLM | None = None def render_prompt(index: int, system_prompt: str, user_prompt: str, state: PromptState, *args: Any) -> PromptState: @@ -99,15 +97,17 @@ def send_prompt_to_llm(state: PromptState) -> str: Returns: str: The response generated by the LLM. - """ - assert state.llm_model_name is not None, "LLM model name is not set." - llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key) + Raises: + ValueError: If the LLM model is not configured. + """ assert state.rendered_prompt is not None, "Prompt has not been rendered yet." 
+
+    if state.llm is None:
+        raise ValueError("LLM model is not configured.")
+
     try:
-        response = asyncio.run(
-            llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions())
-        )
+        response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))
     except Exception as e:  # pylint: disable=broad-except
         response = str(e)
 
@@ -136,7 +136,8 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:
 
 
 def lab_app(  # pylint: disable=missing-param-doc
-    file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None
+    file_pattern: str = core_config.prompt_path_pattern,
+    llm_factory: str | None = core_config.default_llm_factory,
 ) -> None:
     """
     Launches the interactive application for listing, rendering, and testing prompts
@@ -163,9 +164,8 @@ def lab_app(  # pylint: disable=missing-param-doc
     with gr.Blocks() as gr_app:
         prompts_state = gr.State(
             PromptState(
-                llm_model_name=llm_model,
-                llm_api_key=llm_api_key,
                 prompts=list(prompts),
+                llm=get_llm_from_factory(llm_factory) if llm_factory else None,
             )
         )
 
@@ -220,14 +220,15 @@ def show_split(index: int, state: gr.State) -> None:
                     )
                     gr.Textbox(label="Rendered User Prompt", value=rendered_user_prompt, interactive=False)
 
-                llm_enabled = state.llm_model_name is not None
+                llm_enabled = state.llm is not None
                 prompt_ready = state.rendered_prompt is not None
                 llm_request_button = gr.Button(
                     value="Send to LLM",
                     interactive=llm_enabled and prompt_ready,
                 )
                 gr.Markdown(
-                    "To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled
+                    "To enable this button, set an LLM factory function in the CLI options or in your pyproject.toml.",
+                    visible=not llm_enabled,
                 )
                 gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready)
                 llm_prompt_response = gr.Textbox(lines=10, label="LLM response")

diff --git a/packages/ragbits-core/tests/unit/llms/factory/__init__.py b/packages/ragbits-core/tests/unit/llms/factory/__init__.py
new file mode 100644
index 00000000..a3559f0c
--- /dev/null
+++ b/packages/ragbits-core/tests/unit/llms/factory/__init__.py
@@ -0,0 +1,5 @@
+import sys
+from pathlib import Path
+
+# Add "llms" to sys.path
+sys.path.append(str(Path(__file__).parent.parent))

diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py
new file mode 100644
index 00000000..c07272fb
--- /dev/null
+++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py
@@ -0,0 +1,15 @@
+from ragbits.core.config import core_config
+from ragbits.core.llms.factory import get_default_llm
+from ragbits.core.llms.litellm import LiteLLM
+
+
+def test_get_default_llm(monkeypatch):
+    """
+    Test the get_default_llm function.
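+
+    The "factory.test_get_llm_from_factory.mock_llm_factory" path resolves
+    because this package's __init__.py appends the parent "llms" directory
+    to sys.path.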
+ """ + + monkeypatch.setattr(core_config, "default_llm_factory", "factory.test_get_llm_from_factory.mock_llm_factory") + + llm = get_default_llm() + assert isinstance(llm, LiteLLM) + assert llm.model_name == "mock_model" diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py new file mode 100644 index 00000000..8d2a948c --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py @@ -0,0 +1,22 @@ +from ragbits.core.llms.factory import get_llm_from_factory +from ragbits.core.llms.litellm import LiteLLM + + +def mock_llm_factory() -> LiteLLM: + """ + A mock LLM factory that creates a LiteLLM instance with a mock model name. + + Returns: + LiteLLM: An instance of the LiteLLM. + """ + return LiteLLM(model_name="mock_model") + + +def test_get_llm_from_factory(): + """ + Test the get_llm_from_factory function. + """ + llm = get_llm_from_factory("factory.test_get_llm_from_factory.mock_llm_factory") + + assert isinstance(llm, LiteLLM) + assert llm.model_name == "mock_model" diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py new file mode 100644 index 00000000..59c86483 --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py @@ -0,0 +1,20 @@ +from ragbits.core.config import core_config +from ragbits.core.llms.factory import has_default_llm + + +def test_has_default_llm(monkeypatch): + """ + Test the has_default_llm function when the default LLM factory is not set. + """ + monkeypatch.setattr(core_config, "default_llm_factory", None) + + assert has_default_llm() is False + + +def test_has_default_llm_false(monkeypatch): + """ + Test the has_default_llm function when the default LLM factory is set. 
+ """ + monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm") + + assert has_default_llm() is True diff --git a/uv.lock b/uv.lock index f685d643..6156c1af 100644 --- a/uv.lock +++ b/uv.lock @@ -617,7 +617,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -2038,7 +2038,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2047,7 +2046,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2056,7 +2054,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2065,7 +2062,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2085,7 +2081,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = 
"https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2094,7 +2089,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2108,7 +2102,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2120,7 +2113,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2138,7 +2130,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/11/8c/386018fdffdce2ff8d43fedf192ef7d14cab7501cbf78a106dd2e9f1fc1f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:3bf10d85bb1801e9c894c6e197e44dd137d2a0a9e43f8450e9ad13f2df0dd52d", size = 19270432 }, { url = "https://files.pythonhosted.org/packages/fe/e4/486de766851d58699bcfeb3ba6a3beb4d89c3809f75b9d423b9508a8760f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9ae346d16203ae4ea513be416495167a0101d33d2d14935aa9c1829a3fb45142", size = 19745114 }, - { url = "https://files.pythonhosted.org/packages/1a/aa/7b5d8e22d73e03f941293ae62c993642fa41e6525f3213292e007621aa8e/nvidia_nvjitlink_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:410718cd44962bed862a31dd0318620f6f9a8b28a6291967bcfcb446a6516771", size = 161917250 }, ] [[package]] @@ -2147,7 +2138,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = 
"https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -3210,7 +3200,7 @@ dev = [ [[package]] name = "ragbits-workspace" version = "0.1.0" -source = { editable = "." } +source = { virtual = "." } dependencies = [ { name = "ragbits-cli" }, { name = "ragbits-core", extra = ["chromadb", "lab", "litellm", "local"] }, From f8771b39cb6b8a3c372a06a3646583db591dc9e0 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Wed, 16 Oct 2024 12:06:02 +0200 Subject: [PATCH 3/4] chore: Add project-level README (#103) --- CONTRIBUTING.md | 28 +++++++++++++++ README.md | 90 ++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 102 insertions(+), 16 deletions(-) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..5a4ce435 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,28 @@ +# Installation + +## Build from source + +To build and run Ragbits from the source code: + +1. Requirements: [**uv**](https://docs.astral.sh/uv/getting-started/installation/) & [**python**](https://docs.astral.sh/uv/guides/install-python/) 3.10 or higher +2. Install dependencies and run venv in editable mode: + +```bash +$ source ./setup_dev_env.sh +``` + +## Install pre-commit + +To ensure code quality we use pre-commit hook with several checks. Setup it by: + +``` +pre-commit install +``` + +All updated files will be reformatted and linted before the commit. + +To reformat and lint all files in the project, use: + +`pre-commit run --all-files` + +The used linters are configured in `.pre-commit-config.yaml`. You can use `pre-commit autoupdate` to bump tools to the latest versions. diff --git a/README.md b/README.md index 2a2d8255..c47afb68 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,90 @@ -# Ragbits +
+
-Repository for internal experiment with our upcoming LLM framework.
+# Ragbits
+
-# Installation
+*Building blocks for rapid development of GenAI applications*
+
-## Build from source
+[Documentation](https://ragbits.deepsense.ai) | [Contact](https://deepsense.ai/contact/)
+
-To build and run Ragbits from the source code:
+[![PyPI - License](https://img.shields.io/pypi/l/ragbits)](https://pypi.org/project/ragbits)
+[![PyPI - Version](https://img.shields.io/pypi/v/ragbits)](https://pypi.org/project/ragbits)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ragbits)](https://pypi.org/project/ragbits)
-1. Requirements: [**uv**](https://docs.astral.sh/uv/getting-started/installation/) & [**python**](https://docs.astral.sh/uv/guides/install-python/) 3.10 or higher
-2. Install dependencies and run venv in editable mode:
+
+</div>
-```bash
-$ source ./setup_dev_env.sh
-```
+---
+
+## What's Included?
+
+- [X] **[Core](packages/ragbits-core)** - Fundamental tools for working with prompts and LLMs.
+- [X] **[Document Search](packages/ragbits-document-search)** - Handles vector search to retrieve relevant documents.
+- [X] **[CLI](packages/ragbits-cli)** - The `ragbits` shell command, enabling tools such as GUI prompt management.
+- [ ] **Flow Controls** - Manages multi-stage chat flows for performing advanced actions *(coming soon)*.
+- [ ] **Structured Querying** - Queries structured data sources in a predictable manner *(coming soon)*.
+- [ ] **Caching** - Adds a caching layer to reduce costs and response times *(coming soon)*.
+- [ ] **Observability & Audit** - Tracks user queries and events for easier troubleshooting *(coming soon)*.
+- [ ] **Guardrails** - Ensures response safety and relevance *(coming soon)*.
+
+## Installation
+
+To use the complete Ragbits stack, install the `ragbits` package:
+
+```sh
+pip install ragbits
+```
+
+Alternatively, you can use individual components of the stack by installing their respective packages: `ragbits-core`, `ragbits-document-search`, `ragbits-cli`.
+
+## Quickstart
+
+First, create a prompt and a model for the data used in the prompt:
+
+```python
+from pydantic import BaseModel
+from ragbits.core.prompt import Prompt
+
+class Dog(BaseModel):
+    breed: str
+    age: int
+    temperament: str
+
+class DogNamePrompt(Prompt[Dog, str]):
+    system_prompt = """
+    You are a dog name generator. You come up with funny names for dogs given the dog details.
+    """
+
+    user_prompt = """
+    The dog is a {{ breed }} breed, {{ age }} years old, and has a {{ temperament }} temperament.
+    """
+```
-## Install pre-commit
+
+Next, create an instance of the LLM and the prompt:
-To ensure code quality we use pre-commit hook with several checks. Setup it by:
+
+```python
+from ragbits.core.llms.litellm import LiteLLM
+
+llm = LiteLLM("gpt-4o")
+example_dog = Dog(breed="Golden Retriever", age=3, temperament="friendly")
+prompt = DogNamePrompt(example_dog)
+```
-```
-pre-commit install
-```
+
+Finally, generate a response from the LLM using the prompt:
+
+```python
+response = await llm.generate(prompt)
+print(f"Generated dog name: {response}")
+```
-All updated files will be reformatted and linted before the commit.
+
+## License
-To reformat and lint all files in the project, use:
+
+Ragbits is licensed under the [MIT License](LICENSE).
-`pre-commit run --all-files`
+
+## Contributing
-The used linters are configured in `.pre-commit-config.yaml`. You can use `pre-commit autoupdate` to bump tools to the latest versions.
+
+We welcome contributions! Please read [CONTRIBUTING.md](CONTRIBUTING.md) for more information.
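A note on the Quickstart above: `await llm.generate(prompt)` must run inside a coroutine. A minimal, self-contained version of the final snippet (model name and sample data taken from the README; an OpenAI API key in the environment is assumed) might look like this:

```python
import asyncio

from pydantic import BaseModel
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import Prompt


class Dog(BaseModel):
    breed: str
    age: int
    temperament: str


class DogNamePrompt(Prompt[Dog, str]):
    system_prompt = """
    You are a dog name generator. You come up with funny names for dogs given the dog details.
    """

    user_prompt = """
    The dog is a {{ breed }} breed, {{ age }} years old, and has a {{ temperament }} temperament.
    """


async def main() -> None:
    llm = LiteLLM("gpt-4o")
    example_dog = Dog(breed="Golden Retriever", age=3, temperament="friendly")
    prompt = DogNamePrompt(example_dog)
    response = await llm.generate(prompt)
    print(f"Generated dog name: {response}")


asyncio.run(main())
```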
From d61725fc14961607e254657904bc3cf2f165fd75 Mon Sep 17 00:00:00 2001
From: akotyla <79326805+akotyla@users.noreply.github.com>
Date: Wed, 16 Oct 2024 12:06:25 +0200
Subject: [PATCH 4/4] feat(document-search): determine document type automatically (#99)

---
 .../src/ragbits/document_search/_main.py      | 14 ++++++--
 .../document_search/documents/document.py     | 18 +++++++++++
 .../tests/unit/test_document_search.py        | 32 ++++++++++++++++++-
 3 files changed, 61 insertions(+), 3 deletions(-)

diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
index fe3f826c..dd92caee 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
@@ -6,6 +6,7 @@
 from ragbits.core.vector_store import VectorStore, get_vector_store
 from ragbits.document_search.documents.document import Document, DocumentMeta
 from ragbits.document_search.documents.element import Element
+from ragbits.document_search.documents.sources import GCSSource, LocalFileSource, Source
 from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
 from ragbits.document_search.ingestion.providers.base import BaseProvider
 from ragbits.document_search.retrieval.rephrasers import get_rephraser
@@ -104,7 +105,9 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
         return self.reranker.rerank(elements)
 
     async def ingest_document(
-        self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None
+        self,
+        document: Union[DocumentMeta, Document, LocalFileSource, GCSSource],
+        document_processor: Optional[BaseProvider] = None,
     ) -> None:
         """
         Ingest a document.
@@ -114,7 +117,14 @@ async def ingest_document(
             document_processor: The document processor to use. If not provided, the document processor will be
                 determined based on the document metadata.
         """
-        document_meta = document if isinstance(document, DocumentMeta) else document.metadata
+
+        if isinstance(document, Source):
+            document_meta = await DocumentMeta.from_source(document)
+        elif isinstance(document, DocumentMeta):
+            document_meta = document
+        else:
+            document_meta = document.metadata
+
         if document_processor is None:
             document_processor = self.document_processor_router.get_provider(document_meta)
 

diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
index 2ca2ec9a..0d43df91 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/document.py
@@ -97,6 +97,24 @@ def from_local_path(cls, local_path: Path) -> "DocumentMeta":
             source=LocalFileSource(path=local_path),
         )
 
+    @classmethod
+    async def from_source(cls, source: Union[LocalFileSource, GCSSource]) -> "DocumentMeta":
+        """
+        Create document metadata from a source.
+
+        Args:
+            source: The source from which the document is fetched.
+
+        Returns:
+            The document metadata.
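+
+        Raises:
+            ValueError: If the fetched file's extension does not map to a known DocumentType.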
+ """ + path = await source.fetch() + + return cls( + document_type=DocumentType(path.suffix[1:]), + source=source, + ) + class Document(BaseModel): """ diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 76743c89..ba8695b1 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,3 +1,4 @@ +import tempfile from pathlib import Path from typing import Union from unittest.mock import AsyncMock @@ -7,8 +8,10 @@ from ragbits.core.vector_store.in_memory import InMemoryVectorStore from ragbits.document_search import DocumentSearch from ragbits.document_search._main import SearchConfig -from ragbits.document_search.documents.document import Document, DocumentMeta +from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType from ragbits.document_search.documents.element import TextElement +from ragbits.document_search.documents.sources import LocalFileSource +from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.dummy import DummyProvider CONFIG = { @@ -46,6 +49,33 @@ async def test_document_search_from_config(document, expected): assert first_result.content == expected +async def test_document_search_ingest_document_from_source(): + embeddings_mock = AsyncMock() + embeddings_mock.embed_text.return_value = [[0.1, 0.1]] + + providers = {DocumentType.TXT: DummyProvider()} + router = DocumentProcessorRouter.from_config(providers) + + document_search = DocumentSearch( + embedder=embeddings_mock, vector_store=InMemoryVectorStore(), document_processor_router=router + ) + + with tempfile.NamedTemporaryFile(suffix=".txt") as f: + f.write(b"Name of Peppa's brother is George") + f.seek(0) + + source = LocalFileSource(path=Path(f.name)) + + await document_search.ingest_document(source) + + results = await document_search.search("Peppa's brother") + + first_result = results[0] + + assert isinstance(first_result, TextElement) + assert first_result.content == "Name of Peppa's brother is George" + + @pytest.mark.parametrize( "document", [