Skip to content

Commit

Permalink
refactor: add LLMOptions merge (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst authored May 14, 2024
1 parent 3f5c4bf commit 7191c3f
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/dbally/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .__version__ import __version__
from ._main import create_collection
from ._types import NOT_GIVEN, NotGiven
from .collection import Collection

__all__ = [
Expand All @@ -21,4 +22,6 @@
"BaseStructuredView",
"DataFrameBaseView",
"ExecutionResult",
"NotGiven",
"NOT_GIVEN",
]
31 changes: 31 additions & 0 deletions src/dbally/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing_extensions import Literal, override


# Sentinel class used until PEP 0661 is accepted
# Sentinel type used until PEP 0661 is accepted
class NotGiven:
    """
    Singleton sentinel used to distinguish omitted keyword arguments from an
    explicit ``None`` (which may carry different behavior).

    For example:
    ```py
    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
        ...
    get(timeout=1)     # 1s timeout
    get(timeout=None)  # No timeout
    get()              # Default timeout behavior, which may not be statically known at the method definition.
    ```
    """

    def __bool__(self) -> Literal[False]:
        # Always falsy, so truthiness checks treat an omitted option as "no value".
        return False

    @override
    def __repr__(self) -> str:
        # Render as the module-level singleton's name for readable debug output.
        return "NOT_GIVEN"


NOT_GIVEN = NotGiven()
5 changes: 3 additions & 2 deletions src/dbally/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ async def ask(
"What job offers for Data Scientists do we have?"
dry_run: if True, only generate the query without executing it
return_natural_response: if True (and dry_run is False as natural response requires query results),
the natural response will be included in the answer
llm_options: options to use for the LLM client.
the natural response will be included in the answer
llm_options: options to use for the LLM client. If provided, these options are merged with the default
options of the LLM client; values other than NOT_GIVEN in `llm_options` take precedence
Returns:
ExecutionResult object representing the result of the query execution.
Expand Down
39 changes: 36 additions & 3 deletions src/dbally/llm_client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import abc
from abc import ABC
from dataclasses import asdict, dataclass
from typing import Dict, Generic, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.data_models.audit import LLMEvent
from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate

from .._types import NotGiven

LLMOptionsNotGiven = TypeVar("LLMOptionsNotGiven")
LLMClientOptions = TypeVar("LLMClientOptions")


Expand All @@ -19,7 +22,37 @@ class LLMOptions(ABC):
Abstract dataclass that represents all available LLM call options.
"""

dict = asdict
_not_given: ClassVar[Optional[LLMOptionsNotGiven]] = None

def __or__(self, other: "LLMOptions") -> "LLMOptions":
    """
    Merges two LLMOptions, prioritizing non-NOT_GIVEN values from the 'other' object.

    Args:
        other: Options whose explicitly-set (non-NOT_GIVEN) values override this instance's.

    Returns:
        A new instance of the same concrete options class holding the merged values.
    """
    merged = asdict(self)
    for key, value in asdict(other).items():
        # Only fields declared on this options class participate; a NOT_GIVEN
        # value in `other` means "not specified", so the current value is kept.
        if key in merged and not isinstance(value, NotGiven):
            merged[key] = value
    return self.__class__(**merged)

def dict(self) -> Dict[str, Any]:
    """
    Creates a dictionary representation of the LLMOptions instance.

    If a value is None, it will be replaced with a provider-specific
    not-given sentinel.

    Returns:
        A dictionary representation of the LLMOptions instance.
    """
    serialized: Dict[str, Any] = {}
    for key, value in asdict(self).items():
        # Both an explicit None and the NOT_GIVEN sentinel are normalized to the
        # provider-specific "absent" marker stored on the class (`_not_given`).
        if value is None or isinstance(value, NotGiven):
            serialized[key] = self._not_given
        else:
            serialized[key] = value
    return serialized


class LLMClient(Generic[LLMClientOptions], ABC):
Expand Down Expand Up @@ -61,7 +94,7 @@ async def text_generation( # pylint: disable=R0913
Returns:
Text response from LLM.
"""
options = options if options else self.default_options
options = (self.default_options | options) if options else self.default_options

prompt = self._prompt_builder.build(template, fmt)

Expand Down
8 changes: 6 additions & 2 deletions src/dbally/llm_client/openai_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import ClassVar, Dict, List, Optional, Union

from openai import NOT_GIVEN, NotGiven
from openai import NOT_GIVEN as OPENAI_NOT_GIVEN
from openai import NotGiven as OpenAINotGiven

from dbally.data_models.audit import LLMEvent
from dbally.llm_client.base import LLMClient
from dbally.prompts import ChatFormat

from .._types import NOT_GIVEN, NotGiven
from .base import LLMOptions


Expand All @@ -17,6 +19,8 @@ class OpenAIOptions(LLMOptions):
described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.)
"""

_not_given: ClassVar[Optional[OpenAINotGiven]] = OPENAI_NOT_GIVEN

frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN
max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN
n: Union[Optional[int], NotGiven] = NOT_GIVEN
Expand Down
12 changes: 7 additions & 5 deletions tests/integration/test_llm_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ async def call(self, *_, **__) -> str:

@pytest.mark.asyncio
async def test_llm_options_propagation():
default_options = MockLLMOptions(mock_property=1)
custom_options = MockLLMOptions(mock_property=2)
default_options = MockLLMOptions(mock_property1=1, mock_property2="default mock")
custom_options = MockLLMOptions(mock_property1=2)
expected_options = MockLLMOptions(mock_property1=2, mock_property2="default mock")

llm_client = MockLLMClient(default_options=default_options)

collection = create_collection(
Expand Down Expand Up @@ -63,19 +65,19 @@ async def test_llm_options_propagation():
prompt=ANY,
response_format=ANY,
event=ANY,
options=custom_options,
options=expected_options,
),
call(
prompt=ANY,
response_format=ANY,
event=ANY,
options=custom_options,
options=expected_options,
),
call(
prompt=ANY,
response_format=ANY,
event=ANY,
options=custom_options,
options=expected_options,
),
]
)
6 changes: 4 additions & 2 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
"""

from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Union
from unittest.mock import create_autospec

from dbally import NOT_GIVEN, NotGiven
from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
Expand Down Expand Up @@ -63,7 +64,8 @@ async def similar(self, text: str) -> str:

@dataclass
class MockLLMOptions(LLMOptions):
mock_property: int
mock_property1: Union[int, NotGiven] = NOT_GIVEN
mock_property2: Union[str, NotGiven] = NOT_GIVEN


class MockLLMClient(LLMClient[MockLLMOptions]):
Expand Down

0 comments on commit 7191c3f

Please sign in to comment.