Skip to content

Commit

Permalink
Add dynamic loading for modules dependent on optional deps
Browse files Browse the repository at this point in the history
  • Loading branch information
akonarski-ds committed Oct 24, 2024
1 parent d79ef6d commit 57d7141
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 34 deletions.
2 changes: 1 addition & 1 deletion examples/document-search/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# ///
import asyncio

from ragbits.core.embeddings import LiteLLMEmbeddings
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
Expand Down
2 changes: 1 addition & 1 deletion examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import chromadb

from ragbits.core.embeddings import LiteLLMEmbeddings
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta
Expand Down
6 changes: 3 additions & 3 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
]

config = {
"embedder": {"type": "LiteLLMEmbeddings"},
"embedder": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
"vector_store": {
"type": "ChromaDBStore",
"type": "ragbits.core.vector_store.chromadb_store:ChromaDBStore",
"config": {
"chroma_client": {"type": "PersistentClient", "config": {"path": "chroma"}},
"embedding_function": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
Expand All @@ -36,7 +36,7 @@
"type": "LLMQueryRephraser",
"config": {
"llm": {
"type": "LiteLLM",
"type": "ragbits.core.llms.litellm:LiteLLM",
"config": {
"model_name": "gpt-4-turbo",
},
Expand Down
9 changes: 5 additions & 4 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Embeddings
from .litellm import LiteLLMEmbeddings
from .local import LocalEmbeddings
from .noop import NoopEmbeddings

__all__ = ["Embeddings", "LiteLLMEmbeddings", "LocalEmbeddings", "NoopEmbeddings"]
__all__ = ["Embeddings", "NoopEmbeddings"]

module = sys.modules[__name__]

Expand All @@ -24,4 +24,5 @@ def get_embeddings(embedder_config: dict) -> Embeddings:
embeddings_type = embedder_config["type"]
config = embedder_config.get("config", {})

return getattr(module, embeddings_type)(**config)
embbedings = get_cls_from_config(embeddings_type, module)
return embbedings(**config)
10 changes: 4 additions & 6 deletions packages/ragbits-core/src/ragbits/core/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import sys

from .base import LLM
from .litellm import LiteLLM
from .local import LocalLLM
from ragbits.core.utils.config_handling import get_cls_from_config

__all__ = ["LLM", "LiteLLM", "LocalLLM"]
from .base import LLM

__all__ = ["LLM"]

module = sys.modules[__name__]

Expand All @@ -28,8 +27,7 @@ def get_llm(config: dict) -> LLM:
llm_type = config["type"]
llm_config = config.get("config", {})
default_options = llm_config.pop("default_options", None)

llm_cls = getattr(module, llm_type)
llm_cls = get_cls_from_config(llm_type, module)

if not issubclass(llm_cls, LLM):
raise ValueError(f"Invalid LLM class: {llm_cls}")
Expand Down
20 changes: 16 additions & 4 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
from typing import Any


class InvalidConfigError(Exception):
"""
An exception to be raised when an invalid configuration is provided.
"""


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # noqa: ANN401
"""
Retrieves and returns a class based on the given type string. The class can be either in the
Expand All @@ -19,8 +25,14 @@ def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any: # no
Any: The object retrieved from the specified or default module.
"""
if ":" in cls_path:
module_stringified, object_stringified = cls_path.split(":")
module = import_module(module_stringified)
return getattr(module, object_stringified)
try:
module_stringified, object_stringified = cls_path.split(":")
module = import_module(module_stringified)
return getattr(module, object_stringified)
except AttributeError as err:
raise InvalidConfigError(f"Class {object_stringified} not found in module {module_stringified}") from err

return getattr(default_module, cls_path)
try:
return getattr(default_module, cls_path)
except AttributeError as err:
raise InvalidConfigError(f"Class {cls_path} not found in module {default_module}") from err
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from ..utils.config_handling import get_cls_from_config
from .base import VectorDBEntry, VectorStore, WhereQuery
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

__all__ = ["ChromaDBStore", "InMemoryVectorStore", "VectorDBEntry", "VectorStore", "WhereQuery"]
__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "WhereQuery"]

module = sys.modules[__name__]

Expand All @@ -24,7 +23,7 @@ def get_vector_store(vector_store_config: dict) -> VectorStore:
vector_store_cls = get_cls_from_config(vector_store_config["type"], module)
config = vector_store_config.get("config", {})

if vector_store_config["type"] == "ChromaDBStore":
if vector_store_config["type"].endswith("ChromaDBStore"):
return vector_store_cls.from_config(config)

return vector_store_cls(**config)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from ragbits.core.embeddings import Embeddings
from ragbits.core.vector_store import ChromaDBStore, VectorDBEntry
from ragbits.core.vector_store import VectorDBEntry
from ragbits.core.vector_store.chromadb_store import ChromaDBStore


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,8 @@

from .base import BaseProvider
from .dummy import DummyProvider
from .unstructured.default import UnstructuredDefaultProvider
from .unstructured.images import UnstructuredImageProvider
from .unstructured.pdf import UnstructuredPdfProvider

__all__ = [
"BaseProvider",
"DummyProvider",
"UnstructuredDefaultProvider",
"UnstructuredImageProvider",
"UnstructuredPdfProvider",
]

__all__ = ["BaseProvider", "DummyProvider", "get_provider"]

module = sys.modules[__name__]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .default import UnstructuredDefaultProvider
from .images import UnstructuredImageProvider
from .pdf import UnstructuredPdfProvider

__all__ = ["UnstructuredDefaultProvider", "UnstructuredImageProvider", "UnstructuredPdfProvider"]

0 comments on commit 57d7141

Please sign in to comment.