From 21506e6f35de5e40a185356e63d3e359cc0cd40c Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 21 Jun 2024 16:36:45 +0200 Subject: [PATCH 01/53] context logic subpackage; type-hint context extraction --- src/dbally/context/__init__.py | 0 src/dbally/context/context.py | 2 ++ src/dbally/context/utils.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 src/dbally/context/__init__.py create mode 100644 src/dbally/context/context.py create mode 100644 src/dbally/context/utils.py diff --git a/src/dbally/context/__init__.py b/src/dbally/context/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py new file mode 100644 index 00000000..a45fb4a1 --- /dev/null +++ b/src/dbally/context/context.py @@ -0,0 +1,2 @@ +class BaseCallerContext: + pass diff --git a/src/dbally/context/utils.py b/src/dbally/context/utils.py new file mode 100644 index 00000000..4c927a59 --- /dev/null +++ b/src/dbally/context/utils.py @@ -0,0 +1,29 @@ +import typing + +from .context import BaseCallerContext + + +def _extract_filter_context(filter_: typing.Callable) -> typing.Optional[typing.Type[BaseCallerContext]]: + """ + Extracts a SINGLE caller's context given a StructuredView filter. + + Args: + filter_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) + + Returns: + A class inheriting from BaseCallerContext. If not context is given among type hints, value None. + """ + + for type_ in typing.get_type_hints(filter_).values(): + if not isinstance(type_, typing.Union) or not hasattr(type_, '__args__'): + # in the 1st condition here we assume that no user you cannot make filter value context-dependent without providing a base type for it + # the 2nd condition catches situations when user provides a type hint like typing.Union or typing.Union[int] (single part-clas); in both those cases it should be typing.Union to throw an error, although Python for now accepts this syntax,,, + continue + + for union_part_type in type_.__args__: # pyright: ignore + # TODO add custom error for the situation when user provides more than two contexts for a single filter + # for now we extract only the first context + if issubclass(union_part_type, BaseCallerContext): + return union_part_type + + return None From a87e8e23e1d28ca1fbf6faa061555bab15666f80 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 24 Jun 2024 19:03:02 +0200 Subject: [PATCH 02/53] reworked type hint info extraction; extended functionality to also return List[MethodParamWithTyping] required to construct ExposedFunction dataclass --- src/dbally/context/utils.py | 58 +++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/src/dbally/context/utils.py b/src/dbally/context/utils.py index 4c927a59..0fbf56b2 100644 --- a/src/dbally/context/utils.py +++ b/src/dbally/context/utils.py @@ -1,29 +1,51 @@ import typing -from .context import BaseCallerContext +from inspect import isclass +from dbally.context.context import BaseCallerContext +from dbally.views.exposed_functions import MethodParamWithTyping -def _extract_filter_context(filter_: typing.Callable) -> typing.Optional[typing.Type[BaseCallerContext]]: + +def _extract_params_and_context( + filter_method_: typing.Callable, hidden_args: typing.List[str] +) -> typing.Tuple[ + typing.List[MethodParamWithTyping], + typing.Optional[typing.Type[BaseCallerContext]] +]: """ - Extracts a SINGLE caller's context given a StructuredView filter. + Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext class is returned. Args: - filter_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) + filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) Returns: - A class inheriting from BaseCallerContext. If not context is given among type hints, value None. + A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. """ - for type_ in typing.get_type_hints(filter_).values(): - if not isinstance(type_, typing.Union) or not hasattr(type_, '__args__'): - # in the 1st condition here we assume that no user you cannot make filter value context-dependent without providing a base type for it - # the 2nd condition catches situations when user provides a type hint like typing.Union or typing.Union[int] (single part-clas); in both those cases it should be typing.Union to throw an error, although Python for now accepts this syntax,,, - continue - - for union_part_type in type_.__args__: # pyright: ignore - # TODO add custom error for the situation when user provides more than two contexts for a single filter - # for now we extract only the first context - if issubclass(union_part_type, BaseCallerContext): - return union_part_type - - return None + params = [] + context = None + # TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__ + for name_, type_ in typing.get_type_hints(filter_method_).items(): + if isclass(type_) and issubclass(type_, BaseCallerContext): + # this is the case when user provides a context but no other type hint for a specifc arg + # TODO confirm whether this case should be supported + context_ = type_ + type_ = typing.Any + elif typing.get_origin(type_) is typing.Union: + union_subtypes: typing.List[typing.Type] = [] + for subtype_ in typing.get_args(type_): # type: ignore + # TODO add custom error for the situation when user provides more than two contexts for a single filter + # for now we extract only the first context + if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): + if context is None: + context = subtype_ + else: + union_subtypes.append(subtype_) + if union_subtypes: + type_ = typing.Union[tuple(union_subtypes)] # type: ignore + else: + type_ = typing.Any # this ELSE handles the situation when the user provided an typing.Union bare type hint, without specyfing any args. In that case, typing.get_args() returns an empty tuple. Unfortunately, Python does not treat it as an error... + + params.append(MethodParamWithTyping(name_, type_)) + + return params, context From 3ad4ecdf3b8a60211d0230c86c07ecef3d79d40f Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 24 Jun 2024 19:05:26 +0200 Subject: [PATCH 03/53] hidden args handling enabled --- src/dbally/context/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dbally/context/utils.py b/src/dbally/context/utils.py index 0fbf56b2..333c813d 100644 --- a/src/dbally/context/utils.py +++ b/src/dbally/context/utils.py @@ -26,6 +26,9 @@ def _extract_params_and_context( context = None # TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__ for name_, type_ in typing.get_type_hints(filter_method_).items(): + if name_ in hidden_args: + continue + if isclass(type_) and issubclass(type_, BaseCallerContext): # this is the case when user provides a context but no other type hint for a specifc arg # TODO confirm whether this case should be supported From b0cc0ae3425c42e53e65266f2d7c02a53e27e259 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 13:16:20 +0100 Subject: [PATCH 04/53] improved type hints parsing and compatibility using package --- src/dbally/context/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dbally/context/utils.py b/src/dbally/context/utils.py index 333c813d..94f600f1 100644 --- a/src/dbally/context/utils.py +++ b/src/dbally/context/utils.py @@ -1,4 +1,5 @@ import typing +import typing_extensions as type_ext from inspect import isclass @@ -34,9 +35,9 @@ def _extract_params_and_context( # TODO confirm whether this case should be supported context_ = type_ type_ = typing.Any - elif typing.get_origin(type_) is typing.Union: + elif type_ext.get_origin(type_) is typing.Union: union_subtypes: typing.List[typing.Type] = [] - for subtype_ in typing.get_args(type_): # type: ignore + for subtype_ in type_ext.get_args(type_): # type: ignore # TODO add custom error for the situation when user provides more than two contexts for a single filter # for now we extract only the first context if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): From 4ff5f62efddfd906c3efc03de3965564400c87fe Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 16:31:12 +0100 Subject: [PATCH 05/53] dedicated exceptions for contex-related operations --- src/dbally/context/exceptions.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 src/dbally/context/exceptions.py diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py new file mode 100644 index 00000000..91482d5d --- /dev/null +++ b/src/dbally/context/exceptions.py @@ -0,0 +1,21 @@ +from abc import ABC + + +class BaseContextException(Exception, ABC): + """ + A base exception for all specification context-related exception. + """ + pass + + +class ContextNotAvailableError(Exception): + pass + + +class ContextualisationNotAllowed(Exception): + pass + + +# WORKAROUND - traditional inhertiance syntax is not working in context of abstract Exceptions +BaseContextException.register(ContextNotAvailableError) +BaseContextException.register(ContextualisationNotAllowed) From c479c5065b8beb716c5be4a110df2ed431e1638c Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 16:34:35 +0100 Subject: [PATCH 06/53] useful classmethods for context-related operations --- src/dbally/context/__init__.py | 1 + src/dbally/context/context.py | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/dbally/context/__init__.py b/src/dbally/context/__init__.py index e69de29b..46752cf9 100644 --- a/src/dbally/context/__init__.py +++ b/src/dbally/context/__init__.py @@ -0,0 +1 @@ +from .context import BaseCallerContext diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index a45fb4a1..9f7114ad 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,2 +1,27 @@ +import ast + +from typing import List, Optional +from typing_extensions import Self +from dataclasses import dataclass + +from dbally.context.exceptions import ContextNotAvailableError + + +@dataclass class BaseCallerContext: - pass + """ + Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. + LLM will always return `BaseCallerContext()` when the context is required and this call will be later substitue by a proper subclass instance selected based on the filter method signature (type hints). + """ + + @classmethod + def select_context(cls, contexts: List[Self]) -> Self: + if not contexts: + raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") + + # we assume here that each element of `contexts` represents a different subclass of BaseCallerContext + return next(filter(lambda obj: isinstance(obj, cls), contexts)) + + @classmethod + def is_context_call(cls, node: ast.expr) -> bool: + return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == cls.__name__ From e3bb12796bb32607626ee1a9e11e2b876b891791 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 16:44:02 +0100 Subject: [PATCH 07/53] make whole context utils module protected; added IQL parsing helper; bugfix; fixed linting issues --- src/dbally/context/{utils.py => _utils.py} | 30 +++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) rename src/dbally/context/{utils.py => _utils.py} (61%) diff --git a/src/dbally/context/utils.py b/src/dbally/context/_utils.py similarity index 61% rename from src/dbally/context/utils.py rename to src/dbally/context/_utils.py index 94f600f1..f2149258 100644 --- a/src/dbally/context/utils.py +++ b/src/dbally/context/_utils.py @@ -14,13 +14,16 @@ def _extract_params_and_context( typing.Optional[typing.Type[BaseCallerContext]] ]: """ - Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext class is returned. + Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. + Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext + class is returned. Args: filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) Returns: - A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. + A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. + The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. """ params = [] @@ -30,14 +33,16 @@ def _extract_params_and_context( if name_ in hidden_args: continue + # TODO make ExposedFunction preserve information whether the context was available for a certain argument + if isclass(type_) and issubclass(type_, BaseCallerContext): # this is the case when user provides a context but no other type hint for a specifc arg # TODO confirm whether this case should be supported - context_ = type_ + context = type_ type_ = typing.Any elif type_ext.get_origin(type_) is typing.Union: union_subtypes: typing.List[typing.Type] = [] - for subtype_ in type_ext.get_args(type_): # type: ignore + for subtype_ in type_ext.get_args(type_): # type: ignore # TODO add custom error for the situation when user provides more than two contexts for a single filter # for now we extract only the first context if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): @@ -46,10 +51,23 @@ def _extract_params_and_context( else: union_subtypes.append(subtype_) if union_subtypes: - type_ = typing.Union[tuple(union_subtypes)] # type: ignore + type_ = typing.Union[tuple(union_subtypes)] # type: ignore else: - type_ = typing.Any # this ELSE handles the situation when the user provided an typing.Union bare type hint, without specyfing any args. In that case, typing.get_args() returns an empty tuple. Unfortunately, Python does not treat it as an error... + type_ = typing.Any # this ELSE handles the situation when the user provided an typing.Union bare + # type hint without specyfing any args. In that case, typing.get_args() + # returns an empty tuple. Unfortunately, Python does not treat it as an error... params.append(MethodParamWithTyping(name_, type_)) return params, context + + +def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: + if not isinstance(arg.type, BaseCallerContext) and type_ext.get_origin(arg.type) is not typing.Union: + return False + + for subtype in type_ext.get_args(arg.type): + if issubclass(subtype, BaseCallerContext): + return True + + return False From de72c7c6a6c9e9887d35bfc0916a45a64856dc9b Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 17:03:40 +0100 Subject: [PATCH 08/53] parsing type hints _extract_params_and_context() no longer excludes BaseCallerContext subclasses from the Union[] args; linting fix --- src/dbally/context/__init__.py | 2 ++ src/dbally/context/_utils.py | 15 +++++---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/dbally/context/__init__.py b/src/dbally/context/__init__.py index 46752cf9..294fbe2f 100644 --- a/src/dbally/context/__init__.py +++ b/src/dbally/context/__init__.py @@ -1 +1,3 @@ from .context import BaseCallerContext + +__all__ = ["BaseCallerContext"] diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index f2149258..83777f21 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -41,21 +41,16 @@ class is returned. context = type_ type_ = typing.Any elif type_ext.get_origin(type_) is typing.Union: - union_subtypes: typing.List[typing.Type] = [] - for subtype_ in type_ext.get_args(type_): # type: ignore + union_subtypes = type_ext.get_args(type_) + if not union_subtypes: + type_ = typing.Any + + for subtype_ in union_subtypes: # type: ignore # TODO add custom error for the situation when user provides more than two contexts for a single filter # for now we extract only the first context if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): if context is None: context = subtype_ - else: - union_subtypes.append(subtype_) - if union_subtypes: - type_ = typing.Union[tuple(union_subtypes)] # type: ignore - else: - type_ = typing.Any # this ELSE handles the situation when the user provided an typing.Union bare - # type hint without specyfing any args. In that case, typing.get_args() - # returns an empty tuple. Unfortunately, Python does not treat it as an error... params.append(MethodParamWithTyping(name_, type_)) From d3958c0243ee72fe6a6916449644c2ea7f58fb56 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 28 Jun 2024 17:07:31 +0100 Subject: [PATCH 09/53] adjusted the existing code to be aware of contexts (promts yet untouched) --- src/dbally/collection/collection.py | 3 ++ src/dbally/iql/_processor.py | 58 ++++++++++++++++++++++----- src/dbally/iql/_query.py | 9 ++++- src/dbally/views/base.py | 2 + src/dbally/views/exposed_functions.py | 4 +- src/dbally/views/methods_base.py | 9 +++-- src/dbally/views/structured.py | 5 ++- 7 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c207d95b..b5c2940d 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -16,6 +16,7 @@ from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation +from dbally.context.context import BaseCallerContext class Collection: @@ -156,6 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, + context: Optional[List[BaseCallerContext]] = None ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -215,6 +217,7 @@ async def ask( n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, + context=context ) end_time_view = time.monotonic() diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 8127ddfe..81865bc9 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,5 +1,7 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union + +from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Dict +from typing_extensions import Callable from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -11,23 +13,35 @@ IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type - -if TYPE_CHECKING: - from dbally.views.structured import ExposedFunction +from dbally.context.context import BaseCallerContext +from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed +from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context +from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction class IQLProcessor: """ Parses IQL string to tree structure. """ + source: str + allowed_functions: Mapping[str, "ExposedFunction"] + contexts: List[BaseCallerContext] + _event_tracker: EventTracker + def __init__( - self, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + self, + source: str, + allowed_functions: List["ExposedFunction"], + contexts: Optional[List[BaseCallerContext]] = None, + event_tracker: Optional[EventTracker] = None ) -> None: self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} + self.contexts = contexts or [] self._event_tracker = event_tracker or EventTracker() + async def process(self) -> syntax.Node: """ Process IQL string to root IQL.Node. @@ -38,6 +52,7 @@ async def process(self) -> syntax.Node: Raises: IQLError: if parsing fails. """ + # TODO adjust this method to prevent making context class constructor calls lowercase self.source = self._to_lower_except_in_quotes(self.source, ["AND", "OR", "NOT"]) ast_tree = ast.parse(self.source) @@ -75,7 +90,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: if not isinstance(func, ast.Name): raise IQLUnsupportedSyntaxError(node, self.source, context="FunctionCall") - if func.id not in self.allowed_functions: + if func.id not in self.allowed_functions: # TODO add context class constructors to self.allowed_functions raise IQLFunctionNotExists(func, self.source) func_def = self.allowed_functions[func.id] @@ -84,8 +99,8 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: if len(func_def.parameters) != len(node.args): raise ValueError(f"The method {func.id} has incorrect number of arguments") - for arg, arg_def in zip(node.args, func_def.parameters): - arg_value = self._parse_arg(arg) + for i, (arg, arg_def) in enumerate(zip(node.args, func_def.parameters)): + arg_value = self._parse_arg(arg, arg_spec=func_def.parameters[i], parent_func_def=func_def) if arg_def.similarity_index: arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker) @@ -99,12 +114,37 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: return syntax.FunctionCall(func.id, args) - def _parse_arg(self, arg: ast.expr) -> Any: + def _parse_arg( + self, + arg: ast.expr, + arg_spec: Optional[MethodParamWithTyping] = None, + parent_func_def: Optional[ExposedFunction] = None + ) -> Any: + if isinstance(arg, ast.List): return [self._parse_arg(x) for x in arg.elts] + if BaseCallerContext.is_context_call(arg): + if parent_func_def is None or arg_spec is None: + # not sure whether this line will be ever reached + raise IQLArgumentParsingError(arg, self.source) + + if parent_func_def.context_class is None: + raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.") + + if _does_arg_allow_context(arg_spec): + raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.") + + context = parent_func_def.context_class.select_context(self.contexts) + + try: + return getattr(context, arg_spec.name) + except AttributeError: + raise ContextNotAvailableError(f"The LLM detected that the context is required to execute the query and the context object was provided but it is missing the `{arg_spec.name}` field.") + if not isinstance(arg, ast.Constant): raise IQLArgumentParsingError(arg, self.source) + return arg.value @staticmethod diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 7ad86490..10236e29 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -3,6 +3,7 @@ from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor +from dbally.context.context import BaseCallerContext if TYPE_CHECKING: from dbally.views.structured import ExposedFunction @@ -20,7 +21,11 @@ def __init__(self, root: syntax.Node): @classmethod async def parse( - cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + cls, + source: str, + allowed_functions: List["ExposedFunction"], + event_tracker: Optional[EventTracker] = None, + context: Optional[List[BaseCallerContext]] = None ) -> "IQLQuery": """ Parse IQL string to IQLQuery object. @@ -32,4 +37,4 @@ async def parse( Returns: IQLQuery object """ - return cls(await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()) + return cls(await IQLProcessor(source, allowed_functions, context, event_tracker).process()) diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 7b6dfc81..8d9e0bb3 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -6,6 +6,7 @@ from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.similarity import AbstractSimilarityIndex +from dbally.context.context import BaseCallerContext IndexLocation = Tuple[str, str, str] @@ -25,6 +26,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, + context: Optional[List[BaseCallerContext]] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 55254958..f133a973 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,9 +1,10 @@ import re from dataclasses import dataclass from typing import _GenericAlias # type: ignore -from typing import List, Optional, Union +from typing import List, Optional, Union, Type from dbally.similarity import AbstractSimilarityIndex +from dbally.context.context import BaseCallerContext def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: @@ -57,6 +58,7 @@ class ExposedFunction: name: str description: str parameters: List[MethodParamWithTyping] + context_class: Optional[Type[BaseCallerContext]] = None def __str__(self) -> str: base_str = f"{self.name}({', '.join(str(param) for param in self.parameters)})" diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 8eeedfb0..25baa957 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -7,6 +7,7 @@ from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping from dbally.views.structured import BaseStructuredView +from dbally.context._utils import _extract_params_and_context class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): @@ -35,14 +36,14 @@ def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction] hasattr(method, "_methodDecorator") and method._methodDecorator == decorator # pylint: disable=protected-access ): - annotations = method.__annotations__.items() + params, context_class = _extract_params_and_context(method, cls.HIDDEN_ARGUMENTS) + methods.append( ExposedFunction( name=method_name, description=textwrap.dedent(method.__doc__).strip() if method.__doc__ else "", - parameters=[ - MethodParamWithTyping(n, t) for n, t in annotations if n not in cls.HIDDEN_ARGUMENTS - ], + parameters=params, + context_class=context_class ) ) return methods diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c43e4c2b..195d9166 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,11 +4,13 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context.context import BaseCallerContext from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction +from dbally.context.context import BaseCallerContext from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -40,6 +42,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, + context: Optional[List[BaseCallerContext]] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -68,7 +71,7 @@ async def ask( for _ in range(n_retries): try: - filters = await IQLQuery.parse(iql_filters, filter_list, event_tracker=event_tracker) + filters = await IQLQuery.parse(iql_filters, filter_list, event_tracker=event_tracker, context=context) await self.apply_filters(filters) break except (IQLError, ValueError) as e: From be338bf37bc694d999aeb616739de975c19d8dea Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 14:31:18 +0200 Subject: [PATCH 10/53] adjusted _type_validators.validate_arg_type() to handle typing.Union[] --- src/dbally/iql/_type_validators.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 7932cff7..848b17ac 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,3 +1,5 @@ +import typing_extensions as type_ext + from dataclasses import dataclass from typing import _GenericAlias # type: ignore from typing import Any, Callable, Dict, Literal, Optional, Type, Union @@ -46,7 +48,7 @@ def _check_bool(required_type: Type[bool], value: Any) -> _ValidationResult: return _ValidationResult(False, reason=f"{repr(value)} is not of type {required_type.__name__}") -TYPE_VALIDATOR: Dict[Any, Callable[[Any, Any], _ValidationResult]] = { +TYPE_VALIDATOR: Dict[Any, Callable[[Type, Any], _ValidationResult]] = { Literal: _check_literal, float: _check_float, int: _check_int, @@ -65,7 +67,17 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> Returns: _ValidationResult instance """ - actual_type = required_type.__origin__ if isinstance(required_type, _GenericAlias) else required_type + actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias + if actual_type is None: # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary + actual_type = required_type.__origin__ + + if actual_type is Union: + for subtype in type_ext.get_args(required_type): + res = validate_arg_type(subtype, value) + if res.valid: + return _ValidationResult(True) + + return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") # typing.Union does not have __name__ property custom_type_checker = TYPE_VALIDATOR.get(actual_type) From 78f15359e4ff16135b637c4efdb6d3edbcd33b00 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 14:34:45 +0200 Subject: [PATCH 11/53] context._utils._does_arg_allow_context() fix --- src/dbally/context/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 83777f21..5d9739c8 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -58,7 +58,7 @@ class is returned. def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: - if not isinstance(arg.type, BaseCallerContext) and type_ext.get_origin(arg.type) is not typing.Union: + if type_ext.get_origin(arg.type) is not typing.Union and not issubclass(arg.type, BaseCallerContext): return False for subtype in type_ext.get_args(arg.type): From 308e2e1ec2bcf26f101aa4bc79964e7af6899cb8 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 14:37:39 +0200 Subject: [PATCH 12/53] context record is now based on pydantic.BaseModel rather than dataclass + type hint improvements: fixes, new generics and aliases --- src/dbally/context/context.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index 9f7114ad..f51fa147 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,21 +1,24 @@ import ast -from typing import List, Optional +from typing import List, Optional, Type, TypeVar from typing_extensions import Self -from dataclasses import dataclass +from pydantic import BaseModel from dbally.context.exceptions import ContextNotAvailableError -@dataclass -class BaseCallerContext: +T = TypeVar('T', bound='BaseCallerContext') +AllCallerContexts = Optional[List[T]] # TODO confirm the naming + + +class BaseCallerContext(BaseModel): """ Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. LLM will always return `BaseCallerContext()` when the context is required and this call will be later substitue by a proper subclass instance selected based on the filter method signature (type hints). """ @classmethod - def select_context(cls, contexts: List[Self]) -> Self: + def select_context(cls, contexts: List[T]) -> T: if not contexts: raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") From 73741d9883d471c808905cf1531b50c4d99f1f37 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 15:33:03 +0200 Subject: [PATCH 13/53] type hint lifting --- src/dbally/collection/collection.py | 4 ++-- src/dbally/context/_utils.py | 19 ++++++++----------- src/dbally/context/context.py | 7 +++---- src/dbally/iql/_processor.py | 21 ++++++++------------- src/dbally/iql/_query.py | 6 +++--- src/dbally/views/base.py | 6 +++--- src/dbally/views/structured.py | 6 +++--- 7 files changed, 30 insertions(+), 39 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index b5c2940d..2256606e 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -16,7 +16,7 @@ from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation -from dbally.context.context import BaseCallerContext +from dbally.context.context import BaseCallerContext, CustomContextsList class Collection: @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[List[BaseCallerContext]] = None + context: Optional[CustomContextsList] = None ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 5d9739c8..76d46359 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -1,6 +1,6 @@ -import typing import typing_extensions as type_ext +from typing import Sequence, Tuple, Optional, Type, Any, Union from inspect import isclass from dbally.context.context import BaseCallerContext @@ -8,11 +8,8 @@ def _extract_params_and_context( - filter_method_: typing.Callable, hidden_args: typing.List[str] -) -> typing.Tuple[ - typing.List[MethodParamWithTyping], - typing.Optional[typing.Type[BaseCallerContext]] -]: + filter_method_: type_ext.Callable, hidden_args: Sequence[str] +) -> Tuple[Sequence[MethodParamWithTyping], Optional[Type[BaseCallerContext]]]: """ Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext @@ -29,7 +26,7 @@ class is returned. params = [] context = None # TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__ - for name_, type_ in typing.get_type_hints(filter_method_).items(): + for name_, type_ in type_ext.get_type_hints(filter_method_).items(): if name_ in hidden_args: continue @@ -39,11 +36,11 @@ class is returned. # this is the case when user provides a context but no other type hint for a specifc arg # TODO confirm whether this case should be supported context = type_ - type_ = typing.Any - elif type_ext.get_origin(type_) is typing.Union: + type_ = Any + elif type_ext.get_origin(type_) is Union: union_subtypes = type_ext.get_args(type_) if not union_subtypes: - type_ = typing.Any + type_ = Any for subtype_ in union_subtypes: # type: ignore # TODO add custom error for the situation when user provides more than two contexts for a single filter @@ -58,7 +55,7 @@ class is returned. def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: - if type_ext.get_origin(arg.type) is not typing.Union and not issubclass(arg.type, BaseCallerContext): + if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext): return False for subtype in type_ext.get_args(arg.type): diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index f51fa147..3e991aae 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,14 +1,13 @@ import ast -from typing import List, Optional, Type, TypeVar +from typing import Optional, Type, Sequence from typing_extensions import Self from pydantic import BaseModel from dbally.context.exceptions import ContextNotAvailableError -T = TypeVar('T', bound='BaseCallerContext') -AllCallerContexts = Optional[List[T]] # TODO confirm the naming +CustomContextsList = Sequence[Type['BaseCallerContext']] # TODO confirm the naming class BaseCallerContext(BaseModel): @@ -18,7 +17,7 @@ class BaseCallerContext(BaseModel): """ @classmethod - def select_context(cls, contexts: List[T]) -> T: + def select_context(cls, contexts: Sequence[Type[Self]]) -> Type[Self]: if not contexts: raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 81865bc9..49ab1eb9 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,7 +1,6 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Dict -from typing_extensions import Callable +from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Type from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -13,7 +12,7 @@ IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type -from dbally.context.context import BaseCallerContext +from dbally.context.context import BaseCallerContext, CustomContextsList from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction @@ -25,7 +24,7 @@ class IQLProcessor: """ source: str allowed_functions: Mapping[str, "ExposedFunction"] - contexts: List[BaseCallerContext] + contexts: CustomContextsList _event_tracker: EventTracker @@ -33,7 +32,7 @@ def __init__( self, source: str, allowed_functions: List["ExposedFunction"], - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[CustomContextsList] = None, event_tracker: Optional[EventTracker] = None ) -> None: self.source = source @@ -52,7 +51,7 @@ async def process(self) -> syntax.Node: Raises: IQLError: if parsing fails. """ - # TODO adjust this method to prevent making context class constructor calls lowercase + self.source = self._to_lower_except_in_quotes(self.source, ["AND", "OR", "NOT"]) ast_tree = ast.parse(self.source) @@ -132,15 +131,11 @@ def _parse_arg( if parent_func_def.context_class is None: raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.") - if _does_arg_allow_context(arg_spec): + if not _does_arg_allow_context(arg_spec): + print(arg_spec) raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.") - context = parent_func_def.context_class.select_context(self.contexts) - - try: - return getattr(context, arg_spec.name) - except AttributeError: - raise ContextNotAvailableError(f"The LLM detected that the context is required to execute the query and the context object was provided but it is missing the `{arg_spec.name}` field.") + return parent_func_def.context_class.select_context(self.contexts) if not isinstance(arg, ast.Constant): raise IQLArgumentParsingError(arg, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 10236e29..39274f54 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,9 +1,9 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Type from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor -from dbally.context.context import BaseCallerContext +from dbally.context.context import BaseCallerContext, CustomContextsList if TYPE_CHECKING: from dbally.views.structured import ExposedFunction @@ -25,7 +25,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - context: Optional[List[BaseCallerContext]] = None + context: Optional[CustomContextsList] = None ) -> "IQLQuery": """ Parse IQL string to IQLQuery object. diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 8d9e0bb3..104b8daf 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -1,12 +1,12 @@ import abc -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Type from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.similarity import AbstractSimilarityIndex -from dbally.context.context import BaseCallerContext +from dbally.context.context import BaseCallerContext, CustomContextsList IndexLocation = Tuple[str, str, str] @@ -26,7 +26,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[List[BaseCallerContext]] = None + context: Optional[CustomContextsList] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 195d9166..8a2cd3c1 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Type from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult @@ -10,7 +10,7 @@ from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction -from dbally.context.context import BaseCallerContext +from dbally.context.context import BaseCallerContext, CustomContextsList from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -42,7 +42,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[List[BaseCallerContext]] = None + context: Optional[CustomContextsList] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ From 902f5ffffe1b22833f6c37398461b388f45737c4 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 15:36:23 +0200 Subject: [PATCH 14/53] IQL generating LLM prompt passes BaseCallerContext() as filter argument when it detectes the additional context is required --- src/dbally/iql_generator/iql_prompt_template.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py index 2da8abd2..bc419201 100644 --- a/src/dbally/iql_generator/iql_prompt_template.py +++ b/src/dbally/iql_generator/iql_prompt_template.py @@ -60,10 +60,14 @@ def _validate_iql_response(llm_response: str) -> str: "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." + "If a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "In that case, the part of the output will look like this:" + "filter4(BaseCallerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", }, {"role": "user", "content": "{question}"}, - ), + ), # type: ignore # TODO fix it llm_response_parser=_validate_iql_response, ) From 6309070e898bb6a65feff80f52eed15936cd7269 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 2 Jul 2024 15:42:16 +0200 Subject: [PATCH 15/53] comments cleanup --- src/dbally/context/_utils.py | 3 --- src/dbally/iql_generator/iql_prompt_template.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 76d46359..7fea4e9b 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -30,11 +30,8 @@ class is returned. if name_ in hidden_args: continue - # TODO make ExposedFunction preserve information whether the context was available for a certain argument - if isclass(type_) and issubclass(type_, BaseCallerContext): # this is the case when user provides a context but no other type hint for a specifc arg - # TODO confirm whether this case should be supported context = type_ type_ = Any elif type_ext.get_origin(type_) is Union: diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py index bc419201..9a4825ba 100644 --- a/src/dbally/iql_generator/iql_prompt_template.py +++ b/src/dbally/iql_generator/iql_prompt_template.py @@ -68,6 +68,6 @@ def _validate_iql_response(llm_response: str) -> str: "This is CRUCIAL, otherwise the system will crash. ", }, {"role": "user", "content": "{question}"}, - ), # type: ignore # TODO fix it + ), llm_response_parser=_validate_iql_response, ) From d523bf771e4899974326547554919e48475c51d9 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Wed, 3 Jul 2024 10:55:59 +0200 Subject: [PATCH 16/53] type hint fixes --- src/dbally/context/context.py | 9 +++++---- src/dbally/iql/_processor.py | 1 - src/dbally/views/exposed_functions.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index 3e991aae..d3af935d 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,13 +1,14 @@ import ast -from typing import Optional, Type, Sequence +from typing import Sequence, TypeVar from typing_extensions import Self from pydantic import BaseModel from dbally.context.exceptions import ContextNotAvailableError -CustomContextsList = Sequence[Type['BaseCallerContext']] # TODO confirm the naming +CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) +CustomContextsList = Sequence[CustomContext] # TODO confirm the naming class BaseCallerContext(BaseModel): @@ -17,11 +18,11 @@ class BaseCallerContext(BaseModel): """ @classmethod - def select_context(cls, contexts: Sequence[Type[Self]]) -> Type[Self]: + def select_context(cls, contexts: CustomContextsList) -> Self: if not contexts: raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") - # we assume here that each element of `contexts` represents a different subclass of BaseCallerContext + # this method is called from the subclass of BaseCallerContext pointing the right type of custom context return next(filter(lambda obj: isinstance(obj, cls), contexts)) @classmethod diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 49ab1eb9..b6a1648a 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -132,7 +132,6 @@ def _parse_arg( raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.") if not _does_arg_allow_context(arg_spec): - print(arg_spec) raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.") return parent_func_def.context_class.select_context(self.contexts) diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index f133a973..c6d400d2 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass from typing import _GenericAlias # type: ignore -from typing import List, Optional, Union, Type +from typing import Sequence, Optional, Union, Type from dbally.similarity import AbstractSimilarityIndex from dbally.context.context import BaseCallerContext @@ -57,7 +57,7 @@ class ExposedFunction: name: str description: str - parameters: List[MethodParamWithTyping] + parameters: Sequence[MethodParamWithTyping] context_class: Optional[Type[BaseCallerContext]] = None def __str__(self) -> str: From 9ba89e5e43e365cf8fcc105d4eedbff095295d8a Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Wed, 3 Jul 2024 14:32:24 +0200 Subject: [PATCH 17/53] post-merge fixes + minor refactor --- src/dbally/collection/collection.py | 4 +- src/dbally/iql/_query.py | 4 +- src/dbally/iql_generator/iql_generator.py | 8 +- .../iql_generator/iql_prompt_template.py | 73 ------------------- src/dbally/iql_generator/prompt.py | 5 ++ src/dbally/views/base.py | 2 +- src/dbally/views/structured.py | 3 +- 7 files changed, 18 insertions(+), 81 deletions(-) delete mode 100644 src/dbally/iql_generator/iql_prompt_template.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 5e175cb5..45446cc5 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[CustomContextsList] = None + contexts: Optional[CustomContextsList] = None ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -217,7 +217,7 @@ async def ask( n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, - context=context + contexts=contexts ) end_time_view = time.monotonic() diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 2099d39c..6a610c24 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -30,7 +30,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - context: Optional[CustomContextsList] = None + contexts: Optional[CustomContextsList] = None ) -> Self: """ Parse IQL string to IQLQuery object. @@ -43,5 +43,5 @@ async def parse( IQLQuery object """ - root = await IQLProcessor(source, allowed_functions, context, event_tracker).process() + root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process() return cls(root=root, source=source) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 7eeb9154..1946c258 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -8,6 +8,7 @@ from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction +from dbally.context.context import CustomContextsList ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" @@ -42,7 +43,8 @@ async def generate_iql( examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> IQLQuery: + contexts: Optional[CustomContextsList] = None + ) -> Optional[IQLQuery]: """ Generates IQL in text form using LLM. @@ -60,7 +62,7 @@ async def generate_iql( prompt_format = IQLGenerationPromptFormat( question=question, filters=filters, - examples=examples, + examples=examples or [], ) formatted_prompt = self._prompt_template.format_prompt(prompt_format) @@ -78,7 +80,9 @@ async def generate_iql( source=iql, allowed_functions=filters, event_tracker=event_tracker, + contexts=contexts ) except IQLError as exc: + # TODO handle the possibility of variable `response` being not initialized while runnning the following line formatted_prompt = formatted_prompt.add_assistant_message(response) formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index 9a4825ba..00000000 --- a/src/dbally/iql_generator/iql_prompt_template.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.exceptions import DbAllyError -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class UnsupportedQueryError(DbAllyError): - """ - Error raised when IQL generator is unable to construct a query - with given filters. - """ - - -class IQLPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the IQL - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"filters", "question"}) - - -def _validate_iql_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing IQL for filters. - - Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. - """ - - if "unsupported query" in llm_response.lower(): - raise UnsupportedQueryError - return llm_response - - -default_iql_template = IQLPromptTemplate( - chat=( - { - "role": "system", - "content": "You have access to API that lets you query a database:\n" - "\n{filters}\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" - "Remember! Don't give any comments, just the function calls.\n" - "The output will look like this:\n" - 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{filters}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - "If a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." - 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' - "In that case, the part of the output will look like this:" - "filter4(BaseCallerContext())" - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ - "This is CRUCIAL, otherwise the system will crash. ", - }, - {"role": "user", "content": "{question}"}, - ), - llm_response_parser=_validate_iql_response, -) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 44bb2cd4..0cff10ea 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -74,6 +74,11 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." + "If a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "In that case, the part of the output will look like this:" + "filter4(BaseCallerContext())" + "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 365e83d6..43b69dbd 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -27,7 +27,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[CustomContextsList] = None + contexts: Optional[CustomContextsList] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 5e9c4e1a..eda1a48c 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -42,7 +42,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - context: Optional[CustomContextsList] = None + contexts: Optional[CustomContextsList] = None ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -71,6 +71,7 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, + contexts=contexts ) await self.apply_filters(iql) From 5fd802f1dd18db60b71bb5a6d7f107bfbe191196 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Thu, 4 Jul 2024 13:21:45 +0200 Subject: [PATCH 18/53] added missing docstrings; fixed type hints; fixed issues detected by pylint; run pre-commit auto refactor --- src/dbally/collection/collection.py | 10 ++-- src/dbally/context/_utils.py | 23 +++++++-- src/dbally/context/context.py | 51 +++++++++++++++----- src/dbally/context/exceptions.py | 13 ++++-- src/dbally/iql/_processor.py | 54 +++++++++++++++------- src/dbally/iql/_query.py | 10 ++-- src/dbally/iql/_type_validators.py | 13 ++++-- src/dbally/iql_generator/iql_generator.py | 20 ++++---- src/dbally/views/base.py | 12 +++-- src/dbally/views/exposed_functions.py | 4 +- src/dbally/views/freeform/text2sql/view.py | 3 ++ src/dbally/views/methods_base.py | 6 +-- src/dbally/views/structured.py | 11 +++-- 13 files changed, 158 insertions(+), 72 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 45446cc5..6e10aafa 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -3,20 +3,20 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, List, Optional, Type, TypeVar +from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.base import BaseView, IndexLocation -from dbally.context.context import BaseCallerContext, CustomContextsList class Collection: @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. @@ -177,6 +177,8 @@ async def ask( the natural response will be included in the answer llm_options: options to use for the LLM client. If provided, these options will be merged with the default options provided to the LLM client, prioritizing option values other than NOT_GIVEN + contexts: An iterable (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. Returns: ExecutionResult object representing the result of the query execution. @@ -217,7 +219,7 @@ async def ask( n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, - contexts=contexts + contexts=contexts, ) end_time_view = time.monotonic() diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 7fea4e9b..113b0ca2 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -1,15 +1,17 @@ -import typing_extensions as type_ext - -from typing import Sequence, Tuple, Optional, Type, Any, Union from inspect import isclass +from typing import Any, Optional, Sequence, Tuple, Type, Union + +import typing_extensions as type_ext from dbally.context.context import BaseCallerContext from dbally.views.exposed_functions import MethodParamWithTyping +ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]] + def _extract_params_and_context( filter_method_: type_ext.Callable, hidden_args: Sequence[str] -) -> Tuple[Sequence[MethodParamWithTyping], Optional[Type[BaseCallerContext]]]: +) -> Tuple[Sequence[MethodParamWithTyping], ContextClass]: """ Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext @@ -17,9 +19,10 @@ class is returned. Args: filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) + hidden_args: method arguments that should not be extracted Returns: - A tuple. The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. + The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. """ @@ -52,6 +55,16 @@ class is returned. def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: + """ + Verifies whether a method argument allows contextualization based on the type hints attached to a method signature. + + Args: + arg: MethodParamWithTyping container preserving information about the method argument + + Returns: + Verification result. + """ + if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext): return False diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index d3af935d..baf86e93 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,30 +1,59 @@ import ast +from typing import Iterable -from typing import Sequence, TypeVar -from typing_extensions import Self from pydantic import BaseModel +from typing_extensions import Self, TypeAlias from dbally.context.exceptions import ContextNotAvailableError - -CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) -CustomContextsList = Sequence[CustomContext] # TODO confirm the naming +# CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) +CustomContext: TypeAlias = "BaseCallerContext" class BaseCallerContext(BaseModel): """ - Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. - LLM will always return `BaseCallerContext()` when the context is required and this call will be later substitue by a proper subclass instance selected based on the filter method signature (type hints). + Pydantic-based record class. Base class for contexts that are used to pass additional knowledge about + the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. + LLM will always return `BaseCallerContext()` when the context is required and this call will be + later substituted by a proper subclass instance selected based on the filter method signature (type hints). """ @classmethod - def select_context(cls, contexts: CustomContextsList) -> Self: + def select_context(cls, contexts: Iterable[CustomContext]) -> Self: + """ + Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being + an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context + class by its right instance. + + Args: + contexts: A sequence of objects, each being an instance of a different BaseCallerContext subclass. + + Returns: + An instance of the same BaseCallerContext subclass this method is caller from. + + Raises: + ContextNotAvailableError: If the sequence of context objects passed as argument is empty. + """ + if not contexts: - raise ContextNotAvailableError("The LLM detected that the context is required to execute the query and the filter signature allows contextualization while the context was not provided.") + raise ContextNotAvailableError( + "The LLM detected that the context is required to execute the query +\ + and the filter signature allows contextualization while the context was not provided." + ) - # this method is called from the subclass of BaseCallerContext pointing the right type of custom context - return next(filter(lambda obj: isinstance(obj, cls), contexts)) + # TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore` + return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore @classmethod def is_context_call(cls, node: ast.expr) -> bool: + """ + Verifies whether an AST node indicates context substitution. + + Args: + node: An AST node (expression) to verify: + + Returns: + Verification result. + """ + return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == cls.__name__ diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py index 91482d5d..0efa1473 100644 --- a/src/dbally/context/exceptions.py +++ b/src/dbally/context/exceptions.py @@ -3,17 +3,22 @@ class BaseContextException(Exception, ABC): """ - A base exception for all specification context-related exception. + A base (abstract) exception for all specification context-related exception. """ - pass class ContextNotAvailableError(Exception): - pass + """ + An exception inheriting from BaseContextException pointining that no sufficient context information + was provided by the user while calling view.ask(). + """ class ContextualisationNotAllowed(Exception): - pass + """ + An exception inheriting from BaseContextException pointining that the filter method signature + does not allow to provide an additional context. + """ # WORKAROUND - traditional inhertiance syntax is not working in context of abstract Exceptions diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index b6a1648a..fb9c57be 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,8 +1,10 @@ import ast - -from typing import TYPE_CHECKING, Any, List, Optional, Union, Mapping, Type +from typing import Any, Iterable, List, Mapping, Optional, Union from dbally.audit.event_tracker import EventTracker +from dbally.context._utils import _does_arg_allow_context +from dbally.context.context import BaseCallerContext, CustomContext +from dbally.context.exceptions import ContextualisationNotAllowed from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -12,35 +14,48 @@ IQLUnsupportedSyntaxError, ) from dbally.iql._type_validators import validate_arg_type -from dbally.context.context import BaseCallerContext, CustomContextsList -from dbally.context.exceptions import ContextNotAvailableError, ContextualisationNotAllowed -from dbally.context._utils import _extract_params_and_context, _does_arg_allow_context -from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction +from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping class IQLProcessor: """ Parses IQL string to tree structure. + + Attributes: + source: Raw LLM response containing IQL filter calls. + allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View. + contexts: A sequence (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. """ + source: str allowed_functions: Mapping[str, "ExposedFunction"] - contexts: CustomContextsList + contexts: Iterable[CustomContext] _event_tracker: EventTracker - def __init__( self, source: str, - allowed_functions: List["ExposedFunction"], - contexts: Optional[CustomContextsList] = None, - event_tracker: Optional[EventTracker] = None + allowed_functions: Iterable[ExposedFunction], + contexts: Optional[Iterable[CustomContext]] = None, + event_tracker: Optional[EventTracker] = None, ) -> None: + """ + IQLProcessor class constructor. + + Args: + source: Raw LLM response containing IQL filter calls. + allowed_functions: An interable (typically a list) of all filters implemented for a certain View. + contexts: An iterable (typically a list) of context objects, each being an instance of + a subclass of BaseCallerContext. + even_tracker: An EvenTracker instance. + """ + self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} self.contexts = contexts or [] self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: """ Process IQL string to root IQL.Node. @@ -89,7 +104,7 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: if not isinstance(func, ast.Name): raise IQLUnsupportedSyntaxError(node, self.source, context="FunctionCall") - if func.id not in self.allowed_functions: # TODO add context class constructors to self.allowed_functions + if func.id not in self.allowed_functions: raise IQLFunctionNotExists(func, self.source) func_def = self.allowed_functions[func.id] @@ -117,9 +132,8 @@ def _parse_arg( self, arg: ast.expr, arg_spec: Optional[MethodParamWithTyping] = None, - parent_func_def: Optional[ExposedFunction] = None + parent_func_def: Optional[ExposedFunction] = None, ) -> Any: - if isinstance(arg, ast.List): return [self._parse_arg(x) for x in arg.elts] @@ -129,10 +143,16 @@ def _parse_arg( raise IQLArgumentParsingError(arg, self.source) if parent_func_def.context_class is None: - raise ContextualisationNotAllowed("The LLM detected that the context is required to execute the query while the filter signature does not allow it at all.") + raise ContextualisationNotAllowed( + "The LLM detected that the context is required +\ + to execute the query while the filter signature does not allow it at all." + ) if not _does_arg_allow_context(arg_spec): - raise ContextualisationNotAllowed(f"The LLM detected that the context is required to execute the query while the filter signature does allow it for `{arg_spec.name}` argument.") + raise ContextualisationNotAllowed( + f"The LLM detected that the context is required +\ + to execute the query while the filter signature does allow it for `{arg_spec.name}` argument." + ) return parent_func_def.context_class.select_context(self.contexts) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 6a610c24..cc090ad6 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, List, Optional, Type +from typing import TYPE_CHECKING, Iterable, List, Optional + from typing_extensions import Self +from dbally.context.context import CustomContext + from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor -from dbally.context.context import BaseCallerContext, CustomContextsList if TYPE_CHECKING: from dbally.views.structured import ExposedFunction @@ -30,7 +32,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> Self: """ Parse IQL string to IQLQuery object. @@ -39,6 +41,8 @@ async def parse( source: IQL string that needs to be parsed allowed_functions: list of IQL functions that are allowed for this query event_tracker: EventTracker object to track events + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: IQLQuery object """ diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 848b17ac..7b993ef5 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,9 +1,9 @@ -import typing_extensions as type_ext - from dataclasses import dataclass from typing import _GenericAlias # type: ignore from typing import Any, Callable, Dict, Literal, Optional, Type, Union +import typing_extensions as type_ext + @dataclass class _ValidationResult: @@ -67,8 +67,10 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> Returns: _ValidationResult instance """ - actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias - if actual_type is None: # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary + actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type + # typing.Union is an instance of _GenericAlias + if actual_type is None: + # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary actual_type = required_type.__origin__ if actual_type is Union: @@ -77,7 +79,8 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> if res.valid: return _ValidationResult(True) - return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") # typing.Union does not have __name__ property + # typing.Union does not have __name__ property, thus using repr() is necessary + return _ValidationResult(False, f"{repr(value)} is not of type {repr(required_type)}") custom_type_checker = TYPE_VALIDATOR.get(actual_type) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 1946c258..8018f6e1 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,6 +1,7 @@ -from typing import List, Optional +from typing import Iterable, List, Optional from dbally.audit.event_tracker import EventTracker +from dbally.context.context import CustomContext from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM @@ -8,7 +9,6 @@ from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -from dbally.context.context import CustomContextsList ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" @@ -43,8 +43,8 @@ async def generate_iql( examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - contexts: Optional[CustomContextsList] = None - ) -> Optional[IQLQuery]: + contexts: Optional[Iterable[CustomContext]] = None, + ) -> IQLQuery: """ Generates IQL in text form using LLM. @@ -55,6 +55,8 @@ async def generate_iql( examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: Generated IQL query. @@ -77,12 +79,12 @@ async def generate_iql( iql = formatted_prompt.response_parser(response) # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( - source=iql, - allowed_functions=filters, - event_tracker=event_tracker, - contexts=contexts + source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts ) except IQLError as exc: - # TODO handle the possibility of variable `response` being not initialized while runnning the following line + # TODO handle the possibility of variable `response` being not initialized + # while runnning the following line formatted_prompt = formatted_prompt.add_assistant_message(response) formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + + # TODO handle the situation when all retries fails and the return defaults to None diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 43b69dbd..e2292b56 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -1,15 +1,17 @@ import abc -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, Iterable, List, Optional, Tuple + +from typing_extensions import TypeAlias from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex -from dbally.context.context import BaseCallerContext, CustomContextsList -IndexLocation = Tuple[str, str, str] +IndexLocation: TypeAlias = Tuple[str, str, str] class BaseView(metaclass=abc.ABCMeta): @@ -27,7 +29,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. @@ -39,6 +41,8 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: The result of the query. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index c6d400d2..481052f7 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,10 +1,10 @@ import re from dataclasses import dataclass from typing import _GenericAlias # type: ignore -from typing import Sequence, Optional, Union, Type +from typing import Optional, Sequence, Type, Union -from dbally.similarity import AbstractSimilarityIndex from dbally.context.context import BaseCallerContext +from dbally.similarity import AbstractSimilarityIndex def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 7f24f00e..4fbb4bef 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,6 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult +from dbally.context.context import CustomContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -103,6 +104,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the SQL query from the natural language query and @@ -115,6 +117,7 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: Currently not used. Returns: The result of the query. diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 25baa957..a2a2bd9e 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -3,11 +3,11 @@ import textwrap from typing import Any, Callable, List, Tuple +from dbally.context._utils import _extract_params_and_context from dbally.iql import syntax from dbally.views import decorators -from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping +from dbally.views.exposed_functions import ExposedFunction from dbally.views.structured import BaseStructuredView -from dbally.context._utils import _extract_params_and_context class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): @@ -43,7 +43,7 @@ def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction] name=method_name, description=textwrap.dedent(method.__doc__).strip() if method.__doc__ else "", parameters=params, - context_class=context_class + context_class=context_class, ) ) return methods diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index eda1a48c..99f6955a 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,16 +1,15 @@ import abc from collections import defaultdict -from typing import Dict, List, Optional, Type +from typing import Dict, Iterable, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context.context import CustomContext from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction -from dbally.context.context import BaseCallerContext, CustomContextsList from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -42,7 +41,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[CustomContextsList] = None + contexts: Optional[Iterable[CustomContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -55,6 +54,8 @@ async def ask( n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. + contexts: An iterable (typically a list) of context objects, each being + an instance of a subclass of BaseCallerContext. Returns: The result of the query. @@ -71,7 +72,7 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, - contexts=contexts + contexts=contexts, ) await self.apply_filters(iql) From 09bac55b1b25bedc6e86176c1b1cb3e65844d794 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Thu, 4 Jul 2024 14:10:05 +0200 Subject: [PATCH 19/53] reworked parse_param_type() function to increase performance, generality and properly handle types: Union[Type1, Type2, ...], __main__.SomeCustomClass --- src/dbally/views/exposed_functions.py | 34 +++++++++++++++++++++------ 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 481052f7..89c1ee12 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,26 +1,46 @@ -import re from dataclasses import dataclass +from inspect import isclass from typing import _GenericAlias # type: ignore from typing import Optional, Sequence, Type, Union +import typing_extensions as type_ext + from dbally.context.context import BaseCallerContext from dbally.similarity import AbstractSimilarityIndex -def parse_param_type(param_type: Union[type, _GenericAlias]) -> str: +def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: """ Parses the type of a method parameter and returns a string representation of it. Args: - param_type: type of the parameter + param_type: Type of the parameter. Returns: - str: string representation of the type + A string representation of the type. """ - if param_type in {int, float, str, bool, list, dict, set, tuple}: + + # TODO consider using hasattr() to ensure correctness of the IF's below + if isclass(param_type): return param_type.__name__ - if param_type.__module__ == "typing": - return re.sub(r"\btyping\.", "", str(param_type)) + + # typing.Literal['aaa', 'bbb'] edge case handler + # the args are strings not types thus isclass('aaa') is False + # at the same type string has no __module__ property which causes an error + if isinstance(param_type, str): + return f"'{param_type}'" + + if param_type.__module__ == "typing" or param_type.__module__ == "typing_extensions": + type_args = type_ext.get_args(param_type) + if type_args: + param_name = param_type._name # pylint: disable=protected-access + if param_name is None: + # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` + # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` + param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access + + args_str_repr = ", ".join(parse_param_type(arg) for arg in type_args) + return f"{param_name}[{args_str_repr}]" return str(param_type) From d42a369e17bc576e5a6032dddaaa9e968c745831 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Thu, 4 Jul 2024 23:17:44 +0200 Subject: [PATCH 20/53] fix: removed duplicated line from the prompt template --- src/dbally/iql_generator/prompt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 0cff10ea..4c3de405 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -74,11 +74,10 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "If a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' "In that case, the part of the output will look like this:" "filter4(BaseCallerContext())" - "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), From c0b0522a33ae6786dc0a7b85831e8c07d8fca6c1 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Thu, 4 Jul 2024 23:20:35 +0200 Subject: [PATCH 21/53] adjusted existing unit tests to work with new contextualization logic --- tests/unit/test_iql_format.py | 8 ++++++++ tests/unit/test_iql_generator.py | 1 + 2 files changed, 9 insertions(+) diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 8f583c4c..d5f61ae1 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -23,6 +23,10 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "In that case, the part of the output will look like this:" + "filter4(BaseCallerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, @@ -53,6 +57,10 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "In that case, the part of the output will look like this:" + "filter4(BaseCallerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index ce3f593d..3183e394 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -80,6 +80,7 @@ async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventT source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker, + contexts=None ) From 9b2e1314bca0ba83dc3f988fba4bb39b086ecb2d Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Thu, 4 Jul 2024 23:23:43 +0200 Subject: [PATCH 22/53] linter-recommended fixes --- src/dbally/views/exposed_functions.py | 2 +- tests/unit/test_iql_generator.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 89c1ee12..83d1f131 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -30,7 +30,7 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: if isinstance(param_type, str): return f"'{param_type}'" - if param_type.__module__ == "typing" or param_type.__module__ == "typing_extensions": + if param_type.__module__ in ["typing", "typing_extensions"]: type_args = type_ext.get_args(param_type) if type_args: param_name = param_type._name # pylint: disable=protected-access diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 3183e394..7fb0a379 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -77,10 +77,7 @@ async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventT options=None, ) mock_parse.assert_called_once_with( - source="filter_by_id(1)", - allowed_functions=filters, - event_tracker=event_tracker, - contexts=None + source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker, contexts=None ) From 2d0ef4bcf8330affe3f3ebf5c8adb5a86a453111 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 5 Jul 2024 11:06:23 +0200 Subject: [PATCH 23/53] contextualization mechanism - dedicated unit tests --- tests/unit/iql/test_iql_parser.py | 24 ++++++++++++++++++++---- tests/unit/views/test_sqlalchemy_base.py | 19 +++++++++++++++++-- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 94b66e28..74c1bcd3 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -1,28 +1,44 @@ import re -from typing import List +from typing import List, Union import pytest +from dbally.context import BaseCallerContext from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import IQLArgumentValidationError, IQLFunctionNotExists from dbally.iql._processor import IQLProcessor from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping +class TestCustomContext(BaseCallerContext): + city: str + + +class AnotherTestCustomContext(BaseCallerContext): + some_field: str + + async def test_iql_parser(): + custom_context = TestCustomContext(city="cracow") + custom_context2 = AnotherTestCustomContext(some_field="aaa") + parsed = await IQLQuery.parse( - "not (filter_by_name(['John', 'Anne']) and filter_by_city('cracow') and filter_by_company('deepsense.ai'))", + "not (filter_by_name(['John', 'Anne']) and filter_by_city(BaseCallerContext()) and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( name="filter_by_name", description="", parameters=[MethodParamWithTyping(name="name", type=List[str])] ), ExposedFunction( - name="filter_by_city", description="", parameters=[MethodParamWithTyping(name="city", type=str)] + name="filter_by_city", + description="", + parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])], + context_class=TestCustomContext, ), ExposedFunction( name="filter_by_company", description="", parameters=[MethodParamWithTyping(name="company", type=str)] ), ], + contexts=[custom_context, custom_context2], ) not_op = parsed.root @@ -37,7 +53,7 @@ async def test_iql_parser(): assert name_filter.arguments[0] == ["John", "Anne"] assert isinstance(city_filter, syntax.FunctionCall) - assert city_filter.arguments[0] == "cracow" + assert city_filter.arguments[0] is custom_context assert isinstance(company_filter, syntax.FunctionCall) assert company_filter.arguments[0] == "deepsense.ai" diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 079a2135..339af441 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -1,14 +1,20 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name import re +from typing import Union import sqlalchemy +from dbally.context import BaseCallerContext from dbally.iql import IQLQuery from dbally.views.decorators import view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView +class SomeTestContext(BaseCallerContext): + age: int + + class MockSqlAlchemyView(SqlAlchemyBaseView): """ Mock class for testing the SqlAlchemyBaseView @@ -22,12 +28,20 @@ def method_foo(self, idx: int) -> sqlalchemy.ColumnElement: """ Some documentation string """ + return sqlalchemy.literal(idx) @view_filter() async def method_bar(self, city: str, year: int) -> sqlalchemy.ColumnElement: return sqlalchemy.literal(f"hello {city} in {year}") + @view_filter() + async def method_baz(self, age: Union[int, SomeTestContext]) -> sqlalchemy.ColumnElement: + if isinstance(age, SomeTestContext): + return sqlalchemy.literal(age.age) + + return sqlalchemy.literal(age) + def normalize_whitespace(s: str) -> str: """ @@ -44,9 +58,10 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) query = await IQLQuery.parse( - 'method_foo(1) and method_bar("London", 2020)', + 'method_foo(1) and method_bar("London", 2020) and method_baz(BaseCallerContext())', allowed_functions=mock_view.list_filters(), + contexts=[SomeTestContext(age=69)], ) await mock_view.apply_filters(query) sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) - assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020'" + assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020' AND 69" From 6466f611bea3abe609af0d0f6c8939ee705b2e3a Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 5 Jul 2024 11:38:07 +0200 Subject: [PATCH 24/53] cleaned up overengineered code remanining from the previous iteration of development --- src/dbally/iql/_processor.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index fb9c57be..7393aa87 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -113,13 +113,13 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: if len(func_def.parameters) != len(node.args): raise ValueError(f"The method {func.id} has incorrect number of arguments") - for i, (arg, arg_def) in enumerate(zip(node.args, func_def.parameters)): - arg_value = self._parse_arg(arg, arg_spec=func_def.parameters[i], parent_func_def=func_def) + for arg, arg_spec in zip(node.args, func_def.parameters): + arg_value = self._parse_arg(arg, arg_spec=arg_spec, parent_func_def=func_def) - if arg_def.similarity_index: - arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker) + if arg_spec.similarity_index: + arg_value = await arg_spec.similarity_index.similar(arg_value, event_tracker=self._event_tracker) - check_result = validate_arg_type(arg_def.type, arg_value) + check_result = validate_arg_type(arg_spec.type, arg_value) if not check_result.valid: raise IQLArgumentValidationError(message=check_result.reason or "", node=arg, source=self.source) From 637f7fac76a296dcbba6bbb6ea8d11779a3bbfcc Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 8 Jul 2024 10:45:14 +0200 Subject: [PATCH 25/53] replaced pydantic.BaseModel by dataclasses.dataclass, pydantic no longer required --- src/dbally/context/context.py | 6 +++--- tests/unit/iql/test_iql_parser.py | 3 +++ tests/unit/views/test_sqlalchemy_base.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index baf86e93..edc524fe 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,16 +1,16 @@ import ast +from dataclasses import dataclass from typing import Iterable -from pydantic import BaseModel from typing_extensions import Self, TypeAlias from dbally.context.exceptions import ContextNotAvailableError -# CustomContext = TypeVar('CustomContext', bound='BaseCallerContext', covariant=True) CustomContext: TypeAlias = "BaseCallerContext" -class BaseCallerContext(BaseModel): +@dataclass +class BaseCallerContext: """ Pydantic-based record class. Base class for contexts that are used to pass additional knowledge about the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 74c1bcd3..2bb1f0a6 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -1,4 +1,5 @@ import re +from dataclasses import dataclass from typing import List, Union import pytest @@ -10,10 +11,12 @@ from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping +@dataclass class TestCustomContext(BaseCallerContext): city: str +@dataclass class AnotherTestCustomContext(BaseCallerContext): some_field: str diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 339af441..999c774e 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -1,6 +1,7 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name import re +from dataclasses import dataclass from typing import Union import sqlalchemy @@ -11,6 +12,7 @@ from dbally.views.sqlalchemy_base import SqlAlchemyBaseView +@dataclass class SomeTestContext(BaseCallerContext): age: int From f867e25c0e5ea06348f840d98e39ef00e7e4b962 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 8 Jul 2024 11:07:49 +0200 Subject: [PATCH 26/53] BaseCallerContext: dataclass w.o. fields -> interface (abstract class), supporting both dataclasses and pydantic --- src/dbally/context/context.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index edc524fe..acac30fa 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,5 +1,5 @@ import ast -from dataclasses import dataclass +from abc import ABC from typing import Iterable from typing_extensions import Self, TypeAlias @@ -9,13 +9,12 @@ CustomContext: TypeAlias = "BaseCallerContext" -@dataclass -class BaseCallerContext: +class BaseCallerContext(ABC): """ - Pydantic-based record class. Base class for contexts that are used to pass additional knowledge about - the caller environment to the filters. It is not made abstract for the convinience of IQL parsing. - LLM will always return `BaseCallerContext()` when the context is required and this call will be - later substituted by a proper subclass instance selected based on the filter method signature (type hints). + An interface for contexts that are used to pass additional knowledge about + the caller environment to the filters. LLM will always return `BaseCallerContext()` + when the context is required and this call will be later substituted by an instance of + a class implementing this interface, selected based on the filter method signature (type hints). """ @classmethod From 3423033dc26dec7855f522aa527b06afe16b49bc Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 8 Jul 2024 11:21:22 +0200 Subject: [PATCH 27/53] LLM now pastes Context() instead of BaseCallerContext() to indicate that a context is required --- src/dbally/context/context.py | 10 ++++++++-- src/dbally/iql_generator/prompt.py | 4 ++-- tests/unit/iql/test_iql_parser.py | 2 +- tests/unit/test_iql_format.py | 8 ++++---- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index acac30fa..8676b797 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -12,11 +12,13 @@ class BaseCallerContext(ABC): """ An interface for contexts that are used to pass additional knowledge about - the caller environment to the filters. LLM will always return `BaseCallerContext()` + the caller environment to the filters. LLM will always return `Context()` when the context is required and this call will be later substituted by an instance of a class implementing this interface, selected based on the filter method signature (type hints). """ + _alias: str = "Context" + @classmethod def select_context(cls, contexts: Iterable[CustomContext]) -> Self: """ @@ -55,4 +57,8 @@ def is_context_call(cls, node: ast.expr) -> bool: Verification result. """ - return isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == cls.__name__ + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Name) + and node.func.id in [cls._alias, cls.__name__] + ) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 4c3de405..99653147 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -74,10 +74,10 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' "In that case, the part of the output will look like this:" - "filter4(BaseCallerContext())" + "filter4(Context())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 2bb1f0a6..7cde0d5e 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -26,7 +26,7 @@ async def test_iql_parser(): custom_context2 = AnotherTestCustomContext(some_field="aaa") parsed = await IQLQuery.parse( - "not (filter_by_name(['John', 'Anne']) and filter_by_city(BaseCallerContext()) and filter_by_company('deepsense.ai'))", + "not (filter_by_name(['John', 'Anne']) and filter_by_city(Context()) and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( name="filter_by_name", description="", parameters=[MethodParamWithTyping(name="name", type=List[str])] diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index d5f61ae1..91ceb644 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -23,10 +23,10 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' "In that case, the part of the output will look like this:" - "filter4(BaseCallerContext())" + "filter4(Context())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, @@ -57,10 +57,10 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: BaseCallerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' "In that case, the part of the output will look like this:" - "filter4(BaseCallerContext())" + "filter4(Context())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, From 0d8cd1edfd34fc4eb2915d213199d6154dc49c0b Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 9 Jul 2024 11:13:40 +0200 Subject: [PATCH 28/53] docstring typo fixes; more precise return type hint --- src/dbally/context/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index 113b0ca2..a13af521 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -1,5 +1,5 @@ from inspect import isclass -from typing import Any, Optional, Sequence, Tuple, Type, Union +from typing import Any, Optional, Sequence, List, Tuple, Type, Union import typing_extensions as type_ext @@ -11,11 +11,11 @@ def _extract_params_and_context( filter_method_: type_ext.Callable, hidden_args: Sequence[str] -) -> Tuple[Sequence[MethodParamWithTyping], ContextClass]: +) -> Tuple[List[MethodParamWithTyping], ContextClass]: """ - Processes the MethodsBaseView filter method signauture to extract the args and type hints in the desired format. - Context claases are getting excluded the returned MethodParamWithTyping list. Only the first BaseCallerContext - class is returned. + Processes the MethodsBaseView filter method signature to extract the argument names and type hint + in the form of MethodParamWithTyping list. Additionally, the first type hint, pointing to the subclass + of BaseCallerContext is returned. Args: filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) From c97ba15224b281f6822f57be516aac7508054e7c Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 9 Jul 2024 15:51:05 +0200 Subject: [PATCH 29/53] renamed Context() -> AskerContext(); added more detailed detailed examples to the prompt --- src/dbally/context/_utils.py | 2 +- src/dbally/context/context.py | 2 +- src/dbally/iql_generator/prompt.py | 7 ++++--- tests/unit/iql/test_iql_parser.py | 2 +- tests/unit/test_iql_format.py | 14 ++++++++------ tests/unit/views/test_sqlalchemy_base.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py index a13af521..e8e3b0eb 100644 --- a/src/dbally/context/_utils.py +++ b/src/dbally/context/_utils.py @@ -1,5 +1,5 @@ from inspect import isclass -from typing import Any, Optional, Sequence, List, Tuple, Type, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, Union import typing_extensions as type_ext diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index 8676b797..a66a7525 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -17,7 +17,7 @@ class BaseCallerContext(ABC): a class implementing this interface, selected based on the filter method signature (type hints). """ - _alias: str = "Context" + _alias: str = "AskerContext" @classmethod def select_context(cls, contexts: Iterable[CustomContext]) -> Self: diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 99653147..ec10de9c 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -74,10 +74,11 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." - 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" - "filter4(Context())" + "filter4(AskerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 7cde0d5e..d28604e5 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -26,7 +26,7 @@ async def test_iql_parser(): custom_context2 = AnotherTestCustomContext(some_field="aaa") parsed = await IQLQuery.parse( - "not (filter_by_name(['John', 'Anne']) and filter_by_city(Context()) and filter_by_company('deepsense.ai'))", + "not (filter_by_name(['John', 'Anne']) and filter_by_city(AskerContext()) and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( name="filter_by_name", description="", parameters=[MethodParamWithTyping(name="name", type=List[str])] diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 91ceb644..618d17c9 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -23,10 +23,11 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." - 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" - "filter4(Context())" + "filter4(AskerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, @@ -57,10 +58,11 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires knowledge of some additional context, than substitute that argument value by: Context()." - 'The typical input phrase referencing some additional context contains the word "my" or similar phrasing, e.g. "my position name", "my company valuation".' + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" - "filter4(Context())" + "filter4(AskerContext())" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 999c774e..5b3cbf59 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -60,7 +60,7 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) query = await IQLQuery.parse( - 'method_foo(1) and method_bar("London", 2020) and method_baz(BaseCallerContext())', + 'method_foo(1) and method_bar("London", 2020) and method_baz(AskerContext())', allowed_functions=mock_view.list_filters(), contexts=[SomeTestContext(age=69)], ) From 1294a9ca94e69fab36182ed943db275c7ff0d97d Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 9 Jul 2024 17:30:54 +0200 Subject: [PATCH 30/53] type hint parsing changes: SomeCustomContext -> AskerContext; Union[a, b] -> a | b; removed typing & typing_extensions module name prefixes --- src/dbally/views/exposed_functions.py | 31 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 83d1f131..009ac0d3 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from inspect import isclass from typing import _GenericAlias # type: ignore -from typing import Optional, Sequence, Type, Union +from typing import Optional, Sequence, Type, Union, Generator import typing_extensions as type_ext @@ -22,6 +22,11 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: # TODO consider using hasattr() to ensure correctness of the IF's below if isclass(param_type): + if issubclass(param_type, BaseCallerContext): + # this mechanism ensures the LLM will be able to notice the relation between + # the keyword-call specified in the prompt and the filter method signatures + return BaseCallerContext._alias + return param_type.__name__ # typing.Literal['aaa', 'bbb'] edge case handler @@ -31,16 +36,22 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: return f"'{param_type}'" if param_type.__module__ in ["typing", "typing_extensions"]: + param_name = param_type._name # pylint: disable=protected-access + if param_name is None: + # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` + # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` + param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access + type_args = type_ext.get_args(param_type) - if type_args: - param_name = param_type._name # pylint: disable=protected-access - if param_name is None: - # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` - # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` - param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access - - args_str_repr = ", ".join(parse_param_type(arg) for arg in type_args) - return f"{param_name}[{args_str_repr}]" + if not type_args: + return param_name + + parsed_args: Generator[str] = (parse_param_type(arg) for arg in type_args) + if type_ext.get_origin(param_type) is Union: + return " | ".join(parsed_args) + + parsed_args_concatanated = ", ".join(parsed_args) + return f"{param_name}[{parsed_args_concatanated}]" return str(param_type) From 999759b2c85496cb62b7f70454b4917b96404a39 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 12 Jul 2024 11:08:56 +0200 Subject: [PATCH 31/53] refactor: collection.results.[ViewExecutionResult, ExecutionResult]."context" -> "metadata" (to ommit confusing naming overlap with newly added contextualisation functionality); views.exposed_functions pylint warnings resolved --- .../audit/event_handlers/cli_event_handler.py | 4 +- .../event_handlers/langsmith_event_handler.py | 2 +- src/dbally/collection/collection.py | 2 +- src/dbally/collection/results.py | 4 +- src/dbally/context/context.py | 6 +- src/dbally/gradio/gradio_interface.py | 2 +- src/dbally/iql_generator/prompt.py | 2 +- src/dbally/nl_responder/nl_responder.py | 2 +- src/dbally/nl_responder/prompts.py | 8 +-- src/dbally/prompt/template.py | 4 +- src/dbally/views/exposed_functions.py | 66 +++++++++++++------ src/dbally/views/freeform/text2sql/view.py | 4 +- src/dbally/views/pandas_base.py | 2 +- src/dbally/views/sqlalchemy_base.py | 2 +- src/dbally/views/structured.py | 2 +- tests/unit/mocks.py | 2 +- .../similarity/sample_module/submodule.py | 4 +- tests/unit/test_collection.py | 10 +-- tests/unit/test_iql_format.py | 4 +- tests/unit/test_nl_responder.py | 2 +- tests/unit/views/test_methods_base.py | 2 +- tests/unit/views/test_pandas_base.py | 6 +- tests/unit/views/test_sqlalchemy_base.py | 2 +- tests/unit/views/text2sql/test_view.py | 2 +- 24 files changed, 86 insertions(+), 60 deletions(-) diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index aa48e049..7326a6f5 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -127,5 +127,5 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict] self._print_syntax("[green bold]REQUEST OUTPUT:") self._print_syntax(f"Number of rows: {len(output.result.results)}") - if "sql" in output.result.context: - self._print_syntax(f"{output.result.context['sql']}", "psql") + if "sql" in output.result.metadata: + self._print_syntax(f"{output.result.metadata['sql']}", "psql") diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py index c0b619c2..060fb030 100644 --- a/src/dbally/audit/event_handlers/langsmith_event_handler.py +++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py @@ -101,5 +101,5 @@ async def request_end(self, output: RequestEnd, request_context: RunTree) -> Non output: The output of the request. In this case - PSQL query. request_context: Optional context passed from request_start method """ - request_context.end(outputs={"sql": output.result.context["sql"]}) + request_context.end(outputs={"sql": output.result.metadata["sql"]}) request_context.post(exclude_child_runs=False) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 6e10aafa..8ff62599 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -234,7 +234,7 @@ async def ask( result = ExecutionResult( results=view_result.results, - context=view_result.context, + metadata=view_result.metadata, execution_time=time.monotonic() - start_time, execution_time_view=end_time_view - start_time_view, view_name=selected_view, diff --git a/src/dbally/collection/results.py b/src/dbally/collection/results.py index b33cf5e3..65421a34 100644 --- a/src/dbally/collection/results.py +++ b/src/dbally/collection/results.py @@ -14,7 +14,7 @@ class ViewExecutionResult: """ results: List[Dict[str, Any]] - context: Dict[str, Any] + metadata: Dict[str, Any] @dataclass @@ -37,7 +37,7 @@ class ExecutionResult: """ results: List[Dict[str, Any]] - context: Dict[str, Any] + metadata: Dict[str, Any] execution_time: float execution_time_view: float view_name: str diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index a66a7525..bef1ef23 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -17,7 +17,7 @@ class BaseCallerContext(ABC): a class implementing this interface, selected based on the filter method signature (type hints). """ - _alias: str = "AskerContext" + alias: str = "AskerContext" @classmethod def select_context(cls, contexts: Iterable[CustomContext]) -> Self: @@ -58,7 +58,5 @@ def is_context_call(cls, node: ast.expr) -> bool: """ return ( - isinstance(node, ast.Call) - and isinstance(node.func, ast.Name) - and node.func.id in [cls._alias, cls.__name__] + isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id in [cls.alias, cls.__name__] ) diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 761b0dd2..89c61f62 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -114,7 +114,7 @@ async def _ui_ask_query( execution_result = await self.collection.ask( question=question_query, return_natural_response=natural_language_flag ) - generated_query = str(execution_result.context) + generated_query = str(execution_result.metadata) data = self._load_results_into_dataframe(execution_result.results) textual_response = str(execution_result.textual_response) if natural_language_flag else textual_response except UnsupportedQueryError: diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index ec10de9c..92427c88 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -74,7 +74,7 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 7a8f98e4..693efe59 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -70,7 +70,7 @@ async def generate_response( if tokens_count > self._max_tokens_count: prompt_format = QueryExplanationPromptFormat( question=question, - context=result.context, + metadata=result.metadata, results=result.results, ) formatted_prompt = self._explainer_prompt_template.format_prompt(prompt_format) diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py index f99a8a6c..90365623 100644 --- a/src/dbally/nl_responder/prompts.py +++ b/src/dbally/nl_responder/prompts.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import pandas as pd @@ -40,9 +40,9 @@ def __init__( self, *, question: str, - context: Dict[str, Any], + metadata: Dict[str, Any], results: List[Dict[str, Any]], - examples: List[FewShotExample] = None, + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new QueryExplanationPromptFormat instance. @@ -55,7 +55,7 @@ def __init__( """ super().__init__(examples) self.question = question - self.query = next((context.get(key) for key in ("iql", "sql", "query") if context.get(key)), question) + self.query = next((metadata.get(key) for key in ("iql", "sql", "query") if metadata.get(key)), question) self.number_of_results = len(results) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 009ac0d3..9237fd92 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from inspect import isclass from typing import _GenericAlias # type: ignore -from typing import Optional, Sequence, Type, Union, Generator +from typing import Generator, Optional, Sequence, Type, Union import typing_extensions as type_ext @@ -9,7 +9,50 @@ from dbally.similarity import AbstractSimilarityIndex -def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: +class TypeParsingError(ValueError): + """ + Custo error raised when parsing a data type using `parse_param_type()` fails. + """ + + +def _parse_complex_type(param_type: Union[type_ext.Type, _GenericAlias]) -> str: + """ + Generates string representation of a complex type from `typing` or `typing_extensions` module. + + Args: + param_type: type or type alias. + + Returns: + A string representation of the type. + """ + + # delegating large chunk of parsing code to this separate function prevents + # pylint from raising R0911: too-many-return-statements + + param_name = param_type._name # pylint: disable=protected-access + if param_name is None: + # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` + # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` + param_origin_type = type_ext.get_origin(param_type) + if param_origin_type is None: + # probably unnecessary hack ensuring + raise TypeParsingError(f"Unable to parse: {str(param_type)}") + + param_name = param_origin_type._name # pylint: disable=protected-access + + type_args = type_ext.get_args(param_type) + if not type_args: + return param_name + + parsed_args: Generator[str] = (parse_param_type(arg) for arg in type_args) + if type_ext.get_origin(param_type) is Union: + return " | ".join(parsed_args) + + parsed_args_concatanated = ", ".join(parsed_args) + return f"{param_name}[{parsed_args_concatanated}]" + + +def parse_param_type(param_type: Union[type_ext.Type, _GenericAlias, str]) -> str: """ Parses the type of a method parameter and returns a string representation of it. @@ -25,7 +68,7 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: if issubclass(param_type, BaseCallerContext): # this mechanism ensures the LLM will be able to notice the relation between # the keyword-call specified in the prompt and the filter method signatures - return BaseCallerContext._alias + return BaseCallerContext.alias return param_type.__name__ @@ -36,22 +79,7 @@ def parse_param_type(param_type: Union[type, _GenericAlias, str]) -> str: return f"'{param_type}'" if param_type.__module__ in ["typing", "typing_extensions"]: - param_name = param_type._name # pylint: disable=protected-access - if param_name is None: - # workaround for typing.Literal, because: `typing.Literal['aaa', 'bbb']._name is None` - # but at the same time: `type_ext.get_origin(typing.Literal['aaa', 'bbb'])._name == "Literal"` - param_name = type_ext.get_origin(param_type)._name # pylint: disable=protected-access - - type_args = type_ext.get_args(param_type) - if not type_args: - return param_name - - parsed_args: Generator[str] = (parse_param_type(arg) for arg in type_args) - if type_ext.get_origin(param_type) is Union: - return " | ".join(parsed_args) - - parsed_args_concatanated = ", ".join(parsed_args) - return f"{param_name}[{parsed_args_concatanated}]" + return _parse_complex_type(param_type) return str(param_type) diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 4fbb4bef..27596d7e 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -151,7 +151,7 @@ async def ask( ) if dry_run: - return ViewExecutionResult(results=[], context={"sql": sql}) + return ViewExecutionResult(results=[], metadata={"sql": sql}) rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker) break @@ -167,7 +167,7 @@ async def ask( # pylint: disable=protected-access return ViewExecutionResult( results=[dict(row._mapping) for row in rows], - context={ + metadata={ "sql": sql, }, ) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 3d3831f7..9baad3ab 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -83,7 +83,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult( results=filtered_data.to_dict(orient="records"), - context={ + metadata={ "filter_mask": self._filter_mask, }, ) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index b1783558..8454ddaf 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -88,5 +88,5 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult( results=results, - context={"sql": sql}, + metadata={"sql": sql}, ) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 99f6955a..cfe2d6ba 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -78,7 +78,7 @@ async def ask( await self.apply_filters(iql) result = self.execute(dry_run=dry_run) - result.context["iql"] = f"{iql}" + result.metadata["iql"] = f"{iql}" return result diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 75cc914b..df02581b 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -30,7 +30,7 @@ async def apply_filters(self, filters: IQLQuery) -> None: ... def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) class MockIQLGenerator(IQLGenerator): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index 42e05c0a..0e222479 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -24,7 +24,7 @@ async def apply_filters(self, filters: IQLQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) class BarView(MethodsBaseView): @@ -43,4 +43,4 @@ async def apply_filters(self, filters: IQLQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 38ec3e99..57fe0e6f 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -54,7 +54,7 @@ class MockViewWithResults(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] @@ -79,7 +79,7 @@ class MockViewWithSimilarity(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ @@ -106,7 +106,7 @@ class MockViewWithSimilarity2(MockViewBase): """ def execute(self, dry_run=False) -> ViewExecutionResult: - return ViewExecutionResult(results=[{"foo": "bar"}], context={"baz": "qux"}) + return ViewExecutionResult(results=[{"foo": "bar"}], metadata={"baz": "qux"}) def list_filters(self) -> List[ExposedFunction]: return [ @@ -291,7 +291,7 @@ async def test_ask_view_selection_single_view() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.metadata == {"baz": "qux", "iql": "test_filter()"} async def test_ask_view_selection_multiple_views() -> None: @@ -312,7 +312,7 @@ async def test_ask_view_selection_multiple_views() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.metadata == {"baz": "qux", "iql": "test_filter()"} async def test_ask_view_selection_no_views() -> None: diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 618d17c9..64085d9c 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -23,7 +23,7 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" @@ -58,7 +58,7 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value by: AskerContext()." + "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' "In that case, the part of the output will look like this:" diff --git a/tests/unit/test_nl_responder.py b/tests/unit/test_nl_responder.py index e23fe3d1..3f5815a0 100644 --- a/tests/unit/test_nl_responder.py +++ b/tests/unit/test_nl_responder.py @@ -22,7 +22,7 @@ def event_tracker() -> EventTracker: @pytest.fixture def answer() -> ViewExecutionResult: - return ViewExecutionResult(results=[{"id": 1, "name": "Mock name"}], context={"sql": "Mock SQL"}) + return ViewExecutionResult(results=[{"id": 1, "name": "Mock name"}], metadata={"sql": "Mock SQL"}) @pytest.mark.asyncio diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 58959a64..bbd1aee5 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -29,7 +29,7 @@ async def apply_filters(self, filters: IQLQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: - return ViewExecutionResult(results=[], context={}) + return ViewExecutionResult(results=[], metadata={}) def test_list_filters() -> None: diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 51eea791..4e45f29f 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -66,7 +66,7 @@ async def test_filter_or() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON - assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.metadata["filter_mask"].tolist() == [True, False, True, False, True] async def test_filter_and() -> None: @@ -81,7 +81,7 @@ async def test_filter_and() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 - assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.metadata["filter_mask"].tolist() == [False, True, False, False, False] async def test_filter_not() -> None: @@ -96,4 +96,4 @@ async def test_filter_not() -> None: await mock_view.apply_filters(query) result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 - assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.metadata["filter_mask"].tolist() == [True, False, True, True, True] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 5b3cbf59..20f11e72 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -65,5 +65,5 @@ async def test_filter_sql_generation() -> None: contexts=[SomeTestContext(age=69)], ) await mock_view.apply_filters(query) - sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"]) assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020' AND 69" diff --git a/tests/unit/views/text2sql/test_view.py b/tests/unit/views/text2sql/test_view.py index 91b7b50f..f5c51247 100644 --- a/tests/unit/views/text2sql/test_view.py +++ b/tests/unit/views/text2sql/test_view.py @@ -61,7 +61,7 @@ async def test_text2sql_view(sample_db: Engine): response = await collection.ask("Show me customers from New York") - assert response.context["sql"] == llm_response["sql"] + assert response.metadata["sql"] == llm_response["sql"] assert response.results == [ {"id": 1, "name": "Alice", "city": "New York"}, {"id": 3, "name": "Charlie", "city": "New York"}, From 2e1005a169f8d461dcfa5fe3ec807e1e872773ab Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 12 Jul 2024 11:47:04 +0200 Subject: [PATCH 32/53] param type parsing: correctly handling builtins types with args (e.g. list[int]) in Python 3.9+ --- src/dbally/views/exposed_functions.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 9237fd92..9ee307ec 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -15,9 +15,11 @@ class TypeParsingError(ValueError): """ -def _parse_complex_type(param_type: Union[type_ext.Type, _GenericAlias]) -> str: +def _parse_standard_type(param_type: Union[type_ext.Type, _GenericAlias]) -> str: """ - Generates string representation of a complex type from `typing` or `typing_extensions` module. + Generates string representation of a data type (or alias) being neither custom class or string. + This function is primarily intended to parse types from consecutive modules: + `builtins` (>= Python 3.9), `typing` and `typing_extensions`. Args: param_type: type or type alias. @@ -78,9 +80,16 @@ def parse_param_type(param_type: Union[type_ext.Type, _GenericAlias, str]) -> st if isinstance(param_type, str): return f"'{param_type}'" - if param_type.__module__ in ["typing", "typing_extensions"]: - return _parse_complex_type(param_type) + if param_type.__module__ in ["builtins", "typing", "typing_extensions"]: + return _parse_standard_type(param_type) + # TODO add explicit support Generic types, + # although they should be already handled okay by _parse_standard_type() + + # TODO test on various different Python versions > 3.8 + + # fallback, at the moment we expect this to be called only for type aliases + # note that in Python 3.12+ there exists typing.TypeAliasType that can be checked with isinstance() return str(param_type) From 820066d585c6bb8857ea1713c498bd9c092e7826 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 12 Jul 2024 11:47:43 +0200 Subject: [PATCH 33/53] type hint fix: explcitly marked BaseCallerContext.alias as typing.ClassVar --- src/dbally/context/context.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index bef1ef23..b01cad2e 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -1,6 +1,6 @@ import ast from abc import ABC -from typing import Iterable +from typing import ClassVar, Iterable from typing_extensions import Self, TypeAlias @@ -15,9 +15,12 @@ class BaseCallerContext(ABC): the caller environment to the filters. LLM will always return `Context()` when the context is required and this call will be later substituted by an instance of a class implementing this interface, selected based on the filter method signature (type hints). + + Attributes: + alias: Class variable defining an alias which is defined in the prompt for the LLM to reference context. """ - alias: str = "AskerContext" + alias: ClassVar[str] = "AskerContext" @classmethod def select_context(cls, contexts: Iterable[CustomContext]) -> Self: @@ -38,8 +41,8 @@ class by its right instance. if not contexts: raise ContextNotAvailableError( - "The LLM detected that the context is required to execute the query +\ - and the filter signature allows contextualization while the context was not provided." + "The LLM detected that the context is required to execute the query" + "and the filter signature allows contextualization while the context was not provided." ) # TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore` From 25fbfa64ce6c21b85bae918726513b5b97b78bb1 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 15 Jul 2024 12:20:47 +0200 Subject: [PATCH 34/53] docs + benchmarks adjusted to meet new naming [ExecutionResult, ViewExecutionResult]."context" -> "metadata" --- benchmark/dbally_benchmark/e2e_benchmark.py | 2 +- docs/how-to/use_elastic_vector_store_code.py | 2 +- docs/how-to/use_elasticsearch_store_code.py | 2 +- docs/how-to/views/custom_views_code.py | 2 +- docs/quickstart/quickstart2_code.py | 2 +- docs/quickstart/quickstart_code.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index aa686727..f2d86b58 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -31,7 +31,7 @@ async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult: try: result = await collection.ask(example.question, dry_run=True) - sql = result.context["sql"] + sql = result.metadata["sql"] except UnsupportedQueryError: sql = "UnsupportedQueryError" except NoViewFoundError: diff --git a/docs/how-to/use_elastic_vector_store_code.py b/docs/how-to/use_elastic_vector_store_code.py index 4817fcf4..48c3309f 100644 --- a/docs/how-to/use_elastic_vector_store_code.py +++ b/docs/how-to/use_elastic_vector_store_code.py @@ -91,7 +91,7 @@ async def main(country="United States", years_of_experience="2"): f"Find someone from the {country} with more than {years_of_experience} years of experience." ) - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/docs/how-to/use_elasticsearch_store_code.py b/docs/how-to/use_elasticsearch_store_code.py index 1f690c35..181a3c89 100644 --- a/docs/how-to/use_elasticsearch_store_code.py +++ b/docs/how-to/use_elasticsearch_store_code.py @@ -95,7 +95,7 @@ async def main(country="United States", years_of_experience="2"): f"Find someone from the {country} with more than {years_of_experience} years of experience." ) - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/docs/how-to/views/custom_views_code.py b/docs/how-to/views/custom_views_code.py index 33c954c7..573eddc7 100644 --- a/docs/how-to/views/custom_views_code.py +++ b/docs/how-to/views/custom_views_code.py @@ -64,7 +64,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: print(self._filter) filtered_data = list(filter(self._filter, self.get_data())) - return ViewExecutionResult(results=filtered_data, context={}) + return ViewExecutionResult(results=filtered_data, metadata={}) class CandidateView(FilteredIterableBaseView): def get_data(self) -> Iterable: diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 593e7b4a..fd6c6b8c 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -85,7 +85,7 @@ async def main(): result = await collection.ask("Find someone from the United States with more than 2 years of experience.") - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 34ee9765..0e302537 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -63,7 +63,7 @@ async def main(): result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") - print(f"The generated SQL query is: {result.context.get('sql')}") + print(f"The generated SQL query is: {result.metadata.get('sql')}") print() print(f"Retrieved {len(result.results)} candidates:") for candidate in result.results: From a1545773cb52636bab2e9b7ffb91816378a05321 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 15 Jul 2024 12:32:44 +0200 Subject: [PATCH 35/53] redesigned context-not-available error to follow the same principles as other IQL errors, inherting from IQLError, thus enabled its handling by self-reflection mechanism --- src/dbally/context/exceptions.py | 21 ---------------- src/dbally/iql/_exceptions.py | 42 +++++++++++++++++++++++++++++--- src/dbally/iql/_processor.py | 12 +++------ 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py index 0efa1473..15c1d303 100644 --- a/src/dbally/context/exceptions.py +++ b/src/dbally/context/exceptions.py @@ -1,26 +1,5 @@ -from abc import ABC - - -class BaseContextException(Exception, ABC): - """ - A base (abstract) exception for all specification context-related exception. - """ - - class ContextNotAvailableError(Exception): """ An exception inheriting from BaseContextException pointining that no sufficient context information was provided by the user while calling view.ask(). """ - - -class ContextualisationNotAllowed(Exception): - """ - An exception inheriting from BaseContextException pointining that the filter method signature - does not allow to provide an additional context. - """ - - -# WORKAROUND - traditional inhertiance syntax is not working in context of abstract Exceptions -BaseContextException.register(ContextNotAvailableError) -BaseContextException.register(ContextualisationNotAllowed) diff --git a/src/dbally/iql/_exceptions.py b/src/dbally/iql/_exceptions.py index 7df08187..797a7824 100644 --- a/src/dbally/iql/_exceptions.py +++ b/src/dbally/iql/_exceptions.py @@ -1,13 +1,26 @@ import ast from typing import Optional, Union +from typing_extensions import TypeAlias + from dbally.exceptions import DbAllyError +IQLNode: TypeAlias = Union[ast.stmt, ast.expr] + class IQLError(DbAllyError): - """Base exception for all IQL parsing related exceptions.""" + """ + Base exception for all IQL parsing related exceptions. + + Attributes: + node: An IQL Node (AST Exprresion) during which processing an error was encountered. + source: Raw LLM response containing IQL filter calls. + """ + + node: IQLNode + source: str - def __init__(self, message: str, node: Union[ast.stmt, ast.expr], source: str) -> None: + def __init__(self, message: str, node: IQLNode, source: str) -> None: message = message + ": " + source[node.col_offset : node.end_col_offset] super().__init__(message) @@ -18,7 +31,7 @@ def __init__(self, message: str, node: Union[ast.stmt, ast.expr], source: str) - class IQLArgumentParsingError(IQLError): """Raised when an argument cannot be parsed into a valid IQL.""" - def __init__(self, node: Union[ast.stmt, ast.expr], source: str) -> None: + def __init__(self, node: IQLNode, source: str) -> None: message = "Not a valid IQL argument" super().__init__(message, node, source) @@ -26,7 +39,7 @@ def __init__(self, node: Union[ast.stmt, ast.expr], source: str) -> None: class IQLUnsupportedSyntaxError(IQLError): """Raised when trying to parse an unsupported syntax.""" - def __init__(self, node: Union[ast.stmt, ast.expr], source: str, context: Optional[str] = None) -> None: + def __init__(self, node: IQLNode, source: str, context: Optional[str] = None) -> None: node_name = node.__class__.__name__ message = f"{node_name} syntax is not supported in IQL" @@ -47,3 +60,24 @@ def __init__(self, node: ast.Name, source: str) -> None: class IQLArgumentValidationError(IQLError): """Raised when argument is not valid for a given method.""" + + +class IQLContextNotAllowedError(IQLError): + """ + Raised when a context call/keyword has been passed as an argument to the filter + which does not support contextualization for this specific parameter. + """ + + def __init__(self, node: IQLNode, source: str, arg_name: Optional[str] = None) -> None: + if arg_name is None: + message = ( + "The LLM detected that the context is required to execute the query" + "while the filter signature does not allow it at all." + ) + else: + message = ( + "The LLM detected that the context is required to execute the query" + f"while the filter signature does allow it for `{arg_name}` argument." + ) + + super().__init__(message, node, source) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 7393aa87..5e18a480 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -4,11 +4,11 @@ from dbally.audit.event_tracker import EventTracker from dbally.context._utils import _does_arg_allow_context from dbally.context.context import BaseCallerContext, CustomContext -from dbally.context.exceptions import ContextualisationNotAllowed from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, IQLArgumentValidationError, + IQLContextNotAllowedError, IQLError, IQLFunctionNotExists, IQLUnsupportedSyntaxError, @@ -143,16 +143,10 @@ def _parse_arg( raise IQLArgumentParsingError(arg, self.source) if parent_func_def.context_class is None: - raise ContextualisationNotAllowed( - "The LLM detected that the context is required +\ - to execute the query while the filter signature does not allow it at all." - ) + raise IQLContextNotAllowedError(arg, self.source) if not _does_arg_allow_context(arg_spec): - raise ContextualisationNotAllowed( - f"The LLM detected that the context is required +\ - to execute the query while the filter signature does allow it for `{arg_spec.name}` argument." - ) + raise IQLContextNotAllowedError(arg, self.source, arg_name=arg_spec.name) return parent_func_def.context_class.select_context(self.contexts) From 623effdd7b1f2ce4bcd5f50fcd5cdca886a51c3c Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 15 Jul 2024 17:53:45 +0200 Subject: [PATCH 36/53] EXPERIMENTAL: reworked context injection such it is handled immediately in 'structured_view.ask()' and than stored in 'ExposedFunction' instances --- src/dbally/collection/collection.py | 4 ++-- src/dbally/context/context.py | 22 +++++++++----------- src/dbally/context/exceptions.py | 24 +++++++++++++++++++--- src/dbally/iql/_processor.py | 13 +++--------- src/dbally/iql/_query.py | 15 ++++---------- src/dbally/iql/_type_validators.py | 2 +- src/dbally/iql_generator/iql_generator.py | 10 ++------- src/dbally/views/base.py | 8 ++++---- src/dbally/views/exposed_functions.py | 23 ++++++++++++++++++++- src/dbally/views/freeform/text2sql/view.py | 4 ++-- src/dbally/views/structured.py | 23 ++++++++++++++++++--- tests/unit/iql/test_iql_parser.py | 8 +------- tests/unit/test_iql_generator.py | 2 +- tests/unit/views/test_sqlalchemy_base.py | 7 ++++--- 14 files changed, 97 insertions(+), 68 deletions(-) diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index 8ff62599..a0fccde9 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -10,7 +10,7 @@ from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder @@ -157,7 +157,7 @@ async def ask( dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ExecutionResult: """ Ask question in a text form and retrieve the answer based on the available views. diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py index b01cad2e..1f5a32d6 100644 --- a/src/dbally/context/context.py +++ b/src/dbally/context/context.py @@ -2,11 +2,9 @@ from abc import ABC from typing import ClassVar, Iterable -from typing_extensions import Self, TypeAlias +from typing_extensions import Self -from dbally.context.exceptions import ContextNotAvailableError - -CustomContext: TypeAlias = "BaseCallerContext" +from dbally.context.exceptions import BaseContextError class BaseCallerContext(ABC): @@ -23,7 +21,7 @@ class BaseCallerContext(ABC): alias: ClassVar[str] = "AskerContext" @classmethod - def select_context(cls, contexts: Iterable[CustomContext]) -> Self: + def select_context(cls, contexts: Iterable["BaseCallerContext"]) -> Self: """ Typically called from a subclass of BaseCallerContext, selects a member object from `contexts` being an instance of the same class. Effectively provides a type dispatch mechanism, substituting the context @@ -36,17 +34,17 @@ class by its right instance. An instance of the same BaseCallerContext subclass this method is caller from. Raises: - ContextNotAvailableError: If the sequence of context objects passed as argument is empty. + BaseContextError: If no element in `contexts` matches `cls` class. """ - if not contexts: - raise ContextNotAvailableError( - "The LLM detected that the context is required to execute the query" - "and the filter signature allows contextualization while the context was not provided." - ) + try: + selected_context = next(filter(lambda obj: isinstance(obj, cls), contexts)) + except StopIteration as e: + # this custom exception provides more clear message what have just gone wrong + raise BaseContextError() from e # TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore` - return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore + return selected_context # type: ignore @classmethod def is_context_call(cls, node: ast.expr) -> bool: diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py index 15c1d303..c538ee18 100644 --- a/src/dbally/context/exceptions.py +++ b/src/dbally/context/exceptions.py @@ -1,5 +1,23 @@ -class ContextNotAvailableError(Exception): +class BaseContextError(Exception): """ - An exception inheriting from BaseContextException pointining that no sufficient context information - was provided by the user while calling view.ask(). + A base error for context handling logic. """ + + +class SuitableContextNotProvidedError(BaseContextError): + """ + Raised when method argument type hint points that a contextualization is available + but not suitable context was provided. + """ + + def __init__(self, filter_fun_signature: str, context_class_name: str) -> None: + # this syntax 'or BaseCallerContext' is just to prevent type checkers + # from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never + # be reached, unless somebody will use this Exception against its purpose. + # TODO consider raising a warning/error when this happens. + + message = ( + f"No context of class {context_class_name} was provided" + f"while the filter {filter_fun_signature} requires it." + ) + super().__init__(message) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 5e18a480..37c08ee0 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -3,7 +3,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.context._utils import _does_arg_allow_context -from dbally.context.context import BaseCallerContext, CustomContext +from dbally.context.context import BaseCallerContext from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -23,21 +23,17 @@ class IQLProcessor: Attributes: source: Raw LLM response containing IQL filter calls. - allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View. - contexts: A sequence (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. + allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.= """ source: str allowed_functions: Mapping[str, "ExposedFunction"] - contexts: Iterable[CustomContext] _event_tracker: EventTracker def __init__( self, source: str, allowed_functions: Iterable[ExposedFunction], - contexts: Optional[Iterable[CustomContext]] = None, event_tracker: Optional[EventTracker] = None, ) -> None: """ @@ -46,14 +42,11 @@ def __init__( Args: source: Raw LLM response containing IQL filter calls. allowed_functions: An interable (typically a list) of all filters implemented for a certain View. - contexts: An iterable (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. even_tracker: An EvenTracker instance. """ self.source = source self.allowed_functions = {func.name: func for func in allowed_functions} - self.contexts = contexts or [] self._event_tracker = event_tracker or EventTracker() async def process(self) -> syntax.Node: @@ -148,7 +141,7 @@ def _parse_arg( if not _does_arg_allow_context(arg_spec): raise IQLContextNotAllowedError(arg, self.source, arg_name=arg_spec.name) - return parent_func_def.context_class.select_context(self.contexts) + return parent_func_def.context if not isinstance(arg, ast.Constant): raise IQLArgumentParsingError(arg, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index cc090ad6..a9080a49 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,9 +1,7 @@ -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, List, Optional from typing_extensions import Self -from dbally.context.context import CustomContext - from ..audit.event_tracker import EventTracker from . import syntax from ._processor import IQLProcessor @@ -28,11 +26,7 @@ def __str__(self) -> str: @classmethod async def parse( - cls, - source: str, - allowed_functions: List["ExposedFunction"], - event_tracker: Optional[EventTracker] = None, - contexts: Optional[Iterable[CustomContext]] = None, + cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None ) -> Self: """ Parse IQL string to IQLQuery object. @@ -41,11 +35,10 @@ async def parse( source: IQL string that needs to be parsed allowed_functions: list of IQL functions that are allowed for this query event_tracker: EventTracker object to track events - contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. + Returns: IQLQuery object """ - root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process() + root = await IQLProcessor(source, allowed_functions, event_tracker).process() return cls(root=root, source=source) diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 7b993ef5..b06f8305 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -70,7 +70,7 @@ def validate_arg_type(required_type: Union[Type, _GenericAlias], value: Any) -> actual_type = type_ext.get_origin(required_type) if isinstance(required_type, _GenericAlias) else required_type # typing.Union is an instance of _GenericAlias if actual_type is None: - # workaround to prevent type warning in line `if isisntanc(value, actual_type):`, TODO check whether necessary + # workaround to prevent type warning in line `if isisntance(value, actual_type):`, TODO check whether necessary actual_type = required_type.__origin__ if actual_type is Union: diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 8018f6e1..c6aeec31 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,7 +1,6 @@ -from typing import Iterable, List, Optional +from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.context.context import CustomContext from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM @@ -43,7 +42,6 @@ async def generate_iql( examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - contexts: Optional[Iterable[CustomContext]] = None, ) -> IQLQuery: """ Generates IQL in text form using LLM. @@ -55,8 +53,6 @@ async def generate_iql( examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. - contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. Returns: Generated IQL query. @@ -78,9 +74,7 @@ async def generate_iql( # TODO: Move response parsing to llm generate_text method iql = formatted_prompt.response_parser(response) # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, allowed_functions=filters, event_tracker=event_tracker, contexts=contexts - ) + return await IQLQuery.parse(source=iql, allowed_functions=filters, event_tracker=event_tracker) except IQLError as exc: # TODO handle the possibility of variable `response` being not initialized # while runnning the following line diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index e2292b56..a83b961d 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -29,7 +29,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. @@ -59,9 +59,9 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc def list_few_shots(self) -> List[FewShotExample]: """ - List all examples to be injected into few-shot prompt. + Lists all examples to be injected into few-shot prompt. Returns: - List of few-shot examples + List of few-shot examples. """ return [] diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index 9ee307ec..07b88005 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -1,11 +1,12 @@ from dataclasses import dataclass from inspect import isclass from typing import _GenericAlias # type: ignore -from typing import Generator, Optional, Sequence, Type, Union +from typing import Generator, Iterable, Optional, Sequence, Type, Union import typing_extensions as type_ext from dbally.context.context import BaseCallerContext +from dbally.context.exceptions import BaseContextError, SuitableContextNotProvidedError from dbally.similarity import AbstractSimilarityIndex @@ -127,6 +128,7 @@ class ExposedFunction: description: str parameters: Sequence[MethodParamWithTyping] context_class: Optional[Type[BaseCallerContext]] = None + context: Optional[BaseCallerContext] = None def __str__(self) -> str: base_str = f"{self.name}({', '.join(str(param) for param in self.parameters)})" @@ -135,3 +137,22 @@ def __str__(self) -> str: return f"{base_str} - {self.description}" return base_str + + def inject_context(self, contexts: Iterable[BaseCallerContext]) -> None: + """ + Inserts reference to the member of `contexts` of the proper class in self.context. + + Args: + contexts: An iterable of user-provided context objects. + + Raises: + SuitableContextNotProvidedError: Ff no element in `contexts` matches `self.context_class`. + """ + + if self.context_class is None: + return + + try: + self.context = self.context_class.select_context(contexts) + except BaseContextError as e: + raise SuitableContextNotProvidedError(str(self), self.context_class.__name__) from e diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 27596d7e..31f4c041 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,7 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -104,7 +104,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the SQL query from the natural language query and diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index cfe2d6ba..0b5e1f27 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,7 +4,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import CustomContext +from dbally.context.context import BaseCallerContext from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM @@ -33,6 +33,22 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ return IQLGenerator(llm=llm) + @classmethod + def contextualize_filters( + cls, filters: Iterable[ExposedFunction], contexts: Optional[Iterable[BaseCallerContext]] + ) -> None: + """ + Updates a list of filters packed as ExposedFunction's by ingesting the matching context objects. + + Args: + filters: An iterable of filters. + contexts: An iterable of context objects. + """ + + contexts = contexts or [] + for filter_ in filters: + filter_.inject_context(contexts) + async def ask( self, query: str, @@ -41,7 +57,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[CustomContext]] = None, + contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the IQL query from the natural language query\ @@ -65,6 +81,8 @@ async def ask( filters = self.list_filters() examples = self.list_few_shots() + self.contextualize_filters(filters, contexts) + iql = await iql_generator.generate_iql( question=query, filters=filters, @@ -72,7 +90,6 @@ async def ask( event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, - contexts=contexts, ) await self.apply_filters(iql) diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index d28604e5..691ab39d 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -16,14 +16,8 @@ class TestCustomContext(BaseCallerContext): city: str -@dataclass -class AnotherTestCustomContext(BaseCallerContext): - some_field: str - - async def test_iql_parser(): custom_context = TestCustomContext(city="cracow") - custom_context2 = AnotherTestCustomContext(some_field="aaa") parsed = await IQLQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city(AskerContext()) and filter_by_company('deepsense.ai'))", @@ -36,12 +30,12 @@ async def test_iql_parser(): description="", parameters=[MethodParamWithTyping(name="city", type=Union[str, TestCustomContext])], context_class=TestCustomContext, + context=custom_context, ), ExposedFunction( name="filter_by_company", description="", parameters=[MethodParamWithTyping(name="company", type=str)] ), ], - contexts=[custom_context, custom_context2], ) not_op = parsed.root diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 7fb0a379..08189943 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -77,7 +77,7 @@ async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventT options=None, ) mock_parse.assert_called_once_with( - source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker, contexts=None + source="filter_by_id(1)", allowed_functions=filters, event_tracker=event_tracker ) diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 20f11e72..cd115410 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -59,10 +59,11 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) + filters = mock_view.list_filters() + mock_view.contextualize_filters(filters, [SomeTestContext(age=69)]) + query = await IQLQuery.parse( - 'method_foo(1) and method_bar("London", 2020) and method_baz(AskerContext())', - allowed_functions=mock_view.list_filters(), - contexts=[SomeTestContext(age=69)], + 'method_foo(1) and method_bar("London", 2020) and method_baz(AskerContext())', allowed_functions=filters ) await mock_view.apply_filters(query) sql = normalize_whitespace(mock_view.execute(dry_run=True).metadata["sql"]) From afacf5ba384bed4f48a94fc2ed47eb9e1e79ff0a Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Fri, 19 Jul 2024 16:13:09 +0200 Subject: [PATCH 37/53] additional unit tests for the new contextualization mechanism --- tests/unit/views/test_methods_base.py | 42 +++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index bbd1aee5..987e8e8d 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -1,13 +1,28 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name - -from typing import List, Literal, Tuple +import asyncio +from dataclasses import dataclass +from typing import List, Literal, Tuple, Union, Optional from dbally.collection.results import ViewExecutionResult from dbally.iql import IQLQuery from dbally.views.decorators import view_filter -from dbally.views.exposed_functions import MethodParamWithTyping +from dbally.views.exposed_functions import MethodParamWithTyping, ExposedFunction from dbally.views.methods_base import MethodsBaseView +from dbally.context import BaseCallerContext +from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.audit.event_tracker import EventTracker +from dbally.prompt.elements import FewShotExample +from dbally.llms.clients.base import LLMOptions +from dbally.llms.base import LLM + + +@dataclass +class TestCallerContext(BaseCallerContext): + """ + Mock class for testing context. + """ + current_year: Literal['2023', '2024'] class MockMethodsBase(MethodsBaseView): @@ -22,7 +37,7 @@ def method_foo(self, idx: int) -> None: """ @view_filter() - def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: List[Tuple[str, int]]) -> str: + def method_bar(self, cities: List[str], year: Union[Literal["2023", "2024"], TestCallerContext], pairs: List[Tuple[str, int]]) -> str: return f"hello {cities} in {year} of {pairs}" async def apply_filters(self, filters: IQLQuery) -> None: @@ -47,9 +62,24 @@ def test_list_filters() -> None: assert method_bar.description == "" assert method_bar.parameters == [ MethodParamWithTyping("cities", List[str]), - MethodParamWithTyping("year", Literal["2023", "2024"]), + MethodParamWithTyping("year", Union[Literal["2023", "2024"], TestCallerContext]), MethodParamWithTyping("pairs", List[Tuple[str, int]]), ] assert ( - str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'], pairs: List[Tuple[str, int]])" + str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'] | AskerContext, pairs: List[Tuple[str, int]])" ) + + +async def test_contextualization() -> None: + mock_view = MockMethodsBase() + filters = mock_view.list_filters() + test_context = TestCallerContext("2024") + mock_view.contextualize_filters(filters, [test_context]) + + method_foo = [f for f in filters if f.name == "method_foo"][0] + assert method_foo.context_class is None + assert method_foo.context is None + + method_bar = [f for f in filters if f.name == "method_bar"][0] + assert method_bar.context_class is TestCallerContext + assert method_bar.context is test_context From dd8b339f2a3481f0c74ee049560bba68b0c1f3d1 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 22 Jul 2024 11:18:12 +0200 Subject: [PATCH 38/53] context benchmark script and data --- .../dbally_benchmark/context_benchmark.py | 235 ++++++++++++++++++ .../dataset/context_dataset.json | 62 +++++ 2 files changed, 297 insertions(+) create mode 100644 benchmark/dbally_benchmark/context_benchmark.py create mode 100644 benchmark/dbally_benchmark/dataset/context_dataset.json diff --git a/benchmark/dbally_benchmark/context_benchmark.py b/benchmark/dbally_benchmark/context_benchmark.py new file mode 100644 index 00000000..387e61be --- /dev/null +++ b/benchmark/dbally_benchmark/context_benchmark.py @@ -0,0 +1,235 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +import dbally +import asyncio +import typing +import json +import traceback +import os + +import tqdm.asyncio +import sqlalchemy +import pydantic +from typing_extensions import TypeAlias +from copy import deepcopy +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base, AutomapBase +from dataclasses import dataclass, field + +from dbally import decorators, SqlAlchemyBaseView +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llms.litellm import LiteLLM +from dbally.context import BaseCallerContext + + +SQLITE_DB_FILE_REL_PATH = "../../examples/recruiting/data/candidates.db" +engine = create_engine(f"sqlite:///{os.path.abspath(SQLITE_DB_FILE_REL_PATH)}") + +Base: AutomapBase = automap_base() +Base.prepare(autoload_with=engine) + +Candidate = Base.classes.candidates + + +class MyData(BaseCallerContext, pydantic.BaseModel): + first_name: str + surname: str + position: str + years_of_experience: int + university: str + skills: typing.List[str] + country: str + + +class OpenPosition(BaseCallerContext, pydantic.BaseModel): + position: str + min_years_of_experience: int + graduated_from_university: str + required_skills: typing.List[str] + + +class CandidateView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. + """ + + def get_select(self) -> sqlalchemy.Select: + """ + Creates the initial SqlAlchemy select object, which will be used to build the query. + """ + return sqlalchemy.select(Candidate) + + @decorators.view_filter() + def at_least_experience(self, years: typing.Union[int, OpenPosition]) -> sqlalchemy.ColumnElement: + """ + Filters candidates with at least `years` of experience. + """ + if isinstance(years, OpenPosition): + years = years.min_years_of_experience + + return Candidate.years_of_experience >= years + + @decorators.view_filter() + def at_most_experience(self, years: typing.Union[int, MyData]) -> sqlalchemy.ColumnElement: + if isinstance(years, MyData): + years = years.years_of_experience + + return Candidate.years_of_experience <= years + + @decorators.view_filter() + def has_position(self, position: typing.Union[str, OpenPosition]) -> sqlalchemy.ColumnElement: + if isinstance(position, OpenPosition): + position = position.position + + return Candidate.position == position + + @decorators.view_filter() + def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement: + """ + Filters candidates that can be considered for a senior data scientist position. + """ + return sqlalchemy.and_( + Candidate.position.in_(["Data Scientist", "Machine Learning Engineer", "Data Engineer"]), + Candidate.years_of_experience >= 3, + ) + + @decorators.view_filter() + def from_country(self, country: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: + """ + Filters candidates from a specific country. + """ + if isinstance(country, MyData): + return Candidate.country == country.country + + return Candidate.country == country + + @decorators.view_filter() + def graduated_from_university(self, university: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: + if isinstance(university, MyData): + university = university.university + + return Candidate.university == university + + @decorators.view_filter() + def has_skill(self, skill: str) -> sqlalchemy.ColumnElement: + return Candidate.skills.like(f"%{skill}%") + + @decorators.view_filter() + def knows_data_analysis(self) -> sqlalchemy.ColumnElement: + return Candidate.tags.like("%Data Analysis%") + + @decorators.view_filter() + def knows_python(self) -> sqlalchemy.ColumnElement: + return Candidate.skills.like("%Python%") + + @decorators.view_filter() + def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: + if isinstance(first_name, MyData): + first_name = first_name.first_name + + return Candidate.name.startswith(first_name) + + +OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o'] + + +def setup_collection(model_name: OpenAILLMName) -> dbally.Collection: + llm = LiteLLM(model_name=model_name) + + collection = dbally.create_collection("recruitment", llm) + collection.add(CandidateView, lambda: CandidateView(engine)) + + return collection + + +async def generate_iql_from_question( + collection: dbally.Collection, + model_name: OpenAILLMName, + question: str, + contexts: typing.Optional[typing.List[BaseCallerContext]] +) -> typing.Tuple[str, OpenAILLMName, typing.Optional[str]]: + + try: + result = await collection.ask( + question, + contexts=contexts, + dry_run=True + ) + except Exception as e: + exc_pretty = traceback.format_exception_only(e.__class__, e)[0] + return question, model_name, f"FAILED: {exc_pretty}" + + out = result.metadata.get("iql") + if out is None: + return question, model_name, None + + return question, model_name, out.replace('"', '\'') + + +@dataclass +class BenchmarkConfig: + dataset_path: str + out_path: str + n_repeats: int = 5 + llms: typing.List[OpenAILLMName] = field(default_factory=lambda: ['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o']) + + +async def main(config: BenchmarkConfig): + test_set = None + with open(config.dataset_path, 'r') as file: + test_set = json.load(file) + + contexts = [ + MyData( + first_name="John", + surname="Smith", + years_of_experience=4, + position="Data Engineer", + university="University of Toronto", + skills=["Python"], + country="United Kingdom" + ), + OpenPosition( + position="Machine Learning Engineer", + graduated_from_university="Stanford Univeristy", + min_years_of_experience=1, + required_skills=["Python", "SQL"] + ) + ] + + tasks: typing.List[asyncio.Task] = [] + for model_name in config.llms: + collection = setup_collection(model_name) + for test_case in test_set: + answers = [] + for _ in range(config.n_repeats): + task = asyncio.create_task(generate_iql_from_question(collection, model_name, + test_case["question"], contexts=contexts)) + tasks.append(task) + + output_data = { + test_case["question"]:test_case + for test_case in test_set + } + empty_answers = {str(llm_name): [] for llm_name in config.llms} + + total_iter = len(config.llms) * len(test_set) * config.n_repeats + for task in tqdm.asyncio.tqdm.as_completed(tasks, total=total_iter): + question, llm_name, answer = await task + if "answers" not in output_data[question]: + output_data[question]["answers"] = deepcopy(empty_answers) + + output_data[question]["answers"][llm_name].append(answer) + + output_data_list = list(output_data.values()) + + with open(config.out_path, 'w') as file: + file.write(json.dumps(test_set, indent=2)) + + +if __name__ == "__main__": + config = BenchmarkConfig( + dataset_path="dataset/context_dataset.json", + out_path="../../context_benchmark_output.json" + ) + + asyncio.run(main(config)) diff --git a/benchmark/dbally_benchmark/dataset/context_dataset.json b/benchmark/dbally_benchmark/dataset/context_dataset.json new file mode 100644 index 00000000..a0df66e6 --- /dev/null +++ b/benchmark/dbally_benchmark/dataset/context_dataset.json @@ -0,0 +1,62 @@ +[ + { + "question": "Find me French candidates suitable for my position with at least 1 year of experience.", + "correct_answer": "from_country('France') AND has_position(AskerContext()) AND at_least_experience(1)", + "context": false + }, + { + "question": "Please find me candidates from my country who have at most 4 years of experience.", + "correct_answer": "from_country(AskerContext()) AND at_most_experience(4)", + "context": true + }, + { + "question": "Find me candidates who graduated from Stanford University and work as Software Engineers.", + "correct_answer": "graduated_from_university('Stanford University') AND has_position('Software Engineer')", + "context": false + }, + { + "question": "Find me candidates who graduated from my university", + "correct_answer": "graduated_from_university(AskerContext())", + "context": true + }, + { + "question": "Could you find me candidates with at most as experience who also know Python?", + "correct_answer": "at_most_experience(AskerContext()) AND know_python()", + "context": true + }, + { + "question": "Please find me candidates who know Data Analysis and Python", + "correct_answer": "know_python() AND know_data_analysis()", + "context": false + }, + { + "question": "Find me candidates with at least minimal required experience for the currently open position.", + "correct_answer": "at_least_experience(AskerContext())", + "context": true + }, + { + "question": "List candidates with between 2 and 6 years of experience.", + "correct_answer": "at_least_experience(2) AND at_most_experience(6)", + "context": false + }, + { + "question": "Find me candidates who currently have the same position as we look for in our company?", + "correct_answer": "has_position(AskerContext())", + "context": true + }, + { + "question": "Please find me senior data scientist candidates who know Data Analysis and come from my country", + "correct_answer": "senior_data_scientist_position() AND has_skill('Data Analysis') AND from_country(AskerContext())", + "context": true + }, + { + "question": "Find me candidates that have the same first name as me", + "correct_answer": "first_name_is(AskerContext())", + "context": true + }, + { + "question": "List candidates named Mohammed from India", + "correct_answer": "first_name_is('Mohammed') AND from_country('India')", + "context": false + } +] From 6bb0816dc6decc5776c239a9732fceca6d5707ff Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 22 Jul 2024 11:19:39 +0200 Subject: [PATCH 39/53] refactored main prompt (too long lines), missing end-of-line characters --- src/dbally/iql_generator/prompt.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 92427c88..aea3e301 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -73,12 +73,13 @@ def __init__( "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" "\n{filters}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." - 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' - 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' - "In that case, the part of the output will look like this:" - "filter4(AskerContext())" + "It is VERY IMPORTANT not to use methods other than those listed above.\n" + "Finally, if a called function argument value is not directly specified in the query but instead requires " + "some additional execution context, than substitute that argument value with: AskerContext().\n" + 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..\n' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".\n' + "In that case, the part of the output will look like this:\n" + "filter4(AskerContext())\n" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), From f388f92ce9ea6d84d369f0b00c4adc6cda9e58ad Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Mon, 22 Jul 2024 11:33:12 +0200 Subject: [PATCH 40/53] better error handling --- benchmark/dbally_benchmark/context_benchmark.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmark/dbally_benchmark/context_benchmark.py b/benchmark/dbally_benchmark/context_benchmark.py index 387e61be..abea5fb6 100644 --- a/benchmark/dbally_benchmark/context_benchmark.py +++ b/benchmark/dbally_benchmark/context_benchmark.py @@ -19,6 +19,7 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM from dbally.context import BaseCallerContext +from dbally.iql import IQLError SQLITE_DB_FILE_REL_PATH = "../../examples/recruiting/data/candidates.db" @@ -154,6 +155,9 @@ async def generate_iql_from_question( contexts=contexts, dry_run=True ) + except IQLError as e: + exc_pretty = traceback.format_exception_only(e.__class__, e)[0] + return question, model_name, f"FAILED: {exc_pretty}({e.source})" except Exception as e: exc_pretty = traceback.format_exception_only(e.__class__, e)[0] return question, model_name, f"FAILED: {exc_pretty}" From fbecc5171e80652ce55a9a69dde229a33317d949 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 23 Jul 2024 09:33:14 +0200 Subject: [PATCH 41/53] context benchmark dataset fix --- benchmark/dbally_benchmark/dataset/context_dataset.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/dbally_benchmark/dataset/context_dataset.json b/benchmark/dbally_benchmark/dataset/context_dataset.json index a0df66e6..c37f38d1 100644 --- a/benchmark/dbally_benchmark/dataset/context_dataset.json +++ b/benchmark/dbally_benchmark/dataset/context_dataset.json @@ -20,7 +20,7 @@ "context": true }, { - "question": "Could you find me candidates with at most as experience who also know Python?", + "question": "Could you find me candidates with at most as experience as I have who also know Python?", "correct_answer": "at_most_experience(AskerContext()) AND know_python()", "context": true }, From 5d4ff64385a58c360c44c80a83797a6ec08ee893 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 23 Jul 2024 09:33:43 +0200 Subject: [PATCH 42/53] added polars-based accuracy summary to the benchmark --- .../dbally_benchmark/context_benchmark.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/benchmark/dbally_benchmark/context_benchmark.py b/benchmark/dbally_benchmark/context_benchmark.py index abea5fb6..430caf84 100644 --- a/benchmark/dbally_benchmark/context_benchmark.py +++ b/benchmark/dbally_benchmark/context_benchmark.py @@ -1,14 +1,14 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +import polars as pl + import dbally import asyncio import typing import json import traceback import os - import tqdm.asyncio import sqlalchemy -import pydantic from typing_extensions import TypeAlias from copy import deepcopy from sqlalchemy import create_engine @@ -31,7 +31,8 @@ Candidate = Base.classes.candidates -class MyData(BaseCallerContext, pydantic.BaseModel): +@dataclass +class MyData(BaseCallerContext): first_name: str surname: str position: str @@ -41,7 +42,8 @@ class MyData(BaseCallerContext, pydantic.BaseModel): country: str -class OpenPosition(BaseCallerContext, pydantic.BaseModel): +@dataclass +class OpenPosition(BaseCallerContext): position: str min_years_of_experience: int graduated_from_university: str @@ -130,7 +132,7 @@ def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.Col return Candidate.name.startswith(first_name) -OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-4-turbo', 'gpt-4o'] +OpenAILLMName: TypeAlias = typing.Literal['gpt-3.5-turbo', 'gpt-3.5-turbo-instruct', 'gpt-4-turbo', 'gpt-4o'] def setup_collection(model_name: OpenAILLMName) -> dbally.Collection: @@ -224,10 +226,30 @@ async def main(config: BenchmarkConfig): output_data[question]["answers"][llm_name].append(answer) - output_data_list = list(output_data.values()) + df_out_raw = pl.DataFrame(list(output_data.values())) + + df_out = ( + df_out_raw + .unnest("answers") + .unpivot( + on=pl.selectors.starts_with("gpt"), + index=["question", "correct_answer", "context"], + variable_name="model", + value_name="answer" + ) + .explode("answer") + .group_by(["context", "model"]) + .agg([ + (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"), + (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"), + ]) + .sort(["model", "context"]) + ) + + print(df_out) with open(config.out_path, 'w') as file: - file.write(json.dumps(test_set, indent=2)) + file.write(json.dumps(df_out_raw.to_dicts(), indent=2)) if __name__ == "__main__": From e7e88268322db3f093ae5c71cb6ca26ad6ffcae7 Mon Sep 17 00:00:00 2001 From: Jakub Cierocki Date: Tue, 23 Jul 2024 09:34:04 +0200 Subject: [PATCH 43/53] adjusted prompt to reduce halucinations: nested filter/context calls and putting filter args in quotation marks --- src/dbally/iql_generator/prompt.py | 8 ++++++- tests/unit/test_iql_format.py | 38 ++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index aea3e301..4cab3fc4 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -76,10 +76,16 @@ def __init__( "It is VERY IMPORTANT not to use methods other than those listed above.\n" "Finally, if a called function argument value is not directly specified in the query but instead requires " "some additional execution context, than substitute that argument value with: AskerContext().\n" - 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..\n' + "The typical input phrase suggesting that the additional execution context need to be referenced \n" + 'contains words like: "I", "my", "mine", "current", "the" etc..\n' 'For example: "my position name", "my company valuation", "current day", "the ongoing project".\n' "In that case, the part of the output will look like this:\n" "filter4(AskerContext())\n" + "Outside this situation DO NOT combine filters like this:\n" + "filter4(filter2())\n" + "And NEVER quote the filter argument unless you're sure it represents the string/literal datatype, \n" + "Especially do not quote AskerContext() calls like this:\n" + "filter2('AskerContext()')\n" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. " ), diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 64085d9c..b82942d7 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -22,12 +22,19 @@ async def test_iql_prompt_format_default() -> None: "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" "\n\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." - 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' - 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' - "In that case, the part of the output will look like this:" - "filter4(AskerContext())" + "It is VERY IMPORTANT not to use methods other than those listed above.\n" + "Finally, if a called function argument value is not directly specified in the query but instead requires " + "some additional execution context, than substitute that argument value with: AskerContext().\n" + "The typical input phrase suggesting that the additional execution context need to be referenced \n" + 'contains words like: "I", "my", "mine", "current", "the" etc..\n' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".\n' + "In that case, the part of the output will look like this:\n" + "filter4(AskerContext())\n" + "Outside this situation DO NOT combine filters like this:\n" + "filter4(filter2())\n" + "And NEVER quote the filter argument unless you're sure it represents the string/literal datatype, \n" + "Especially do not quote AskerContext() calls like this:\n" + "filter2('AskerContext()')\n" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, @@ -57,12 +64,19 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" "\n\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - "Finally, if a called function argument value is not directly specified in the query but instead requires some additional execution context, than substitute that argument value with: AskerContext()." - 'The typical input phrase suggesting that the additional execution context need to be referenced contains words like: "I", "my", "mine", "current", "the" etc..' - 'For example: "my position name", "my company valuation", "current day", "the ongoing project".' - "In that case, the part of the output will look like this:" - "filter4(AskerContext())" + "It is VERY IMPORTANT not to use methods other than those listed above.\n" + "Finally, if a called function argument value is not directly specified in the query but instead requires " + "some additional execution context, than substitute that argument value with: AskerContext().\n" + "The typical input phrase suggesting that the additional execution context need to be referenced \n" + 'contains words like: "I", "my", "mine", "current", "the" etc..\n' + 'For example: "my position name", "my company valuation", "current day", "the ongoing project".\n' + "In that case, the part of the output will look like this:\n" + "filter4(AskerContext())\n" + "Outside this situation DO NOT combine filters like this:\n" + "filter4(filter2())\n" + "And NEVER quote the filter argument unless you're sure it represents the string/literal datatype, \n" + "Especially do not quote AskerContext() calls like this:\n" + "filter2('AskerContext()')\n" """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, From 8eefd9b4a125bee414e80e932a927a76c50482c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 09:30:11 +0100 Subject: [PATCH 44/53] fix linters --- src/dbally/iql/_type_validators.py | 3 ++- src/dbally/views/freeform/text2sql/view.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index 1b28c310..dfcab015 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from typing import ( # type: ignore - Annotated, Any, Callable, Dict, @@ -13,6 +12,8 @@ get_origin, ) +from typing_extensions import Annotated + @dataclass class _ValidationResult: diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 6c9a0413..147573f5 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -100,11 +100,11 @@ async def ask( self, query: str, llm: LLM, + contexts: Optional[List[BaseCallerContext]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[BaseCallerContext]] = None, ) -> ViewExecutionResult: """ Executes the query and returns the result. It generates the SQL query from the natural language query and @@ -113,11 +113,11 @@ async def ask( Args: query: The natural language query to execute. llm: The LLM used to execute the query. + contexts: Currently not used. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. llm_options: Options to use for the LLM. - contexts: Currently not used. Returns: The result of the query. From c28091fc711824433025eaaa0bb56237cd83f41b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 10:09:04 +0100 Subject: [PATCH 45/53] fix tests --- tests/unit/views/test_methods_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 7c076aaf..51c1548e 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -12,7 +12,7 @@ @dataclass -class TestCallerContext(BaseCallerContext): +class CallerContext(BaseCallerContext): """ Mock class for testing context. """ @@ -33,7 +33,7 @@ def method_foo(self, idx: int) -> None: @view_filter() def method_bar( - self, cities: List[str], year: Union[Literal["2023", "2024"], TestCallerContext], pairs: List[Tuple[str, int]] + self, cities: List[str], year: Union[Literal["2023", "2024"], CallerContext], pairs: List[Tuple[str, int]] ) -> str: return f"hello {cities} in {year} of {pairs}" @@ -45,7 +45,7 @@ def method_baz(self) -> None: @view_aggregation() def method_qux( - self, ages: List[int], years: Union[Literal["2023", "2024"], TestCallerContext], names: List[str] + self, ages: List[int], years: Union[Literal["2023", "2024"], CallerContext], names: List[str] ) -> str: return f"hello {ages} and {names}" @@ -74,7 +74,7 @@ def test_list_filters() -> None: assert method_bar.description == "" assert method_bar.parameters == [ MethodParamWithTyping("cities", List[str]), - MethodParamWithTyping("year", Union[Literal["2023", "2024"], TestCallerContext]), + MethodParamWithTyping("year", Union[Literal["2023", "2024"], CallerContext]), MethodParamWithTyping("pairs", List[Tuple[str, int]]), ] assert ( @@ -98,7 +98,7 @@ def test_list_aggregations() -> None: assert method_qux.description == "" assert method_qux.parameters == [ MethodParamWithTyping("ages", List[int]), - MethodParamWithTyping("years", Union[Literal["2023", "2024"], TestCallerContext]), + MethodParamWithTyping("years", Union[Literal["2023", "2024"], CallerContext]), MethodParamWithTyping("names", List[str]), ] assert str(method_qux) == "method_qux(ages: List[int], years: Literal['2023', '2024'] | Context, names: List[str])" From 69a8d588504fe89a2a9b9463c928468f9d37932c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 10:23:50 +0100 Subject: [PATCH 46/53] fix tests --- tests/unit/iql/test_iql_parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index babfa4ed..70b71f2e 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -1,8 +1,9 @@ import re from dataclasses import dataclass -from typing import Annotated, List, Union +from typing import List, Union import pytest +from typing_extensions import Annotated from dbally.context.context import BaseCallerContext from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax From d6c8fc6038c0d25496628d7678370bd51aedb478 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 10:36:53 +0100 Subject: [PATCH 47/53] fix tests --- src/dbally/iql/_type_validators.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/dbally/iql/_type_validators.py b/src/dbally/iql/_type_validators.py index dfcab015..7f0bce48 100644 --- a/src/dbally/iql/_type_validators.py +++ b/src/dbally/iql/_type_validators.py @@ -1,18 +1,7 @@ from dataclasses import dataclass -from typing import ( # type: ignore - Any, - Callable, - Dict, - Literal, - Optional, - Type, - Union, - _GenericAlias, - get_args, - get_origin, -) - -from typing_extensions import Annotated +from typing import Any, Callable, Dict, Literal, Optional, Type, Union, _GenericAlias # type: ignore + +from typing_extensions import Annotated, get_args, get_origin @dataclass From d7026d446f72c79612c421a2e428932a1cfda466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 10:59:39 +0100 Subject: [PATCH 48/53] rm old benchmarks --- .../dbally_benchmark/context_benchmark.py | 252 ------------------ .../dataset/context_dataset.json | 62 ----- benchmark/dbally_benchmark/e2e_benchmark.py | 153 ----------- 3 files changed, 467 deletions(-) delete mode 100644 benchmark/dbally_benchmark/context_benchmark.py delete mode 100644 benchmark/dbally_benchmark/dataset/context_dataset.json delete mode 100644 benchmark/dbally_benchmark/e2e_benchmark.py diff --git a/benchmark/dbally_benchmark/context_benchmark.py b/benchmark/dbally_benchmark/context_benchmark.py deleted file mode 100644 index e9108759..00000000 --- a/benchmark/dbally_benchmark/context_benchmark.py +++ /dev/null @@ -1,252 +0,0 @@ -# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, missing-class-docstring, broad-exception-caught -import asyncio -import json -import os -import traceback -import typing -from copy import deepcopy -from dataclasses import dataclass, field - -import polars as pl -import sqlalchemy -import tqdm.asyncio -from sqlalchemy import create_engine -from sqlalchemy.ext.automap import AutomapBase, automap_base -from typing_extensions import TypeAlias - -import dbally -from dbally import SqlAlchemyBaseView, decorators -from dbally.collection import Collection -from dbally.context import BaseCallerContext -from dbally.iql import IQLError -from dbally.llms.litellm import LiteLLM - -SQLITE_DB_FILE_REL_PATH = "../../examples/recruiting/data/candidates.db" -engine = create_engine(f"sqlite:///{os.path.abspath(SQLITE_DB_FILE_REL_PATH)}") - -Base: AutomapBase = automap_base() -Base.prepare(autoload_with=engine) - -Candidate = Base.classes.candidates - - -@dataclass -class MyData(BaseCallerContext): - first_name: str - surname: str - position: str - years_of_experience: int - university: str - skills: typing.List[str] - country: str - - -@dataclass -class OpenPosition(BaseCallerContext): - position: str - min_years_of_experience: int - graduated_from_university: str - required_skills: typing.List[str] - - -class CandidateView(SqlAlchemyBaseView): - """ - A view for retrieving candidates from the database. - """ - - def get_select(self) -> sqlalchemy.Select: - """ - Creates the initial SqlAlchemy select object, which will be used to build the query. - """ - return sqlalchemy.select(Candidate) - - @decorators.view_filter() - def at_least_experience(self, years: typing.Union[int, OpenPosition]) -> sqlalchemy.ColumnElement: - """ - Filters candidates with at least `years` of experience. - """ - if isinstance(years, OpenPosition): - years = years.min_years_of_experience - - return Candidate.years_of_experience >= years - - @decorators.view_filter() - def at_most_experience(self, years: typing.Union[int, MyData]) -> sqlalchemy.ColumnElement: - if isinstance(years, MyData): - years = years.years_of_experience - - return Candidate.years_of_experience <= years - - @decorators.view_filter() - def has_position(self, position: typing.Union[str, OpenPosition]) -> sqlalchemy.ColumnElement: - if isinstance(position, OpenPosition): - position = position.position - - return Candidate.position == position - - @decorators.view_filter() - def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement: - """ - Filters candidates that can be considered for a senior data scientist position. - """ - return sqlalchemy.and_( - Candidate.position.in_(["Data Scientist", "Machine Learning Engineer", "Data Engineer"]), - Candidate.years_of_experience >= 3, - ) - - @decorators.view_filter() - def from_country(self, country: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: - """ - Filters candidates from a specific country. - """ - if isinstance(country, MyData): - return Candidate.country == country.country - - return Candidate.country == country - - @decorators.view_filter() - def graduated_from_university(self, university: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: - if isinstance(university, MyData): - university = university.university - - return Candidate.university == university - - @decorators.view_filter() - def has_skill(self, skill: str) -> sqlalchemy.ColumnElement: - return Candidate.skills.like(f"%{skill}%") - - @decorators.view_filter() - def knows_data_analysis(self) -> sqlalchemy.ColumnElement: - return Candidate.tags.like("%Data Analysis%") - - @decorators.view_filter() - def knows_python(self) -> sqlalchemy.ColumnElement: - return Candidate.skills.like("%Python%") - - @decorators.view_filter() - def first_name_is(self, first_name: typing.Union[str, MyData]) -> sqlalchemy.ColumnElement: - if isinstance(first_name, MyData): - first_name = first_name.first_name - - return Candidate.name.startswith(first_name) - - -OpenAILLMName: TypeAlias = typing.Literal["gpt-3.5-turbo", "gpt-3.5-turbo-instruct", "gpt-4-turbo", "gpt-4o"] - - -def setup_collection(model_name: OpenAILLMName) -> Collection: - llm = LiteLLM(model_name=model_name) - - collection = dbally.create_collection("recruitment", llm) - collection.add(CandidateView, lambda: CandidateView(engine)) - - return collection - - -async def generate_iql_from_question( - collection: Collection, - model_name: OpenAILLMName, - question: str, - contexts: typing.Optional[typing.List[BaseCallerContext]], -) -> typing.Tuple[str, OpenAILLMName, typing.Optional[str]]: - try: - result = await collection.ask(question, contexts=contexts, dry_run=True) - except IQLError as e: - exc_pretty = traceback.format_exception_only(e.__class__, e)[0] - return question, model_name, f"FAILED: {exc_pretty}({e.source})" - except Exception as e: - exc_pretty = traceback.format_exception_only(e.__class__, e)[0] - return question, model_name, f"FAILED: {exc_pretty}" - - out = result.metadata.get("iql") - if out is None: - return question, model_name, None - - return question, model_name, out.replace('"', "'") - - -@dataclass -class BenchmarkConfig: - dataset_path: str - out_path: str - n_repeats: int = 5 - llms: typing.List[OpenAILLMName] = field(default_factory=lambda: ["gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"]) - - -async def main(config_: BenchmarkConfig): - test_set = None - with open(config_.dataset_path, encoding="utf-8") as file: - test_set = json.load(file) - - contexts = [ - MyData( - first_name="John", - surname="Smith", - years_of_experience=4, - position="Data Engineer", - university="University of Toronto", - skills=["Python"], - country="United Kingdom", - ), - OpenPosition( - position="Machine Learning Engineer", - graduated_from_university="Stanford Univeristy", - min_years_of_experience=1, - required_skills=["Python", "SQL"], - ), - ] - - tasks: typing.List[asyncio.Task] = [] - for model_name in config.llms: - collection = setup_collection(model_name) - for test_case in test_set: - for _ in range(config.n_repeats): - task = asyncio.create_task( - generate_iql_from_question(collection, model_name, test_case["question"], contexts=contexts) - ) - tasks.append(task) - - output_data = {test_case["question"]: test_case for test_case in test_set} - empty_answers = {str(llm_name): [] for llm_name in config.llms} - - total_iter = len(config.llms) * len(test_set) * config.n_repeats - for task in tqdm.asyncio.tqdm.as_completed(tasks, total=total_iter): - question, llm_name, answer = await task - if "answers" not in output_data[question]: - output_data[question]["answers"] = deepcopy(empty_answers) - - output_data[question]["answers"][llm_name].append(answer) - - df_out_raw = pl.DataFrame(list(output_data.values())) - - df_out = ( - df_out_raw.unnest("answers") - .unpivot( - on=pl.selectors.starts_with("gpt"), - index=["question", "correct_answer", "context"], - variable_name="model", - value_name="answer", - ) - .explode("answer") - .group_by(["context", "model"]) - .agg( - [ - (pl.col("correct_answer") == pl.col("answer")).mean().alias("frac_hits"), - (pl.col("correct_answer") == pl.col("answer")).sum().alias("n_hits"), - ] - ) - .sort(["model", "context"]) - ) - - print(df_out) - - with open(config.out_path, "w", encoding="utf-8") as file: - file.write(json.dumps(df_out_raw.to_dicts(), indent=2)) - - -if __name__ == "__main__": - config = BenchmarkConfig( - dataset_path="dataset/context_dataset.json", out_path="../../context_benchmark_output.json" - ) - - asyncio.run(main(config)) diff --git a/benchmark/dbally_benchmark/dataset/context_dataset.json b/benchmark/dbally_benchmark/dataset/context_dataset.json deleted file mode 100644 index c37f38d1..00000000 --- a/benchmark/dbally_benchmark/dataset/context_dataset.json +++ /dev/null @@ -1,62 +0,0 @@ -[ - { - "question": "Find me French candidates suitable for my position with at least 1 year of experience.", - "correct_answer": "from_country('France') AND has_position(AskerContext()) AND at_least_experience(1)", - "context": false - }, - { - "question": "Please find me candidates from my country who have at most 4 years of experience.", - "correct_answer": "from_country(AskerContext()) AND at_most_experience(4)", - "context": true - }, - { - "question": "Find me candidates who graduated from Stanford University and work as Software Engineers.", - "correct_answer": "graduated_from_university('Stanford University') AND has_position('Software Engineer')", - "context": false - }, - { - "question": "Find me candidates who graduated from my university", - "correct_answer": "graduated_from_university(AskerContext())", - "context": true - }, - { - "question": "Could you find me candidates with at most as experience as I have who also know Python?", - "correct_answer": "at_most_experience(AskerContext()) AND know_python()", - "context": true - }, - { - "question": "Please find me candidates who know Data Analysis and Python", - "correct_answer": "know_python() AND know_data_analysis()", - "context": false - }, - { - "question": "Find me candidates with at least minimal required experience for the currently open position.", - "correct_answer": "at_least_experience(AskerContext())", - "context": true - }, - { - "question": "List candidates with between 2 and 6 years of experience.", - "correct_answer": "at_least_experience(2) AND at_most_experience(6)", - "context": false - }, - { - "question": "Find me candidates who currently have the same position as we look for in our company?", - "correct_answer": "has_position(AskerContext())", - "context": true - }, - { - "question": "Please find me senior data scientist candidates who know Data Analysis and come from my country", - "correct_answer": "senior_data_scientist_position() AND has_skill('Data Analysis') AND from_country(AskerContext())", - "context": true - }, - { - "question": "Find me candidates that have the same first name as me", - "correct_answer": "first_name_is(AskerContext())", - "context": true - }, - { - "question": "List candidates named Mohammed from India", - "correct_answer": "first_name_is('Mohammed') AND from_country('India')", - "context": false - } -] diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py deleted file mode 100644 index f2d86b58..00000000 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ /dev/null @@ -1,153 +0,0 @@ -import asyncio -import json -import os -from functools import partial -from pathlib import Path -from typing import Any, List - -import hydra -import neptune -from dbally_benchmark.config import BenchmarkConfig -from dbally_benchmark.constants import VIEW_REGISTRY, EvaluationType, ViewName -from dbally_benchmark.dataset.bird_dataset import BIRDDataset, BIRDExample -from dbally_benchmark.paths import PATH_EXPERIMENTS -from dbally_benchmark.text2sql.metrics import calculate_dataset_metrics -from dbally_benchmark.text2sql.text2sql_result import Text2SQLResult -from dbally_benchmark.utils import batch, get_datetime_str, set_up_gitlab_metadata -from hydra.utils import instantiate -from loguru import logger -from neptune.utils import stringify_unsupported -from omegaconf import DictConfig -from sqlalchemy import create_engine - -import dbally -from dbally.collection import Collection -from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError -from dbally.llms.litellm import LiteLLM -from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE - - -async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult: - try: - result = await collection.ask(example.question, dry_run=True) - sql = result.metadata["sql"] - except UnsupportedQueryError: - sql = "UnsupportedQueryError" - except NoViewFoundError: - sql = "NoViewFoundError" - except Exception: # pylint: disable=broad-exception-caught - sql = "Error" - - return Text2SQLResult( - db_id=example.db_id, question=example.question, ground_truth_sql=example.SQL, predicted_sql=sql - ) - - -async def run_dbally_for_dataset(dataset: BIRDDataset, collection: Collection) -> List[Text2SQLResult]: - """ - Transforms questions into SQL queries using a IQL approach. - - Args: - dataset: The dataset containing questions to be transformed into SQL queries. - collection: Container for a set of views used by db-ally. - - Returns: - A list of Text2SQLResult objects representing the predictions. - """ - - results: List[Text2SQLResult] = [] - - for group in batch(dataset, 5): - current_results = await asyncio.gather( - *[_run_dbally_for_single_example(example, collection) for example in group] - ) - results = [*current_results, *results] - - return results - - -async def evaluate(cfg: DictConfig) -> Any: - """ - Runs db-ally evaluation for a single dataset defined in hydra config. - - Args: - cfg: hydra config, loads automatically from path passed on to the decorator - """ - - output_dir = PATH_EXPERIMENTS / cfg.output_path / get_datetime_str() - output_dir.mkdir(exist_ok=True, parents=True) - cfg = instantiate(cfg) - benchmark_cfg = BenchmarkConfig() - - engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") - - llm = LiteLLM( - model_name="gpt-4", - api_key=benchmark_cfg.openai_api_key, - ) - - db = dbally.create_collection(cfg.db_name, llm) - - for view_name in cfg.view_names: - view = VIEW_REGISTRY[ViewName(view_name)] - db.add(view, partial(view, engine)) - - run = None - if cfg.neptune.log: - run = neptune.init_run( - project=benchmark_cfg.neptune_project, - api_token=benchmark_cfg.neptune_api_token, - ) - run["config"] = stringify_unsupported(cfg) - tags = list(cfg.neptune.get("tags", [])) + [EvaluationType.END2END.value, cfg.model_name, cfg.db_name] - run["sys/tags"].add(tags) - - if "CI_MERGE_REQUEST_IID" in os.environ: - run = set_up_gitlab_metadata(run) - - metrics_file_name, results_file_name = "metrics.json", "eval_results.json" - - logger.info(f"Running db-ally predictions for dataset {cfg.dataset_path}") - evaluation_dataset = BIRDDataset.from_json_file( - Path(cfg.dataset_path), difficulty_levels=cfg.get("difficulty_levels") - ) - dbally_results = await run_dbally_for_dataset(dataset=evaluation_dataset, collection=db) - - with open(output_dir / results_file_name, "w", encoding="utf-8") as outfile: - json.dump([result.model_dump() for result in dbally_results], outfile, indent=4) - - logger.info("Calculating metrics") - metrics = calculate_dataset_metrics(dbally_results, engine) - - with open(output_dir / metrics_file_name, "w", encoding="utf-8") as outfile: - json.dump(metrics, outfile, indent=4) - - logger.info(f"db-ally predictions saved under directory: {output_dir}") - - if run: - run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat) - run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat) - run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE) - run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix()) - run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix()) - run["evaluation/metrics"] = stringify_unsupported(metrics) - logger.info(f"Evaluation results logged to neptune at {run.get_url()}") - - -@hydra.main(version_base=None, config_path="experiment_config", config_name="evaluate_e2e_config") -def main(cfg: DictConfig): - """ - Runs db-ally evaluation for a single dataset defined in hydra config. - The following metrics are calculated during evaluation: exact match, valid SQL, - execution accuracy and valid efficiency score. - - Args: - cfg: hydra config, loads automatically from path passed on to the decorator. - """ - - asyncio.run(evaluate(cfg)) - - -if __name__ == "__main__": - main() # pylint: disable=E1120 From e8271ac2afc8690aa968ea104ab8b3b400b5a73e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 23 Sep 2024 21:01:11 +0100 Subject: [PATCH 49/53] some renames and stuff --- src/dbally/collection/collection.py | 21 +++--- src/dbally/context.py | 11 ++++ src/dbally/context/__init__.py | 3 - src/dbally/context/_utils.py | 75 ---------------------- src/dbally/context/context.py | 17 ----- src/dbally/context/exceptions.py | 23 ------- src/dbally/iql/_processor.py | 4 +- src/dbally/iql/_query.py | 4 +- src/dbally/iql_generator/iql_generator.py | 8 +-- src/dbally/iql_generator/prompt.py | 8 +-- src/dbally/views/base.py | 6 +- src/dbally/views/exposed_functions.py | 8 +-- src/dbally/views/freeform/text2sql/view.py | 4 +- src/dbally/views/structured.py | 4 +- tests/unit/iql/test_iql_parser.py | 14 ++-- tests/unit/test_fallback_collection.py | 4 +- tests/unit/views/test_methods_base.py | 4 +- tests/unit/views/test_sqlalchemy_base.py | 4 +- 18 files changed, 59 insertions(+), 163 deletions(-) create mode 100644 src/dbally/context.py delete mode 100644 src/dbally/context/__init__.py delete mode 100644 src/dbally/context/_utils.py delete mode 100644 src/dbally/context/context.py delete mode 100644 src/dbally/context/exceptions.py diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index b5d680f7..a01d7dc7 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -4,7 +4,7 @@ import textwrap import time from collections import defaultdict -from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar +from typing import Callable, Dict, List, Optional, Type, TypeVar import dbally from dbally.audit.event_handlers.base import EventHandler @@ -12,7 +12,7 @@ from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult, ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions @@ -228,7 +228,7 @@ async def _ask_view( event_tracker: EventTracker, llm_options: Optional[LLMOptions], dry_run: bool, - contexts: Iterable[BaseCallerContext], + contexts: List[Context], ) -> ViewExecutionResult: """ Ask the selected view to provide an answer to the question. @@ -247,11 +247,11 @@ async def _ask_view( view_result = await selected_view.ask( query=question, llm=self._llm, + contexts=contexts, event_tracker=event_tracker, n_retries=self.n_retries, dry_run=dry_run, llm_options=llm_options, - contexts=contexts, ) return view_result @@ -298,9 +298,11 @@ def get_all_event_handlers(self) -> List[EventHandler]: return self._event_handlers return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers())) + # pylint: disable=too-many-arguments async def _handle_fallback( self, question: str, + contexts: Optional[List[Context]], dry_run: bool, return_natural_response: bool, llm_options: Optional[LLMOptions], @@ -322,7 +324,6 @@ async def _handle_fallback( Returns: The result from the fallback collection. - """ if not self._fallback_collection: raise caught_exception @@ -337,6 +338,7 @@ async def _handle_fallback( async with event_tracker.track_event(fallback_event) as span: result = await self._fallback_collection.ask( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, @@ -348,10 +350,10 @@ async def _handle_fallback( async def ask( self, question: str, + contexts: Optional[List[Context]] = None, dry_run: bool = False, return_natural_response: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[BaseCallerContext]] = None, event_tracker: Optional[EventTracker] = None, ) -> ExecutionResult: """ @@ -366,14 +368,14 @@ async def ask( Args: question: question posed using natural language representation e.g\ - "What job offers for Data Scientists do we have?" + "What job offers for Data Scientists do we have?" + contexts: list of context objects, each being an instance of + a subclass of Context. May contain contexts irrelevant for the currently processed query. dry_run: if True, only generate the query without executing it return_natural_response: if True (and dry_run is False as natural response requires query results), the natural response will be included in the answer llm_options: options to use for the LLM client. If provided, these options will be merged with the default options provided to the LLM client, prioritizing option values other than NOT_GIVEN - contexts: An iterable (typically a list) of context objects, each being an instance of - a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query. event_tracker: Event tracker object for given ask. Returns: @@ -433,6 +435,7 @@ async def ask( if self._fallback_collection: result = await self._handle_fallback( question=question, + contexts=contexts, dry_run=dry_run, return_natural_response=return_natural_response, llm_options=llm_options, diff --git a/src/dbally/context.py b/src/dbally/context.py new file mode 100644 index 00000000..46c65030 --- /dev/null +++ b/src/dbally/context.py @@ -0,0 +1,11 @@ +from abc import ABC +from typing import ClassVar + + +class Context(ABC): + """ + Base class for all contexts that are used to pass additional knowledge about the caller environment to the view. + """ + + type_name: ClassVar[str] = "Context" + alias_name: ClassVar[str] = "CONTEXT" diff --git a/src/dbally/context/__init__.py b/src/dbally/context/__init__.py deleted file mode 100644 index 294fbe2f..00000000 --- a/src/dbally/context/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .context import BaseCallerContext - -__all__ = ["BaseCallerContext"] diff --git a/src/dbally/context/_utils.py b/src/dbally/context/_utils.py deleted file mode 100644 index e8e3b0eb..00000000 --- a/src/dbally/context/_utils.py +++ /dev/null @@ -1,75 +0,0 @@ -from inspect import isclass -from typing import Any, List, Optional, Sequence, Tuple, Type, Union - -import typing_extensions as type_ext - -from dbally.context.context import BaseCallerContext -from dbally.views.exposed_functions import MethodParamWithTyping - -ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]] - - -def _extract_params_and_context( - filter_method_: type_ext.Callable, hidden_args: Sequence[str] -) -> Tuple[List[MethodParamWithTyping], ContextClass]: - """ - Processes the MethodsBaseView filter method signature to extract the argument names and type hint - in the form of MethodParamWithTyping list. Additionally, the first type hint, pointing to the subclass - of BaseCallerContext is returned. - - Args: - filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator) - hidden_args: method arguments that should not be extracted - - Returns: - The first field contains the list of arguments, each encapsulated as MethodParamWithTyping. - The 2nd is the BaseCallerContext subclass provided for this filter, or None if no context specified. - """ - - params = [] - context = None - # TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__ - for name_, type_ in type_ext.get_type_hints(filter_method_).items(): - if name_ in hidden_args: - continue - - if isclass(type_) and issubclass(type_, BaseCallerContext): - # this is the case when user provides a context but no other type hint for a specifc arg - context = type_ - type_ = Any - elif type_ext.get_origin(type_) is Union: - union_subtypes = type_ext.get_args(type_) - if not union_subtypes: - type_ = Any - - for subtype_ in union_subtypes: # type: ignore - # TODO add custom error for the situation when user provides more than two contexts for a single filter - # for now we extract only the first context - if isclass(subtype_) and issubclass(subtype_, BaseCallerContext): - if context is None: - context = subtype_ - - params.append(MethodParamWithTyping(name_, type_)) - - return params, context - - -def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool: - """ - Verifies whether a method argument allows contextualization based on the type hints attached to a method signature. - - Args: - arg: MethodParamWithTyping container preserving information about the method argument - - Returns: - Verification result. - """ - - if type_ext.get_origin(arg.type) is not Union and not issubclass(arg.type, BaseCallerContext): - return False - - for subtype in type_ext.get_args(arg.type): - if issubclass(subtype, BaseCallerContext): - return True - - return False diff --git a/src/dbally/context/context.py b/src/dbally/context/context.py deleted file mode 100644 index aea21d13..00000000 --- a/src/dbally/context/context.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc import ABC -from typing import ClassVar - - -class BaseCallerContext(ABC): - """ - An interface for contexts that are used to pass additional knowledge about - the caller environment to the filters. LLM will always return `Context()` - when the context is required and this call will be later substituted by an instance of - a class implementing this interface, selected based on the filter method signature (type hints). - - Attributes: - alias: Class variable defining an alias which is defined in the prompt for the LLM to reference context. - """ - - type_name: ClassVar[str] = "Context" - alias_name: ClassVar[str] = "CONTEXT" diff --git a/src/dbally/context/exceptions.py b/src/dbally/context/exceptions.py deleted file mode 100644 index c538ee18..00000000 --- a/src/dbally/context/exceptions.py +++ /dev/null @@ -1,23 +0,0 @@ -class BaseContextError(Exception): - """ - A base error for context handling logic. - """ - - -class SuitableContextNotProvidedError(BaseContextError): - """ - Raised when method argument type hint points that a contextualization is available - but not suitable context was provided. - """ - - def __init__(self, filter_fun_signature: str, context_class_name: str) -> None: - # this syntax 'or BaseCallerContext' is just to prevent type checkers - # from raising a warning, as filter_.context_class can be None. It's essenially a fallback that should never - # be reached, unless somebody will use this Exception against its purpose. - # TODO consider raising a warning/error when this happens. - - message = ( - f"No context of class {context_class_name} was provided" - f"while the filter {filter_fun_signature} requires it." - ) - super().__init__(message) diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index 9befc28d..d66e6977 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -4,7 +4,7 @@ from typing import Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import syntax from dbally.iql._exceptions import ( IQLArgumentParsingError, @@ -34,7 +34,7 @@ def __init__( self, source: str, allowed_functions: List[ExposedFunction], - allowed_contexts: Optional[List[BaseCallerContext]] = None, + allowed_contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, ) -> None: self.source = source diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 03abc912..1d8e1ef2 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -8,7 +8,7 @@ from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: - from dbally.context.context import BaseCallerContext + from dbally.context import Context from dbally.views.exposed_functions import ExposedFunction @@ -33,7 +33,7 @@ async def parse( cls, source: str, allowed_functions: List["ExposedFunction"], - allowed_contexts: Optional[List["BaseCallerContext"]] = None, + allowed_contexts: Optional[List["Context"]] = None, event_tracker: Optional[EventTracker] = None, ) -> Self: """ diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 5b4087b7..c6700ef3 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -3,7 +3,7 @@ from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLError, IQLQuery from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( @@ -67,7 +67,7 @@ async def __call__( question: str, filters: List[ExposedFunction], aggregations: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -146,7 +146,7 @@ async def __call__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, event_tracker: Optional[EventTracker] = None, @@ -265,7 +265,7 @@ async def __call__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: List[FewShotExample], llm: LLM, llm_options: Optional[LLMOptions] = None, diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 8189d6fd..64c99d8a 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -3,7 +3,7 @@ from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.exceptions import DbAllyError from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample @@ -21,7 +21,7 @@ class UnsupportedQueryError(DbAllyError): async def _iql_filters_parser( response: str, allowed_functions: List[ExposedFunction], - allowed_contexts: List[BaseCallerContext], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLFiltersQuery: """ @@ -53,7 +53,7 @@ async def _iql_filters_parser( async def _iql_aggregation_parser( response: str, allowed_functions: List[ExposedFunction], - allowed_contexts: List[BaseCallerContext], + allowed_contexts: List[Context], event_tracker: Optional[EventTracker] = None, ) -> IQLAggregationQuery: """ @@ -127,7 +127,7 @@ def __init__( *, question: str, methods: List[ExposedFunction], - contexts: List[BaseCallerContext], + contexts: List[Context], examples: Optional[List[FewShotExample]] = None, ) -> None: """ diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 31e8f363..d85edfa3 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -25,7 +25,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, @@ -38,7 +38,7 @@ async def ask( query: The natural language query to execute. llm: The LLM used to execute the query. contexts: An iterable (typically a list) of context objects, each being - an instance of a subclass of BaseCallerContext. + an instance of a subclass of Context. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. diff --git a/src/dbally/views/exposed_functions.py b/src/dbally/views/exposed_functions.py index d769c12c..bb03fe68 100644 --- a/src/dbally/views/exposed_functions.py +++ b/src/dbally/views/exposed_functions.py @@ -4,7 +4,7 @@ from typing_extensions import _AnnotatedAlias, get_origin -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.similarity import AbstractSimilarityIndex @@ -21,11 +21,11 @@ def __str__(self) -> str: return f"{self.name}: {self._parse_type()}" @property - def contexts(self) -> List[Type[BaseCallerContext]]: + def contexts(self) -> List[Type[Context]]: """ Returns the contexts if the type is annotated with them. """ - return [arg for arg in getattr(self.type, "__args__", []) if issubclass(arg, BaseCallerContext)] + return [arg for arg in getattr(self.type, "__args__", []) if issubclass(arg, Context)] @property def similarity_index(self) -> Optional[AbstractSimilarityIndex]: @@ -51,7 +51,7 @@ def _parse_type_inner(param_type: Union[type, _GenericAlias]) -> str: if param_type.__module__ == "typing": return re.sub(r"\btyping\.", "", str(param_type)) - if issubclass(param_type, BaseCallerContext): + if issubclass(param_type, Context): return param_type.type_name if hasattr(param_type, "__name__"): diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 147573f5..c53007c0 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -8,7 +8,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.template import PromptTemplate @@ -100,7 +100,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index b9577bb0..2e477216 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,7 +4,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM @@ -35,7 +35,7 @@ async def ask( self, query: str, llm: LLM, - contexts: Optional[List[BaseCallerContext]] = None, + contexts: Optional[List[Context]] = None, event_tracker: Optional[EventTracker] = None, n_retries: int = 3, dry_run: bool = False, diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index 70b71f2e..c74e5eda 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -5,7 +5,7 @@ import pytest from typing_extensions import Annotated -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, @@ -59,7 +59,7 @@ async def test_iql_filter_parser(): async def test_iql_filter_context_parser(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str test_context = TestCustomContext(city="cracow") @@ -109,7 +109,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_filter_context_not_allowed_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotAllowedError) as exc_info: @@ -140,7 +140,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_filter_context_not_found_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotFoundError) as exc_info: @@ -360,7 +360,7 @@ async def test_iql_aggregation_parser(): async def test_iql_aggregation_context_parser(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str test_context = TestCustomContext(city="cracow") @@ -387,7 +387,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_aggregation_context_not_allowed_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotAllowedError) as exc_info: @@ -412,7 +412,7 @@ class TestCustomContext(BaseCallerContext): async def test_iql_aggregation_context_not_found_error(): @dataclass - class TestCustomContext(BaseCallerContext): + class TestCustomContext(Context): city: str with pytest.raises(IQLContextNotFoundError) as exc_info: diff --git a/tests/unit/test_fallback_collection.py b/tests/unit/test_fallback_collection.py index 9cfeb009..4a4dbb86 100644 --- a/tests/unit/test_fallback_collection.py +++ b/tests/unit/test_fallback_collection.py @@ -8,7 +8,7 @@ from dbally.audit import CLIEventHandler, EventTracker, OtelEventHandler from dbally.audit.event_handlers.buffer_event_handler import BufferEventHandler from dbally.collection import Collection, ViewExecutionResult -from dbally.context.context import BaseCallerContext +from dbally.context import Context from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms import LLM from dbally.llms.clients import LLMOptions @@ -42,7 +42,7 @@ async def ask( n_retries: int = 3, dry_run: bool = False, llm_options: Optional[LLMOptions] = None, - contexts: Optional[Iterable[BaseCallerContext]] = None, + contexts: Optional[Iterable[Context]] = None, ) -> ViewExecutionResult: return ViewExecutionResult( results=[{"mock_result": "fallback_result"}], metadata={"mock_context": "fallback_context"} diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 51c1548e..e870e2dc 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,7 +4,7 @@ from typing import List, Literal, Tuple, Union from dbally.collection.results import ViewExecutionResult -from dbally.context import BaseCallerContext +from dbally.context import Context from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping @@ -12,7 +12,7 @@ @dataclass -class CallerContext(BaseCallerContext): +class CallerContext(Context): """ Mock class for testing context. """ diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 8924a835..3b621a54 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -6,14 +6,14 @@ import sqlalchemy -from dbally.context import BaseCallerContext +from dbally.context import Context from dbally.iql import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @dataclass -class SomeTestContext(BaseCallerContext): +class SomeTestContext(Context): age: int From bdcc7b34b4324aebef2dc21ff0fe7e6c26131943 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 24 Sep 2024 00:33:42 +0100 Subject: [PATCH 50/53] fix benchmarks --- benchmarks/sql/bench/pipelines/collection.py | 6 +- benchmarks/sql/bench/pipelines/view.py | 8 +- .../sql/bench/views/structured/superhero.py | 139 ------------------ 3 files changed, 7 insertions(+), 146 deletions(-) diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index 19831b0d..4e3ee58a 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -85,10 +85,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL(source=result.context["iql"]["filters"]), - aggregation=IQL(source=result.context["iql"]["aggregation"]), + filters=IQL(source=result.metadata["iql"]["filters"]), + aggregation=IQL(source=result.metadata["iql"]["aggregation"]), ), - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index be9d8263..a8379bea 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -104,10 +104,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL(source=result.context["iql"]["filters"]), - aggregation=IQL(source=result.context["iql"]["aggregation"]), + filters=IQL(source=result.metadata["iql"]["filters"]), + aggregation=IQL(source=result.metadata["iql"]["aggregation"]), ), - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( @@ -179,7 +179,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: else: prediction = ExecutionResult( view_name=view.__class__.__name__, - sql=result.context["sql"], + sql=result.metadata["sql"], ) reference = ExecutionResult( diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 2a6a75a0..3c2f4169 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -76,12 +76,6 @@ class SuperheroFilterMixin: def filter_by_superhero_id(self, superhero_id: int) -> ColumnElement: """ Filters the view by the superhero id. - - Args: - superhero_id: The id of the superhero. - - Returns: - The filter condition. """ return Superhero.id == superhero_id @@ -89,12 +83,6 @@ def filter_by_superhero_id(self, superhero_id: int) -> ColumnElement: def filter_by_superhero_name(self, superhero_name: str) -> ColumnElement: """ Filters the view by the superhero nick or handle. - - Args: - superhero_name: The abstract nick or handle of the superhero. - - Returns: - The filter condition. """ return Superhero.superhero_name == superhero_name @@ -102,9 +90,6 @@ def filter_by_superhero_name(self, superhero_name: str) -> ColumnElement: def filter_by_missing_superhero_full_name(self) -> ColumnElement: """ Filters the view by the missing full name of the superhero. - - Returns: - The filter condition. """ return Superhero.full_name == None @@ -112,12 +97,6 @@ def filter_by_missing_superhero_full_name(self) -> ColumnElement: def filter_by_superhero_full_name(self, superhero_full_name: str) -> ColumnElement: """ Filters the view by the full name of the superhero. - - Args: - superhero_full_name: The human name of the superhero. - - Returns: - The filter condition. """ return Superhero.full_name == superhero_full_name @@ -125,12 +104,6 @@ def filter_by_superhero_full_name(self, superhero_full_name: str) -> ColumnEleme def filter_by_superhero_first_name(self, superhero_first_name: str) -> ColumnElement: """ Filters the view by the simmilar full name of the superhero. - - Args: - superhero_first_name: The first name of the superhero. - - Returns: - The filter condition. """ return Superhero.full_name.like(f"{superhero_first_name}%") @@ -138,12 +111,6 @@ def filter_by_superhero_first_name(self, superhero_first_name: str) -> ColumnEle def filter_by_height_cm(self, height_cm: float) -> ColumnElement: """ Filters the view by the height of the superhero. - - Args: - height_cm: The height of the superhero. - - Returns: - The filter condition. """ return Superhero.height_cm == height_cm @@ -151,12 +118,6 @@ def filter_by_height_cm(self, height_cm: float) -> ColumnElement: def filter_by_height_cm_less_than(self, height_cm: float) -> ColumnElement: """ Filters the view by the height of the superhero. - - Args: - height_cm: The height of the superhero. - - Returns: - The filter condition. """ return Superhero.height_cm < height_cm @@ -164,12 +125,6 @@ def filter_by_height_cm_less_than(self, height_cm: float) -> ColumnElement: def filter_by_height_cm_greater_than(self, height_cm: float) -> ColumnElement: """ Filters the view by the height of the superhero. - - Args: - height_cm: The height of the superhero. - - Returns: - The filter condition. """ return Superhero.height_cm > height_cm @@ -177,13 +132,6 @@ def filter_by_height_cm_greater_than(self, height_cm: float) -> ColumnElement: def filter_by_height_cm_between(self, begin_height_cm: float, end_height_cm: float) -> ColumnElement: """ Filters the view by the height of the superhero. - - Args: - begin_height_cm: The begin height of the superhero. - end_height_cm: The end height of the superhero. - - Returns: - The filter condition. """ return Superhero.height_cm.between(begin_height_cm, end_height_cm) @@ -191,9 +139,6 @@ def filter_by_height_cm_between(self, begin_height_cm: float, end_height_cm: flo def filter_by_the_tallest(self) -> ColumnElement: """ Filter the view by the tallest superhero. - - Returns: - The filter condition. """ return Superhero.height_cm == select(func.max(Superhero.height_cm)).scalar_subquery() @@ -201,9 +146,6 @@ def filter_by_the_tallest(self) -> ColumnElement: def filter_by_missing_weight(self) -> ColumnElement: """ Filters the view by the missing weight of the superhero. - - Returns: - The filter condition. """ return Superhero.weight_kg == 0 or Superhero.weight_kg == None @@ -211,12 +153,6 @@ def filter_by_missing_weight(self) -> ColumnElement: def filter_by_weight_kg(self, weight_kg: int) -> ColumnElement: """ Filters the view by the weight of the superhero. - - Args: - weight_kg: The weight of the superhero. - - Returns: - The filter condition. """ return Superhero.weight_kg == weight_kg @@ -224,12 +160,6 @@ def filter_by_weight_kg(self, weight_kg: int) -> ColumnElement: def filter_by_weight_kg_greater_than(self, weight_kg: int) -> ColumnElement: """ Filters the view by the weight of the superhero. - - Args: - weight_kg: The weight of the superhero. - - Returns: - The filter condition. """ return Superhero.weight_kg > weight_kg @@ -237,12 +167,6 @@ def filter_by_weight_kg_greater_than(self, weight_kg: int) -> ColumnElement: def filter_by_weight_kg_less_than(self, weight_kg: int) -> ColumnElement: """ Filters the view by the weight of the superhero. - - Args: - weight_kg: The weight of the superhero. - - Returns: - The filter condition. """ return Superhero.weight_kg < weight_kg @@ -250,12 +174,6 @@ def filter_by_weight_kg_less_than(self, weight_kg: int) -> ColumnElement: def filter_by_weight_greater_than_percentage_of_average(self, average_percentage: int) -> ColumnElement: """ Filters the view by the weight greater than the percentage of average of superheroes. - - Args: - average_percentage: The percentage of the average weight. - - Returns: - The filter condition. """ return Superhero.weight_kg * 100 > select(func.avg(Superhero.weight_kg)).scalar_subquery() * average_percentage @@ -263,9 +181,6 @@ def filter_by_weight_greater_than_percentage_of_average(self, average_percentage def filter_by_the_heaviest(self) -> ColumnElement: """ Filters the view by the heaviest superhero. - - Returns: - The filter condition. """ return Superhero.weight_kg == select(func.max(Superhero.weight_kg)).scalar_subquery() @@ -273,9 +188,6 @@ def filter_by_the_heaviest(self) -> ColumnElement: def filter_by_missing_publisher(self) -> ColumnElement: """ Filters the view by the missing publisher of the superhero. - - Returns: - The filter condition. """ return Superhero.publisher_id == None @@ -295,12 +207,6 @@ def __init__(self, *args, **kwargs) -> None: def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: """ Filters the view by the superhero eye colour. - - Args: - eye_colour: The eye colour of the superhero. - - Returns: - The filter condition. """ return self.eye_colour.colour == eye_colour @@ -308,12 +214,6 @@ def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: def filter_by_hair_colour(self, hair_colour: str) -> ColumnElement: """ Filters the view by the superhero hair colour. - - Args: - hair_colour: The hair colour of the superhero. - - Returns: - The filter condition. """ return self.hair_colour.colour == hair_colour @@ -321,12 +221,6 @@ def filter_by_hair_colour(self, hair_colour: str) -> ColumnElement: def filter_by_skin_colour(self, skin_colour: str) -> ColumnElement: """ Filters the view by the superhero skin colour. - - Args: - skin_colour: The skin colour of the superhero. - - Returns: - The filter condition. """ return self.skin_colour.colour == skin_colour @@ -334,9 +228,6 @@ def filter_by_skin_colour(self, skin_colour: str) -> ColumnElement: def filter_by_same_hair_and_eye_colour(self) -> ColumnElement: """ Filters the view by the superhero with the same hair and eye colour. - - Returns: - The filter condition. """ return self.eye_colour.colour == self.hair_colour.colour @@ -344,9 +235,6 @@ def filter_by_same_hair_and_eye_colour(self) -> ColumnElement: def filter_by_same_hair_and_skin_colour(self) -> ColumnElement: """ Filters the view by the superhero with the same hair and skin colour. - - Returns: - The filter condition. """ return self.hair_colour.colour == self.skin_colour.colour @@ -360,12 +248,6 @@ class PublisherFilterMixin: def filter_by_publisher_name(self, publisher_name: str) -> ColumnElement: """ Filters the view by the publisher name. - - Args: - publisher_name: The name of the publisher. - - Returns: - The filter condition. """ return Publisher.publisher_name == publisher_name @@ -379,12 +261,6 @@ class AlignmentFilterMixin: def filter_by_alignment(self, alignment: Literal["Good", "Bad", "Neutral", "N/A"]) -> ColumnElement: """ Filters the view by the superhero alignment. - - Args: - alignment: The alignment of the superhero. - - Returns: - The filter condition. """ return Alignment.alignment == alignment @@ -398,12 +274,6 @@ class GenderFilterMixin: def filter_by_gender(self, gender: Literal["Male", "Female", "N/A"]) -> ColumnElement: """ Filters the view by the object gender. - - Args: - gender: The gender of the object. - - Returns: - The filter condition. """ return Gender.gender == gender @@ -417,12 +287,6 @@ class RaceFilterMixin: def filter_by_race(self, race: str) -> ColumnElement: """ Filters the view by the object race. - - Args: - race: The race of the object. - - Returns: - The filter condition. """ return Race.race == race @@ -436,9 +300,6 @@ class SuperheroAggregationMixin: def count_superheroes(self) -> Select: """ Counts the number of superheros. - - Returns: - The superheros count. """ return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) From c82e579b35a0606710c11d77231d04787cfd804a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 25 Sep 2024 09:14:24 +0200 Subject: [PATCH 51/53] rm chroma file --- chroma/chroma.sqlite3 | Bin 147456 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 chroma/chroma.sqlite3 diff --git a/chroma/chroma.sqlite3 b/chroma/chroma.sqlite3 deleted file mode 100644 index 9664a835c715ab5d903760f1fa276264545f33c4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 147456 zcmeI*OKcm*0R~{MNNG(<7VS6*BP*`Qc2bFnsBrlZB^aqI%e2c>4_l(*xGg{|xl5Ta zMatZjY^5hjc7XJ8Dw;#DEdumbpqHQsiuM+uXiq()2vYP~6le|&iUPTGW_M>Fd`WHO zKo)=661h9GyEEU+&OToRg5dimAwHk)9Q}ET{*2Hczcn48 z@BAh34_)5R`6g?>>8BBd@E0uF*UYyrOCSgW5P$##AOHafKmY;|fB*y_0D(gmcrom{ zHXp^c`62USbr6661Rwwb2tWV=5P$##AOL|QA#hd<(R7>$zvm18BmArI&%^J9{}%pp z_@nTLM{rKr9$k*Ndv6RBCHl;r^z&sh(txr;aryG6w+& zKmY;|fB*y_009U<00I!`c>!Gi_k0xq00bZa0SG_<0uX=z1Rwwb2=sse8~=~~{~jo4 zWDx=ofB*y_009U<00Izz00eqk0R8_xUHr&A1Rwwb2tWV=5P$##AOHaf^oRhi|9hmM zkxd9d00Izz00bZa0SG_<0ubnF0gV6e>EcJ`ApijgKmY;|fB*y_009Upg5P$##AOHafKmY;|fB*y_ z0D+?>a9s$j#-efS6=UyBzjzqG-dAK@F3Ot9r``SkAN#@|AGI7JJrIBZ1Rwwb2tWV= z5P$##AOHaf>?bhhUk&m*08KXrj7c75j{pDE7yfj=H6k1W5P$##AOHafKmY;|fB*y_ z0D&VZa4`^#HI4V@z5jeP0RH~pkt{hR3IY&-00bZa0SG_<0uX=z1R!vz0{HuXhsueS zK>z{}fB*y_009U<00Izz00fSt0IvU!WU(Mo5P$##AOHafKmY;|fB*y_0D(gl!1e#3 za$;o=fB*y_009U<00Izz00bZafg>q^@BbgkVnL!H009U<00Izz00bZa0SG_<0*5Mq z{{KVe#L6H50SG_<0uX=z1Rwwb2tWV=M_HhcChi>lId8p>9_7R$Z4iI}1Rwwb2tWV= z5P$##AOHaf{BHtm{J$v7_=evf`p@8G=okI(^nF4n@!C&d@4P6E2rG}H3+t*bE5*Jz zeI;V*dQnxBN^LD>%$9X|QoCQR=qh`Q#mLI7c{)Lqx~h>{T_@V+#zwuN6WLgZlgIZFlH}=15`gB$K`~aDJ7>AIT0^E9v=`#f-CR7Zj}%#AgOx+%`InV^V#MEX=FS^=&PdeZIheJ zxrN!~JLKEhI}=8jrmn54HN8ONI|IS)HFZL!G>}57*TDY{{$hOIMqgG)h!e z3bIb}xrOXXesl;RzFE`CLE6&=qak@HltM}CHb_SZ>pf%7cLuO>gu23{jxtg=f)w;|B zXrWk+E_OmUI6>&P+GJu~eps+b&BQwZX3AXaF*nz)0Ngll=9crTv-8b0qfp+g(%tl- zPF7cPi#JGF*V0D1+?l1uiDIf$PA5vKVlgq5nNkyyGBrIFms6?uOnEvz^@GuVF(Tah z$;%7cwx$#Vtg5DBx>}QKx)$Sqt5tI~C7wddT{o&k_O0AXeuY?zIhEF^kS$r}m7{Gn z@UGELK$mqA+FoqdD);FQx>l_h<&*Q(p}m)SfP zWlgmjL$X%cQftNv+!exXv%E~SeXOb_)+A~LeSPDJVDX|bx#}C0QWLHbpj2pMq_MWy zYAj6{zHnDy9?96F_akPm6UGsRk1$&P%9Ma$2gC zD^aZQsXZ_1rea_RtQER*%;o8E%g#v4GGjcgUGqV%FE1^y<3F$8mqXWW@See947_bG+F7~e@Qhf)b8kx=64OiD>gN=lk3CFHo8DNiNRrJ1BG z@7?gzcKXUYVKXV#IiGcO>DWB;@GcF4xjVCytTty48a;b!8=%|!zqx-qCgy5f&88E? zTF%y%j2|KGbFcOXPIpVpv(>FB?tCtAh$jL(?d}tT#Eg_I6{qBMJQ+`?<(V`q!&GWI zp3JDzGg4yj)jly2%syT*OOn+o$yKL1mZjt9IxB2{fNgW+>)vYf6~lF8lMf*1?V zJ&G*YXVO}Mo@#8WX2CggT}^=J6v>HUE@$nF)_I?Glx#h%|5B&lbgylgS4e#Q-m3y< z_qEUH&Hb1cqBr^$J-4VgjEe*_np!K>*`=njvzxQ*-IrBKR)v&gCY@HwN+Ovks?+Lp zGA@-g2QT4zhk3+JEv;kLmS$YOV6N(f?|7!&)~6_UqH3T}X1Hbl2Uri;nCe&D|U`~IoE-}b%RcSiW3?=xS~_njkKF_16_KmY;| zfItrj?54wFoR?ht!+mp)0?Fpc9oq9Zmu5Sbkne3hD_J~T(_ z73595ZcG=<4tlL1Ss%jRJ1Is+M}?iRp=`~y-iAHG*32%xMrR=J^D{bSuM7=DR$`qJ z!@D{8xnDtQo^x{d#Gn|7#e}^%BYn-`&C^4k6gP+3U1Z*oZBI!T@N1!g$lSSoENRJF zxz+}Ez7Z57=gtYcOKv-jhkJ*~@wxjIuQ6`&E9%nEo*0PSNp{MytLN3JKy79QcSl4q zl1vJFcMOL%QQFKp;i*Y&B3Mwa5Z=l8H1$;P#_9MEx_>=7W7yaiTvlgv^K`(Y(j02< zU&h!^-bLIpPVXu2f=>1gL~dN@R4donb+b8QPI5l`ch>|la^ZsTc*JmaLYz;19$RO= zrOJjsHNTWfquubD=4fo1j*fN@pKI!$Z?{CbQl;Yi(({~J3xuk2ZD4?G3zQO2Td}I+*z?$t<$HQu%}!^cg2Vp zktAWy?5}n<&pW4D=ZKy}w1zu5}kB zGRnXeS7j7w7_s}RFsp~Trcww23 znc<6=kI|7%4@9!(JCzcTOUuaq_>6SwV2eohBB$~9{| Date: Mon, 30 Sep 2024 02:16:56 +0200 Subject: [PATCH 52/53] add contexts to benchmarks + fix types --- benchmarks/sql/bench.py | 3 ++- benchmarks/sql/bench/contexts/__init__.py | 10 +++++++ benchmarks/sql/bench/contexts/superhero.py | 21 +++++++++++++++ benchmarks/sql/bench/metrics/base.py | 7 ++--- benchmarks/sql/bench/pipelines/base.py | 27 ++++++++++++++++--- benchmarks/sql/bench/pipelines/collection.py | 7 +++-- benchmarks/sql/bench/pipelines/view.py | 16 ++++++----- .../sql/bench/views/structured/__init__.py | 0 benchmarks/sql/config/setup/collection.yaml | 2 ++ .../sql/config/setup/contexts/superhero.yaml | 4 +++ benchmarks/sql/config/setup/iql-view.yaml | 2 ++ benchmarks/sql/config/setup/sql-view.yaml | 2 ++ 12 files changed, 85 insertions(+), 16 deletions(-) create mode 100644 benchmarks/sql/bench/contexts/__init__.py create mode 100644 benchmarks/sql/bench/contexts/superhero.py create mode 100644 benchmarks/sql/bench/views/structured/__init__.py create mode 100644 benchmarks/sql/config/setup/contexts/superhero.yaml diff --git a/benchmarks/sql/bench.py b/benchmarks/sql/bench.py index 0257b52d..d8007be9 100644 --- a/benchmarks/sql/bench.py +++ b/benchmarks/sql/bench.py @@ -28,6 +28,7 @@ ) from bench.pipelines import CollectionEvaluationPipeline, IQLViewEvaluationPipeline, SQLViewEvaluationPipeline from bench.utils import save +from hydra.core.hydra_config import HydraConfig from neptune.utils import stringify_unsupported from omegaconf import DictConfig @@ -120,7 +121,7 @@ async def bench(config: DictConfig) -> None: log.info("Evaluation finished. Saving results...") - output_dir = Path(hydra.core.hydra_config.HydraConfig.get().runtime.output_dir) + output_dir = Path(HydraConfig.get().runtime.output_dir) metrics_file = output_dir / "metrics.json" results_file = output_dir / "results.json" diff --git a/benchmarks/sql/bench/contexts/__init__.py b/benchmarks/sql/bench/contexts/__init__.py new file mode 100644 index 00000000..0d1dd11c --- /dev/null +++ b/benchmarks/sql/bench/contexts/__init__.py @@ -0,0 +1,10 @@ +from typing import Dict, Type + +from dbally.context import Context + +from .superhero import SuperheroContext, UserContext + +CONTEXTS_REGISTRY: Dict[str, Type[Context]] = { + UserContext.__name__: UserContext, + SuperheroContext.__name__: SuperheroContext, +} diff --git a/benchmarks/sql/bench/contexts/superhero.py b/benchmarks/sql/bench/contexts/superhero.py new file mode 100644 index 00000000..edd28657 --- /dev/null +++ b/benchmarks/sql/bench/contexts/superhero.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from dbally.context import Context + + +@dataclass +class UserContext(Context): + """ + Current user data. + """ + + name: str = "John Doe" + + +@dataclass +class SuperheroContext(Context): + """ + Current user favourite superhero data. + """ + + name: str = "Batman" diff --git a/benchmarks/sql/bench/metrics/base.py b/benchmarks/sql/bench/metrics/base.py index d0e78072..8df7d1d3 100644 --- a/benchmarks/sql/bench/metrics/base.py +++ b/benchmarks/sql/bench/metrics/base.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Type +from omegaconf import DictConfig from typing_extensions import Self from ..pipelines import EvaluationResult @@ -11,7 +12,7 @@ class Metric(ABC): Base class for metrics. """ - def __init__(self, config: Optional[Dict] = None) -> None: + def __init__(self, config: Optional[DictConfig] = None) -> None: """ Initializes the metric. @@ -38,7 +39,7 @@ class MetricSet: Represents a set of metrics. """ - def __init__(self, *metrics: List[Type[Metric]]) -> None: + def __init__(self, *metrics: Type[Metric]) -> None: """ Initializes the metric set. @@ -48,7 +49,7 @@ def __init__(self, *metrics: List[Type[Metric]]) -> None: self._metrics = metrics self.metrics: List[Metric] = [] - def __call__(self, config: Dict) -> Self: + def __call__(self, config: DictConfig) -> Self: """ Initializes the metrics. diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index dc8d83ea..e19aef56 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,7 +1,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union +from omegaconf import DictConfig + +from dbally.context import Context from dbally.iql._exceptions import IQLError from dbally.iql._query import IQLQuery from dbally.iql_generator.prompt import UnsupportedQueryError @@ -10,6 +13,8 @@ from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM +from ..contexts import CONTEXTS_REGISTRY + @dataclass class IQL: @@ -23,7 +28,7 @@ class IQL: generated: bool = True @classmethod - def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + def from_query(cls, query: Optional[Union[IQLQuery, BaseException]]) -> "IQL": """ Creates an IQL object from the query. @@ -81,7 +86,11 @@ class EvaluationPipeline(ABC): Collection evaluation pipeline. """ - def get_llm(self, config: Dict) -> LLM: + def __init__(self, config: DictConfig) -> None: + super().__init__() + self.contexts = self.get_contexts(config.setup) + + def get_llm(self, config: DictConfig) -> LLM: """ Returns the LLM based on the configuration. @@ -95,6 +104,18 @@ def get_llm(self, config: Dict) -> LLM: return LocalLLM(config.model_name.split("/", 1)[1]) return LiteLLM(config.model_name) + def get_contexts(self, config: DictConfig) -> List[Context]: + """ + Returns the contexts based on the configuration. + + Args: + config: The contexts configuration. + + Returns: + The contexts. + """ + return [CONTEXTS_REGISTRY[context]() for contexts in config.contexts.values() for context in contexts] + @abstractmethod async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """ diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index 4e3ee58a..7ded27d1 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -1,5 +1,6 @@ from typing import Any, Dict +from omegaconf import DictConfig from sqlalchemy import create_engine import dbally @@ -17,16 +18,17 @@ class CollectionEvaluationPipeline(EvaluationPipeline): Collection evaluation pipeline. """ - def __init__(self, config: Dict) -> None: + def __init__(self, config: DictConfig) -> None: """ Constructs the pipeline for evaluating collection predictions. Args: config: The configuration for the pipeline. """ + super().__init__(config) self.collection = self.get_collection(config.setup) - def get_collection(self, config: Dict) -> Collection: + def get_collection(self, config: DictConfig) -> Collection: """ Sets up the collection based on the configuration. @@ -68,6 +70,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: try: result = await self.collection.ask( question=data["question"], + contexts=self.contexts, dry_run=True, return_natural_response=False, ) diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index a8379bea..fc502115 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -3,11 +3,11 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Type +from omegaconf import DictConfig from sqlalchemy import create_engine +from dbally.views.base import BaseView from dbally.views.exceptions import ViewExecutionError -from dbally.views.freeform.text2sql.view import BaseText2SQLView -from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from ..views import VIEWS_REGISTRY from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult @@ -18,18 +18,19 @@ class ViewEvaluationPipeline(EvaluationPipeline, ABC): View evaluation pipeline. """ - def __init__(self, config: Dict) -> None: + def __init__(self, config: DictConfig) -> None: """ Constructs the pipeline for evaluating IQL predictions. Args: config: The configuration for the pipeline. """ + super().__init__(config) self.llm = self.get_llm(config.setup.llm) self.dbs = self.get_dbs(config.setup) self.views = self.get_views(config.setup) - def get_dbs(self, config: Dict) -> Dict: + def get_dbs(self, config: DictConfig) -> Dict: """ Returns the database object based on the database name. @@ -42,7 +43,7 @@ def get_dbs(self, config: Dict) -> Dict: return {db: create_engine(f"sqlite:///data/{db}.db") for db in config.views} @abstractmethod - def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: + def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: """ Creates the view classes mapping based on the configuration. @@ -59,7 +60,7 @@ class IQLViewEvaluationPipeline(ViewEvaluationPipeline): IQL view evaluation pipeline. """ - def get_views(self, config: Dict) -> Dict[str, Type[SqlAlchemyBaseView]]: + def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: """ Creates the view classes mapping based on the configuration. @@ -89,6 +90,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: result = await view.ask( query=data["question"], llm=self.llm, + contexts=self.contexts, dry_run=True, n_retries=0, ) @@ -140,7 +142,7 @@ class SQLViewEvaluationPipeline(ViewEvaluationPipeline): SQL view evaluation pipeline. """ - def get_views(self, config: Dict) -> Dict[str, Type[BaseText2SQLView]]: + def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: """ Creates the view classes mapping based on the configuration. diff --git a/benchmarks/sql/bench/views/structured/__init__.py b/benchmarks/sql/bench/views/structured/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarks/sql/config/setup/collection.yaml b/benchmarks/sql/config/setup/collection.yaml index 2eafb34a..f1d3a13a 100644 --- a/benchmarks/sql/config/setup/collection.yaml +++ b/benchmarks/sql/config/setup/collection.yaml @@ -5,3 +5,5 @@ defaults: - llm@generator_llm: gpt-3.5-turbo - views/structured@views: - superhero + - contexts: + - superhero diff --git a/benchmarks/sql/config/setup/contexts/superhero.yaml b/benchmarks/sql/config/setup/contexts/superhero.yaml new file mode 100644 index 00000000..fcb6f70f --- /dev/null +++ b/benchmarks/sql/config/setup/contexts/superhero.yaml @@ -0,0 +1,4 @@ +superhero: [ + UserContext, + SuperheroContext, +] diff --git a/benchmarks/sql/config/setup/iql-view.yaml b/benchmarks/sql/config/setup/iql-view.yaml index e652bc3b..be482a85 100644 --- a/benchmarks/sql/config/setup/iql-view.yaml +++ b/benchmarks/sql/config/setup/iql-view.yaml @@ -4,3 +4,5 @@ defaults: - llm: gpt-3.5-turbo - views/structured@views: - superhero + - contexts: + - superhero diff --git a/benchmarks/sql/config/setup/sql-view.yaml b/benchmarks/sql/config/setup/sql-view.yaml index e4e1f7d9..5d9de669 100644 --- a/benchmarks/sql/config/setup/sql-view.yaml +++ b/benchmarks/sql/config/setup/sql-view.yaml @@ -4,3 +4,5 @@ defaults: - llm: gpt-3.5-turbo - views/freeform@views: - superhero + - contexts: + - superhero From fab9d3ff3c20b24b7feb1f886a68bfa4821c1431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 1 Oct 2024 17:02:02 +0200 Subject: [PATCH 53/53] small refactor --- benchmarks/sql/bench/pipelines/base.py | 47 ++++++++--- benchmarks/sql/bench/pipelines/collection.py | 69 +++++++++------- benchmarks/sql/bench/pipelines/view.py | 85 +++++++------------- 3 files changed, 104 insertions(+), 97 deletions(-) diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index e19aef56..acde042e 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from functools import cached_property +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union from omegaconf import DictConfig +from sqlalchemy import Engine, create_engine from dbally.context import Context from dbally.iql._exceptions import IQLError @@ -12,9 +14,12 @@ from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM +from dbally.views.base import BaseView from ..contexts import CONTEXTS_REGISTRY +ViewT = TypeVar("ViewT", bound=BaseView) + @dataclass class IQL: @@ -88,9 +93,10 @@ class EvaluationPipeline(ABC): def __init__(self, config: DictConfig) -> None: super().__init__() - self.contexts = self.get_contexts(config.setup) + self.config = config - def get_llm(self, config: DictConfig) -> LLM: + @staticmethod + def _get_llm(config: DictConfig) -> LLM: """ Returns the LLM based on the configuration. @@ -104,17 +110,12 @@ def get_llm(self, config: DictConfig) -> LLM: return LocalLLM(config.model_name.split("/", 1)[1]) return LiteLLM(config.model_name) - def get_contexts(self, config: DictConfig) -> List[Context]: + @cached_property + def dbs(self) -> Dict[str, Engine]: """ - Returns the contexts based on the configuration. - - Args: - config: The contexts configuration. - - Returns: - The contexts. + Returns the database engines based on the configuration. """ - return [CONTEXTS_REGISTRY[context]() for contexts in config.contexts.values() for context in contexts] + return {db: create_engine(f"sqlite:///data/{db}.db") for db in self.config.setup.views} @abstractmethod async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: @@ -127,3 +128,25 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: Returns: The evaluation result. """ + + +class ViewEvaluationMixin(Generic[ViewT]): + """ + View evaluation mixin. + """ + + @cached_property + def contexts(self) -> List[Context]: + """ + Returns the contexts based on the configuration. + """ + return [ + CONTEXTS_REGISTRY[context]() for contexts in self.config.setup.contexts.values() for context in contexts + ] + + @cached_property + @abstractmethod + def views(self) -> Dict[str, Type[ViewT]]: + """ + Returns the view classes mapping based on the configuration + """ diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index 7ded27d1..784086a7 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -1,59 +1,68 @@ -from typing import Any, Dict - -from omegaconf import DictConfig -from sqlalchemy import create_engine +from functools import cached_property +from typing import Any, Dict, Type, Union import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError +from dbally.llms.base import LLM from dbally.view_selection.llm_view_selector import LLMViewSelector from dbally.views.exceptions import ViewExecutionError +from dbally.views.freeform.text2sql.view import BaseText2SQLView +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from ..views import VIEWS_REGISTRY -from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult, ViewEvaluationMixin -class CollectionEvaluationPipeline(EvaluationPipeline): +class CollectionEvaluationPipeline( + EvaluationPipeline, ViewEvaluationMixin[Union[SqlAlchemyBaseView, BaseText2SQLView]] +): """ Collection evaluation pipeline. """ - def __init__(self, config: DictConfig) -> None: + @cached_property + def selector(self) -> LLM: """ - Constructs the pipeline for evaluating collection predictions. - - Args: - config: The configuration for the pipeline. + Returns the selector LLM. """ - super().__init__(config) - self.collection = self.get_collection(config.setup) + return self._get_llm(self.config.setup.selector_llm) - def get_collection(self, config: DictConfig) -> Collection: + @cached_property + def generator(self) -> LLM: """ - Sets up the collection based on the configuration. - - Args: - config: The collection configuration. + Returns the generator LLM. + """ + return self._get_llm(self.config.setup.generator_llm) - Returns: - The collection. + @cached_property + def views(self) -> Dict[str, Type[Union[SqlAlchemyBaseView, BaseText2SQLView]]]: + """ + Returns the view classes mapping based on the configuration. + """ + return { + db: cls + for db, views in self.config.setup.views.items() + for view in views + if issubclass(cls := VIEWS_REGISTRY[view], (SqlAlchemyBaseView, BaseText2SQLView)) + } + + @cached_property + def collection(self) -> Collection: + """ + Returns the collection used for evaluation. """ - generator_llm = self.get_llm(config.generator_llm) - selector_llm = self.get_llm(config.selector_llm) - view_selector = LLMViewSelector(selector_llm) + view_selector = LLMViewSelector(self.selector) collection = dbally.create_collection( - name=config.name, - llm=generator_llm, + name=self.config.setup.name, + llm=self.generator, view_selector=view_selector, ) collection.n_retries = 0 - for db_name, view_names in config.views.items(): - db = create_engine(f"sqlite:///data/{db_name}.db") - for view_name in view_names: - view_cls = VIEWS_REGISTRY[view_name] - collection.add(view_cls, lambda: view_cls(db)) # pylint: disable=cell-var-from-loop + for db, view in self.views.items(): + collection.add(view, lambda: view(self.dbs[db])) # pylint: disable=cell-var-from-loop return collection diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index fc502115..237f2858 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -1,77 +1,53 @@ # pylint: disable=duplicate-code from abc import ABC, abstractmethod +from functools import cached_property from typing import Any, Dict, Type -from omegaconf import DictConfig -from sqlalchemy import create_engine - -from dbally.views.base import BaseView +from dbally.llms.base import LLM from dbally.views.exceptions import ViewExecutionError +from dbally.views.freeform.text2sql.view import BaseText2SQLView +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from ..views import VIEWS_REGISTRY -from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult +from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult, ViewEvaluationMixin, ViewT -class ViewEvaluationPipeline(EvaluationPipeline, ABC): +class ViewEvaluationPipeline(EvaluationPipeline, ViewEvaluationMixin[ViewT], ABC): """ View evaluation pipeline. """ - def __init__(self, config: DictConfig) -> None: - """ - Constructs the pipeline for evaluating IQL predictions. - - Args: - config: The configuration for the pipeline. - """ - super().__init__(config) - self.llm = self.get_llm(config.setup.llm) - self.dbs = self.get_dbs(config.setup) - self.views = self.get_views(config.setup) - - def get_dbs(self, config: DictConfig) -> Dict: + @cached_property + def llm(self) -> LLM: """ - Returns the database object based on the database name. - - Args: - config: The database configuration. - - Returns: - The database object. + Returns the LLM based on the configuration. """ - return {db: create_engine(f"sqlite:///data/{db}.db") for db in config.views} + return self._get_llm(self.config.setup.llm) + @cached_property @abstractmethod - def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: + def views(self) -> Dict[str, Type[ViewT]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. - - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration """ -class IQLViewEvaluationPipeline(ViewEvaluationPipeline): +class IQLViewEvaluationPipeline(ViewEvaluationPipeline[SqlAlchemyBaseView]): """ IQL view evaluation pipeline. """ - def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: + @cached_property + def views(self) -> Dict[str, Type[SqlAlchemyBaseView]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. - - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration. """ return { - view_name: VIEWS_REGISTRY[view_name] for view_names in config.views.values() for view_name in view_names + view: cls + for views in self.config.setup.views.values() + for view in views + if issubclass(cls := VIEWS_REGISTRY[view], SqlAlchemyBaseView) } async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: @@ -137,22 +113,21 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: ) -class SQLViewEvaluationPipeline(ViewEvaluationPipeline): +class SQLViewEvaluationPipeline(ViewEvaluationPipeline[BaseText2SQLView]): """ SQL view evaluation pipeline. """ - def get_views(self, config: DictConfig) -> Dict[str, Type[BaseView]]: + @cached_property + def views(self) -> Dict[str, Type[BaseText2SQLView]]: """ - Creates the view classes mapping based on the configuration. - - Args: - config: The views configuration. - - Returns: - The view classes mapping. + Returns the view classes mapping based on the configuration. """ - return {db_id: VIEWS_REGISTRY[view_name] for db_id, view_name in config.views.items()} + return { + db: cls + for db, view in self.config.setup.views.items() + if issubclass(cls := VIEWS_REGISTRY[view], BaseText2SQLView) + } async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: """