Skip to content

Commit

Permalink
feat(prompts): Load prompt pattern config from pyproject.toml (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer authored Oct 10, 2024
1 parent a774147 commit d73e9d0
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 7 deletions.
1 change: 1 addition & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"jinja2>=3.1.4",
"pydantic>=2.9.1",
"typer~=0.12.5",
"tomli~=2.0.2",
]

[project.optional-dependencies]
Expand Down
15 changes: 15 additions & 0 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pydantic import BaseModel

from ragbits.core.utils._pyproject import get_config_instance


class CoreConfig(BaseModel):
"""
Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files.
"""

# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"


core_config = get_config_instance(CoreConfig, subproject="core")
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from pathlib import Path
from typing import Any, get_origin

from ragbits.core.config import core_config
from ragbits.core.prompt import Prompt

DEFAULT_FILE_PATTERN = "**/prompt_*.py"


class PromptDiscovery:
"""
Expand All @@ -18,7 +17,7 @@ class PromptDiscovery:
root_path (Path): The root path to search for Prompt objects. Defaults to the directory where the script is run.
"""

def __init__(self, file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd()):
def __init__(self, file_pattern: str = core_config.prompt_path_pattern, root_path: Path = Path.cwd()):
self.file_pattern = file_pattern
self.root_path = root_path

Expand Down
5 changes: 3 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from pydantic import BaseModel
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery
from ragbits.core.prompt.discovery import PromptDiscovery


@dataclass(frozen=True)
Expand Down Expand Up @@ -135,7 +136,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:


def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None
file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down
6 changes: 4 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import yaml
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.prompt.discovery import PromptDiscovery
from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN


def generate_configs(
file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd(), target_path: Path = Path("promptfooconfigs")
file_pattern: str = core_config.prompt_path_pattern,
root_path: Path = Path.cwd(),
target_path: Path = Path("promptfooconfigs"),
) -> None:
"""
Generates promptfoo configuration files for all discovered prompts.
Expand Down
79 changes: 79 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pathlib import Path
from typing import Any, TypeVar

import tomli
from pydantic import BaseModel


def find_pyproject(current_dir: Path = Path.cwd()) -> Path:
"""
Find the pyproject.toml file in the current directory or any of its parents.
Args:
current_dir (Path, optional): The directory to start searching from. Defaults to the
current working directory.
Returns:
Path: The path to the found pyproject.toml file.
Raises:
FileNotFoundError: If the pyproject.toml file is not found.
"""
possible_dirs = [current_dir, *current_dir.parents]
for possible_dir in possible_dirs:
pyproject = possible_dir / "pyproject.toml"
if pyproject.exists():
return pyproject
raise FileNotFoundError("pyproject.toml not found")


def get_ragbits_config(current_dir: Path = Path.cwd()) -> dict[str, Any]:
"""
Get the ragbits configuration from the project's pyproject.toml file.
Only configuration from the [tool.ragbits] section is returned.
If the project doesn't include any ragbits configuration, an empty dictionary is returned.
Args:
current_dir (Path, optional): The directory to start searching for the pyproject.toml file. Defaults to the
current working directory.
Returns:
dict: The ragbits configuration.
"""
try:
pyproject = find_pyproject(current_dir)
except FileNotFoundError:
# Projects are not required to use pyproject.toml
# No file just means no configuration
return {}

with pyproject.open("rb") as f:
pyproject_data = tomli.load(f)
return pyproject_data.get("tool", {}).get("ragbits", {})


ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel)


def get_config_instance(
model: type[ConfigModelT], subproject: str | None = None, current_dir: Path = Path.cwd()
) -> ConfigModelT:
"""
Creates an instace of pydantic model loaded with the configuration from pyproject.toml.
Args:
model (Type[BaseModel]): The pydantic model to instantiate.
subproject (str, optional): The subproject to get the configuration for, defaults to giving entire
ragbits configuration.
current_dir (Path, optional): The directory to start searching for the pyproject.toml file. Defaults to the
current working directory
Returns:
ConfigModelT: The model instance loaded with the configuration
"""
config = get_ragbits_config(current_dir)
print(config)
if subproject:
config = config.get(subproject, {})
return model(**config)
25 changes: 25 additions & 0 deletions packages/ragbits-core/tests/unit/utils/pyproject/test_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path

import pytest

from ragbits.core.utils._pyproject import find_pyproject

projects_dir = Path(__file__).parent / "testprojects"


def test_find_in_current_dir():
"""Test finding a pyproject.toml file in the current directory."""
found = find_pyproject(projects_dir / "happy_project")
assert found == projects_dir / "happy_project" / "pyproject.toml"


def test_find_in_parent_dir():
"""Test finding a pyproject.toml file in a parent directory."""
found = find_pyproject(projects_dir / "happy_project" / "subdirectory")
assert found == projects_dir / "happy_project" / "pyproject.toml"


def test_find_not_found():
"""Test that it raises FileNotFoundError if the file is not found."""
with pytest.raises(FileNotFoundError):
find_pyproject(Path("/"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pathlib import Path

from ragbits.core.utils._pyproject import get_ragbits_config

projects_dir = Path(__file__).parent / "testprojects"


def test_get_config():
"""Test getting config from pyproject.toml file."""
config = get_ragbits_config(projects_dir / "happy_project")

assert config == {
"lorem": "ipsum",
"happy-project": {
"foo": "bar",
"is_happy": True,
"happiness_level": 100,
},
}


def test_get_config_no_file():
"""Test getting config when the pyproject.toml file is not found."""
config = get_ragbits_config(Path("/"))

assert config == {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pathlib import Path

from pydantic import BaseModel

from ragbits.core.utils._pyproject import get_config_instance

projects_dir = Path(__file__).parent / "testprojects"


class HappyProjectConfig(BaseModel):
foo: str
is_happy: bool
happiness_level: int


class PartialHappyProjectConfig(BaseModel):
foo: str
is_happy: bool


class OptionalHappyProjectConfig(BaseModel):
foo: str = "bar"
is_happy: bool = True
happiness_level: int = 100


def test_get_config_instance():
"""Test getting Pydantic model instance from pyproject.toml file."""
config = get_config_instance(
HappyProjectConfig,
subproject="happy-project",
current_dir=projects_dir / "happy_project",
)

assert config == HappyProjectConfig(foo="bar", is_happy=True, happiness_level=100)


def test_get_config_instance_additional_fields():
"""Test that unknown fields are ignored."""
config = get_config_instance(
PartialHappyProjectConfig,
subproject="happy-project",
current_dir=projects_dir / "happy_project",
)

assert config == PartialHappyProjectConfig(foo="bar", is_happy=True)


def test_get_config_instance_optional_fields():
"""Test that optional fields are filled with default values if not present in the file."""
config = get_config_instance(
OptionalHappyProjectConfig,
subproject="happy-project",
current_dir=projects_dir / "happy_project",
)

assert config == OptionalHappyProjectConfig(foo="bar", is_happy=True, happiness_level=100)


def test_get_config_instance_no_file():
"""Test getting config when the pyproject.toml file is not found (wich no required fields)."""
config = get_config_instance(
OptionalHappyProjectConfig,
subproject="happy-project",
current_dir=Path("/"),
)

assert config == OptionalHappyProjectConfig()
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[project]
name = "happy-project"

[tool.ragbits]
lorem = "ipsum"

[tool.ragbits.happy-project]
foo = "bar"
is_happy = true
happiness_level = 100
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d73e9d0

Please sign in to comment.