From d73e9d01752d34b842676c5f98d77d53b63371e8 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Thu, 10 Oct 2024 13:38:40 +0200 Subject: [PATCH] feat(prompts): Load prompt pattern config from pyproject.toml (#77) --- packages/ragbits-core/pyproject.toml | 1 + .../ragbits-core/src/ragbits/core/config.py | 15 ++++ .../core/prompt/discovery/prompt_discovery.py | 5 +- .../src/ragbits/core/prompt/lab/app.py | 5 +- .../src/ragbits/core/prompt/promptfoo.py | 6 +- .../src/ragbits/core/utils/_pyproject.py | 79 +++++++++++++++++++ .../tests/unit/utils/pyproject/test_find.py | 25 ++++++ .../unit/utils/pyproject/test_get_config.py | 26 ++++++ .../unit/utils/pyproject/test_get_instace.py | 68 ++++++++++++++++ .../testprojects/happy_project/pyproject.toml | 10 +++ uv.lock | 2 + 11 files changed, 235 insertions(+), 7 deletions(-) create mode 100644 packages/ragbits-core/src/ragbits/core/config.py create mode 100644 packages/ragbits-core/src/ragbits/core/utils/_pyproject.py create mode 100644 packages/ragbits-core/tests/unit/utils/pyproject/test_find.py create mode 100644 packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py create mode 100644 packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py create mode 100644 packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml diff --git a/packages/ragbits-core/pyproject.toml b/packages/ragbits-core/pyproject.toml index b61e83d7f..22efa58ae 100644 --- a/packages/ragbits-core/pyproject.toml +++ b/packages/ragbits-core/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "jinja2>=3.1.4", "pydantic>=2.9.1", "typer~=0.12.5", + "tomli~=2.0.2", ] [project.optional-dependencies] diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py new file mode 100644 index 000000000..49c0b12b8 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + +from ragbits.core.utils._pyproject import get_config_instance + + +class CoreConfig(BaseModel): + """ + Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files. + """ + + # Pattern used to search for prompt files + prompt_path_pattern: str = "**/prompt_*.py" + + +core_config = get_config_instance(CoreConfig, subproject="core") diff --git a/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py b/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py index 0b5aad9e1..121ae3f0f 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/discovery/prompt_discovery.py @@ -4,10 +4,9 @@ from pathlib import Path from typing import Any, get_origin +from ragbits.core.config import core_config from ragbits.core.prompt import Prompt -DEFAULT_FILE_PATTERN = "**/prompt_*.py" - class PromptDiscovery: """ @@ -18,7 +17,7 @@ class PromptDiscovery: root_path (Path): The root path to search for Prompt objects. Defaults to the directory where the script is run. """ - def __init__(self, file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd()): + def __init__(self, file_pattern: str = core_config.prompt_path_pattern, root_path: Path = Path.cwd()): self.file_pattern = file_pattern self.root_path = root_path diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 2f9498f07..9b6a7dc49 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -13,10 +13,11 @@ from pydantic import BaseModel from rich.console import Console +from ragbits.core.config import core_config from ragbits.core.llms import LiteLLM from ragbits.core.llms.clients import LiteLLMOptions from ragbits.core.prompt import Prompt -from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery +from ragbits.core.prompt.discovery import PromptDiscovery @dataclass(frozen=True) @@ -135,7 +136,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]: def lab_app( # pylint: disable=missing-param-doc - file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None + file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None ) -> None: """ Launches the interactive application for listing, rendering, and testing prompts diff --git a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py index 450a2ef3d..c650edd6d 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py @@ -4,12 +4,14 @@ import yaml from rich.console import Console +from ragbits.core.config import core_config from ragbits.core.prompt.discovery import PromptDiscovery -from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN def generate_configs( - file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd(), target_path: Path = Path("promptfooconfigs") + file_pattern: str = core_config.prompt_path_pattern, + root_path: Path = Path.cwd(), + target_path: Path = Path("promptfooconfigs"), ) -> None: """ Generates promptfoo configuration files for all discovered prompts. diff --git a/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py new file mode 100644 index 000000000..f29a55f7f --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/utils/_pyproject.py @@ -0,0 +1,79 @@ +from pathlib import Path +from typing import Any, TypeVar + +import tomli +from pydantic import BaseModel + + +def find_pyproject(current_dir: Path = Path.cwd()) -> Path: + """ + Find the pyproject.toml file in the current directory or any of its parents. + + Args: + current_dir (Path, optional): The directory to start searching from. Defaults to the + current working directory. + + Returns: + Path: The path to the found pyproject.toml file. + + Raises: + FileNotFoundError: If the pyproject.toml file is not found. + """ + possible_dirs = [current_dir, *current_dir.parents] + for possible_dir in possible_dirs: + pyproject = possible_dir / "pyproject.toml" + if pyproject.exists(): + return pyproject + raise FileNotFoundError("pyproject.toml not found") + + +def get_ragbits_config(current_dir: Path = Path.cwd()) -> dict[str, Any]: + """ + Get the ragbits configuration from the project's pyproject.toml file. + + Only configuration from the [tool.ragbits] section is returned. + If the project doesn't include any ragbits configuration, an empty dictionary is returned. + + Args: + current_dir (Path, optional): The directory to start searching for the pyproject.toml file. Defaults to the + current working directory. + + Returns: + dict: The ragbits configuration. + """ + try: + pyproject = find_pyproject(current_dir) + except FileNotFoundError: + # Projects are not required to use pyproject.toml + # No file just means no configuration + return {} + + with pyproject.open("rb") as f: + pyproject_data = tomli.load(f) + return pyproject_data.get("tool", {}).get("ragbits", {}) + + +ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel) + + +def get_config_instance( + model: type[ConfigModelT], subproject: str | None = None, current_dir: Path = Path.cwd() +) -> ConfigModelT: + """ + Creates an instace of pydantic model loaded with the configuration from pyproject.toml. + + Args: + model (Type[BaseModel]): The pydantic model to instantiate. + subproject (str, optional): The subproject to get the configuration for, defaults to giving entire + ragbits configuration. + current_dir (Path, optional): The directory to start searching for the pyproject.toml file. Defaults to the + current working directory + + Returns: + ConfigModelT: The model instance loaded with the configuration + """ + config = get_ragbits_config(current_dir) + print(config) + if subproject: + config = config.get(subproject, {}) + return model(**config) diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py new file mode 100644 index 000000000..2694721ad --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_find.py @@ -0,0 +1,25 @@ +from pathlib import Path + +import pytest + +from ragbits.core.utils._pyproject import find_pyproject + +projects_dir = Path(__file__).parent / "testprojects" + + +def test_find_in_current_dir(): + """Test finding a pyproject.toml file in the current directory.""" + found = find_pyproject(projects_dir / "happy_project") + assert found == projects_dir / "happy_project" / "pyproject.toml" + + +def test_find_in_parent_dir(): + """Test finding a pyproject.toml file in a parent directory.""" + found = find_pyproject(projects_dir / "happy_project" / "subdirectory") + assert found == projects_dir / "happy_project" / "pyproject.toml" + + +def test_find_not_found(): + """Test that it raises FileNotFoundError if the file is not found.""" + with pytest.raises(FileNotFoundError): + find_pyproject(Path("/")) diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py new file mode 100644 index 000000000..2c12dabde --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_config.py @@ -0,0 +1,26 @@ +from pathlib import Path + +from ragbits.core.utils._pyproject import get_ragbits_config + +projects_dir = Path(__file__).parent / "testprojects" + + +def test_get_config(): + """Test getting config from pyproject.toml file.""" + config = get_ragbits_config(projects_dir / "happy_project") + + assert config == { + "lorem": "ipsum", + "happy-project": { + "foo": "bar", + "is_happy": True, + "happiness_level": 100, + }, + } + + +def test_get_config_no_file(): + """Test getting config when the pyproject.toml file is not found.""" + config = get_ragbits_config(Path("/")) + + assert config == {} diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py new file mode 100644 index 000000000..9c0af1391 --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/pyproject/test_get_instace.py @@ -0,0 +1,68 @@ +from pathlib import Path + +from pydantic import BaseModel + +from ragbits.core.utils._pyproject import get_config_instance + +projects_dir = Path(__file__).parent / "testprojects" + + +class HappyProjectConfig(BaseModel): + foo: str + is_happy: bool + happiness_level: int + + +class PartialHappyProjectConfig(BaseModel): + foo: str + is_happy: bool + + +class OptionalHappyProjectConfig(BaseModel): + foo: str = "bar" + is_happy: bool = True + happiness_level: int = 100 + + +def test_get_config_instance(): + """Test getting Pydantic model instance from pyproject.toml file.""" + config = get_config_instance( + HappyProjectConfig, + subproject="happy-project", + current_dir=projects_dir / "happy_project", + ) + + assert config == HappyProjectConfig(foo="bar", is_happy=True, happiness_level=100) + + +def test_get_config_instance_additional_fields(): + """Test that unknown fields are ignored.""" + config = get_config_instance( + PartialHappyProjectConfig, + subproject="happy-project", + current_dir=projects_dir / "happy_project", + ) + + assert config == PartialHappyProjectConfig(foo="bar", is_happy=True) + + +def test_get_config_instance_optional_fields(): + """Test that optional fields are filled with default values if not present in the file.""" + config = get_config_instance( + OptionalHappyProjectConfig, + subproject="happy-project", + current_dir=projects_dir / "happy_project", + ) + + assert config == OptionalHappyProjectConfig(foo="bar", is_happy=True, happiness_level=100) + + +def test_get_config_instance_no_file(): + """Test getting config when the pyproject.toml file is not found (wich no required fields).""" + config = get_config_instance( + OptionalHappyProjectConfig, + subproject="happy-project", + current_dir=Path("/"), + ) + + assert config == OptionalHappyProjectConfig() diff --git a/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml new file mode 100644 index 000000000..813467863 --- /dev/null +++ b/packages/ragbits-core/tests/unit/utils/pyproject/testprojects/happy_project/pyproject.toml @@ -0,0 +1,10 @@ +[project] +name = "happy-project" + +[tool.ragbits] +lorem = "ipsum" + +[tool.ragbits.happy-project] +foo = "bar" +is_happy = true +happiness_level = 100 diff --git a/uv.lock b/uv.lock index 606485be3..561fda344 100644 --- a/uv.lock +++ b/uv.lock @@ -2872,6 +2872,7 @@ source = { editable = "packages/ragbits-core" } dependencies = [ { name = "jinja2" }, { name = "pydantic" }, + { name = "tomli" }, { name = "typer" }, ] @@ -2908,6 +2909,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = "~=1.46.0" }, { name = "numpy", marker = "extra == 'local'", specifier = "~=1.24.0" }, { name = "pydantic", specifier = ">=2.9.1" }, + { name = "tomli", specifier = "~=2.0.2" }, { name = "torch", marker = "extra == 'local'", specifier = "~=2.2.1" }, { name = "transformers", marker = "extra == 'local'", specifier = "~=4.44.2" }, { name = "typer", specifier = "~=0.12.5" },