diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 87970123..5bd310a0 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -4,7 +4,6 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.llm_client.base import LLMClient, LLMOptions -from dbally.prompts.prompt_builder import PromptBuilder from dbally.views.exposed_functions import ExposedFunction @@ -28,19 +27,16 @@ def __init__( self, llm_client: LLMClient, prompt_template: Optional[IQLPromptTemplate] = None, - prompt_builder: Optional[PromptBuilder] = None, promptify_view: Optional[Callable] = None, ) -> None: """ Args: llm_client: LLM client used to generate IQL prompt_template: If not provided by the users is set to `default_iql_template` - prompt_builder: PromptBuilder used to insert arguments into the prompt and adjust style per model. promptify_view: Function formatting filters for prompt """ self._llm_client = llm_client self._prompt_template = prompt_template or copy.deepcopy(default_iql_template) - self._prompt_builder = prompt_builder or PromptBuilder() self._promptify_view = promptify_view or _promptify_filters async def generate_iql( diff --git a/src/dbally/llm_client/base.py b/src/dbally/llm_client/base.py index 51110066..896e3eed 100644 --- a/src/dbally/llm_client/base.py +++ b/src/dbally/llm_client/base.py @@ -4,6 +4,7 @@ import abc from abc import ABC from dataclasses import asdict, dataclass +from functools import cached_property from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union from dbally.audit.event_tracker import EventTracker @@ -67,12 +68,18 @@ class LLMClient(Generic[LLMClientOptions], ABC): def __init__(self, model_name: str, default_options: Optional[LLMClientOptions] = None) -> None: self.model_name = model_name self.default_options = default_options or self._options_cls() - self._prompt_builder = PromptBuilder(self.model_name) def __init_subclass__(cls) -> None: if not hasattr(cls, "_options_cls"): raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute") + @cached_property + def _prompt_builder(self) -> PromptBuilder: + """ + Prompt builder used to construct final prompts for the LLM. + """ + return PromptBuilder() + async def text_generation( # pylint: disable=R0913 self, template: PromptTemplate, diff --git a/src/dbally/prompts/prompt_builder.py b/src/dbally/prompts/prompt_builder.py index a6e5596c..6ab00852 100644 --- a/src/dbally/prompts/prompt_builder.py +++ b/src/dbally/prompts/prompt_builder.py @@ -13,19 +13,19 @@ class PromptBuilder: def __init__(self, model_name: Optional[str] = None) -> None: """ Args: - model_name: Name of the model to load a tokenizer for. - Tokenizer is used to append special tokens to the prompt. If empty, no tokens will be added. + model_name: name of the tokenizer model to use. If provided, the tokenizer will convert the prompt to the + format expected by the model. The model_name should be a model available on huggingface.co/models. Raises: OSError: If model_name is not found in huggingface.co/models """ self._tokenizer: Optional["PreTrainedTokenizer"] = None - if model_name is not None and not model_name.startswith("gpt"): + if model_name is not None: try: from transformers import AutoTokenizer # pylint: disable=import-outside-toplevel except ImportError as exc: - raise ImportError("You need to install transformers package to use huggingface models.") from exc + raise ImportError("You need to install transformers package to use huggingface tokenizers") from exc self._tokenizer = AutoTokenizer.from_pretrained(model_name) diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index b5a6eb9e..477f763f 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -4,7 +4,6 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate from dbally.llm_client.base import LLMClient, LLMOptions -from dbally.prompts import PromptBuilder from dbally.view_selection.base import ViewSelector from dbally.view_selection.view_selector_prompt_template import default_view_selector_template @@ -24,20 +23,17 @@ def __init__( self, llm_client: LLMClient, prompt_template: Optional[IQLPromptTemplate] = None, - prompt_builder: Optional[PromptBuilder] = None, promptify_views: Optional[Callable[[Dict[str, str]], str]] = None, ) -> None: """ Args: llm_client: LLM client used to generate IQL prompt_template: template for the prompt used for the view selection - prompt_builder: PromptBuilder used to insert arguments into the prompt and adjust style per model promptify_views: Function formatting filters for prompt. By default names and descriptions of\ all views are concatenated """ self._llm_client = llm_client self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) - self._prompt_builder = prompt_builder or PromptBuilder() self._promptify_views = promptify_views or _promptify_views async def select_view( diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index a2b42771..35fd1c8b 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -1,10 +1,9 @@ -from unittest.mock import ANY, AsyncMock, Mock, call +from unittest.mock import ANY, AsyncMock, call import pytest from dbally import create_collection -from dbally.llm_client.base import LLMClient -from tests.unit.mocks import MockLLMOptions, MockViewBase +from tests.unit.mocks import MockLLMClient, MockLLMOptions, MockViewBase class MockView1(MockViewBase): @@ -15,23 +14,6 @@ class MockView2(MockViewBase): ... -class MockLLMClient(LLMClient[MockLLMOptions]): - _options_cls = MockLLMOptions - - # TODO: Start calling super().__init__ and remove the pyling comment below - # as soon as the base class is refactored to not have PromptBuilder initialization - # hardcoded in its constructor. - # See: DBALLY-105 - # pylint: disable=super-init-not-called - def __init__(self, default_options: MockLLMOptions) -> None: - self.model_name = "gpt-4" - self.default_options = default_options - self._prompt_builder = Mock() - - async def call(self, *_, **__) -> str: - ... - - @pytest.mark.asyncio async def test_llm_options_propagation(): default_options = MockLLMOptions(mock_property1=1, mock_property2="default mock") diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index f43f1f9e..2e8ceafa 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union from unittest.mock import create_autospec from dbally import NOT_GIVEN, NotGiven @@ -71,16 +71,12 @@ class MockLLMOptions(LLMOptions): class MockLLMClient(LLMClient[MockLLMOptions]): _options_cls = MockLLMOptions - # TODO: Start calling super().__init__ and remove the pyling comment below - # as soon as the base class is refactored to not have PromptBuilder initialization - # hardcoded in its constructor. - # See: DBALLY-105 - # pylint: disable=super-init-not-called - def __init__(self, *_, **__) -> None: - self.model_name = "mock model" - - async def text_generation(self, *_, **__) -> str: - return "mock response" + def __init__( + self, + model_name: str = "gpt-4-mock", + default_options: Optional[MockLLMOptions] = None, + ) -> None: + super().__init__(model_name, default_options) async def call(self, *_, **__) -> str: return "mock response"