From 704246039424b1526f263c6395bb654688c657c6 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Tue, 8 Oct 2024 15:03:39 +0200 Subject: [PATCH] feat(prompts): Load prompt pattern config from pyproject.toml --- packages/ragbits-core/pyproject.toml | 1 + .../ragbits-core/src/ragbits/core/config.py | 15 ++++ .../core/prompt/discovery/prompt_discovery.py | 5 +- .../src/ragbits/core/prompt/lab/app.py | 5 +- .../src/ragbits/core/prompt/promptfoo.py | 6 +- .../src/ragbits/core/utils/_pyproject.py | 68 +++++++++++++++++++ uv.lock | 2 + 7 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 packages/ragbits-core/src/ragbits/core/config.py create mode 100644 packages/ragbits-core/src/ragbits/core/utils/_pyproject.py diff --git a/packages/ragbits-core/pyproject.toml b/packages/ragbits-core/pyproject.toml index b61e83d7f..22efa58ae 100644 --- a/packages/ragbits-core/pyproject.toml +++ b/packages/ragbits-core/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "jinja2>=3.1.4", "pydantic>=2.9.1", "typer~=0.12.5", + "tomli~=2.0.2", ] [project.optional-dependencies] diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py new file mode 100644 index 000000000..b676048e9 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + +from ragbits.core.utils._pyproject import get_config_instance + + +class CoreConfig(BaseModel): + """ + Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files. + """ + + # Pattern used to search for prompt files + prompt_path_pattern: str = "**/prompt_*.py" + + +core_config = get_config_instance(CoreConfig) diff --git a/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py b/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py index 0b5aad9e1..121ae3f0f 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py @@ -4,10 +4,9 @@ from pathlib import Path from typing import Any, get_origin +from ragbits.core.config import core_config from ragbits.core.prompt import Prompt -DEFAULT_FILE_PATTERN = "**/prompt_*.py" - class PromptDiscovery: """ @@ -18,7 +17,7 @@ class PromptDiscovery: root_path (Path): The root path to search for Prompt objects. Defaults to the directory where the script is run. """ - def __init__(self, file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd()): + def __init__(self, file_pattern: str = core_config.prompt_path_pattern, root_path: Path = Path.cwd()): self.file_pattern = file_pattern self.root_path = root_path diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 2f9498f07..9b6a7dc49 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -13,10 +13,11 @@ from pydantic import BaseModel from rich.console import Console +from ragbits.core.config import core_config from ragbits.core.llms import LiteLLM from ragbits.core.llms.clients import LiteLLMOptions from ragbits.core.prompt import Prompt -from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery +from ragbits.core.prompt.discovery import PromptDiscovery @dataclass(frozen=True) @@ -135,7 +136,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]: def lab_app( # pylint: disable=missing-param-doc - file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None + file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None ) -> None: """ Launches the interactive application for listing, rendering, and testing prompts diff --git a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py index 450a2ef3d..c650edd6d 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py @@ -4,12 +4,14 @@ import yaml from rich.console import Console +from ragbits.core.config import core_config from ragbits.core.prompt.discovery import PromptDiscovery -from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN def generate_configs( - file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd(), target_path: Path = Path("promptfooconfigs") + file_pattern: str = core_config.prompt_path_pattern, + root_path: Path = Path.cwd(), + target_path: Path = Path("promptfooconfigs"), ) -> None: """ Generates promptfoo configuration files for all discovered prompts. diff --git a/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py new file mode 100644 index 000000000..c3b70b75b --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py @@ -0,0 +1,68 @@ +from pathlib import Path +from typing import Any, TypeVar + +import tomli +from pydantic import BaseModel + + +def find_pyproject() -> Path: + """ + Find the pyproject.toml file in the current directory or any of its parents. + + Returns: + Path: The path to the found pyproject.toml file. + + Raises: + FileNotFoundError: If the pyproject.toml file is not found. + """ + current_dir = Path.cwd() + possible_dirs = [current_dir, *current_dir.parents] + for possible_dir in possible_dirs: + pyproject = possible_dir / "pyproject.toml" + if pyproject.exists(): + return pyproject + raise FileNotFoundError("pyproject.toml not found") + + +def get_ragbits_config() -> dict[str, Any]: + """ + Get the ragbits configuration from the project's pyproject.toml file. + + Only configuration from the [tool.ragbits] section is returned. + If the project doesn't include any ragbits configuration, an empty dictionary is returned. + + Returns: + dict: The ragbits configuration. + """ + try: + pyproject = find_pyproject() + except FileNotFoundError: + # Projects are not required to use pyproject.toml + # No file just means no configuration + return {} + + with pyproject.open("rb") as f: + pyproject_data = tomli.load(f) + return pyproject_data.get("tool", {}).get("ragbits", {}) + + +ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel) + + +def get_config_instance(model: type[ConfigModelT], subproject: str | None = None) -> ConfigModelT: + """ + Creates an instace of pydantic model loaded with the configuration from pyproject.toml. + + Args: + model (Type[BaseModel]): The pydantic model to instantiate. + subproject (str, optional): The subproject to get the configuration for, defaults to giving entire + ragbits configuration. + + Returns: + ConfigModelT: The model instance loaded with the configuration + """ + config = get_ragbits_config() + print(config) + if subproject: + config = config.get(subproject, {}) + return model(**config) diff --git a/uv.lock b/uv.lock index 606485be3..561fda344 100644 --- a/uv.lock +++ b/uv.lock @@ -2872,6 +2872,7 @@ source = { editable = "packages/ragbits-core" } dependencies = [ { name = "jinja2" }, { name = "pydantic" }, + { name = "tomli" }, { name = "typer" }, ] @@ -2908,6 +2909,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" }, { name = "numpy", marker = "extra == 'local'", specifier = "~=1.24.0" }, { name = "pydantic", specifier = ">=2.9.1" }, + { name = "tomli", specifier = "~=2.0.2" }, { name = "torch", marker = "extra == 'local'", specifier = "~=2.2.1" }, { name = "transformers", marker = "extra == 'local'", specifier = "~=4.44.2" }, { name = "typer", specifier = "~=0.12.5" },