From 3a6d5ace7c93347f83bd9c40041e529cb16eb846 Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Mon, 14 Oct 2024 17:53:27 +0200 Subject: [PATCH 1/2] feat(llms): option to set a default LLM factory --- .../ragbits-core/src/ragbits/core/config.py | 3 + .../src/ragbits/core/llms/factory.py | 60 +++++++++++++++++++ .../src/ragbits/core/prompt/lab/app.py | 35 +++++------ .../tests/unit/llms/factory/__init__.py | 0 .../unit/llms/factory/test_get_default_llm.py | 16 +++++ .../llms/factory/test_get_llm_from_factory.py | 22 +++++++ .../unit/llms/factory/test_has_default_llm.py | 20 +++++++ 7 files changed, 139 insertions(+), 17 deletions(-) create mode 100644 packages/ragbits-core/src/ragbits/core/llms/factory.py create mode 100644 packages/ragbits-core/tests/unit/llms/factory/__init__.py create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py create mode 100644 packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py diff --git a/packages/ragbits-core/src/ragbits/core/config.py b/packages/ragbits-core/src/ragbits/core/config.py index 49c0b12b..830dfc9f 100644 --- a/packages/ragbits-core/src/ragbits/core/config.py +++ b/packages/ragbits-core/src/ragbits/core/config.py @@ -11,5 +11,8 @@ class CoreConfig(BaseModel): # Pattern used to search for prompt files prompt_path_pattern: str = "**/prompt_*.py" + # Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm" + default_llm_factory: str | None = None + core_config = get_config_instance(CoreConfig, subproject="core") diff --git a/packages/ragbits-core/src/ragbits/core/llms/factory.py b/packages/ragbits-core/src/ragbits/core/llms/factory.py new file mode 100644 index 00000000..02bd4704 --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/llms/factory.py @@ -0,0 +1,60 @@ +import importlib + +from ragbits.core.config import core_config +from ragbits.core.llms.base import LLM +from ragbits.core.llms.litellm import LiteLLM + + +def get_llm_from_factory(factory_path: str) -> LLM: + """ + Get an instance of an LLM using a factory function specified by the user. + + Args: + factory_path (str): The path to the factory function. + + Returns: + LLM: An instance of the LLM. + """ + module_name, function_name = factory_path.rsplit(".", 1) + module = importlib.import_module(module_name) + function = getattr(module, function_name) + return function() + + +def has_default_llm() -> bool: + """ + Check if the default LLM factory is set in the configuration. + + Returns: + bool: Whether the default LLM factory is set. + """ + return core_config.default_llm_factory is not None + + +def get_default_llm() -> LLM: + """ + Get an instance of the default LLM using the factory function + specified in the configuration. + + Returns: + LLM: An instance of the default LLM. + + Raises: + ValueError: If the default LLM factory is not set. + """ + factory = core_config.default_llm_factory + if factory is None: + raise ValueError("Default LLM factory is not set") + + return get_llm_from_factory(factory) + + +def simple_litellm_factory() -> LLM: + """ + A basic LLM factory that creates an LiteLLM instance with the default model, + default options, and assumes that the API key is set in the environment. + + Returns: + LLM: An instance of the LiteLLM. + """ + return LiteLLM() diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 9b6a7dc4..05648b01 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -14,8 +14,8 @@ from rich.console import Console from ragbits.core.config import core_config -from ragbits.core.llms import LiteLLM -from ragbits.core.llms.clients import LiteLLMOptions +from ragbits.core.llms import LLM +from ragbits.core.llms.factory import get_llm_from_factory from ragbits.core.prompt import Prompt from ragbits.core.prompt.discovery import PromptDiscovery @@ -30,14 +30,12 @@ class PromptState: Attributes: prompts (list): A list containing discovered prompts. rendered_prompt (Prompt): The most recently rendered Prompt instance. - llm_model_name (str): The name of the selected LLM model. - llm_api_key (str | None): The API key for the chosen LLM model. + llm (LLM): The LLM instance to be used for generating responses. """ prompts: list = field(default_factory=list) rendered_prompt: Prompt | None = None - llm_model_name: str | None = None - llm_api_key: str | None = None + llm: LLM | None = None def render_prompt(index: int, system_prompt: str, user_prompt: str, state: PromptState, *args: Any) -> PromptState: @@ -99,15 +97,17 @@ def send_prompt_to_llm(state: PromptState) -> str: Returns: str: The response generated by the LLM. - """ - assert state.llm_model_name is not None, "LLM model name is not set." - llm_client = LiteLLM(model_name=state.llm_model_name, api_key=state.llm_api_key) + Raises: + ValueError: If the LLM model is not configured. + """ assert state.rendered_prompt is not None, "Prompt has not been rendered yet." + + if state.llm is None: + raise ValueError("LLM model is not configured.") + try: - response = asyncio.run( - llm_client.client.call(conversation=state.rendered_prompt.chat, options=LiteLLMOptions()) - ) + response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt)) except Exception as e: # pylint: disable=broad-except response = str(e) @@ -136,7 +136,8 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]: def lab_app( # pylint: disable=missing-param-doc - file_pattern: str = core_config.prompt_path_pattern, llm_model: str | None = None, llm_api_key: str | None = None + file_pattern: str = core_config.prompt_path_pattern, + llm_factory: str | None = core_config.default_llm_factory, ) -> None: """ Launches the interactive application for listing, rendering, and testing prompts @@ -163,9 +164,8 @@ def lab_app( # pylint: disable=missing-param-doc with gr.Blocks() as gr_app: prompts_state = gr.State( PromptState( - llm_model_name=llm_model, - llm_api_key=llm_api_key, prompts=list(prompts), + llm=get_llm_from_factory(llm_factory) if llm_factory else None, ) ) @@ -220,14 +220,15 @@ def show_split(index: int, state: gr.State) -> None: ) gr.Textbox(label="Rendered User Prompt", value=rendered_user_prompt, interactive=False) - llm_enabled = state.llm_model_name is not None + llm_enabled = state.llm is not None prompt_ready = state.rendered_prompt is not None llm_request_button = gr.Button( value="Send to LLM", interactive=llm_enabled and prompt_ready, ) gr.Markdown( - "To enable this button, select an LLM model when starting the app in CLI.", visible=not llm_enabled + "To enable this button set an LLM factory function in CLI options or your pyproject.toml", + visible=not llm_enabled, ) gr.Markdown("To enable this button, render a prompt first.", visible=llm_enabled and not prompt_ready) llm_prompt_response = gr.Textbox(lines=10, label="LLM response") diff --git a/packages/ragbits-core/tests/unit/llms/factory/__init__.py b/packages/ragbits-core/tests/unit/llms/factory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py new file mode 100644 index 00000000..9a6c9885 --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py @@ -0,0 +1,16 @@ +from ragbits.core.config import core_config +from ragbits.core.llms.factory import get_default_llm +from ragbits.core.llms.litellm import LiteLLM + + +def test_get_default_llm(monkeypatch): + """ + Test the get_llm_from_factory function. + """ + monkeypatch.setattr( + core_config, "default_llm_factory", "tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory" + ) + + llm = get_default_llm() + assert isinstance(llm, LiteLLM) + assert llm.model_name == "mock_model" diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py new file mode 100644 index 00000000..53188bb3 --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py @@ -0,0 +1,22 @@ +from ragbits.core.llms.factory import get_llm_from_factory +from ragbits.core.llms.litellm import LiteLLM + + +def mock_llm_factory() -> LiteLLM: + """ + A mock LLM factory that creates a LiteLLM instance with a mock model name. + + Returns: + LiteLLM: An instance of the LiteLLM. + """ + return LiteLLM(model_name="mock_model") + + +def test_get_llm_from_factory(): + """ + Test the get_llm_from_factory function. + """ + llm = get_llm_from_factory("tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory") + + assert isinstance(llm, LiteLLM) + assert llm.model_name == "mock_model" diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py new file mode 100644 index 00000000..59c86483 --- /dev/null +++ b/packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py @@ -0,0 +1,20 @@ +from ragbits.core.config import core_config +from ragbits.core.llms.factory import has_default_llm + + +def test_has_default_llm(monkeypatch): + """ + Test the has_default_llm function when the default LLM factory is not set. + """ + monkeypatch.setattr(core_config, "default_llm_factory", None) + + assert has_default_llm() is False + + +def test_has_default_llm_false(monkeypatch): + """ + Test the has_default_llm function when the default LLM factory is set. + """ + monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm") + + assert has_default_llm() is True From d5b4955c4606d04fc7a95bc985a016a3b9f07eab Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Tue, 15 Oct 2024 10:44:21 +0200 Subject: [PATCH 2/2] Fix tests --- .../tests/unit/llms/factory/__init__.py | 5 +++++ .../unit/llms/factory/test_get_default_llm.py | 5 ++--- .../unit/llms/factory/test_get_llm_from_factory.py | 2 +- uv.lock | 14 ++------------ 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/packages/ragbits-core/tests/unit/llms/factory/__init__.py b/packages/ragbits-core/tests/unit/llms/factory/__init__.py index e69de29b..a3559f0c 100644 --- a/packages/ragbits-core/tests/unit/llms/factory/__init__.py +++ b/packages/ragbits-core/tests/unit/llms/factory/__init__.py @@ -0,0 +1,5 @@ +import sys +from pathlib import Path + +# Add "llms" to sys.path +sys.path.append(str(Path(__file__).parent.parent)) diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py index 9a6c9885..c07272fb 100644 --- a/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py @@ -7,9 +7,8 @@ def test_get_default_llm(monkeypatch): """ Test the get_llm_from_factory function. """ - monkeypatch.setattr( - core_config, "default_llm_factory", "tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory" - ) + + monkeypatch.setattr(core_config, "default_llm_factory", "factory.test_get_llm_from_factory.mock_llm_factory") llm = get_default_llm() assert isinstance(llm, LiteLLM) diff --git a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py index 53188bb3..8d2a948c 100644 --- a/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py +++ b/packages/ragbits-core/tests/unit/llms/factory/test_get_llm_from_factory.py @@ -16,7 +16,7 @@ def test_get_llm_from_factory(): """ Test the get_llm_from_factory function. """ - llm = get_llm_from_factory("tests.unit.llms.factory.test_get_llm_from_factory.mock_llm_factory") + llm = get_llm_from_factory("factory.test_get_llm_from_factory.mock_llm_factory") assert isinstance(llm, LiteLLM) assert llm.model_name == "mock_model" diff --git a/uv.lock b/uv.lock index f685d643..6156c1af 100644 --- a/uv.lock +++ b/uv.lock @@ -617,7 +617,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -2038,7 +2038,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2047,7 +2046,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2056,7 +2054,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2065,7 +2062,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2085,7 +2081,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2094,7 +2089,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2108,7 +2102,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2120,7 +2113,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2138,7 +2130,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/11/8c/386018fdffdce2ff8d43fedf192ef7d14cab7501cbf78a106dd2e9f1fc1f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:3bf10d85bb1801e9c894c6e197e44dd137d2a0a9e43f8450e9ad13f2df0dd52d", size = 19270432 }, { url = "https://files.pythonhosted.org/packages/fe/e4/486de766851d58699bcfeb3ba6a3beb4d89c3809f75b9d423b9508a8760f/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9ae346d16203ae4ea513be416495167a0101d33d2d14935aa9c1829a3fb45142", size = 19745114 }, - { url = "https://files.pythonhosted.org/packages/1a/aa/7b5d8e22d73e03f941293ae62c993642fa41e6525f3213292e007621aa8e/nvidia_nvjitlink_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:410718cd44962bed862a31dd0318620f6f9a8b28a6291967bcfcb446a6516771", size = 161917250 }, ] [[package]] @@ -2147,7 +2138,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -3210,7 +3200,7 @@ dev = [ [[package]] name = "ragbits-workspace" version = "0.1.0" -source = { editable = "." } +source = { virtual = "." } dependencies = [ { name = "ragbits-cli" }, { name = "ragbits-core", extra = ["chromadb", "lab", "litellm", "local"] },