Skip to content

Commit

Permalink
feat(prompts): Load prompt pattern config from pyproject.toml
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 8, 2024
1 parent 2e95436 commit 7042460
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 7 deletions.
1 change: 1 addition & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"jinja2>=3.1.4",
"pydantic>=2.9.1",
"typer~=0.12.5",
"tomli~=2.0.2",
]

[project.optional-dependencies]
Expand Down
15 changes: 15 additions & 0 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pydantic import BaseModel

from ragbits.core.utils._pyproject import get_config_instance


class CoreConfig(BaseModel):
"""
Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files.
"""

# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"


core_config = get_config_instance(CoreConfig)
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from pathlib import Path
from typing import Any, get_origin

from ragbits.core.config import core_config
from ragbits.core.prompt import Prompt

DEFAULT_FILE_PATTERN = "**/prompt_*.py"


class PromptDiscovery:
"""
Expand All @@ -18,7 +17,7 @@ class PromptDiscovery:
root_path (Path): The root path to search for Prompt objects. Defaults to the directory where the script is run.
"""

def __init__(self, file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd()):
def __init__(self, file_pattern: str = core_config.prompt_path_pattern, root_path: Path = Path.cwd()):
self.file_pattern = file_pattern
self.root_path = root_path

Expand Down
5 changes: 3 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from pydantic import BaseModel
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.llms import LiteLLM
from ragbits.core.llms.clients import LiteLLMOptions
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN, PromptDiscovery
from ragbits.core.prompt.discovery import PromptDiscovery


@dataclass(frozen=True)
Expand Down Expand Up @@ -135,7 +136,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:


def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = DEFAULT_FILE_PATTERN, llm_model: str | None = None, llm_api_key: str | None = None
file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down
6 changes: 4 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import yaml
from rich.console import Console

from ragbits.core.config import core_config
from ragbits.core.prompt.discovery import PromptDiscovery
from ragbits.core.prompt.discovery.prompt_discovery import DEFAULT_FILE_PATTERN


def generate_configs(
file_pattern: str = DEFAULT_FILE_PATTERN, root_path: Path = Path.cwd(), target_path: Path = Path("promptfooconfigs")
file_pattern: str = core_config.prompt_path_pattern,
root_path: Path = Path.cwd(),
target_path: Path = Path("promptfooconfigs"),
) -> None:
"""
Generates promptfoo configuration files for all discovered prompts.
Expand Down
68 changes: 68 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from pathlib import Path
from typing import Any, TypeVar

import tomli
from pydantic import BaseModel


def find_pyproject() -> Path:
"""
Find the pyproject.toml file in the current directory or any of its parents.
Returns:
Path: The path to the found pyproject.toml file.
Raises:
FileNotFoundError: If the pyproject.toml file is not found.
"""
current_dir = Path.cwd()
possible_dirs = [current_dir, *current_dir.parents]
for possible_dir in possible_dirs:
pyproject = possible_dir / "pyproject.toml"
if pyproject.exists():
return pyproject
raise FileNotFoundError("pyproject.toml not found")


def get_ragbits_config() -> dict[str, Any]:
"""
Get the ragbits configuration from the project's pyproject.toml file.
Only configuration from the [tool.ragbits] section is returned.
If the project doesn't include any ragbits configuration, an empty dictionary is returned.
Returns:
dict: The ragbits configuration.
"""
try:
pyproject = find_pyproject()
except FileNotFoundError:
# Projects are not required to use pyproject.toml
# No file just means no configuration
return {}

with pyproject.open("rb") as f:
pyproject_data = tomli.load(f)
return pyproject_data.get("tool", {}).get("ragbits", {})


ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel)


def get_config_instance(model: type[ConfigModelT], subproject: str | None = None) -> ConfigModelT:
"""
Creates an instace of pydantic model loaded with the configuration from pyproject.toml.
Args:
model (Type[BaseModel]): The pydantic model to instantiate.
subproject (str, optional): The subproject to get the configuration for, defaults to giving entire
ragbits configuration.
Returns:
ConfigModelT: The model instance loaded with the configuration
"""
config = get_ragbits_config()
print(config)
if subproject:
config = config.get(subproject, {})
return model(**config)
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 7042460

Please sign in to comment.