-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(llms): option to set a default LLM factory
- Loading branch information
1 parent
cc33d4c
commit 3a6d5ac
Showing
7 changed files
with
139 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import importlib | ||
|
||
from ragbits.core.config import core_config | ||
from ragbits.core.llms.base import LLM | ||
from ragbits.core.llms.litellm import LiteLLM | ||
|
||
|
||
def get_llm_from_factory(factory_path: str) -> LLM: | ||
""" | ||
Get an instance of an LLM using a factory function specified by the user. | ||
Args: | ||
factory_path (str): The path to the factory function. | ||
Returns: | ||
LLM: An instance of the LLM. | ||
""" | ||
module_name, function_name = factory_path.rsplit(".", 1) | ||
module = importlib.import_module(module_name) | ||
function = getattr(module, function_name) | ||
return function() | ||
|
||
|
||
def has_default_llm() -> bool: | ||
""" | ||
Check if the default LLM factory is set in the configuration. | ||
Returns: | ||
bool: Whether the default LLM factory is set. | ||
""" | ||
return core_config.default_llm_factory is not None | ||
|
||
|
||
def get_default_llm() -> LLM: | ||
""" | ||
Get an instance of the default LLM using the factory function | ||
specified in the configuration. | ||
Returns: | ||
LLM: An instance of the default LLM. | ||
Raises: | ||
ValueError: If the default LLM factory is not set. | ||
""" | ||
factory = core_config.default_llm_factory | ||
if factory is None: | ||
raise ValueError("Default LLM factory is not set") | ||
|
||
return get_llm_from_factory(factory) | ||
|
||
|
||
def simple_litellm_factory() -> LLM: | ||
""" | ||
A basic LLM factory that creates an LiteLLM instance with the default model, | ||
default options, and assumes that the API key is set in the environment. | ||
Returns: | ||
LLM: An instance of the LiteLLM. | ||
""" | ||
return LiteLLM() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
16 changes: 16 additions & 0 deletions
16
packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from ragbits.core.config import core_config | ||
from ragbits.core.llms.factory import get_default_llm | ||
from ragbits.core.llms.litellm import LiteLLM | ||
|
||
|
||
def test_get_default_llm(monkeypatch): | ||
""" | ||
Test the get_llm_from_factory function. | ||
""" | ||
monkeypatch.setattr( | ||
core_config, "default_llm_factory", "tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory" | ||
) | ||
|
||
llm = get_default_llm() | ||
Check failure on line 14 in packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py
|
||
assert isinstance(llm, LiteLLM) | ||
assert llm.model_name == "mock_model" |
22 changes: 22 additions & 0 deletions
22
packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from ragbits.core.llms.factory import get_llm_from_factory | ||
from ragbits.core.llms.litellm import LiteLLM | ||
|
||
|
||
def mock_llm_factory() -> LiteLLM: | ||
""" | ||
A mock LLM factory that creates a LiteLLM instance with a mock model name. | ||
Returns: | ||
LiteLLM: An instance of the LiteLLM. | ||
""" | ||
return LiteLLM(model_name="mock_model") | ||
|
||
|
||
def test_get_llm_from_factory(): | ||
""" | ||
Test the get_llm_from_factory function. | ||
""" | ||
llm = get_llm_from_factory("tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory") | ||
Check failure on line 19 in packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py
|
||
|
||
assert isinstance(llm, LiteLLM) | ||
assert llm.model_name == "mock_model" |
20 changes: 20 additions & 0 deletions
20
packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from ragbits.core.config import core_config | ||
from ragbits.core.llms.factory import has_default_llm | ||
|
||
|
||
def test_has_default_llm(monkeypatch): | ||
""" | ||
Test the has_default_llm function when the default LLM factory is not set. | ||
""" | ||
monkeypatch.setattr(core_config, "default_llm_factory", None) | ||
|
||
assert has_default_llm() is False | ||
|
||
|
||
def test_has_default_llm_false(monkeypatch): | ||
""" | ||
Test the has_default_llm function when the default LLM factory is set. | ||
""" | ||
monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm") | ||
|
||
assert has_default_llm() is True |