Skip to content

Commit

Permalink
fix(cli): fix cli failing on optional deps and get rid of unnecesssar…
Browse files Browse the repository at this point in the history
…y module loading (#111)
  • Loading branch information
akonarski-ds authored Oct 16, 2024
1 parent d61725f commit 17d7b3d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 15 deletions.
9 changes: 6 additions & 3 deletions packages/ragbits-cli/src/ragbits/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib.util
import pkgutil
import sys
from pathlib import Path

from typer import Typer

Expand All @@ -18,14 +20,15 @@ def main() -> None:
- if found it imports the `register` function from the `cli` module and calls it with the `app` object
- register function should add the CLI commands to the `app` object
"""
help_only = len(sys.argv) == 1 or sys.argv[1] == "--help"

cli_enabled_modules = [
module
for module in pkgutil.iter_modules(ragbits.__path__)
if module.ispkg and module.name != "cli" and importlib.util.find_spec(f"ragbits.{module.name}.cli")
for i, module in enumerate(pkgutil.iter_modules(ragbits.__path__))
if module.ispkg and module.name != "cli" and (Path(ragbits.__path__[i]) / module.name / "cli.py").exists()
]
for module in cli_enabled_modules:
register_func = importlib.import_module(f"ragbits.{module.name}.cli").register
register_func(app)
register_func(app, help_only)

app()
3 changes: 3 additions & 0 deletions packages/ragbits-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ local = [
lab = [
"gradio~=4.44.0",
]
promptfoo = [
"PyYAML~=6.0.2",
]

[tool.uv]
dev-dependencies = [
Expand Down
15 changes: 9 additions & 6 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# pylint: disable=import-outside-toplevel
import typer

from .prompt.lab.app import lab_app
from .prompt.promptfoo import generate_configs

prompts_app = typer.Typer(no_args_is_help=True)


def register(app: typer.Typer) -> None:
def register(app: typer.Typer, help_only: bool) -> None:
"""
Register the CLI commands for the package.
Args:
app: The Typer object to register the commands with.
help_only: A boolean indicating whether it is a help-only run.
"""
prompts_app.command(name="lab")(lab_app)
prompts_app.command(name="generate-promptfoo-configs")(generate_configs)
if not help_only:
from .prompt.lab.app import lab_app
from .prompt.promptfoo import generate_configs

prompts_app.command(name="lab")(lab_app)
prompts_app.command(name="generate-promptfoo-configs")(generate_configs)
app.add_typer(prompts_app, name="prompts", help="Commands for managing prompts")
15 changes: 14 additions & 1 deletion packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
from pathlib import Path

import yaml
try:
import yaml

HAS_PYYAML = True
except ImportError:
HAS_PYYAML = False

from rich.console import Console

from ragbits.core.config import core_config
Expand All @@ -21,6 +27,13 @@ def generate_configs(
root_path: The root path to search for Prompt objects. Defaults to the directory where the script is run.
target_path: The path to save the promptfoo configuration files. Defaults to "promptfooconfigs".
"""
if not HAS_PYYAML:
Console(stderr=True).print(
"To generate configs for promptfoo, you need the PyYAML library. Please install it using the following"
" command:\n[b]pip install ragbits-core\\[promptfoo][/b]"
)
return

prompts = PromptDiscovery(file_pattern=file_pattern, root_path=root_path).discover()
Console().print(
f"Discovered {len(prompts)} prompts."
Expand Down
1 change: 0 additions & 1 deletion packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ def get_config_instance(
ConfigModelT: The model instance loaded with the configuration
"""
config = get_ragbits_config(current_dir)
print(config)
if subproject:
config = config.get(subproject, {})
return model(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class ChromaDBStore(VectorStore):
def __init__(
self,
index_name: str,
chroma_client: chromadb.ClientAPI,
embedding_function: Union[Embeddings, chromadb.EmbeddingFunction],
chroma_client: "chromadb.ClientAPI",
embedding_function: Union[Embeddings, "chromadb.EmbeddingFunction"],
max_distance: Optional[float] = None,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
):
Expand Down Expand Up @@ -72,7 +72,7 @@ def from_config(cls, config: dict) -> "ChromaDBStore":
distance_method=config.get("distance_method", "l2"),
)

def _get_chroma_collection(self) -> chromadb.Collection:
def _get_chroma_collection(self) -> "chromadb.Collection":
"""
Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection.
If the collection doesn't exist, it creates one.
Expand Down Expand Up @@ -116,7 +116,7 @@ def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], dic
return doc_id, embedding, metadata

@property
def embedding_function(self) -> Union[Embeddings, chromadb.EmbeddingFunction]:
def embedding_function(self) -> Union[Embeddings, "chromadb.EmbeddingFunction"]:
"""
Returns the embedding function.
Expand Down

0 comments on commit 17d7b3d

Please sign in to comment.