From 139ab9be6514cb0777251efafc3e272cac00e6cd Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Thu, 29 Aug 2024 10:10:08 +0200 Subject: [PATCH] feat: aggregations in structured views (#62) --- benchmarks/sql/bench/pipelines/base.py | 26 +- benchmarks/sql/bench/pipelines/collection.py | 40 +-- benchmarks/sql/bench/pipelines/view.py | 35 +-- .../sql/bench/views/structured/superhero.py | 23 +- docs/how-to/views/custom_views_code.py | 1 + docs/quickstart/quickstart_code.py | 1 - src/dbally/iql/__init__.py | 12 +- src/dbally/iql/_processor.py | 77 ++++-- src/dbally/iql/_query.py | 37 ++- src/dbally/iql_generator/iql_generator.py | 232 ++++++++++++++---- src/dbally/iql_generator/prompt.py | 172 +++++++++---- src/dbally/prompt/template.py | 4 +- src/dbally/view_selection/prompt.py | 2 +- src/dbally/views/decorators.py | 15 ++ src/dbally/views/exceptions.py | 18 +- src/dbally/views/methods_base.py | 51 +++- src/dbally/views/pandas_base.py | 88 +++++-- src/dbally/views/sqlalchemy_base.py | 39 +-- src/dbally/views/structured.py | 97 ++++---- tests/integration/test_llm_options.py | 2 +- tests/unit/iql/test_iql_parser.py | 226 +++++++++++++++-- tests/unit/mocks.py | 24 +- .../similarity/sample_module/submodule.py | 12 +- tests/unit/test_collection.py | 19 +- tests/unit/test_iql_format.py | 28 +-- tests/unit/test_iql_generator.py | 188 ++++++++------ tests/unit/views/test_methods_base.py | 39 ++- tests/unit/views/test_pandas_base.py | 99 +++++++- tests/unit/views/test_sqlalchemy_base.py | 51 +++- 29 files changed, 1235 insertions(+), 423 deletions(-) diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index 38bcb304..dc8d83ea 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union +from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLQuery +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM +from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM @@ -16,6 +20,25 @@ class IQL: source: Optional[str] = None unsupported: bool = False valid: bool = True + generated: bool = True + + @classmethod + def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + """ + Creates an IQL object from the query. + + Args: + query: The IQL query or exception. + + Returns: + The IQL object. + """ + return cls( + source=query.source if isinstance(query, (IQLQuery, IQLError)) else None, + unsupported=isinstance(query, UnsupportedQueryError), + valid=not isinstance(query, IQLError), + generated=not isinstance(query, LLMError), + ) @dataclass @@ -47,6 +70,7 @@ class EvaluationResult: """ db_id: str + question_id: str question: str reference: ExecutionResult prediction: ExecutionResult diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index dfc127cf..19831b0d 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -5,10 +5,8 @@ import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.view_selection.llm_view_selector import LLMViewSelector -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from ..views import VIEWS_REGISTRY from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return_natural_response=False, ) except NoViewFoundError: - prediction = ExecutionResult( - view_name=None, - iql=None, - sql=None, - ) - except IQLGenerationError as exc: + prediction = ExecutionResult() + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=exc.view_name, iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL( - source=result.context.get("iql"), - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), - sql=result.context.get("sql"), + sql=result.context["sql"], ) reference = ExecutionResult( @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index d4ae8515..be9d8263 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -5,9 +5,7 @@ from sqlalchemy import create_engine -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, n_retries=0, ) - except IQLGenerationError as exc: + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=result.context["iql"], - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), sql=result.context["sql"], ) @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: filters=IQL( source=data["iql_filters"], unsupported=data["iql_filters_unsupported"], - valid=True, ), aggregation=IQL( source=data["iql_aggregation"], unsupported=data["iql_aggregation_unsupported"], - valid=True, ), context=data["iql_context"], ), @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index db57498e..2a6a75a0 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.declarative import DeferredReflection from sqlalchemy.orm import aliased, declarative_base -from dbally.views.decorators import view_filter +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView Base = declarative_base(cls=DeferredReflection) @@ -285,8 +285,8 @@ class SuperheroColourFilterMixin: Mixin for filtering the view by the superhero colour attributes. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.eye_colour = aliased(Colour) self.hair_colour = aliased(Colour) self.skin_colour = aliased(Colour) @@ -427,10 +427,27 @@ def filter_by_race(self, race: str) -> ColumnElement: return Race.race == race +class SuperheroAggregationMixin: + """ + Mixin for aggregating the view by the superhero attributes. + """ + + @view_aggregation() + def count_superheroes(self) -> Select: + """ + Counts the number of superheros. + + Returns: + The superheros count. + """ + return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + + class SuperheroView( DBInitMixin, SqlAlchemyBaseView, SuperheroFilterMixin, + SuperheroAggregationMixin, SuperheroColourFilterMixin, AlignmentFilterMixin, GenderFilterMixin, diff --git a/docs/how-to/views/custom_views_code.py b/docs/how-to/views/custom_views_code.py index 33c954c7..c64a2ffb 100644 --- a/docs/how-to/views/custom_views_code.py +++ b/docs/how-to/views/custom_views_code.py @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=filtered_data, context={}) + class CandidateView(FilteredIterableBaseView): def get_data(self) -> Iterable: return [ diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index be2aab37..ef73cad0 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -6,7 +6,6 @@ from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -import dbally from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM diff --git a/src/dbally/iql/__init__.py b/src/dbally/iql/__init__.py index 0df0a766..20bde9eb 100644 --- a/src/dbally/iql/__init__.py +++ b/src/dbally/iql/__init__.py @@ -1,5 +1,13 @@ from . import syntax from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError -from ._query import IQLQuery +from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery -__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"] +__all__ = [ + "IQLQuery", + "IQLFiltersQuery", + "IQLAggregationQuery", + "syntax", + "IQLError", + "IQLArgumentParsingError", + "IQLUnsupportedSyntaxError", +] diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index f1adf64c..1bd72bcc 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,5 +1,6 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -19,10 +20,12 @@ if TYPE_CHECKING: from dbally.views.structured import ExposedFunction +RootT = TypeVar("RootT", bound=syntax.Node) -class IQLProcessor: + +class IQLProcessor(Generic[RootT], ABC): """ - Parses IQL string to tree structure. + Base class for IQL processors. """ def __init__( @@ -32,9 +35,9 @@ def __init__( self.allowed_functions = {func.name: func for func in allowed_functions} self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: + async def process(self) -> RootT: """ - Process IQL string to root IQL.Node. + Process IQL string to IQL root node. Returns: IQL node which is root of the tree representing IQL query. @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node: return await self._parse_node(ast_tree.body[0].value) - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: - if isinstance(node, ast.BoolOp): - return await self._parse_bool_op(node) - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.operand)) - if isinstance(node, ast.Call): - return await self._parse_call(node) - - raise IQLUnsupportedSyntaxError(node, self.source) + @abstractmethod + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT: + """ + Parses AST node to IQL node. - async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: - if isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.values[0])) - if isinstance(node.op, ast.And): - return syntax.And([await self._parse_node(x) for x in node.values]) - if isinstance(node.op, ast.Or): - return syntax.Or([await self._parse_node(x) for x in node.values]) + Args: + node: AST node to parse. - raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + Returns: + IQL node. + """ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: func = node.func @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str: converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower() return converted_text + + +class IQLFiltersProcessor(IQLProcessor[syntax.Node]): + """ + IQL processor for filters. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: + if isinstance(node, ast.BoolOp): + return await self._parse_bool_op(node) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.operand)) + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) + + async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: + if isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.values[0])) + if isinstance(node.op, ast.And): + return syntax.And([await self._parse_node(x) for x in node.values]) + if isinstance(node.op, ast.Or): + return syntax.Or([await self._parse_node(x) for x in node.values]) + + raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + + +class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]): + """ + IQL processor for aggregation. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall: + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index dd831a91..57b3b4ed 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,26 +1,29 @@ -from typing import TYPE_CHECKING, List, Optional +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type from ..audit.event_tracker import EventTracker from . import syntax -from ._processor import IQLProcessor +from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: from dbally.views.structured import ExposedFunction -class IQLQuery: +class IQLQuery(Generic[RootT], ABC): """ IQLQuery container. It stores IQL as a syntax tree defined in `IQL` class. """ - root: syntax.Node + root: RootT + source: str + _processor: Type[IQLProcessor[RootT]] - def __init__(self, root: syntax.Node, source: str) -> None: + def __init__(self, root: RootT, source: str) -> None: self.root = root - self._source = source + self.source = source def __str__(self) -> str: - return self._source + return self.source @classmethod async def parse( @@ -28,7 +31,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - ) -> "IQLQuery": + ) -> "IQLQuery[RootT]": """ Parse IQL string to IQLQuery object. @@ -43,5 +46,21 @@ async def parse( Raises: IQLError: If parsing fails. """ - root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process() return cls(root=root, source=source) + + +class IQLFiltersQuery(IQLQuery[syntax.Node]): + """ + IQL filters query container. + """ + + _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor + + +class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): + """ + IQL aggregation query container. + """ + + _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 27347734..4ea65340 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,11 +1,16 @@ -from typing import List, Optional +import asyncio +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, + FILTERS_GENERATION_TEMPLATE, + DecisionPromptFormat, IQLGenerationPromptFormat, ) from dbally.llms.base import LLM @@ -15,57 +20,151 @@ from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" +IQLQueryT = TypeVar("IQLQueryT", bound=IQLQuery) -class IQLGenerator: +@dataclass +class IQLGeneratorState: + """ + State of the IQL generator. """ - Class used to generate IQL from natural language question. - In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way. - The class used to generate IQL from natural language query is `IQLGenerator`. + filters: Optional[Union[IQLFiltersQuery, Exception]] = None + aggregation: Optional[Union[IQLAggregationQuery, Exception]] = None + + @property + def failed(self) -> bool: + """ + Checks if the generation failed. - IQL generation is done using the method `self.generate_iql`. - It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. + Returns: + True if the generation failed, False otherwise. + """ + return isinstance(self.filters, Exception) or isinstance(self.aggregation, Exception) + + +class IQLGenerator: + """ + Orchestrates all IQL operations for the given question. """ def __init__( self, - llm: LLM, - *, - decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None, - generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None, + filters_generation: Optional["IQLOperationGenerator"] = None, + aggregation_generation: Optional["IQLOperationGenerator"] = None, ) -> None: """ Constructs a new IQLGenerator instance. Args: - llm: LLM used to generate IQL. decision_prompt: Prompt template for filtering decision making. generation_prompt: Prompt template for IQL generation. """ - self._llm = llm - self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE - self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE + self._filters_generation = filters_generation or IQLOperationGenerator[IQLFiltersQuery]( + FILTERING_DECISION_TEMPLATE, + FILTERS_GENERATION_TEMPLATE, + ) + self._aggregation_generation = aggregation_generation or IQLOperationGenerator[IQLAggregationQuery]( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, + ) - async def generate( + # pylint: disable=too-many-arguments + async def __call__( self, + *, question: str, filters: List[ExposedFunction], - event_tracker: EventTracker, - examples: Optional[List[FewShotExample]] = None, + aggregations: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> Optional[IQLQuery]: + ) -> IQLGeneratorState: """ - Generates IQL in text form using LLM. + Generates IQL operations for the given question. Args: question: User question. filters: List of filters exposed by the view. + aggregations: List of aggregations exposed by the view. + examples: List of examples to be injected during filters and aggregation generation. + llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. + + Returns: + Generated IQL operations. + """ + filters, aggregation = await asyncio.gather( + self._filters_generation( + question=question, + methods=filters, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + self._aggregation_generation( + question=question, + methods=aggregations, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + return_exceptions=True, + ) + return IQLGeneratorState( + filters=filters, + aggregation=aggregation, + ) + + +class IQLOperationGenerator(Generic[IQLQueryT]): + """ + Generates IQL queries for the given question. + """ + + def __init__( + self, + assessor_prompt: PromptTemplate[DecisionPromptFormat], + generator_prompt: PromptTemplate[IQLGenerationPromptFormat], + ) -> None: + """ + Constructs a new IQLGenerator instance. + + Args: + assessor_prompt: Prompt template for filtering decision making. + generator_prompt: Prompt template for IQL generation. + """ + self.assessor = IQLQuestionAssessor(assessor_prompt) + self.generator = IQLQueryGenerator[IQLQueryT](generator_prompt) + + async def __call__( + self, + *, + question: str, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, + llm_options: Optional[LLMOptions] = None, + n_retries: int = 3, + ) -> Optional[IQLQueryT]: + """ + Generates IQL query for the given question. + + Args: + llm: LLM used to generate IQL. + question: User question. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. @@ -77,38 +176,52 @@ async def generate( IQLError: If IQL parsing fails after all retries. UnsupportedQueryError: If the question is not supported by the view. """ - decision = await self._decide_on_generation( + decision = await self.assessor( question=question, - event_tracker=event_tracker, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) if not decision: return None - return await self._generate_iql( + return await self.generator( question=question, - filters=filters, - event_tracker=event_tracker, + methods=methods, examples=examples, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) - async def _decide_on_generation( + +class IQLQuestionAssessor: + """ + Assesses whether a question requires applying IQL operation or not. + """ + + def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - event_tracker: EventTracker, + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> bool: """ - Decides whether the question requires filtering or not. + Decides whether the question requires generating IQL or not. Args: question: User question. - event_tracker: Event store used to audit the generation process. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to LLM API in case of errors. Returns: @@ -117,12 +230,14 @@ async def _decide_on_generation( Raises: LLMError: If LLM text generation fails after all retries. """ - prompt_format = FilteringDecisionPromptFormat(question=question) - formatted_prompt = self._decision_prompt.format_prompt(prompt_format) + prompt_format = DecisionPromptFormat( + question=question, + ) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -133,24 +248,39 @@ async def _decide_on_generation( if retry == n_retries: raise exc - async def _generate_iql( + +class IQLQueryGenerator(Generic[IQLQueryT]): + """ + Generates IQL queries for the given question. + """ + + ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" + + def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - filters: List[ExposedFunction], - event_tracker: Optional[EventTracker] = None, - examples: Optional[List[FewShotExample]] = None, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> IQLQuery: + ) -> IQLQueryT: """ - Generates IQL in text form using LLM. + Generates IQL query for the given question. Args: question: User question. filters: List of filters exposed by the view. - event_tracker: Event store used to audit the generation process. examples: List of examples to be injected into the conversation. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. Returns: @@ -163,24 +293,22 @@ async def _generate_iql( """ prompt_format = IQLGenerationPromptFormat( question=question, - filters=filters, + methods=methods, examples=examples, ) - formatted_prompt = self._generation_prompt.format_prompt(prompt_format) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) # TODO: Move response parsing to llm generate_text method - iql = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, - allowed_functions=filters, + return await formatted_prompt.response_parser( + response=response, + allowed_functions=methods, event_tracker=event_tracker, ) except LLMError as exc: @@ -190,4 +318,4 @@ async def _generate_iql( if retry == n_retries: raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 5dfc2028..8d8e7101 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -1,8 +1,10 @@ # pylint: disable=C0301 -from typing import List +from typing import List, Optional +from dbally.audit.event_tracker import EventTracker from dbally.exceptions import DbAllyError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.views.exposed_functions import ExposedFunction @@ -15,26 +17,65 @@ class UnsupportedQueryError(DbAllyError): """ -def _validate_iql_response(llm_response: str) -> str: +async def _iql_filters_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLFiltersQuery: """ - Validates LLM response to IQL + Parses the response from the LLM to IQL. Args: - llm_response: LLM response + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. Returns: - A string containing IQL for filters. + IQL query for filters. Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. + UnsuppotedQueryError: When IQL generator is unable to construct a query with given filters. """ - if "unsupported query" in llm_response.lower(): + if "unsupported query" in response.lower(): raise UnsupportedQueryError - return llm_response + return await IQLFiltersQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) -def _decision_iql_response_parser(response: str) -> bool: + +async def _iql_aggregation_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLAggregationQuery: + """ + Parses the response from the LLM to IQL. + + Args: + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. + + Returns: + IQL query for aggregations. + + Raises: + UnsuppotedQueryError: When IQL generator is unable to construct a query with given aggregations. + """ + if "unsupported query" in response.lower(): + raise UnsupportedQueryError + + return await IQLAggregationQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) + + +def _decision_parser(response: str) -> bool: """ Parses the response from the decision prompt. @@ -52,7 +93,7 @@ def _decision_iql_response_parser(response: str) -> bool: return "true" in decision -class FilteringDecisionPromptFormat(PromptFormat): +class DecisionPromptFormat(PromptFormat): """ IQL prompt format, providing a question and filters to be used in the conversation. """ @@ -71,69 +112,69 @@ def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> N class IQLGenerationPromptFormat(PromptFormat): """ - IQL prompt format, providing a question and filters to be used in the conversation. + IQL prompt format, providing a question and methods to be used in the conversation. """ def __init__( self, *, question: str, - filters: List[ExposedFunction], - examples: List[FewShotExample] = None, + methods: List[ExposedFunction], + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(filter) for filter in filters]) + self.methods = "\n".join(str(method) for method in methods) -IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( +FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( [ { "role": "system", "content": ( - "You have access to API that lets you query a database:\n" - "\n{filters}\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" - "Remember! Don't give any comments, just the function calls.\n" - "The output will look like this:\n" - 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{filters}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ - "This is CRUCIAL, otherwise the system will crash. " + "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" + "Initial data filtering is a process in which the result set is reduced to only include the rows " + "that meet certain criteria specified in the question.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Hint: ${{hint}}" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" ), }, { "role": "user", - "content": "{question}", + "content": ( + "Question: {question}\n" + "Hint: Look for words indicating data specific features.\n" + "Reasoning: Let's think step by step in order to " + ), }, ], - response_parser=_validate_iql_response, + response_parser=_decision_parser, ) - -FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat]( +AGGREGATION_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( [ { "role": "system", "content": ( - "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" - "Initial data filtering is a process in which the result set is reduced to only include the rows " - "that meet certain criteria specified in the question.\n\n" + "Given a question, determine whether the answer requires computing the aggregation in order to compute it.\n" + "Aggregation is a process in which the result set is reduced to a single value.\n\n" "---\n\n" "Follow the following format.\n\n" "Question: ${{question}}\n" - "Hint: ${{hint}}" "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" "Decision: indicates whether the answer to the question requires initial data filtering. " "(Respond with True or False)\n\n" @@ -141,12 +182,61 @@ def __init__( }, { "role": "user", + "content": ("Question: {question}\n" "Reasoning: Let's think step by step in order to "), + }, + ], + response_parser=_decision_parser, +) + +FILTERS_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( + [ + { + "role": "system", "content": ( - "Question: {question}\n" - "Hint: Look for words indicating data specific features.\n" - "Reasoning: Let's think step by step in order to " + "You have access to an API that lets you query a database:\n" + "\n{methods}\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL, otherwise the system will crash. " ), }, + { + "role": "user", + "content": "{question}", + }, + ], + response_parser=_iql_filters_parser, +) + +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{methods}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n' + ), + }, + { + "role": "user", + "content": "{question}", + }, ], - response_parser=_decision_iql_response_parser, + response_parser=_iql_aggregation_parser, ) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/view_selection/prompt.py b/src/dbally/view_selection/prompt.py index cdbedf5a..2d49efa9 100644 --- a/src/dbally/view_selection/prompt.py +++ b/src/dbally/view_selection/prompt.py @@ -35,7 +35,7 @@ def __init__( "role": "system", "content": ( "You are a very smart database programmer. " - "You have access to API that lets you query a database:\n" + "You have access to an API that lets you query a database:\n" "First you need to select a class to query, based on its description and the user question. " "You have the following classes to choose from:\n" "{views}\n" diff --git a/src/dbally/views/decorators.py b/src/dbally/views/decorators.py index ac537f5f..d318cfc4 100644 --- a/src/dbally/views/decorators.py +++ b/src/dbally/views/decorators.py @@ -14,3 +14,18 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin return func return wrapped + + +def view_aggregation() -> typing.Callable: + """ + Decorator for marking a method as an aggregation + + Returns: + Function that returns the decorated method + """ + + def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc + func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access + return func + + return wrapped diff --git a/src/dbally/views/exceptions.py b/src/dbally/views/exceptions.py index 277064a4..15770e9a 100644 --- a/src/dbally/views/exceptions.py +++ b/src/dbally/views/exceptions.py @@ -1,26 +1,22 @@ -from typing import Optional - from dbally.exceptions import DbAllyError +from dbally.iql_generator.iql_generator import IQLGeneratorState -class IQLGenerationError(DbAllyError): +class ViewExecutionError(DbAllyError): """ - Exception for when an error occurs while generating IQL for a view. + Exception for when an error occurs while executing a view. """ def __init__( self, view_name: str, - filters: Optional[str] = None, - aggregation: Optional[str] = None, + iql: IQLGeneratorState, ) -> None: """ Args: view_name: Name of the view that caused the error. - filters: Filters generated by the view. - aggregation: Aggregation generated by the view. + iql: View IQL generator state. """ - super().__init__(f"Error while generating IQL for view {view_name}") + super().__init__(f"Error while executing view {view_name}") self.view_name = view_name - self.filters = filters - self.aggregation = aggregation + self.iql = iql diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 8eeedfb0..8bf93363 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -1,6 +1,6 @@ -import abc import inspect import textwrap +from abc import ABC from typing import Any, Callable, List, Tuple from dbally.iql import syntax @@ -9,13 +9,13 @@ from dbally.views.structured import BaseStructuredView -class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): +class MethodsBaseView(BaseStructuredView, ABC): """ Base class for views that use view methods to expose filters. """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["self", "select", "return"] + HIDDEN_ARGUMENTS = ["cls", "self", "return"] @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -56,15 +56,25 @@ def list_filters(self) -> List[ExposedFunction]: """ return self.list_methods_by_decorator(decorators.view_filter) + def list_aggregations(self) -> List[ExposedFunction]: + """ + List aggregations in the given view + + Returns: + Aggregations defined inside the View and decorated with `decorators.view_aggregation`. + """ + return self.list_methods_by_decorator(decorators.view_aggregation) + def _method_with_args_from_call( self, func: syntax.FunctionCall, method_decorator: Callable - ) -> Tuple[Callable, list]: + ) -> Tuple[Callable, List]: """ Converts a IQL FunctionCall node to a method object and its arguments. Args: func: IQL FunctionCall node - method_decorator: The decorator that thhe method should have + method_decorator: The decorator that the method should have + (currently allows discrimination between filters and aggregations) Returns: Tuple with the method object and its arguments @@ -84,6 +94,21 @@ def _method_with_args_from_call( return method, func.arguments + async def _call_method(self, method: Callable, args: List) -> Any: + """ + Calls the method with the given arguments. If the method is a coroutine, it will be awaited. + + Args: + method: The method to call. + args: The arguments to pass to the method. + + Returns: + The result of the method call. + """ + if inspect.iscoroutinefunction(method): + return await method(*args) + return method(*args) + async def call_filter_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall filter to a method call. If the method is a coroutine, it will be awaited. @@ -95,7 +120,17 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: The result of the method call """ method, args = self._method_with_args_from_call(func, decorators.view_filter) + return await self._call_method(method, args) - if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: + """ + Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. + + Args: + func: IQL FunctionCall node + + Returns: + The result of the method call + """ + method, args = self._method_with_args_from_call(func, decorators.view_aggregation) + return await self._call_method(method, args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index d7fa9446..e4da84c4 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,14 +1,36 @@ import asyncio +from dataclasses import dataclass from functools import reduce -from typing import Optional +from typing import List, Optional, Union import pandas as pd from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView +@dataclass(frozen=True) +class Aggregation: + """ + Represents an aggregation to be applied to a Pandas DataFrame. + """ + + column: str + function: str + + +@dataclass(frozen=True) +class AggregationGroup: + """ + Represents an aggregations and groupbys to be applied to a Pandas DataFrame. + """ + + aggregations: Optional[List[Aggregation]] = None + groupbys: Optional[Union[str, List[str]]] = None + + class DataFrameBaseView(MethodsBaseView): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -19,48 +41,58 @@ class DataFrameBaseView(MethodsBaseView): def __init__(self, df: pd.DataFrame) -> None: """ + Creates a new instance of the DataFrame view. + Args: - df: Pandas DataFrame with the data to be filtered + df: Pandas DataFrame with the data to be filtered. """ super().__init__() self.df = df - - # The mask to be applied to the dataframe to filter the data self._filter_mask: Optional[pd.Series] = None + self._aggregation_group: AggregationGroup = AggregationGroup() - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: IQLQuery object representing the filters to apply + filters: IQLQuery object representing the filters to apply. """ - self._filter_mask = await self.build_filter_node(filters.root) + self._filter_mask = await self._build_filter_node(filters.root) - async def build_filter_node(self, node: syntax.Node) -> pd.Series: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: + """ + Applies the aggregation of choice to the view. + + Args: + aggregation: IQLQuery object representing the aggregation to apply. + """ + self._aggregation_group = await self.call_aggregation_method(aggregation.root) + + async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing a boolean mask to be applied to the dataframe. Args: - node: IQLQuery node representing the filter or logical operator + node: IQLQuery node representing the filter or logical operator. Returns: - A boolean mask that can be used to filter the original DataFrame + A boolean mask that can be used to filter the original DataFrame. Raises: - ValueError: If the node type is not supported + ValueError: If the node type is not supported. """ if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x & y, children) if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x | y, children) if isinstance(node, syntax.Not): - child = await self.build_filter_node(node.child) + child = await self._build_filter_node(node.child) return ~child raise ValueError(f"Unsupported grammar: {node}") @@ -70,21 +102,35 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Args: dry_run: If True, the method will only add `context` field to the `ExecutionResult` with the\ - mask that would be applied to the dataframe + mask that would be applied to the dataframe. Returns: - ExecutionResult object with the results and the context information with the binary mask + ExecutionResult object with the results and the context information with the binary mask. """ - filtered_data = pd.DataFrame.empty + results = pd.DataFrame() if not dry_run: - filtered_data = self.df + results = self.df if self._filter_mask is not None: - filtered_data = filtered_data.loc[self._filter_mask] + results = results.loc[self._filter_mask] + + if self._aggregation_group.groupbys is not None: + results = results.groupby(self._aggregation_group.groupbys) + + if self._aggregation_group.aggregations is not None: + results = results.agg( + **{ + f"{agg.column}_{agg.function}": (agg.column, agg.function) + for agg in self._aggregation_group.aggregations + } + ) + results = results.reset_index() return ViewExecutionResult( - results=filtered_data.to_dict(orient="records"), + results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, + "groupbys": self._aggregation_group.groupbys, + "aggregations": self._aggregation_group.aggregations, }, ) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 2e15669c..3a7c7981 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -1,11 +1,11 @@ import abc import asyncio -from typing import Optional import sqlalchemy from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView @@ -15,28 +15,42 @@ class SqlAlchemyBaseView(MethodsBaseView): """ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: + """ + Creates a new instance of the SQL view. + + Args: + sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. + """ super().__init__() + self.select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine - self._select = self.get_select() - self._where_clause: Optional[sqlalchemy.ColumnElement] = None @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: - r""" + """ Creates the initial [SqlAlchemy select object ](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select) which will be used to build the query. """ - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: IQLQuery object representing the filters to apply + filters: IQLQuery object representing the filters to apply. + """ + self.select = self.select.where(await self._build_filter_node(filters.root)) + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ - self._where_clause = await self._build_filter_node(filters.root) + Applies the chosen aggregation to the view. + + Args: + aggregation: IQLQuery object representing the aggregation to apply. + """ + self.select = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -64,6 +78,7 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu return alchemy_op(*await nodes) if hasattr(bool_op, "child"): return alchemy_op(await self._build_filter_node(bool_op.child)) + raise ValueError(f"BoolOp {bool_op} has no children") def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -78,17 +93,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - - if self._where_clause is not None: - self._select = self._select.where(self._where_clause) - - sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: + rows = connection.execute(self.select).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(self._select).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 8d56c064..2e5cff85 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,13 +4,11 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery -from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.exposed_functions import ExposedFunction from ..similarity import AbstractSimilarityIndex @@ -23,17 +21,14 @@ class BaseStructuredView(BaseView): to be able to list all available filters, apply them and execute queries. """ - def get_iql_generator(self, llm: LLM) -> IQLGenerator: + def get_iql_generator(self) -> IQLGenerator: """ Returns the IQL generator for the view. - Args: - llm: LLM used to generate the IQL queries. - Returns: IQL generator for the view. """ - return IQLGenerator(llm=llm) + return IQLGenerator() async def ask( self, @@ -60,42 +55,41 @@ async def ask( The result of the query. Raises: - LLMError: If LLM text generation API fails. - IQLGenerationError: If the IQL generation fails. + ViewExecutionError: When an error occurs while executing the view. """ - iql_generator = self.get_iql_generator(llm) - filters = self.list_filters() examples = self.list_few_shots() - - try: - iql = await iql_generator.generate( - question=query, - filters=filters, - examples=examples, - event_tracker=event_tracker, - llm_options=llm_options, - n_retries=n_retries, - ) - except UnsupportedQueryError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( + aggregations = self.list_aggregations() + + iql_generator = self.get_iql_generator() + iql = await iql_generator( + question=query, + filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + + if iql.failed: + raise ViewExecutionError( view_name=self.__class__.__name__, - filters=exc.source, - aggregation=None, - ) from exc + iql=iql, + ) - if iql: - await self.apply_filters(iql) + if iql.filters: + await self.apply_filters(iql.filters) - result = self.execute(dry_run=dry_run) - result.context["iql"] = str(iql) if iql else None + if iql.aggregation: + await self.apply_aggregation(iql.aggregation) + result = self.execute(dry_run=dry_run) + result.context["iql"] = { + "filters": str(iql.filters) if iql.filters else None, + "aggregation": str(iql.aggregation) if iql.aggregation else None, + } return result @abc.abstractmethod @@ -108,12 +102,30 @@ def list_filters(self) -> List[ExposedFunction]: """ @abc.abstractmethod - async def apply_filters(self, filters: IQLQuery) -> None: + def list_aggregations(self) -> List[ExposedFunction]: + """ + Lists all available aggregations for the View. + + Returns: + Aggregations defined inside the View. + """ + + @abc.abstractmethod + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply + filters: IQLQuery object representing the filters to apply. + """ + + @abc.abstractmethod + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: + """ + Applies the chosen aggregation to the view. + + Args: + aggregation: IQLQuery object representing the aggregation to apply. """ @abc.abstractmethod @@ -122,7 +134,10 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Executes the query and returns the result. Args: - dry_run: if True, should only generate the query without executing it + dry_run: if True, should only generate the query without executing it. + + Returns: + The view execution result. """ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]: diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index 892d979e..62a6766d 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -31,7 +31,7 @@ async def test_llm_options_propagation(): llm_options=custom_options, ) - assert llm.client.call.call_count == 3 + assert llm.client.call.call_count == 4 llm.client.call.assert_has_calls( [ diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index ae5d2269..bed83d0a 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -3,7 +3,7 @@ import pytest -from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax +from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, IQLFunctionNotExists, @@ -14,11 +14,12 @@ IQLSyntaxError, ) from dbally.iql._processor import IQLProcessor +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -async def test_iql_parser(): - parsed = await IQLQuery.parse( +async def test_iql_filter_parser(): + parsed = await IQLFiltersQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city('cracow') and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( @@ -51,9 +52,9 @@ async def test_iql_parser(): assert company_filter.arguments[0] == "deepsense.ai" -async def test_iql_parser_arg_error(): +async def test_iql_filter_parser_arg_error(): with pytest.raises(IQLArgumentParsingError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_city('Cracow') and filter_by_name(lambda x: x + 1)", allowed_functions=[ ExposedFunction( @@ -76,9 +77,9 @@ async def test_iql_parser_arg_error(): assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) -async def test_iql_parser_syntax_error(): +async def test_iql_filter_parser_syntax_error(): with pytest.raises(IQLSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age(", allowed_functions=[ ExposedFunction( @@ -94,9 +95,9 @@ async def test_iql_parser_syntax_error(): assert exc_info.match(re.escape("Syntax error in: filter_by_age(")) -async def test_iql_parser_multiple_expression_error(): +async def test_iql_filter_parser_multiple_expression_error(): with pytest.raises(IQLMultipleStatementsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age\nfilter_by_age", allowed_functions=[ ExposedFunction( @@ -112,9 +113,9 @@ async def test_iql_parser_multiple_expression_error(): assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) -async def test_iql_parser_empty_expression_error(): +async def test_iql_filter_parser_empty_expression_error(): with pytest.raises(IQLNoStatementError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "", allowed_functions=[ ExposedFunction( @@ -130,9 +131,9 @@ async def test_iql_parser_empty_expression_error(): assert exc_info.match(re.escape("Empty IQL")) -async def test_iql_parser_no_expression_error(): +async def test_iql_filter_parser_no_expression_error(): with pytest.raises(IQLNoExpressionError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "import filter_by_age", allowed_functions=[ ExposedFunction( @@ -148,9 +149,9 @@ async def test_iql_parser_no_expression_error(): assert exc_info.match(re.escape("No expression found in IQL: import filter_by_age")) -async def test_iql_parser_unsupported_syntax_error(): +async def test_iql_filter_parser_unsupported_syntax_error(): with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age() >= 30", allowed_functions=[ ExposedFunction( @@ -166,9 +167,9 @@ async def test_iql_parser_unsupported_syntax_error(): assert exc_info.match(re.escape("Compare syntax is not supported in IQL: filter_by_age() >= 30")) -async def test_iql_parser_method_not_exists(): +async def test_iql_filter_parser_method_not_exists(): with pytest.raises(IQLFunctionNotExists) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_how_old_somebody_is(40)", allowed_functions=[ ExposedFunction( @@ -184,9 +185,9 @@ async def test_iql_parser_method_not_exists(): assert exc_info.match(re.escape("Function filter_by_how_old_somebody_is not exists: filter_by_how_old_somebody_is")) -async def test_iql_parser_incorrect_number_of_arguments_fail(): +async def test_iql_filter_parser_incorrect_number_of_arguments_fail(): with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old', 40)", allowed_functions=[ ExposedFunction( @@ -204,9 +205,9 @@ async def test_iql_parser_incorrect_number_of_arguments_fail(): ) -async def test_iql_parser_argument_validation_fail(): +async def test_iql_filter_parser_argument_validation_fail(): with pytest.raises(IQLArgumentValidationError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old')", allowed_functions=[ ExposedFunction( @@ -222,6 +223,189 @@ async def test_iql_parser_argument_validation_fail(): assert exc_info.match(re.escape("'too old' is not of type int: 'too old'")) +async def test_iql_aggregation_parser(): + parsed = await IQLAggregationQuery.parse( + "mean_age_by_city('Paris')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert isinstance(parsed.root, syntax.FunctionCall) + assert parsed.root.name == "mean_age_by_city" + assert parsed.root.arguments == ["Paris"] + + +async def test_iql_aggregation_parser_arg_error(): + with pytest.raises(IQLArgumentParsingError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(lambda x: x + 1)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) + + +async def test_iql_aggregation_parser_syntax_error(): + with pytest.raises(IQLSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Syntax error in: mean_age_by_city(")) + + +async def test_iql_aggregation_parser_multiple_expression_error(): + with pytest.raises(IQLMultipleStatementsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city\nmean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) + + +async def test_iql_aggregation_parser_empty_expression_error(): + with pytest.raises(IQLNoStatementError) as exc_info: + await IQLAggregationQuery.parse( + "", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Empty IQL")) + + +async def test_iql_aggregation_parser_no_expression_error(): + with pytest.raises(IQLNoExpressionError) as exc_info: + await IQLAggregationQuery.parse( + "import mean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("No expression found in IQL: import mean_age_by_city")) + + +@pytest.mark.parametrize( + "iql, info", + [ + ("mean_age_by_city() >= 30", "Compare syntax is not supported in IQL: mean_age_by_city() >= 30"), + ( + "mean_age_by_city('Paris') and mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') and mean_age_by_city('London')", + ), + ( + "mean_age_by_city('Paris') or mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') or mean_age_by_city('London')", + ), + ("not mean_age_by_city('Paris')", "UnaryOp syntax is not supported in IQL: not mean_age_by_city('Paris')"), + ], +) +async def test_iql_aggregation_parser_unsupported_syntax_error(iql, info): + with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + iql, + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + assert exc_info.match(re.escape(info)) + + +async def test_iql_aggregation_parser_method_not_exists(): + with pytest.raises(IQLFunctionNotExists) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_town()", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Function mean_age_by_town not exists: mean_age_by_town")) + + +async def test_iql_aggregation_parser_incorrect_number_of_arguments_fail(): + with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city('too old')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match( + re.escape("The method mean_age_by_city has incorrect number of arguments: mean_age_by_city('too old')") + ) + + +async def test_iql_aggregation_parser_argument_validation_fail(): + with pytest.raises(IQLArgumentValidationError): + await IQLAggregationQuery.parse( + "mean_age_by_city(12)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + def test_keywords_lowercase(): rv = IQLProcessor._to_lower_except_in_quotes( """NOT filter1(230) AND (NOT filter_2("NOT ADMIN") AND filter_('IS NOT ADMIN')) OR NOT filter_4()""", diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 00384ec3..69174389 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,8 +9,8 @@ from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven -from dbally.iql import IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.similarity.index import AbstractSimilarityIndex @@ -26,20 +26,26 @@ class MockViewBase(BaseStructuredView): def list_filters(self) -> List[ExposedFunction]: return [] - async def apply_filters(self, filters: IQLQuery) -> None: + def list_aggregations(self) -> List[ExposedFunction]: + return [] + + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... - def execute(self, dry_run=False) -> ViewExecutionResult: + def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=[], context={}) class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: IQLQuery) -> None: - self.iql = iql - super().__init__(llm=MockLLM()) + def __init__(self, state: IQLGeneratorState) -> None: + self.state = state + super().__init__() - async def generate(self, *_, **__) -> IQLQuery: - return self.iql + async def __call__(self, *_, **__) -> IQLGeneratorState: + return self.state class MockViewSelector(ViewSelector): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index 42e05c0a..ab4b6c7e 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -3,7 +3,7 @@ from typing_extensions import Annotated from dbally import MethodsBaseView, decorators -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.structured import ViewExecutionResult from tests.unit.mocks import MockSimilarityIndex @@ -20,7 +20,10 @@ def method_foo(self, idx: Annotated[str, index_foo]) -> str: def method_bar(self, city: Annotated[str, index_foo], year: Annotated[int, index_bar]) -> str: return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -39,7 +42,10 @@ def method_qux(self, city: str, year: int) -> str: """ return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 058da20b..1d675d84 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -10,8 +10,9 @@ from dbally.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql.syntax import FunctionCall +from dbally.iql_generator.iql_generator import IQLGeneratorState from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector @@ -59,8 +60,16 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__) -> MockIQLGenerator: - return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def get_iql_generator(self) -> MockIQLGenerator: + return MockIQLGenerator( + IQLGeneratorState( + filters=IQLFiltersQuery(FunctionCall("test_filter", []), "test_filter()"), + aggregation=IQLAggregationQuery(FunctionCall("test_aggregation", []), "test_aggregation()"), + ), + ) + + def list_aggregations(self) -> List[ExposedFunction]: + return [ExposedFunction("test_aggregation", "", [])] @pytest.fixture(name="similarity_classes") @@ -291,7 +300,7 @@ async def test_ask_view_selection_single_view() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_multiple_views() -> None: @@ -312,7 +321,7 @@ async def test_ask_view_selection_multiple_views() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_no_views() -> None: diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 8f583c4c..3a21a1fe 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,21 +1,21 @@ -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import FILTERS_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.prompt.elements import FewShotExample async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=[], ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { "role": "system", - "content": "You have access to API that lets you query a database:\n" + "content": "You have access to an API that lets you query a database:\n" "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' @@ -23,7 +23,7 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, }, @@ -35,17 +35,17 @@ async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { "role": "system", - "content": "You have access to API that lets you query a database:\n" + "content": "You have access to an API that lets you query a database:\n" "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' @@ -53,7 +53,7 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, }, @@ -67,12 +67,12 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) - assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert len(formatted_prompt.chat) == len(FILTERS_GENERATION_TEMPLATE.chat) + (len(examples) * 2) assert formatted_prompt.chat[1]["role"] == "user" assert formatted_prompt.chat[1]["content"] == examples[0].question assert formatted_prompt.chat[2]["role"] == "assistant" diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 90e17f0d..0defc8e1 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,29 +1,23 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, call, patch +from unittest.mock import AsyncMock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import ( - FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, - IQLGenerationPromptFormat, -) +from dbally.iql import IQLAggregationQuery, IQLError, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM class MockView(MethodsBaseView): - def get_select(self) -> sqlalchemy.Select: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False): @@ -56,125 +50,177 @@ def event_tracker() -> EventTracker: @pytest.fixture -def iql_generator(llm: MockLLM) -> IQLGenerator: - return IQLGenerator(llm) +def iql_generator() -> IQLGenerator: + return IQLGenerator() @pytest.mark.asyncio -async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: +async def test_iql_generation( + iql_generator: IQLGenerator, + llm: MockLLM, + event_tracker: EventTracker, + view: MockView, +) -> None: filters = view.list_filters() - - decision_format = FilteringDecisionPromptFormat( - question="Mock_question", - ) - generation_format = IQLGenerationPromptFormat( - question="Mock_question", - filters=filters, - ) - - decision_prompt = FILTERING_DECISION_TEMPLATE.format_prompt(decision_format) - generation_prompt = IQL_GENERATION_TEMPLATE.format_prompt(generation_format) + aggregations = view.list_aggregations() + examples = view.list_few_shots() llm_responses = [ "decision: true", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - iql = await iql_generator.generate( + iql_filter_parser_response = "filter_by_id(1)" + iql_aggregation_parser_response = "aggregate_by_id()" + + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch( + "dbally.iql.IQLFiltersQuery.parse", AsyncMock(return_value=iql_filter_parser_response) + ) as mock_filters_parse, patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(return_value=iql_aggregation_parser_response) + ) as mock_aggregation_parse: + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, ) - assert iql == "filter_by_id(1)" - iql_generator._llm.generate_text.assert_has_calls( - [ - call( - prompt=decision_prompt, - event_tracker=event_tracker, - options=None, - ), - call( - prompt=generation_prompt, - event_tracker=event_tracker, - options=None, - ), - ] + assert iql == IQLGeneratorState( + filters=iql_filter_parser_response, + aggregation=iql_aggregation_parser_response, ) - mock_parse.assert_called_once_with( - source="filter_by_id(1)", + assert llm.generate_text.call_count == 4 + mock_filters_parse.assert_called_once_with( + source=llm_responses[1], allowed_functions=filters, event_tracker=event_tracker, ) + mock_aggregation_parse.assert_called_once_with( + source=llm_responses[3], + allowed_functions=aggregations, + event_tracker=event_tracker, + ) @pytest.mark.asyncio async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), IQLError("err4", "src4"), ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - assert iql is None - assert iql_generator._llm.generate_text.call_count == 4 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == 10 + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] @pytest.mark.asyncio async def test_iql_generation_response_after_max_retries( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "filter_by_id(1)", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "aggregate_by_id()", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), "filter_by_id(1)", ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - - assert iql == "filter_by_id(1)" - assert iql_generator._llm.generate_text.call_count == 5 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[2:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == len(llm_responses) + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 58959a64..57c0b68a 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,8 +4,8 @@ from typing import List, Literal, Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from dbally.views.methods_base import MethodsBaseView @@ -25,7 +25,20 @@ def method_foo(self, idx: int) -> None: def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: List[Tuple[str, int]]) -> str: return f"hello {cities} in {year} of {pairs}" - async def apply_filters(self, filters: IQLQuery) -> None: + @view_aggregation() + def method_baz(self) -> None: + """ + Some documentation string + """ + + @view_aggregation() + def method_qux(self, ages: List[int], names: List[str]) -> str: + return f"hello {ages} and {names}" + + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -53,3 +66,23 @@ def test_list_filters() -> None: assert ( str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'], pairs: List[Tuple[str, int]])" ) + + +def test_list_aggregations() -> None: + """ + Tests that the list_aggregations method works correctly + """ + mock_view = MockMethodsBase() + aggregations = mock_view.list_aggregations() + assert len(aggregations) == 2 + method_baz = [f for f in aggregations if f.name == "method_baz"][0] + assert method_baz.description == "Some documentation string" + assert method_baz.parameters == [] + assert str(method_baz) == "method_baz() - Some documentation string" + method_qux = [f for f in aggregations if f.name == "method_qux"][0] + assert method_qux.description == "" + assert method_qux.parameters == [ + MethodParamWithTyping("ages", List[int]), + MethodParamWithTyping("names", List[str]), + ] + assert str(method_qux) == "method_qux(ages: List[int], names: List[str])" diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 51eea791..029fe30f 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,10 +1,12 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name + import pandas as pd -from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter -from dbally.views.pandas_base import DataFrameBaseView +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery +from dbally.views.decorators import view_aggregation, view_filter +from dbally.views.pandas_base import Aggregation, AggregationGroup, DataFrameBaseView MOCK_DATA = [ {"name": "Alice", "city": "London", "year": 2020, "age": 30}, @@ -53,13 +55,30 @@ def filter_age(self, age: int) -> pd.Series: def filter_name(self, name: str) -> pd.Series: return self.df["name"] == name + @view_aggregation() + def mean_age_by_city(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="age", function="mean"), + ], + groupbys="city", + ) + + @view_aggregation() + def count_records(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="name", function="count"), + ], + ) + async def test_filter_or() -> None: """ Test that the filtering the DataFrame with logical OR works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Berlin") or filter_city("London")', allowed_functions=mock_view.list_filters(), ) @@ -67,6 +86,8 @@ async def test_filter_or() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_and() -> None: @@ -74,7 +95,7 @@ async def test_filter_and() -> None: Test that the filtering the DataFrame with logical AND works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Paris") and filter_year(2020)', allowed_functions=mock_view.list_filters(), ) @@ -82,6 +103,8 @@ async def test_filter_and() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_not() -> None: @@ -89,7 +112,7 @@ async def test_filter_not() -> None: Test that the filtering the DataFrame with logical NOT works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'not (filter_city("Paris") and filter_year(2020))', allowed_functions=mock_view.list_filters(), ) @@ -97,3 +120,67 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None + + +async def test_aggregation() -> None: + """ + Test that DataFrame aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( + "count_records()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"index": "name_count", "name": 5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] is None + assert result.context["aggregations"] == [Aggregation(column="name", function="count")] + + +async def test_aggregtion_with_groupby() -> None: + """ + Test that DataFrame aggregation with groupby works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"city": "Berlin", "age_mean": 45.0}, + {"city": "London", "age_mean": 32.5}, + {"city": "Paris", "age_mean": 32.5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] + + +async def test_filters_and_aggregtion() -> None: + """ + Test that DataFrame filtering and aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLFiltersQuery.parse( + "filter_city('Paris')", + allowed_functions=mock_view.list_filters(), + ) + await mock_view.apply_filters(query) + query = await IQLAggregationQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [{"city": "Paris", "age_mean": 32.5}] + assert result.context["filter_mask"].tolist() == [False, True, False, True, False] + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 079a2135..571e6a70 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -4,8 +4,9 @@ import sqlalchemy -from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -28,6 +29,13 @@ def method_foo(self, idx: int) -> sqlalchemy.ColumnElement: async def method_bar(self, city: str, year: int) -> sqlalchemy.ColumnElement: return sqlalchemy.literal(f"hello {city} in {year}") + @view_aggregation() + def method_baz(self) -> sqlalchemy.Select: + """ + Some documentation string + """ + return self.select.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) + def normalize_whitespace(s: str) -> str: """ @@ -43,10 +51,47 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters(), ) await mock_view.apply_filters(query) sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020'" + + +async def test_aggregation_sql_generation() -> None: + """ + Tests that the SQL generation based on aggregations works correctly + """ + + mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) + mock_view = MockSqlAlchemyView(mock_connection.engine) + query = await IQLAggregationQuery.parse( + "method_baz()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 GROUP BY 'baz'" + + +async def test_filter_and_aggregation_sql_generation() -> None: + """ + Tests that the SQL generation based on filters and aggregations works correctly + """ + + mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) + mock_view = MockSqlAlchemyView(mock_connection.engine) + query = await IQLFiltersQuery.parse( + 'method_foo(1) and method_bar("London", 2020)', + allowed_functions=mock_view.list_filters() + mock_view.list_aggregations(), + ) + await mock_view.apply_filters(query) + query = await IQLAggregationQuery.parse( + "method_baz()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 WHERE 1 AND 'hello London in 2020' GROUP BY 'baz'"