From 73992af5a79335a8cf8ecaf890f115bd44af3e9e Mon Sep 17 00:00:00 2001 From: akotyla <79326805+akotyla@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:14:36 +0200 Subject: [PATCH 01/10] chore: add script to update package version (#29) --- scripts/update_ragbits_package.py | 162 ++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) create mode 100644 scripts/update_ragbits_package.py diff --git a/scripts/update_ragbits_package.py b/scripts/update_ragbits_package.py new file mode 100644 index 00000000..85c9758f --- /dev/null +++ b/scripts/update_ragbits_package.py @@ -0,0 +1,162 @@ +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "tomlkit", +# "inquirer", +# "rich", +# "typing-extensions", +# "typer" +# ] +# /// +# To run this script and update a package, run the following command: +# +# uv run scripts/update_ragbits_package.py +# + +from copy import deepcopy +from enum import Enum +from pathlib import Path +from typing import Optional + +import tomlkit +import typer +from inquirer.shortcuts import confirm, list_input, text +from rich import print as pprint + +PACKAGES_DIR = Path(__file__).parent.parent / "packages" + + +class UpdateType(Enum): + """ + Enum representing the type of version update: major, minor, or patch. + """ + + MAJOR = "major" + MINOR = "minor" + PATCH = "patch" + + +def _update_type_to_enum(update_type: Optional[str] = None) -> Optional[UpdateType]: + if update_type is not None: + return UpdateType(update_type) + return None + + +def _version_to_list(version_string): + return [int(part) for part in version_string.split(".")] + + +def _check_update_type(version: str, new_version: str) -> Optional[UpdateType]: + version_list = _version_to_list(version) + new_version_list = _version_to_list(new_version) + + if version_list[0] != new_version_list[0]: + return UpdateType.MAJOR + if version_list[1] != new_version_list[1]: + return UpdateType.MINOR + if version_list[2] != new_version_list[2]: + return UpdateType.PATCH + return None + + +def _get_updated_version(version: str, update_type: UpdateType) -> str: + version_list = _version_to_list(version) + + if update_type == UpdateType.MAJOR: + new_version_list = [version_list[0] + 1, 0, 0] + elif update_type == UpdateType.MINOR: + new_version_list = [version_list[0], version_list[1] + 1, 0] + else: + new_version_list = deepcopy(version_list) + new_version_list[2] = new_version_list[2] + 1 + + return ".".join([str(n) for n in new_version_list]) + + +def _update_pkg_version( + pkg_name: str, + pkg_pyproject: Optional[tomlkit.TOMLDocument] = None, + new_version: Optional[str] = None, + update_type: Optional[UpdateType] = None, +) -> tuple[str, str]: + if not pkg_pyproject: + pkg_pyproject = tomlkit.parse((PACKAGES_DIR / pkg_name / "pyproject.toml").read_text()) + + version = pkg_pyproject["project"]["version"] + + if not new_version: + if update_type is not None: + new_version = _get_updated_version(version, update_type=update_type) + else: + pprint(f"Current version of the [bold]{pkg_name}[/bold] package is: [bold]{version}[/bold]") + new_version = text("Enter the new version", default=_get_updated_version(version, UpdateType.PATCH)) + + pkg_pyproject["project"]["version"] = new_version + (PACKAGES_DIR / pkg_name / "pyproject.toml").write_text(tomlkit.dumps(pkg_pyproject)) + + assert isinstance(new_version, str) + pprint(f"[green]The {pkg_name} package was successfully updated from {version} to {new_version}.[/green]") + + return version, new_version + + +def run(pkg_name: Optional[str] = typer.Argument(None), update_type: Optional[str] = typer.Argument(None)) -> None: + """ + Main entry point for the package version updater. Updates package versions based on user input. + + Based on the provided package name and update type, this function updates the version of a + specific package. If the package is "ragbits-core", all other packages that depend on it + will also be updated accordingly. + + If no package name or update type is provided, the user will be prompted to select a package + and version update type interactively. For "ragbits-core", the user is asked for confirmation + before proceeding with a global update. + + Args: + pkg_name: Name of the package to update. If not provided, the user is prompted. + update_type: Type of version update to apply (major, minor or patch). If not provided, + the user is prompted for this input. + + Raises: + ValueError: If the provided `pkg_name` is not found in the available packages. + """ + + packages: list[str] = [obj.name for obj in PACKAGES_DIR.iterdir() if obj.is_dir()] + + if pkg_name is not None: + if pkg_name not in packages: + raise ValueError(f"Package '{pkg_name}' not found in available packages.") + else: + pkg_name = list_input("Enter the package name", choices=packages) + + casted_update_type = _update_type_to_enum(update_type) + + user_prompt_required = pkg_name is None or casted_update_type is None + + if pkg_name == "ragbits-core": + if user_prompt_required: + print("When upgrading the ragbits-core package it is also necessary to upgrade the other packages.") + is_continue = confirm(message="Do you want to continue?") + else: + is_continue = True + + if is_continue: + ragbits_version, new_ragbits_version = _update_pkg_version(pkg_name, update_type=casted_update_type) + casted_update_type = _check_update_type(ragbits_version, new_ragbits_version) + + for pkg in [pkg for pkg in packages if pkg != "ragbits-core"]: + pkg_pyproject = tomlkit.parse((PACKAGES_DIR / pkg / "pyproject.toml").read_text()) + pkg_pyproject["project"]["dependencies"] = [ + dep for dep in pkg_pyproject["project"]["dependencies"] if "ragbits" not in dep + ] + pkg_pyproject["project"]["dependencies"].append(f"ragbits=={new_ragbits_version}") + _update_pkg_version(pkg, pkg_pyproject, update_type=casted_update_type) + + else: + pprint("[red]The ragbits-core package was not successfully updated.[/red]") + else: + _update_pkg_version(pkg_name, update_type=casted_update_type) + + +if __name__ == "__main__": + typer.run(run) From c4647cca7125268063cb3e3cda3d255ce92cd815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:01:11 +0200 Subject: [PATCH 02/10] feat(document-search): add chunking in unstructured provider (#48) --- .../document_search/ingestion/providers/unstructured.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py index 037bed34..2e81b8ab 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/unstructured.py @@ -2,6 +2,7 @@ from io import BytesIO from typing import Optional +from unstructured.chunking.basic import chunk_elements from unstructured.documents.elements import Element as UnstructuredElement from unstructured.partition.api import partition_via_api @@ -17,6 +18,8 @@ "split_pdf_concurrency_level": 15, } +DEFAULT_CHUNKING_KWARGS: dict = {} + UNSTRUCTURED_API_KEY_ENV = "UNSTRUCTURED_API_KEY" UNSTRUCTURED_API_URL_ENV = "UNSTRUCTURED_API_URL" @@ -47,7 +50,7 @@ class UnstructuredProvider(BaseProvider): DocumentType.XML, } - def __init__(self, partition_kwargs: Optional[dict] = None): + def __init__(self, partition_kwargs: Optional[dict] = None, chunking_kwargs: Optional[dict] = None): """Initialize the UnstructuredProvider. Args: @@ -55,6 +58,7 @@ def __init__(self, partition_kwargs: Optional[dict] = None): for the available options: https://docs.unstructured.io/api-reference/api-services/api-parameters """ self.partition_kwargs = partition_kwargs or DEFAULT_PARTITION_KWARGS + self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS async def process(self, document_meta: DocumentMeta) -> list[Element]: """Process the document using the Unstructured API. @@ -86,6 +90,7 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]: api_url=api_url, **self.partition_kwargs, ) + elements = chunk_elements(elements, **self.chunking_kwargs) return [_to_text_element(element, document_meta) for element in elements] From 937235149fee1acdf0003d58fedf096a4906db24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Thu, 26 Sep 2024 10:01:27 +0200 Subject: [PATCH 03/10] fix: not existing dir for nested gcs objects (#47) --- .../src/ragbits/document_search/documents/sources.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py index 3c7240de..fc5a93a8 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py @@ -106,7 +106,7 @@ async def fetch(self) -> Path: raise ImportError("You need to install the 'gcloud-aio-storage' package to use Google Cloud Storage") if (local_dir_env := os.getenv(LOCAL_STORAGE_DIR_ENV)) is None: - local_dir = Path(tempfile.gettempdir()) + local_dir = Path(tempfile.gettempdir()) / "ragbits" else: local_dir = Path(local_dir_env) @@ -117,6 +117,7 @@ async def fetch(self) -> Path: if not path.is_file(): async with Storage() as client: content = await client.download(self.bucket, self.object_name) + Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True) with open(path, mode="wb+") as file_object: file_object.write(content) From d390703c7301c7b12a84e52878ba2009f56b1514 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Mon, 30 Sep 2024 12:04:47 +0200 Subject: [PATCH 04/10] feat(prompt-discovery): Look for prompts in arbitrary files using patterns (#38) --- .../src/ragbits/dev_kit/prompt_lab/app.py | 163 ++++++++---------- .../prompt_lab/discovery/prompt_discovery.py | 99 +++++------ .../tests/unit/discovery/__init__.py | 0 .../discovery/prompt_classes_for_tests.py | 5 + .../unit/discovery/test_prompt_discovery.py | 42 ++--- 5 files changed, 132 insertions(+), 177 deletions(-) create mode 100644 packages/ragbits-dev-kit/tests/unit/discovery/__init__.py diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py index 879d7963..a3a6d25c 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py @@ -8,7 +8,8 @@ from ragbits.core.llms import LiteLLM from ragbits.core.llms.clients import LiteLLMOptions -from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import PromptDiscovery +from ragbits.core.prompt import Prompt +from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery class PromptState: @@ -18,8 +19,7 @@ class PromptState: This class holds various data structures used throughout the application's lifecycle. Attributes: - prompts_state (dict): A dictionary containing discovered prompts. Keys are prompt names, - values are details about the corresponding Prompt-based class. + prompts (list): A list containing discovered prompts. variable_values (dict): A dictionary to store values entered by the user for prompt input fields. dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields. current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the @@ -30,64 +30,39 @@ class PromptState: temp_field_name (str): Temporary field name used internally. """ - prompts_state: dict = {} + prompts: list = [] variable_values: dict = {} dynamic_tb: dict = {} - current_prompt = None - llm_model_name: str = "" - llm_api_key: str | None = "" + current_prompt: Prompt | None = None + llm_model_name: str | None = None + llm_api_key: str | None = None temp_field_name: str = "" -def get_prompts_list(path: str, state: gr.State) -> gr.State: +def load_prompts_list(pattern: str, state: gr.State) -> gr.State: """ - Fetches a list of prompts based on provided paths. + Fetches a list of prompts based on provided paths and updates the application state. - This function takes a comma-separated string of paths to prompt definition files and uses the + This function takes a path-pattern for discovering prompt definition files and uses the PromptDiscovery class to discover prompts within those files. The discovered prompts are then stored in the application state object. Args: - path (str): A comma-separated string of paths to prompt definition files. + pattern (str): A pattern for looking up prompt files. state (gr.State): The Gradio state object to update with discovered prompts. Returns: gr.State: The updated Gradio state object containing the list of discovered prompts. """ - prompt_paths: list[str] = path.split(",") - - obj = PromptDiscovery(file_paths=prompt_paths) - discovered_prompts = obj.discover() - state.value.prompts_state.update(discovered_prompts) + obj = PromptDiscovery(file_pattern=pattern) + discovered_prompts = list(obj.discover()) + state.value.prompts = discovered_prompts return state -def display_data(key: str, state: PromptState) -> tuple[gr.Textbox, gr.Textbox]: - """ - Displays system and user prompts for a given key from the prompts state. - - This function retrieves the system and user prompt text associated with a particular prompt key - from the application state and creates Gradio Textbox elements for them. - - Args: - key (str): The key of the prompt to display data for in the prompts state. - state (PromptState): The application state object containing prompt data. - - Returns: - tuple[gr.Textbox, gr.Textbox]: A tuple containing two Gradio Textbox elements, - one for the system prompt and one for the user prompt. - """ - data = state.prompts_state[key] - - system_prompt_text = gr.Textbox(label="System Prompt", value=data["system_prompt"]) - user_prompt_text = gr.Textbox(label="User Prompt", value=data["user_prompt"]) - - return system_prompt_text, user_prompt_text - - def render_prompt( - key: str, system_prompt: str, user_prompt: str, state: gr.State, *args: Any + index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any ) -> tuple[str, str, gr.State]: """ Renders a prompt based on the provided key, system prompt, user prompt, and input variables. @@ -96,7 +71,7 @@ def render_prompt( associated with the given key. It then updates the current prompt in the application state. Args: - key (str): The key of the prompt to render. + index (int): The index of the prompt to render in the prompts state. system_prompt (str): The system prompt template for the prompt. user_prompt (str): The user prompt template for the prompt. state (PromptState): The application state object. @@ -108,13 +83,13 @@ def render_prompt( """ variables = dict(zip(state.dynamic_tb.keys(), args)) - prompt_constructor = state.prompts_state[key]["object"] - prompt_constructor.system_prompt_template = jinja2.Template(system_prompt) - prompt_constructor.user_prompt_template = jinja2.Template(user_prompt) - - input_constructor = state.prompts_state[key]["input_type"] + prompt_class = state.prompts[index] + prompt_class.system_prompt_template = jinja2.Template(system_prompt) + prompt_class.user_prompt_template = jinja2.Template(user_prompt) - prompt_object = prompt_constructor(input_data=input_constructor(**variables)) + input_type = prompt_class.input_type + input_data = input_type(**variables) if input_type is not None else None + prompt_object = prompt_class(input_data=input_data) state.current_prompt = prompt_object chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat} @@ -122,6 +97,22 @@ def render_prompt( return chat_dict["system"], chat_dict["user"], state +def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]: + """ + Returns a list of prompt choices based on the discovered prompts. + + This function generates a list of tuples containing the names of discovered prompts and their + corresponding indices. + + Args: + state (gr.State): The application state object. + + Returns: + list[tuple[str, int]]: A list of tuples containing prompt names and their indices. + """ + return [(prompt.__name__, idx) for idx, prompt in enumerate(state.value.prompts)] + + def send_prompt_to_llm(state: gr.State) -> str: """ Sends the current prompt to the LLM and returns the response. @@ -147,7 +138,7 @@ def send_prompt_to_llm(state: gr.State) -> str: return response -def get_input_type_fields(obj: BaseModel) -> list[dict]: +def get_input_type_fields(obj: BaseModel | None) -> list[dict]: """ Retrieves the field names and default values from the input type of a prompt. @@ -160,6 +151,8 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]: Returns: list[dict]: A list of dictionaries, each containing a field name and its default value. """ + if obj is None: + return [] return [ {"field_name": k, "field_default_value": v["schema"].get("default", None)} for (k, v) in obj.__pydantic_core_schema__["schema"]["fields"].items() @@ -170,7 +163,9 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]: @typer_app.command() -def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) -> None: +def run_app( + file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None +) -> None: """ Launches the interactive application for working with Large Language Models (LLMs). @@ -178,7 +173,7 @@ def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) 1. Initializes the application state using the PromptState class. 2. Sets the LLM model name and API key based on user-provided arguments. - 3. Fetches a list of prompts from the specified paths using the get_prompts_list function. + 3. Fetches a list of prompts from the specified paths using the load_prompts_list function. 4. Creates a Gradio interface with various UI elements: - A dropdown menu for selecting prompts. - Textboxes for displaying and potentially modifying system and user prompts. @@ -186,10 +181,9 @@ def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) - Buttons for rendering prompts, sending prompts to the LLM, and displaying the response. Args: - prompts_paths (str): A comma-separated string of paths to prompt definition files. + file_pattern (str): A pattern for looking up prompt files. llm_model (str): The name of the LLM model to use. llm_api_key (str): The API key for the chosen LLM model. - """ with gr.Blocks() as gr_app: prompt_state_obj = PromptState() @@ -197,43 +191,32 @@ def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) prompt_state_obj.llm_api_key = llm_api_key prompts_state = gr.State(value=prompt_state_obj) - - prompts_state = get_prompts_list(path=prompts_paths, state=prompts_state) - - # Load some values on initialization of the project - # 1) List of available prompts (based on paths provided by user - # 2) Select first prompt as loaded by default - # 3) Set init values of "rendered" fields - - prompts_list = list(prompts_state.value.prompts_state.keys()) - initial_prompt_key = prompts_list[0] - system_prompt_value = prompts_state.value.prompts_state[initial_prompt_key]["system_prompt"] - user_prompt_value = prompts_state.value.prompts_state[initial_prompt_key]["user_prompt"] - - prompt_selection_dropdown = gr.Dropdown(choices=prompts_list, value=initial_prompt_key, label="Select Prompt") - - list_of_vars = [] + prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state) + prompt_selection_dropdown = gr.Dropdown( + choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt" + ) @gr.render(inputs=[prompt_selection_dropdown, prompts_state]) - def show_split(key: str, state: gr.State) -> None: + def show_split(index: int, state: gr.State) -> None: + prompt = state.prompts[index] + list_of_vars = [] with gr.Row(): with gr.Column(scale=1): with gr.Tab("Inputs"): - if key != "": - input_fields: list = get_input_type_fields(state.prompts_state[key]["input_type"]) - tb_dict = {} - for entry in input_fields: - with gr.Row(): - var = gr.Textbox( - label=entry["field_name"], - value=entry["field_default_value"], - interactive=True, - ) - list_of_vars.append(var) - - tb_dict[entry["field_name"]] = var - - state.dynamic_tb = tb_dict + input_fields: list = get_input_type_fields(prompt.input_type) + tb_dict = {} + for entry in input_fields: + with gr.Row(): + var = gr.Textbox( + label=entry["field_name"], + value=entry["field_default_value"], + interactive=True, + ) + list_of_vars.append(var) + + tb_dict[entry["field_name"]] = var + + state.dynamic_tb = tb_dict render_prompt_button = gr.Button(value="Render prompts") @@ -242,7 +225,7 @@ def show_split(key: str, state: gr.State) -> None: with gr.Row(): with gr.Column(): prompt_details_system_prompt = gr.Textbox( - label="System Prompt", value=system_prompt_value, interactive=True + label="System Prompt", value=prompt.system_prompt, interactive=True ) with gr.Column(): @@ -253,7 +236,7 @@ def show_split(key: str, state: gr.State) -> None: with gr.Row(): with gr.Column(): prompt_details_user_prompt = gr.Textbox( - label="User Prompt", value=user_prompt_value, interactive=True + label="User Prompt", value=prompt.user_prompt, interactive=True ) with gr.Column(): @@ -261,14 +244,10 @@ def show_split(key: str, state: gr.State) -> None: label="Rendered User Prompt", value="", interactive=False ) - llm_request_button = gr.Button(value="Send to LLM") + llm_enabled = state.llm_model_name is not None + llm_request_button = gr.Button(value="Send to LLM", interactive=llm_enabled) llm_prompt_response = gr.Textbox(lines=10, label="LLM response") - prompt_selection_dropdown.change( - display_data, - [prompt_selection_dropdown, prompts_state], - [prompt_details_system_prompt, prompt_details_user_prompt], - ) render_prompt_button.click( render_prompt, [ diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/discovery/prompt_discovery.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/discovery/prompt_discovery.py index 3bc1f736..0b5aad9e 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/discovery/prompt_discovery.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/discovery/prompt_discovery.py @@ -1,12 +1,12 @@ -import importlib +import importlib.util import inspect import os -from collections import namedtuple -from typing import Any +from pathlib import Path +from typing import Any, get_origin from ragbits.core.prompt import Prompt -PromptDetails = namedtuple("PromptDetails", ["system_prompt", "user_prompt", "input_type", "object"]) +DEFAULT_FILE_PATTERN = "**/prompt_*.py" class PromptDiscovery: @@ -14,81 +14,60 @@ class PromptDiscovery: Discovers Prompt objects within Python modules. Args: - file_paths (list[str]): List of file paths containing Prompt objects. + file_pattern (str): The file pattern to search for Prompt objects. Defaults to "**/prompt_*.py" + root_path (Path): The root path to search for Prompt objects. Defaults to the directory where the script is run. """ - def __init__(self, file_paths: list[str]): - self.file_paths = file_paths + def __init__(self, file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd()): + self.file_pattern = file_pattern + self.root_path = root_path - def is_submodule(self, module: Any, sub_module: Any) -> bool: + @staticmethod + def is_prompt_subclass(obj: Any) -> bool: """ - Checks if a module is a submodule of another. + Checks if an object is a class that is a subclass of Prompt (but not Prompt itself). Args: - module (module): The parent module. - sub_module (module): The potential submodule. + obj (any): The object to check. Returns: - bool: True if `sub_module` is a submodule of `module`, False otherwise. + bool: True if `obj` is a subclass of Prompt, False otherwise. """ + # See https://bugs.python.org/issue44293 for the reason why we need to check for get_origin(obj) + # in order to avoid generic type aliases (which `isclass` sees as classes, but `issubclass` don't). + return inspect.isclass(obj) and not get_origin(obj) and issubclass(obj, Prompt) and obj != Prompt - try: - value = module.__spec__.submodule_search_locations[0] in sub_module.__spec__.submodule_search_locations[0] - return value - except TypeError: - return False - - def process_module(self, module: Any, main_module: Any) -> dict: + def discover(self) -> set[type[Prompt]]: """ - Processes a module to find Prompt objects. - - Args: - module (module): The module to process. - main_module (module): The main module. + Discovers Prompt objects within the specified file paths. Returns: - dict: A dictionary mapping Prompt names to their corresponding PromptDetails objects. + set[Prompt]: The discovered Prompt objects. """ - result_dict = {} - for key, value in inspect.getmembers(module): - if inspect.isclass(value) and key != "Prompt" and issubclass(value, Prompt): - result_dict[key] = value + result_set: set[type[Prompt]] = set() - elif inspect.ismodule(value) and not key.startswith("_") and self.is_submodule(main_module, value): - temp_dict = self.process_module(value, main_module) + for file_path in self.root_path.glob(self.file_pattern): + # remove file extenson and remove directory separators with dots + module_name = str(file_path).rsplit(".", 1)[0].replace(os.sep, ".") - if len(temp_dict.keys()) == 0: - continue + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + print(f"Skipping {file_path}, not a Python module") + continue - result_dict = {**result_dict, **temp_dict} + module = importlib.util.module_from_spec(spec) - return result_dict + assert spec.loader is not None - def discover(self) -> dict: - """ - Discovers Prompt objects within the specified file paths. + try: + spec.loader.exec_module(module) + except Exception as e: # pylint: disable=broad-except + print(f"Skipping {file_path}, loading failed: {e}") + continue - Returns: - dict: A dictionary mapping Prompt names to their corresponding PromptDetails objects. - """ + for _, obj in inspect.getmembers(module): + if self.is_prompt_subclass(obj): + result_set.add(obj) - result_dict = {} - for prompt_path_str in self.file_paths: - if prompt_path_str.endswith(".py"): - temp_module = importlib.import_module(os.path.basename(prompt_path_str[:-3])) - else: - temp_module = importlib.import_module(os.path.basename(prompt_path_str)) - - temp_results = self.process_module(temp_module, temp_module) - - for key, test_obj in temp_results.items(): - if key not in result_dict: - result_dict[key] = PromptDetails( - system_prompt=test_obj.system_prompt, - user_prompt=test_obj.user_prompt, - input_type=test_obj.input_type, - object=test_obj, - )._asdict() - - return result_dict + return result_set diff --git a/packages/ragbits-dev-kit/tests/unit/discovery/__init__.py b/packages/ragbits-dev-kit/tests/unit/discovery/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-dev-kit/tests/unit/discovery/prompt_classes_for_tests.py b/packages/ragbits-dev-kit/tests/unit/discovery/prompt_classes_for_tests.py index 2453dddd..b0fc1a54 100644 --- a/packages/ragbits-dev-kit/tests/unit/discovery/prompt_classes_for_tests.py +++ b/packages/ragbits-dev-kit/tests/unit/discovery/prompt_classes_for_tests.py @@ -59,3 +59,8 @@ class MyBasePrompt(Prompt, ABC): class MyPromptWithBase(MyBasePrompt): user_prompt = "custom user prompt" + + +class PromptWithoutInput(Prompt): + system_prompt = "fake system prompt without typing" + user_prompt = "fake user prompt without typing" diff --git a/packages/ragbits-dev-kit/tests/unit/discovery/test_prompt_discovery.py b/packages/ragbits-dev-kit/tests/unit/discovery/test_prompt_discovery.py index b51b33a4..cc650d86 100644 --- a/packages/ragbits-dev-kit/tests/unit/discovery/test_prompt_discovery.py +++ b/packages/ragbits-dev-kit/tests/unit/discovery/test_prompt_discovery.py @@ -1,38 +1,30 @@ -import sys from pathlib import Path from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import PromptDiscovery +current_dir = Path(__file__).parent -def test_prompt_discovery_from_file(): - test_paths = ["prompt_classes_for_tests.py"] - - discovery_result = PromptDiscovery(test_paths).discover() - - assert len(discovery_result.keys()) == 4 - assert "PromptForTest" in discovery_result.keys() - assert discovery_result["PromptForTest"]["user_prompt"] == "fake user prompt" - assert discovery_result["PromptForTest"]["system_prompt"] == "fake system prompt" - assert len(discovery_result["PromptForTest"]["input_type"].model_fields) == 6 +def test_prompt_discovery_from_file(): + discovery_results = PromptDiscovery(root_path=current_dir).discover() + print(discovery_results) - assert "PromptForTest2" in discovery_result.keys() - assert discovery_result["PromptForTest2"]["user_prompt"] == "fake user prompt2" - assert discovery_result["PromptForTest2"]["system_prompt"] == "fake system prompt2" - assert len(discovery_result["PromptForTest2"]["input_type"].model_fields) == 1 + assert len(discovery_results) == 5 - assert "MyPromptWithBase" in discovery_result.keys() - assert discovery_result["MyPromptWithBase"]["user_prompt"] == "custom user prompt" - assert discovery_result["MyPromptWithBase"]["system_prompt"] == "my base system prompt" + class_names = [cls.__name__ for cls in discovery_results] + assert "PromptForTest" in class_names + assert "PromptForTest2" in class_names + assert "PromptWithoutInput" in class_names + assert "PromptForTestInput" not in class_names def test_prompt_discovery_from_package(): - sys.path.append(str(Path(__file__).parent)) - test_paths = ["ragbits_tests_pkg_with_prompts"] - - discovery_result = PromptDiscovery(test_paths).discover() + discovery_results = PromptDiscovery( + root_path=current_dir, file_pattern="ragbits_tests_pkg_with_prompts/**/*.py" + ).discover() - assert len(discovery_result.keys()) == 2 + assert len(discovery_results) == 2 - assert "PromptForTestA" in discovery_result - assert "PromptForTestB" in discovery_result + class_names = [cls.__name__ for cls in discovery_results] + assert "PromptForTestA" in class_names + assert "PromptForTestB" in class_names From fdc5790ac106d37fa8b6e792c2949b2d5bb82aed Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Mon, 30 Sep 2024 12:25:41 +0200 Subject: [PATCH 05/10] feat(prompt-lab): Register a CLI command for Prompt Lab (#52) --- .../ragbits-cli/src/ragbits/cli/__init__.py | 2 +- packages/ragbits-core/src/ragbits/core/cli.py | 19 ------------ .../src/ragbits/dev_kit/cli.py | 16 ++++++++++ .../src/ragbits/dev_kit/prompt_lab/app.py | 30 ++----------------- 4 files changed, 20 insertions(+), 47 deletions(-) delete mode 100644 packages/ragbits-core/src/ragbits/core/cli.py create mode 100644 packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py diff --git a/packages/ragbits-cli/src/ragbits/cli/__init__.py b/packages/ragbits-cli/src/ragbits/cli/__init__.py index 21fc531e..45d30fbe 100644 --- a/packages/ragbits-cli/src/ragbits/cli/__init__.py +++ b/packages/ragbits-cli/src/ragbits/cli/__init__.py @@ -5,7 +5,7 @@ import ragbits -app = Typer() +app = Typer(no_args_is_help=True) def main() -> None: diff --git a/packages/ragbits-core/src/ragbits/core/cli.py b/packages/ragbits-core/src/ragbits/core/cli.py deleted file mode 100644 index 46945b99..00000000 --- a/packages/ragbits-core/src/ragbits/core/cli.py +++ /dev/null @@ -1,19 +0,0 @@ -from typer import Typer - -prompts_app = Typer() - - -@prompts_app.command() -def placeholder() -> None: - """Placeholder command""" - print("foo") - - -def register(app: Typer) -> None: - """ - Register the CLI commands for the ragbits-core package. - - Args: - app: The Typer object to register the commands with. - """ - app.add_typer(prompts_app, name="prompts") diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py new file mode 100644 index 00000000..22475bb8 --- /dev/null +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py @@ -0,0 +1,16 @@ +import typer + +from .prompt_lab.app import lab_app + +prompts_app = typer.Typer(no_args_is_help=True) + + +def register(app: typer.Typer) -> None: + """ + Register the CLI commands for the package. + + Args: + app: The Typer object to register the commands with. + """ + prompts_app.command(name="lab")(lab_app) + app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts") diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py index a3a6d25c..9e9d5776 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py @@ -3,7 +3,6 @@ import gradio as gr import jinja2 -import typer from pydantic import BaseModel from ragbits.core.llms import LiteLLM @@ -159,31 +158,12 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]: ] -typer_app = typer.Typer(no_args_is_help=True) - - -@typer_app.command() -def run_app( +def lab_app( # pylint: disable=missing-param-doc file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None ) -> None: """ - Launches the interactive application for working with Large Language Models (LLMs). - - This function serves as the entry point for the application. It performs several key tasks: - - 1. Initializes the application state using the PromptState class. - 2. Sets the LLM model name and API key based on user-provided arguments. - 3. Fetches a list of prompts from the specified paths using the load_prompts_list function. - 4. Creates a Gradio interface with various UI elements: - - A dropdown menu for selecting prompts. - - Textboxes for displaying and potentially modifying system and user prompts. - - Textboxes for entering input values based on the selected prompt. - - Buttons for rendering prompts, sending prompts to the LLM, and displaying the response. - - Args: - file_pattern (str): A pattern for looking up prompt files. - llm_model (str): The name of the LLM model to use. - llm_api_key (str): The API key for the chosen LLM model. + Launches the interactive application for listing, rendering, and testing prompts + defined within the current project. """ with gr.Blocks() as gr_app: prompt_state_obj = PromptState() @@ -262,7 +242,3 @@ def show_split(index: int, state: gr.State) -> None: llm_request_button.click(send_prompt_to_llm, prompts_state, llm_prompt_response) gr_app.launch() - - -if __name__ == "__main__": - typer_app() From 877b35f7ce415c4467dbd1c5c7ca2b6d568a3b76 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Wed, 2 Oct 2024 11:42:01 +0200 Subject: [PATCH 06/10] fix(prompt-lab): app shouldn't crash when no prompts found (#53) --- .../src/ragbits/dev_kit/prompt_lab/app.py | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py index 9e9d5776..a6db0b90 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py @@ -4,6 +4,7 @@ import gradio as gr import jinja2 from pydantic import BaseModel +from rich.console import Console from ragbits.core.llms import LiteLLM from ragbits.core.llms.clients import LiteLLMOptions @@ -38,28 +39,6 @@ class PromptState: temp_field_name: str = "" -def load_prompts_list(pattern: str, state: gr.State) -> gr.State: - """ - Fetches a list of prompts based on provided paths and updates the application state. - - This function takes a path-pattern for discovering prompt definition files and uses the - PromptDiscovery class to discover prompts within those files. The discovered prompts are then - stored in the application state object. - - Args: - pattern (str): A pattern for looking up prompt files. - state (gr.State): The Gradio state object to update with discovered prompts. - - Returns: - gr.State: The updated Gradio state object containing the list of discovered prompts. - """ - obj = PromptDiscovery(file_pattern=pattern) - discovered_prompts = list(obj.discover()) - state.value.prompts = discovered_prompts - - return state - - def render_prompt( index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any ) -> tuple[str, str, gr.State]: @@ -91,9 +70,7 @@ def render_prompt( prompt_object = prompt_class(input_data=input_data) state.current_prompt = prompt_object - chat_dict = {entry["role"]: entry["content"] for entry in prompt_object.chat} - - return chat_dict["system"], chat_dict["user"], state + return prompt_object.system_message, prompt_object.user_message, state def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]: @@ -165,13 +142,24 @@ def lab_app( # pylint: disable=missing-param-doc Launches the interactive application for listing, rendering, and testing prompts defined within the current project. """ + prompts = PromptDiscovery(file_pattern=file_pattern).discover() + + if not prompts: + Console(stderr=True).print( + f"""No prompts were found for the given file pattern: [b]{file_pattern}[/b]. + +Please make sure that you are executing the command from the correct directory \ +or provide a custom file pattern using the [b]--file-pattern[/b] flag.""" + ) + return + with gr.Blocks() as gr_app: prompt_state_obj = PromptState() prompt_state_obj.llm_model_name = llm_model prompt_state_obj.llm_api_key = llm_api_key + prompt_state_obj.prompts = list(prompts) prompts_state = gr.State(value=prompt_state_obj) - prompts_state = load_prompts_list(pattern=file_pattern, state=prompts_state) prompt_selection_dropdown = gr.Dropdown( choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt" ) From 0c808cd21683783ca3f4cafb85c4a1675de3f495 Mon Sep 17 00:00:00 2001 From: Alan Konarski <129968242+akonarski-ds@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:34:14 +0200 Subject: [PATCH 07/10] feat(prompts): integration with promptfoo (#54) --- .pre-commit-config.yaml | 2 +- .../src/ragbits/core/prompt/prompt.py | 26 +++++++-- .../tests/unit/prompts/test_prompt.py | 53 +++++++++++++++++++ packages/ragbits-dev-kit/README.md | 25 +++++++++ .../src/ragbits/dev_kit/cli.py | 2 + .../src/ragbits/dev_kit/promptfoo.py | 33 ++++++++++++ 6 files changed, 137 insertions(+), 4 deletions(-) create mode 100644 packages/ragbits-dev-kit/src/ragbits/dev_kit/promptfoo.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8e6b3bd..7428c1b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: - id: mypy # You can add additional plugins for mypy below # such as types-python-dateutil - additional_dependencies: [pydantic>=2.8.2] + additional_dependencies: [pydantic>=2.8.2, types-pyyaml>=6.0.12] exclude: (/test_|setup.py|/tests/|docs/) # Sort imports alphabetically, and automatically separated into sections and by type. diff --git a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py index 69e87efa..75b3c093 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py @@ -22,7 +22,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass= system_prompt: Optional[str] = None user_prompt: str - additional_messages: ChatFormat = [] + additional_messages: Optional[ChatFormat] = None # function that parses the response from the LLM to specific output type # if not provided, the class tries to set it automatically based on the output type @@ -125,10 +125,13 @@ def chat(self) -> ChatFormat: Returns: ChatFormat: A list of dictionaries, each containing the role and content of a message. """ - return [ + chat = [ *([{"role": "system", "content": self.system_message}] if self.system_message is not None else []), {"role": "user", "content": self.user_message}, - ] + self.additional_messages + ] + if self.additional_messages: + chat.extend(self.additional_messages) + return chat def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": """ @@ -140,6 +143,8 @@ def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ + if self.additional_messages is None: + self.additional_messages = [] self.additional_messages.append({"role": "user", "content": message}) return self @@ -153,6 +158,8 @@ def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]": Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ + if self.additional_messages is None: + self.additional_messages = [] self.additional_messages.append({"role": "assistant", "content": message}) return self @@ -190,3 +197,16 @@ def parse_response(self, response: str) -> OutputT: ResponseParsingError: If the response cannot be parsed. """ return self.response_parser(response) + + @classmethod + def to_promptfoo(cls, config: dict[str, Any]) -> ChatFormat: + """ + Generate a prompt in the promptfoo format from a promptfoo test configuration. + + Args: + config: The promptfoo test configuration. + + Returns: + ChatFormat: The prompt in the format used by promptfoo. + """ + return cls(cls.input_type.model_validate(config["vars"])).chat # type: ignore diff --git a/packages/ragbits-core/tests/unit/prompts/test_prompt.py b/packages/ragbits-core/tests/unit/prompts/test_prompt.py index da0352cd..557849f0 100644 --- a/packages/ragbits-core/tests/unit/prompts/test_prompt.py +++ b/packages/ragbits-core/tests/unit/prompts/test_prompt.py @@ -198,3 +198,56 @@ class TestPrompt(Prompt[_PromptInput, str]): prompt = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) assert prompt.output_schema() is None + + +def test_to_promptfoo(): + """Test that a prompt can be converted to a promptfoo prompt.""" + promptfoo_test_config = { + "vars": {"name": "John", "age": 25, "theme": "pop"}, + } + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + + assert TestPrompt.to_promptfoo(promptfoo_test_config) == [ + {"role": "system", "content": "You are a song generator for a adult named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + ] + + +def test_two_instances_do_not_share_additional_messages(): + """ + Test that two instances of a prompt do not share additional messages. + """ + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + + prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) + prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + + prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) + prompt2.add_assistant_message("It's a nice tune.") + + assert prompt1.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "I like it."}, + ] + + assert prompt2.chat == [ + {"role": "system", "content": "You are a song generator for a adult named Alice."}, + {"role": "user", "content": "Theme for the song is rock."}, + {"role": "assistant", "content": "It's a nice tune."}, + ] diff --git a/packages/ragbits-dev-kit/README.md b/packages/ragbits-dev-kit/README.md index f27a2678..099abe68 100644 --- a/packages/ragbits-dev-kit/README.md +++ b/packages/ragbits-dev-kit/README.md @@ -1 +1,26 @@ # Ragbits Development Kit + +## Promptfoo Integration + +Ragbits' `Prompt` abstraction can be seamlessly integrated with the `promptfoo` tool. After installing `promptfoo` as +specified in the [promptfoo documentation](https://www.promptfoo.dev/docs/installation/), you can generate promptfoo +configuration files for all the prompts discovered by our autodiscover mechanism by running the following command: + +```bash +rbts prompts generate-promptfoo-configs +``` + +This command will generate a YAML files in the directory specified by `--target-path` (`promptfooconfigs` by +default). The generated file should look like this: + +```yaml +prompts: + - file:///path/to/your/prompt:PromptClass.to_promptfoo +``` + +You can then edit the generated file to add your custom `promptfoo` configurations. Once your `promptfoo` configuration +file is ready, you can run `promptfoo` with the following command: + +```bash +promptfoo -c /path/to/generated/promptfoo-config.yaml eval +``` \ No newline at end of file diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py index 22475bb8..73f5267f 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/cli.py @@ -1,6 +1,7 @@ import typer from .prompt_lab.app import lab_app +from .promptfoo import generate_configs prompts_app = typer.Typer(no_args_is_help=True) @@ -13,4 +14,5 @@ def register(app: typer.Typer) -> None: app: The Typer object to register the commands with. """ prompts_app.command(name="lab")(lab_app) + prompts_app.command(name="generate-promptfoo-configs")(generate_configs) app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts") diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/promptfoo.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/promptfoo.py new file mode 100644 index 00000000..d7357621 --- /dev/null +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/promptfoo.py @@ -0,0 +1,33 @@ +import os +from pathlib import Path + +import yaml +from rich.console import Console + +from ragbits.dev_kit.prompt_lab.discovery import PromptDiscovery +from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN + + +def generate_configs( + file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd(), target_path: Path = Path("promptfooconfigs") +) -> None: + """ + Generates promptfoo configuration files for all discovered prompts. + + Args: + file_pattern: The file pattern to search for Prompt objects. Defaults to "**/prompt_*.py" + root_path: The root path to search for Prompt objects. Defaults to the directory where the script is run. + target_path: The path to save the promptfoo configuration files. Defaults to "promptfooconfigs". + """ + prompts = PromptDiscovery(file_pattern=file_pattern, root_path=root_path).discover() + Console().print( + f"Discovered {len(prompts)} prompts." + f" Saving promptfoo configuration files to [bold green]{target_path}[/] folder ..." + ) + + if not target_path.exists(): + target_path.mkdir() + for prompt in prompts: + with open(target_path / f"{prompt.__qualname__}.yaml", "w", encoding="utf-8") as f: + prompt_path = f'file://{prompt.__module__.replace(".", os.sep)}.py:{prompt.__qualname__}.to_promptfoo' + yaml.dump({"prompts": [prompt_path]}, f) From 2857978dff681cd583a782a13002aafe0dca4f85 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 3 Oct 2024 11:18:41 +0200 Subject: [PATCH 08/10] fix(prompt-lab): prevent crash when sending to LLM before rendering (#57) --- .../src/ragbits/dev_kit/prompt_lab/app.py | 76 +++++++++---------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py index a6db0b90..8b494ae4 100644 --- a/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py +++ b/packages/ragbits-dev-kit/src/ragbits/dev_kit/prompt_lab/app.py @@ -1,4 +1,5 @@ import asyncio +from dataclasses import dataclass, field, replace from typing import Any import gradio as gr @@ -12,6 +13,7 @@ from ragbits.dev_kit.prompt_lab.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery +@dataclass(frozen=True) class PromptState: """ Class to store the current state of the application. @@ -20,28 +22,18 @@ class PromptState: Attributes: prompts (list): A list containing discovered prompts. - variable_values (dict): A dictionary to store values entered by the user for prompt input fields. - dynamic_tb (dict): A dictionary containing dynamically created textboxes based on prompt input fields. - current_prompt (Prompt): The currently processed Prompt object. This is created upon clicking the - "Render Prompt" button and reflects in the "Rendered Prompt" field. - It is used for communication with the LLM. + rendered_prompt (Prompt): The most recently rendered Prompt instance. llm_model_name (str): The name of the selected LLM model. llm_api_key (str | None): The API key for the chosen LLM model. - temp_field_name (str): Temporary field name used internally. """ - prompts: list = [] - variable_values: dict = {} - dynamic_tb: dict = {} - current_prompt: Prompt | None = None + prompts: list = field(default_factory=list) + rendered_prompt: Prompt | None = None llm_model_name: str | None = None llm_api_key: str | None = None - temp_field_name: str = "" -def render_prompt( - index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any -) -> tuple[str, str, gr.State]: +def render_prompt(index: int, system_prompt: str, user_prompt: str, state: gr.State, *args: Any) -> gr.State: """ Renders a prompt based on the provided key, system prompt, user prompt, and input variables. @@ -56,21 +48,20 @@ def render_prompt( args (tuple): A tuple of input values for the prompt. Returns: - tuple[str, str, PromptState]: A tuple containing the rendered system prompt, rendered user prompt, - and the updated application state. + gr.State: The updated application state object. """ - variables = dict(zip(state.dynamic_tb.keys(), args)) - prompt_class = state.prompts[index] prompt_class.system_prompt_template = jinja2.Template(system_prompt) prompt_class.user_prompt_template = jinja2.Template(user_prompt) input_type = prompt_class.input_type + input_fields = get_input_type_fields(input_type) + variables = {field["field_name"]: value for field, value in zip(input_fields, args)} input_data = input_type(**variables) if input_type is not None else None prompt_object = prompt_class(input_data=input_data) - state.current_prompt = prompt_object + state = replace(state, rendered_prompt=prompt_object) - return prompt_object.system_message, prompt_object.user_message, state + return state def list_prompt_choices(state: gr.State) -> list[tuple[str, int]]: @@ -102,12 +93,12 @@ def send_prompt_to_llm(state: gr.State) -> str: Returns: str: The response generated by the LLM. """ - current_prompt = state.current_prompt - llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key) try: - response = asyncio.run(llm_client.client.call(conversation=current_prompt.chat, options=LiteLLMOptions())) + response = asyncio.run( + llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions()) + ) except Exception as e: # pylint: disable=broad-except response = str(e) @@ -154,12 +145,14 @@ def lab_app( # pylint: disable=missing-param-doc return with gr.Blocks() as gr_app: - prompt_state_obj = PromptState() - prompt_state_obj.llm_model_name = llm_model - prompt_state_obj.llm_api_key = llm_api_key - prompt_state_obj.prompts = list(prompts) + prompts_state = gr.State( + PromptState( + llm_model_name=llm_model, + llm_api_key=llm_api_key, + prompts=list(prompts), + ) + ) - prompts_state = gr.State(value=prompt_state_obj) prompt_selection_dropdown = gr.Dropdown( choices=list_prompt_choices(prompts_state), value=0, label="Select Prompt" ) @@ -172,7 +165,6 @@ def show_split(index: int, state: gr.State) -> None: with gr.Column(scale=1): with gr.Tab("Inputs"): input_fields: list = get_input_type_fields(prompt.input_type) - tb_dict = {} for entry in input_fields: with gr.Row(): var = gr.Textbox( @@ -182,10 +174,6 @@ def show_split(index: int, state: gr.State) -> None: ) list_of_vars.append(var) - tb_dict[entry["field_name"]] = var - - state.dynamic_tb = tb_dict - render_prompt_button = gr.Button(value="Render prompts") with gr.Column(scale=4): @@ -197,9 +185,8 @@ def show_split(index: int, state: gr.State) -> None: ) with gr.Column(): - prompt_details_system_prompt_rendered = gr.Textbox( - label="Rendered System Prompt", value="", interactive=False - ) + system_message = state.rendered_prompt.system_message if state.rendered_prompt else "" + gr.Textbox(label="Rendered System Prompt", value=system_message, interactive=False) with gr.Row(): with gr.Column(): @@ -208,12 +195,19 @@ def show_split(index: int, state: gr.State) -> None: ) with gr.Column(): - prompt_details_user_prompt_rendered = gr.Textbox( - label="Rendered User Prompt", value="", interactive=False - ) + user_message = state.rendered_prompt.user_message if state.rendered_prompt else "" + gr.Textbox(label="Rendered User Prompt", value=user_message, interactive=False) llm_enabled = state.llm_model_name is not None - llm_request_button = gr.Button(value="Send to LLM", interactive=llm_enabled) + prompt_ready = state.rendered_prompt is not None + llm_request_button = gr.Button( + value="Send to LLM", + interactive=llm_enabled and prompt_ready, + ) + gr.Markdown( + "To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled + ) + gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready) llm_prompt_response = gr.Textbox(lines=10, label="LLM response") render_prompt_button.click( @@ -225,7 +219,7 @@ def show_split(index: int, state: gr.State) -> None: prompts_state, *list_of_vars, ], - [prompt_details_system_prompt_rendered, prompt_details_user_prompt_rendered, prompts_state], + [prompts_state], ) llm_request_button.click(send_prompt_to_llm, prompts_state, llm_prompt_response) From 08acb63884a70cd7d23eb00ba6fb789dfe84f0b3 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 3 Oct 2024 14:33:08 +0200 Subject: [PATCH 09/10] feat(prompts): Make the `Prompt` interface more clear in regard to messages (#59) --- .../ragbits-core/examples/prompt_example.py | 4 +- .../src/ragbits/core/prompt/prompt.py | 51 +++++------ .../tests/unit/prompts/test_prompt.py | 84 +++++++++++++++---- 3 files changed, 94 insertions(+), 45 deletions(-) diff --git a/packages/ragbits-core/examples/prompt_example.py b/packages/ragbits-core/examples/prompt_example.py index 51ce05ea..64c37c36 100644 --- a/packages/ragbits-core/examples/prompt_example.py +++ b/packages/ragbits-core/examples/prompt_example.py @@ -43,8 +43,8 @@ class LoremPrompt(Prompt[LoremPromptInput, LoremPromptOutput]): if __name__ == "__main__": - lorem_prompt = LoremPrompt(LoremPromptInput(theme="business")) - lorem_prompt.add_assistant_message("Lorem Ipsum biznessum dolor copy machinum yearly reportum") + lorem_prompt = LoremPrompt(LoremPromptInput(theme="animals")) + lorem_prompt.add_few_shot("theme: business", "Lorem Ipsum biznessum dolor copy machinum yearly reportum") print("CHAT:") print(lorem_prompt.chat) print() diff --git a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py index 75b3c093..876a68a4 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py @@ -22,7 +22,9 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass= system_prompt: Optional[str] = None user_prompt: str - additional_messages: Optional[ChatFormat] = None + + # Additional messages to be added to the conversation after the system prompt + few_shots: ChatFormat = [] # function that parses the response from the LLM to specific output type # if not provided, the class tries to set it automatically based on the output type @@ -111,10 +113,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if self.input_type and input_data is None: raise ValueError("Input data must be provided") - self.system_message = ( + self.rendered_system_prompt = ( self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None ) - self.user_message = self._render_template(self.user_prompt_template, input_data) + self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data) + + # Additional few shot examples that can be added dynamically using methods + # (in opposite to the static `few_shots` attribute which is defined in the class) + self._instace_few_shots: ChatFormat = [] super().__init__() @property @@ -126,41 +132,30 @@ def chat(self) -> ChatFormat: ChatFormat: A list of dictionaries, each containing the role and content of a message. """ chat = [ - *([{"role": "system", "content": self.system_message}] if self.system_message is not None else []), - {"role": "user", "content": self.user_message}, + *( + [{"role": "system", "content": self.rendered_system_prompt}] + if self.rendered_system_prompt is not None + else [] + ), + *self.few_shots, + *self._instace_few_shots, + {"role": "user", "content": self.rendered_user_prompt}, ] - if self.additional_messages: - chat.extend(self.additional_messages) return chat - def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": - """ - Add a message from the user to the conversation. - - Args: - message (str): The message to add. - - Returns: - Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. - """ - if self.additional_messages is None: - self.additional_messages = [] - self.additional_messages.append({"role": "user", "content": message}) - return self - - def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]": + def add_few_shot(self, user_message: str, assistant_message: str) -> "Prompt[InputT, OutputT]": """ - Add a message from the assistant to the conversation. + Add a few-shot example to the conversation. Args: - message (str): The message to add. + user_message (str): The message from the user. + assistant_message (str): The message from the assistant. Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ - if self.additional_messages is None: - self.additional_messages = [] - self.additional_messages.append({"role": "assistant", "content": message}) + self._instace_few_shots.append({"role": "user", "content": user_message}) + self._instace_few_shots.append({"role": "assistant", "content": assistant_message}) return self def output_schema(self) -> Optional[Dict | Type[BaseModel]]: diff --git a/packages/ragbits-core/tests/unit/prompts/test_prompt.py b/packages/ragbits-core/tests/unit/prompts/test_prompt.py index 557849f0..3af149bc 100644 --- a/packages/ragbits-core/tests/unit/prompts/test_prompt.py +++ b/packages/ragbits-core/tests/unit/prompts/test_prompt.py @@ -105,7 +105,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable user_prompt = "Hello" prompt = TestPrompt() - assert prompt.user_message == "Hello" + assert prompt.rendered_user_prompt == "Hello" assert prompt.chat == [{"role": "user", "content": "Hello"}] @@ -121,8 +121,8 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable user_prompt = "Theme for the song is {{ theme }}." prompt = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) - assert prompt.system_message == "You are a song generator for a adult named Alice." - assert prompt.user_message == "Theme for the song is rock." + assert prompt.rendered_system_prompt == "You are a song generator for a adult named Alice." + assert prompt.rendered_user_prompt == "Theme for the song is rock." assert prompt.chat == [ {"role": "system", "content": "You are a song generator for a adult named Alice."}, {"role": "user", "content": "Theme for the song is rock."}, @@ -139,8 +139,8 @@ class TestPrompt(Prompt[str, str]): # type: ignore # pylint: disable=unused-var user_prompt = "Hello" -def test_adding_messages(): - """Test that messages can be added to the conversation.""" +def test_defining_few_shots(): + """Test that few shots can be defined for the prompt.""" class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable """A test prompt""" @@ -149,15 +149,68 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. """ user_prompt = "Theme for the song is {{ theme }}." + few_shots = [ + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + ] - prompt = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) - prompt.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + + assert prompt.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "Theme for the song is rock."}, + ] + + +def test_adding_few_shots(): + """Test that few shots can be added to the conversation.""" + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + prompt.add_few_shot("Theme for the song is pop.", "It's a really catchy tune.") assert prompt.chat == [ {"role": "system", "content": "You are a song generator for a child named John."}, {"role": "user", "content": "Theme for the song is pop."}, {"role": "assistant", "content": "It's a really catchy tune."}, - {"role": "user", "content": "I like it."}, + {"role": "user", "content": "Theme for the song is rock."}, + ] + + +def test_defining_and_adding_few_shots(): + """Test that few shots can be defined and added to the conversation.""" + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + few_shots = [ + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + ] + + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + prompt.add_few_shot("Theme for the song is experimental underground jazz.", "It's quite hard to dance to.") + + assert prompt.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "Theme for the song is experimental underground jazz."}, + {"role": "assistant", "content": "It's quite hard to dance to."}, + {"role": "user", "content": "Theme for the song is rock."}, ] @@ -173,7 +226,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable """ prompt = TestPrompt() - assert prompt.user_message == "Hello\nWorld" + assert prompt.rendered_user_prompt == "Hello\nWorld" def test_output_format(): @@ -220,7 +273,7 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable ] -def test_two_instances_do_not_share_additional_messages(): +def test_two_instances_do_not_share_few_shots(): """ Test that two instances of a prompt do not share additional messages. """ @@ -234,20 +287,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable user_prompt = "Theme for the song is {{ theme }}." prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) - prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + prompt1.add_few_shot("Theme for the song is 80s disco.", "I can't stop dancing.") prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) - prompt2.add_assistant_message("It's a nice tune.") + prompt2.add_few_shot("Theme for the song is 90s pop.", "Why do I know all the words?") assert prompt1.chat == [ {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is 80s disco."}, + {"role": "assistant", "content": "I can't stop dancing."}, {"role": "user", "content": "Theme for the song is pop."}, - {"role": "assistant", "content": "It's a really catchy tune."}, - {"role": "user", "content": "I like it."}, ] assert prompt2.chat == [ {"role": "system", "content": "You are a song generator for a adult named Alice."}, + {"role": "user", "content": "Theme for the song is 90s pop."}, + {"role": "assistant", "content": "Why do I know all the words?"}, {"role": "user", "content": "Theme for the song is rock."}, - {"role": "assistant", "content": "It's a nice tune."}, ] From 704eef2445895092f3f1ef05a0b00c24c0325780 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Mon, 7 Oct 2024 09:04:27 +0200 Subject: [PATCH 10/10] fix(document-search): avoid metadata mutation (#63) --- .../vector_store/chromadb_store.py | 80 +++---------------- .../tests/unit/test_chromadb_store.py | 56 ++----------- 2 files changed, 17 insertions(+), 119 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py index 685950be..6d4d4bc4 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/vector_store/chromadb_store.py @@ -1,5 +1,4 @@ import json -from copy import deepcopy from hashlib import sha256 from typing import List, Literal, Optional, Union @@ -79,48 +78,16 @@ def _return_best_match(self, retrieved: dict) -> Optional[str]: return None - def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]: + def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dict]: doc_id = sha256(entry.key.encode("utf-8")).hexdigest() embedding = entry.vector - text = entry.metadata["content"] - metadata = deepcopy(entry.metadata) - metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"]) - metadata["key"] = entry.key - metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()} + metadata = { + "__key": entry.key, + "__metadata": json.dumps(entry.metadata, default=str), + } - return doc_id, embedding, text, metadata - - def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]: - """ - Processes the metadata dictionary by parsing JSON strings if applicable. - - Args: - metadata: A dictionary containing metadata where values may be JSON strings. - - Returns: - A dictionary with the same keys as the input, where JSON strings are parsed - into their respective Python data types. - """ - return {key: json.loads(val) if self._is_json(val) else val for key, val in metadata.items()} - - def _is_json(self, myjson: str) -> bool: - """ - Check if the provided string is a valid JSON. - - Args: - myjson: The string to be checked. - - Returns: - True if the string is a valid JSON, False otherwise. - """ - try: - if isinstance(myjson, str): - json.loads(myjson) - return True - return False - except ValueError: - return False + return doc_id, embedding, metadata @property def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]: @@ -139,12 +106,10 @@ async def store(self, entries: List[VectorDBEntry]) -> None: Args: entries: The entries to store. """ - collection = self._get_chroma_collection() - entries_processed = list(map(self._process_db_entry, entries)) - ids, embeddings, texts, metadatas = map(list, zip(*entries_processed)) + ids, embeddings, metadatas = map(list, zip(*entries_processed)) - collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas) + self._collection.add(ids=ids, embeddings=embeddings, metadatas=metadatas) async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]: """ @@ -157,43 +122,20 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry] Returns: The retrieved entries. """ - collection = self._get_chroma_collection() - query_result = collection.query(query_embeddings=[vector], n_results=k) + query_result = self._collection.query(query_embeddings=[vector], n_results=k) db_entries = [] for meta in query_result.get("metadatas"): db_entry = VectorDBEntry( - key=meta[0].get("key"), + key=meta[0]["__key"], vector=vector, - metadata=self._process_metadata(meta[0]), + metadata=json.loads(meta[0]["__metadata"]), ) db_entries.append(db_entry) return db_entries - async def find_similar(self, text: str) -> Optional[str]: - """ - Finds the most similar text in the chroma collection or returns None if the most similar text - has distance bigger than `self.max_distance`. - - Args: - text: The text to find similar to. - - Returns: - The most similar text or None if no similar text is found. - """ - - collection = self._get_chroma_collection() - - if isinstance(self._embedding_function, Embeddings): - embedding = await self._embedding_function.embed_text([text]) - retrieved = collection.query(query_embeddings=embedding, n_results=1) - else: - retrieved = collection.query(query_texts=[text], n_results=1) - - return self._return_best_match(retrieved) - def __repr__(self) -> str: """ Returns the string representation of the object. diff --git a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py index fd86dde0..9d45bdc1 100644 --- a/packages/ragbits-document-search/tests/unit/test_chromadb_store.py +++ b/packages/ragbits-document-search/tests/unit/test_chromadb_store.py @@ -71,15 +71,6 @@ def test_get_chroma_collection(mock_chromadb_store): assert mock_chromadb_store._chroma_client.get_or_create_collection.called -def test_get_chroma_collection_with_custom_embedding_function( - custom_embedding_function, mock_chromadb_store_with_custom_embedding_function, mock_chroma_client -): - mock_chroma_client.get_or_create_collection.assert_called_once_with( - name="test_index", - metadata={"hnsw:space": "l2"}, - ) - - async def test_stores_entries_correctly(mock_chromadb_store): data = [ VectorDBEntry( @@ -96,17 +87,15 @@ async def test_stores_entries_correctly(mock_chromadb_store): def test_process_db_entry(mock_chromadb_store, mock_vector_db_entry): - id, embedding, text, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) - print(f"metadata: {metadata}, type: {type(metadata)}") + id, embedding, metadata = mock_chromadb_store._process_db_entry(mock_vector_db_entry) assert id == sha256(b"test_key").hexdigest() assert embedding == [0.1, 0.2, 0.3] - assert text == "test content" assert ( - metadata["document"] - == '{"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}' + metadata["__metadata"] + == '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}' ) - assert metadata["key"] == "test_key" + assert metadata["__key"] == "test_key" async def test_store(mock_chromadb_store, mock_vector_db_entry): @@ -122,9 +111,8 @@ async def test_retrieves_entries_correctly(mock_chromadb_store): "metadatas": [ [ { - "key": "test_key", - "content": "test content", - "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}, + "__key": "test_key", + "__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}', } ] ], @@ -143,27 +131,6 @@ async def test_handles_empty_retrieve(mock_chromadb_store): assert len(entries) == 0 -async def test_find_similar(mock_chromadb_store, mock_embedding_function): - mock_embedding_function.embed_text.return_value = [[0.1, 0.2, 0.3]] - mock_chromadb_store._embedding_function = mock_embedding_function - mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = { - "documents": [["test content"]], - "distances": [[0.1]], - } - result = await mock_chromadb_store.find_similar("test text") - assert result == "test content" - - -async def test_find_similar_with_custom_embeddings(mock_chromadb_store, custom_embedding_function): - mock_chromadb_store._embedding_function = custom_embedding_function - mock_chromadb_store._chroma_client.get_or_create_collection().query.return_value = { - "documents": [["test content"]], - "distances": [[0.1]], - } - result = await mock_chromadb_store.find_similar("test text") - assert result == "test content" - - def test_repr(mock_chromadb_store): assert repr(mock_chromadb_store) == "ChromaDBStore(index_name=test_index)" @@ -180,14 +147,3 @@ def test_return_best_match(mock_chromadb_store, retrieved, max_distance, expecte mock_chromadb_store._max_distance = max_distance result = mock_chromadb_store._return_best_match(retrieved) assert result == expected - - -def test_is_json_valid_string(mock_chromadb_store): - # Arrange - valid_json_string = '{"key": "value"}' - - # Act - result = mock_chromadb_store._is_json(valid_json_string) - - # Assert - assert result is True