diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index b5d680f7..a01d7dc7 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -4,7 +4,7 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Callable, Dict, List, Optional, Type, TypeVar import dbally from dbally.audit.event_handlers.base import EventHandler @@ -12,7 +12,7 @@ from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions @@ -228,7 +228,7 @@ async def _ask_view( event_tracker: EventTracker, llm_options: Optional[LLMOptions], dry_run: bool, - contexts: Iterable[BaseCallerContext], + contexts: List[Context], ) -> ViewExecutionResult: """ Ask the selected view to provide an answer to the question. @@ -247,11 +247,11 @@ async def _ask_view( view_result = await selected_view.ask( query=question, llm=self._llm, + contexts=contexts, event_tracker=event_tracker, n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, - contexts=contexts, ) return view_result @@ -298,9 +298,11 @@ def get_all_event_handlers(self) -> List[EventHandler]: return self._event_handlers return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers())) + # pylint: disable=too-many-arguments async def _handle_fallback( self, question: str, + contexts: Optional[List[Context]], dry_run: bool, return_natural_response: bool, llm_options: Optional[LLMOptions], @@ -322,7 +324,6 @@ async def _handle_fallback( Returns: The result from the fallback collection. - """ if not self._fallback_collection: raise caught_exception @@ -337,6 +338,7 @@ async def _handle_fallback( async with event_tracker.track_event(fallback_event) as span: result = await self._fallback_collection.ask( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, @@ -348,10 +350,10 @@ async def _handle_fallback( async def ask( self, question: str, + contexts: Optional[List[Context]] = None, dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[BaseCallerContext]] = None, event_tracker: Optional[EventTracker] = None, ) -> ExecutionResult: """ @@ -366,14 +368,14 @@ async def ask( Args: question: question posed using natural language representation e.g\ - "What job offers for Data Scientists do we have?" + "What job offers for Data Scientists do we have?" + contexts: list of context objects, each being an instance of + a subclass of Context. May contain contexts irrelevant for the currently processed query. dry_run: if True, only generate the query without executing it return_natural_response: if True (and dry_run is False as natural response requires query results), the natural response will be included in the answer llm_options: options to use for the LLM client. If provided, these options will be merged with the default options provided to the LLM client, prioritizing option values other than NOT_GIVEN - contexts: An iterable (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. event_tracker: Event tracker object for given ask. Returns: @@ -433,6 +435,7 @@ async def ask( if self._fallback_collection: result = await self._handle_fallback( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, diff --git a/src/dbally/context.py b/src/dbally/context.py new file mode 100644 index 00000000..46c65030 --- /dev/null +++ b/src/dbally/context.py @@ -0,0 +1,11 @@ +from abc import ABC +from typing import ClassVar + + +class Context(ABC): + """ + Base class for all contexts that are used to pass additional knowledge about the caller environment to the view. + """ + + type_name: ClassVar[str] = "Context" + alias_name: ClassVar[str] = "CONTEXT" diff --git a/src/dbally/context/__init__.py b/src/dbally/context/__init__.py deleted file mode 100644 index 294fbe2f..00000000 --- a/src/dbally/context/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .context import BaseCallerContext - -__all__ = ["BaseCallerContext"] diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py deleted file mode 100644 index e8e3b0eb..00000000 --- a/src/dbally/context/_utils.py +++ /dev/null @@ -1,75 +0,0 @@ -from inspect import isclass -from typing import Any, List, Optional, Sequence, Tuple, Type, Union - -import typing_extensions as type_ext - -from dbally.context.context import BaseCallerContext -from dbally.views.exposed_functions import MethodParamWithTyping - -ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]] - - -def _extract_params_and_context( - filter_method_: type_ext.Callable, hidden_args: Sequence[str] -) -> Tuple[List[MethodParamWithTyping], ContextClass]: - """ - Processes the MethodsBaseView filter method signature to extract the argument names and type hint - in the form of MethodParamWithTyping list. Additionally, the first type hint, pointing to the subclass - of BaseCallerContext is returned. - - Args: - filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) - hidden_args: method arguments that should not be extracted - - Returns: - The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. - The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. - """ - - params = [] - context = None - # TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__ - for name_, type_ in type_ext.get_type_hints(filter_method_).items(): - if name_ in hidden_args: - continue - - if isclass(type_) and issubclass(type_, BaseCallerContext): - # this is the case when user provides a context but no other type hint for a specifc arg - context = type_ - type_ = Any - elif type_ext.get_origin(type_) is Union: - union_subtypes = type_ext.get_args(type_) - if not union_subtypes: - type_ = Any - - for subtype_ in union_subtypes: # type: ignore - # TODO add custom error for the situation when user provides more than two contexts for a single filter - # for now we extract only the first context - if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): - if context is None: - context = subtype_ - - params.append(MethodParamWithTyping(name_, type_)) - - return params, context - - -def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: - """ - Verifies whether a method argument allows contextualization based on the type hints attached to a method signature. - - Args: - arg: MethodParamWithTyping container preserving information about the method argument - - Returns: - Verification result. - """ - - if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext): - return False - - for subtype in type_ext.get_args(arg.type): - if issubclass(subtype, BaseCallerContext): - return True - - return False diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py deleted file mode 100644 index aea21d13..00000000 --- a/src/dbally/context/context.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc import ABC -from typing import ClassVar - - -class BaseCallerContext(ABC): - """ - An interface for contexts that are used to pass additional knowledge about - the caller environment to the filters. LLM will always return `Context()` - when the context is required and this call will be later substituted by an instance of - a class implementing this interface, selected based on the filter method signature (type hints). - - Attributes: - alias: Class variable defining an alias which is defined in the prompt for the LLM to reference context. - """ - - type_name: ClassVar[str] = "Context" - alias_name: ClassVar[str] = "CONTEXT" diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py deleted file mode 100644 index c538ee18..00000000 --- a/src/dbally/context/exceptions.py +++ /dev/null @@ -1,23 +0,0 @@ -class BaseContextError(Exception): - """ - A base error for context handling logic. - """ - - -class SuitableContextNotProvidedError(BaseContextError): - """ - Raised when method argument type hint points that a contextualization is available - but not suitable context was provided. - """ - - def __init__(self, filter_fun_signature: str, context_class_name: str) -> None: - # this syntax 'or BaseCallerContext' is just to prevent type checkers - # from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never - # be reached, unless somebody will use this Exception against its purpose. - # TODO consider raising a warning/error when this happens. - - message = ( - f"No context of class {context_class_name} was provided" - f"while the filter {filter_fun_signature} requires it." - ) - super().__init__(message) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 9befc28d..d66e6977 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -4,7 +4,7 @@ from typing import Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -34,7 +34,7 @@ def __init__( self, source: str, allowed_functions: List[ExposedFunction], - allowed_contexts: Optional[List[BaseCallerContext]] = None, + allowed_contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, ) -> None: self.source = source diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 03abc912..1d8e1ef2 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -8,7 +8,7 @@ from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: - from dbally.context.context import BaseCallerContext + from dbally.context import Context from dbally.views.exposed_functions import ExposedFunction @@ -33,7 +33,7 @@ async def parse( cls, source: str, allowed_functions: List["ExposedFunction"], - allowed_contexts: Optional[List["BaseCallerContext"]] = None, + allowed_contexts: Optional[List["Context"]] = None, event_tracker: Optional[EventTracker] = None, ) -> Self: """ diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 5b4087b7..c6700ef3 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLError, IQLQuery from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( @@ -67,7 +67,7 @@ async def __call__( question: str, filters: List[ExposedFunction], aggregations: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -146,7 +146,7 @@ async def __call__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -265,7 +265,7 @@ async def __call__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, llm_options: Optional[LLMOptions] = None, diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 8189d6fd..64c99d8a 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -3,7 +3,7 @@ from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.exceptions import DbAllyError from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample @@ -21,7 +21,7 @@ class UnsupportedQueryError(DbAllyError): async def _iql_filters_parser( response: str, allowed_functions: List[ExposedFunction], - allowed_contexts: List[BaseCallerContext], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLFiltersQuery: """ @@ -53,7 +53,7 @@ async def _iql_filters_parser( async def _iql_aggregation_parser( response: str, allowed_functions: List[ExposedFunction], - allowed_contexts: List[BaseCallerContext], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLAggregationQuery: """ @@ -127,7 +127,7 @@ def __init__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: Optional[List[FewShotExample]] = None, ) -> None: """ diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 31e8f363..d85edfa3 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -25,7 +25,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, @@ -38,7 +38,7 @@ async def ask( query: The natural language query to execute. llm: The LLM used to execute the query. contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. + an instance of a subclass of Context. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index d769c12c..bb03fe68 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -4,7 +4,7 @@ from typing_extensions import _AnnotatedAlias, get_origin -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.similarity import AbstractSimilarityIndex @@ -21,11 +21,11 @@ def __str__(self) -> str: return f"{self.name}: {self._parse_type()}" @property - def contexts(self) -> List[Type[BaseCallerContext]]: + def contexts(self) -> List[Type[Context]]: """ Returns the contexts if the type is annotated with them. """ - return [arg for arg in getattr(self.type, "__args__", []) if issubclass(arg, BaseCallerContext)] + return [arg for arg in getattr(self.type, "__args__", []) if issubclass(arg, Context)] @property def similarity_index(self) -> Optional[AbstractSimilarityIndex]: @@ -51,7 +51,7 @@ def _parse_type_inner(param_type: Union[type, _GenericAlias]) -> str: if param_type.__module__ == "typing": return re.sub(r"\btyping\.", "", str(param_type)) - if issubclass(param_type, BaseCallerContext): + if issubclass(param_type, Context): return param_type.type_name if hasattr(param_type, "__name__"): diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 147573f5..c53007c0 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,7 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -100,7 +100,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index b9577bb0..2e477216 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,7 +4,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM @@ -35,7 +35,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 70b71f2e..c74e5eda 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -5,7 +5,7 @@ import pytest from typing_extensions import Annotated -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, @@ -59,7 +59,7 @@ async def test_iql_filter_parser(): async def test_iql_filter_context_parser(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str test_context = TestCustomContext(city="cracow") @@ -109,7 +109,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_filter_context_not_allowed_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotAllowedError) as exc_info: @@ -140,7 +140,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_filter_context_not_found_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotFoundError) as exc_info: @@ -360,7 +360,7 @@ async def test_iql_aggregation_parser(): async def test_iql_aggregation_context_parser(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str test_context = TestCustomContext(city="cracow") @@ -387,7 +387,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_aggregation_context_not_allowed_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotAllowedError) as exc_info: @@ -412,7 +412,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_aggregation_context_not_found_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotFoundError) as exc_info: diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py index 9cfeb009..4a4dbb86 100644 --- a/tests/unit/test_fallback_collection.py +++ b/tests/unit/test_fallback_collection.py @@ -8,7 +8,7 @@ from dbally.audit import CLIEventHandler, EventTracker, OtelEventHandler from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection, ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms import LLM from dbally.llms.clients import LLMOptions @@ -42,7 +42,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[BaseCallerContext]] = None, + contexts: Optional[Iterable[Context]] = None, ) -> ViewExecutionResult: return ViewExecutionResult( results=[{"mock_result": "fallback_result"}], metadata={"mock_context": "fallback_context"} diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 51c1548e..e870e2dc 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,7 +4,7 @@ from typing import List, Literal, Tuple, Union from dbally.collection.results import ViewExecutionResult -from dbally.context import BaseCallerContext +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping @@ -12,7 +12,7 @@ @dataclass -class CallerContext(BaseCallerContext): +class CallerContext(Context): """ Mock class for testing context. """ diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 8924a835..3b621a54 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -6,14 +6,14 @@ import sqlalchemy -from dbally.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @dataclass -class SomeTestContext(BaseCallerContext): +class SomeTestContext(Context): age: int