-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(prompts): Load prompt pattern config from pyproject.toml
- Loading branch information
1 parent
2e95436
commit 7042460
Showing
7 changed files
with
95 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from pydantic import BaseModel | ||
|
||
from ragbits.core.utils._pyproject import get_config_instance | ||
|
||
|
||
class CoreConfig(BaseModel): | ||
""" | ||
Configuration for the ragbits-core package, loaded from downstream projects' pyproject.toml files. | ||
""" | ||
|
||
# Pattern used to search for prompt files | ||
prompt_path_pattern: str = "**/prompt_*.py" | ||
|
||
|
||
core_config = get_config_instance(CoreConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from pathlib import Path | ||
from typing import Any, TypeVar | ||
|
||
import tomli | ||
from pydantic import BaseModel | ||
|
||
|
||
def find_pyproject() -> Path: | ||
""" | ||
Find the pyproject.toml file in the current directory or any of its parents. | ||
Returns: | ||
Path: The path to the found pyproject.toml file. | ||
Raises: | ||
FileNotFoundError: If the pyproject.toml file is not found. | ||
""" | ||
current_dir = Path.cwd() | ||
possible_dirs = [current_dir, *current_dir.parents] | ||
for possible_dir in possible_dirs: | ||
pyproject = possible_dir / "pyproject.toml" | ||
if pyproject.exists(): | ||
return pyproject | ||
raise FileNotFoundError("pyproject.toml not found") | ||
|
||
|
||
def get_ragbits_config() -> dict[str, Any]: | ||
""" | ||
Get the ragbits configuration from the project's pyproject.toml file. | ||
Only configuration from the [tool.ragbits] section is returned. | ||
If the project doesn't include any ragbits configuration, an empty dictionary is returned. | ||
Returns: | ||
dict: The ragbits configuration. | ||
""" | ||
try: | ||
pyproject = find_pyproject() | ||
except FileNotFoundError: | ||
# Projects are not required to use pyproject.toml | ||
# No file just means no configuration | ||
return {} | ||
|
||
with pyproject.open("rb") as f: | ||
pyproject_data = tomli.load(f) | ||
return pyproject_data.get("tool", {}).get("ragbits", {}) | ||
|
||
|
||
ConfigModelT = TypeVar("ConfigModelT", bound=BaseModel) | ||
|
||
|
||
def get_config_instance(model: type[ConfigModelT], subproject: str | None = None) -> ConfigModelT: | ||
""" | ||
Creates an instace of pydantic model loaded with the configuration from pyproject.toml. | ||
Args: | ||
model (Type[BaseModel]): The pydantic model to instantiate. | ||
subproject (str, optional): The subproject to get the configuration for, defaults to giving entire | ||
ragbits configuration. | ||
Returns: | ||
ConfigModelT: The model instance loaded with the configuration | ||
""" | ||
config = get_ragbits_config() | ||
print(config) | ||
if subproject: | ||
config = config.get(subproject, {}) | ||
return model(**config) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.