diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index 38bcb304..dc8d83ea 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union +from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLQuery +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM +from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM @@ -16,6 +20,25 @@ class IQL: source: Optional[str] = None unsupported: bool = False valid: bool = True + generated: bool = True + + @classmethod + def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + """ + Creates an IQL object from the query. + + Args: + query: The IQL query or exception. + + Returns: + The IQL object. + """ + return cls( + source=query.source if isinstance(query, (IQLQuery, IQLError)) else None, + unsupported=isinstance(query, UnsupportedQueryError), + valid=not isinstance(query, IQLError), + generated=not isinstance(query, LLMError), + ) @dataclass @@ -47,6 +70,7 @@ class EvaluationResult: """ db_id: str + question_id: str question: str reference: ExecutionResult prediction: ExecutionResult diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index dfc127cf..19831b0d 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -5,10 +5,8 @@ import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.view_selection.llm_view_selector import LLMViewSelector -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from ..views import VIEWS_REGISTRY from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return_natural_response=False, ) except NoViewFoundError: - prediction = ExecutionResult( - view_name=None, - iql=None, - sql=None, - ) - except IQLGenerationError as exc: + prediction = ExecutionResult() + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=exc.view_name, iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL( - source=result.context.get("iql"), - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), - sql=result.context.get("sql"), + sql=result.context["sql"], ) reference = ExecutionResult( @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index d4ae8515..be9d8263 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -5,9 +5,7 @@ from sqlalchemy import create_engine -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, n_retries=0, ) - except IQLGenerationError as exc: + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=result.context["iql"], - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), sql=result.context["sql"], ) @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: filters=IQL( source=data["iql_filters"], unsupported=data["iql_filters_unsupported"], - valid=True, ), aggregation=IQL( source=data["iql_aggregation"], unsupported=data["iql_aggregation_unsupported"], - valid=True, ), context=data["iql_context"], ), @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 56932947..2a6a75a0 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -286,12 +286,11 @@ class SuperheroColourFilterMixin: """ def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.eye_colour = aliased(Colour) self.hair_colour = aliased(Colour) self.skin_colour = aliased(Colour) - super().__init__(*args, **kwargs) - @view_filter() def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: """ @@ -441,11 +440,12 @@ def count_superheroes(self) -> Select: Returns: The superheros count. """ - return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) class SuperheroView( DBInitMixin, + SqlAlchemyBaseView, SuperheroFilterMixin, SuperheroAggregationMixin, SuperheroColourFilterMixin, @@ -453,7 +453,6 @@ class SuperheroView( GenderFilterMixin, PublisherFilterMixin, RaceFilterMixin, - SqlAlchemyBaseView, ): """ View for querying only superheros data. Contains the superhero id, superhero name, full name, height, weight, diff --git a/src/dbally/exceptions.py b/src/dbally/exceptions.py index 62faac37..6b095cd7 100644 --- a/src/dbally/exceptions.py +++ b/src/dbally/exceptions.py @@ -2,10 +2,3 @@ class DbAllyError(Exception): """ Base class for all exceptions raised by db-ally. """ - - -class UnsupportedAggregationError(DbAllyError): - """ - Error raised when AggregationFormatter is unable to construct a query - with given aggregation. - """ diff --git a/src/dbally/iql/__init__.py b/src/dbally/iql/__init__.py index 0df0a766..20bde9eb 100644 --- a/src/dbally/iql/__init__.py +++ b/src/dbally/iql/__init__.py @@ -1,5 +1,13 @@ from . import syntax from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError -from ._query import IQLQuery +from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery -__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"] +__all__ = [ + "IQLQuery", + "IQLFiltersQuery", + "IQLAggregationQuery", + "syntax", + "IQLError", + "IQLArgumentParsingError", + "IQLUnsupportedSyntaxError", +] diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index f1adf64c..1bd72bcc 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,5 +1,6 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -19,10 +20,12 @@ if TYPE_CHECKING: from dbally.views.structured import ExposedFunction +RootT = TypeVar("RootT", bound=syntax.Node) -class IQLProcessor: + +class IQLProcessor(Generic[RootT], ABC): """ - Parses IQL string to tree structure. + Base class for IQL processors. """ def __init__( @@ -32,9 +35,9 @@ def __init__( self.allowed_functions = {func.name: func for func in allowed_functions} self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: + async def process(self) -> RootT: """ - Process IQL string to root IQL.Node. + Process IQL string to IQL root node. Returns: IQL node which is root of the tree representing IQL query. @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node: return await self._parse_node(ast_tree.body[0].value) - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: - if isinstance(node, ast.BoolOp): - return await self._parse_bool_op(node) - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.operand)) - if isinstance(node, ast.Call): - return await self._parse_call(node) - - raise IQLUnsupportedSyntaxError(node, self.source) + @abstractmethod + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT: + """ + Parses AST node to IQL node. - async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: - if isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.values[0])) - if isinstance(node.op, ast.And): - return syntax.And([await self._parse_node(x) for x in node.values]) - if isinstance(node.op, ast.Or): - return syntax.Or([await self._parse_node(x) for x in node.values]) + Args: + node: AST node to parse. - raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + Returns: + IQL node. + """ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: func = node.func @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str: converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower() return converted_text + + +class IQLFiltersProcessor(IQLProcessor[syntax.Node]): + """ + IQL processor for filters. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: + if isinstance(node, ast.BoolOp): + return await self._parse_bool_op(node) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.operand)) + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) + + async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: + if isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.values[0])) + if isinstance(node.op, ast.And): + return syntax.And([await self._parse_node(x) for x in node.values]) + if isinstance(node.op, ast.Or): + return syntax.Or([await self._parse_node(x) for x in node.values]) + + raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + + +class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]): + """ + IQL processor for aggregation. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall: + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index dd831a91..57b3b4ed 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,26 +1,29 @@ -from typing import TYPE_CHECKING, List, Optional +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type from ..audit.event_tracker import EventTracker from . import syntax -from ._processor import IQLProcessor +from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: from dbally.views.structured import ExposedFunction -class IQLQuery: +class IQLQuery(Generic[RootT], ABC): """ IQLQuery container. It stores IQL as a syntax tree defined in `IQL` class. """ - root: syntax.Node + root: RootT + source: str + _processor: Type[IQLProcessor[RootT]] - def __init__(self, root: syntax.Node, source: str) -> None: + def __init__(self, root: RootT, source: str) -> None: self.root = root - self._source = source + self.source = source def __str__(self) -> str: - return self._source + return self.source @classmethod async def parse( @@ -28,7 +31,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - ) -> "IQLQuery": + ) -> "IQLQuery[RootT]": """ Parse IQL string to IQLQuery object. @@ -43,5 +46,21 @@ async def parse( Raises: IQLError: If parsing fails. """ - root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process() return cls(root=root, source=source) + + +class IQLFiltersQuery(IQLQuery[syntax.Node]): + """ + IQL filters query container. + """ + + _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor + + +class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): + """ + IQL aggregation query container. + """ + + _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 27347734..2222e179 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,11 +1,16 @@ -from typing import List, Optional +import asyncio +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, + FILTERS_GENERATION_TEMPLATE, + DecisionPromptFormat, IQLGenerationPromptFormat, ) from dbally.llms.base import LLM @@ -15,57 +20,151 @@ from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" +IQLQueryT = TypeVar("IQLQueryT", bound=IQLQuery) -class IQLGenerator: +@dataclass +class IQLGeneratorState: + """ + State of the IQL generator. """ - Class used to generate IQL from natural language question. - In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way. - The class used to generate IQL from natural language query is `IQLGenerator`. + filters: Optional[Union[IQLFiltersQuery, Exception]] = None + aggregation: Optional[Union[IQLAggregationQuery, Exception]] = None + + @property + def failed(self) -> bool: + """ + Checks if the generation failed. - IQL generation is done using the method `self.generate_iql`. - It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. + Returns: + True if the generation failed, False otherwise. + """ + return isinstance(self.filters, Exception) or isinstance(self.aggregation, Exception) + + +class IQLGenerator: + """ + Program that orchestrates all IQL operations for the given question. """ def __init__( self, - llm: LLM, - *, - decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None, - generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None, + filters_generation: Optional["IQLOperationGenerator"] = None, + aggregation_generation: Optional["IQLOperationGenerator"] = None, ) -> None: """ Constructs a new IQLGenerator instance. Args: - llm: LLM used to generate IQL. decision_prompt: Prompt template for filtering decision making. generation_prompt: Prompt template for IQL generation. """ - self._llm = llm - self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE - self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE + self._filters_generation = filters_generation or IQLOperationGenerator[IQLFiltersQuery]( + FILTERING_DECISION_TEMPLATE, + FILTERS_GENERATION_TEMPLATE, + ) + self._aggregation_generation = aggregation_generation or IQLOperationGenerator[IQLAggregationQuery]( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, + ) - async def generate( + # pylint: disable=too-many-arguments + async def __call__( self, + *, question: str, filters: List[ExposedFunction], - event_tracker: EventTracker, - examples: Optional[List[FewShotExample]] = None, + aggregations: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> Optional[IQLQuery]: + ) -> IQLGeneratorState: """ - Generates IQL in text form using LLM. + Generates IQL operations for the given question. Args: question: User question. filters: List of filters exposed by the view. + aggregations: List of aggregations exposed by the view. + examples: List of examples to be injected during filters and aggregation generation. + llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. + + Returns: + Generated IQL operations. + """ + filters, aggregation = await asyncio.gather( + self._filters_generation( + question=question, + methods=filters, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + self._aggregation_generation( + question=question, + methods=aggregations, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + return_exceptions=True, + ) + return IQLGeneratorState( + filters=filters, + aggregation=aggregation, + ) + + +class IQLOperationGenerator(Generic[IQLQueryT]): + """ + Program that generates IQL queries for the given question. + """ + + def __init__( + self, + assessor_prompt: PromptTemplate[DecisionPromptFormat], + generator_prompt: PromptTemplate[IQLGenerationPromptFormat], + ) -> None: + """ + Constructs a new IQLGenerator instance. + + Args: + assessor_prompt: Prompt template for filtering decision making. + generator_prompt: Prompt template for IQL generation. + """ + self.assessor = IQLQuestionAssessor(assessor_prompt) + self.generator = IQLQueryGenerator[IQLQueryT](generator_prompt) + + async def __call__( + self, + *, + question: str, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, + llm_options: Optional[LLMOptions] = None, + n_retries: int = 3, + ) -> Optional[IQLQueryT]: + """ + Generates IQL query for the given question. + + Args: + llm: LLM used to generate IQL. + question: User question. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. @@ -77,38 +176,52 @@ async def generate( IQLError: If IQL parsing fails after all retries. UnsupportedQueryError: If the question is not supported by the view. """ - decision = await self._decide_on_generation( + decision = await self.assessor( question=question, - event_tracker=event_tracker, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) if not decision: return None - return await self._generate_iql( + return await self.generator( question=question, - filters=filters, - event_tracker=event_tracker, + methods=methods, examples=examples, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) - async def _decide_on_generation( + +class IQLQuestionAssessor: + """ + Program that assesses whether a question requires applying IQL operation or not. + """ + + def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - event_tracker: EventTracker, + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> bool: """ - Decides whether the question requires filtering or not. + Decides whether the question requires generating IQL or not. Args: question: User question. - event_tracker: Event store used to audit the generation process. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to LLM API in case of errors. Returns: @@ -117,12 +230,14 @@ async def _decide_on_generation( Raises: LLMError: If LLM text generation fails after all retries. """ - prompt_format = FilteringDecisionPromptFormat(question=question) - formatted_prompt = self._decision_prompt.format_prompt(prompt_format) + prompt_format = DecisionPromptFormat( + question=question, + ) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -133,24 +248,39 @@ async def _decide_on_generation( if retry == n_retries: raise exc - async def _generate_iql( + +class IQLQueryGenerator(Generic[IQLQueryT]): + """ + Program that generates IQL queries for the given question. + """ + + ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" + + def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - filters: List[ExposedFunction], - event_tracker: Optional[EventTracker] = None, - examples: Optional[List[FewShotExample]] = None, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> IQLQuery: + ) -> IQLQueryT: """ - Generates IQL in text form using LLM. + Generates IQL query for the given question. Args: question: User question. filters: List of filters exposed by the view. - event_tracker: Event store used to audit the generation process. examples: List of examples to be injected into the conversation. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. Returns: @@ -163,24 +293,22 @@ async def _generate_iql( """ prompt_format = IQLGenerationPromptFormat( question=question, - filters=filters, + methods=methods, examples=examples, ) - formatted_prompt = self._generation_prompt.format_prompt(prompt_format) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) # TODO: Move response parsing to llm generate_text method - iql = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, - allowed_functions=filters, + return await formatted_prompt.response_parser( + response=response, + allowed_functions=methods, event_tracker=event_tracker, ) except LLMError as exc: @@ -190,4 +318,4 @@ async def _generate_iql( if retry == n_retries: raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 4e5a45ec..bf33fbe0 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -1,8 +1,10 @@ # pylint: disable=C0301 -from typing import List +from typing import List, Optional +from dbally.audit.event_tracker import EventTracker from dbally.exceptions import DbAllyError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.views.exposed_functions import ExposedFunction @@ -15,26 +17,65 @@ class UnsupportedQueryError(DbAllyError): """ -def _validate_iql_response(llm_response: str) -> str: +async def _iql_filters_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLFiltersQuery: """ - Validates LLM response to IQL + Parses the response from the LLM to IQL. Args: - llm_response: LLM response + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. Returns: - A string containing IQL for filters. + IQL query for filters. Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. + UnsuppotedQueryError: When IQL generator is unable to construct a query with given filters. """ - if "unsupported query" in llm_response.lower(): + if "unsupported query" in response.lower(): raise UnsupportedQueryError - return llm_response + return await IQLFiltersQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) -def _decision_iql_response_parser(response: str) -> bool: + +async def _iql_aggregation_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLAggregationQuery: + """ + Parses the response from the LLM to IQL. + + Args: + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. + + Returns: + IQL query for aggregations. + + Raises: + UnsuppotedQueryError: When IQL generator is unable to construct a query with given aggregations. + """ + if "unsupported query" in response.lower(): + raise UnsupportedQueryError + + return await IQLAggregationQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) + + +def _decision_parser(response: str) -> bool: """ Parses the response from the decision prompt. @@ -52,7 +93,7 @@ def _decision_iql_response_parser(response: str) -> bool: return "true" in decision -class FilteringDecisionPromptFormat(PromptFormat): +class DecisionPromptFormat(PromptFormat): """ IQL prompt format, providing a question and filters to be used in the conversation. """ @@ -71,44 +112,96 @@ def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> N class IQLGenerationPromptFormat(PromptFormat): """ - IQL prompt format, providing a question and filters to be used in the conversation. + IQL prompt format, providing a question and methods to be used in the conversation. """ def __init__( self, *, question: str, - filters: List[ExposedFunction], - examples: List[FewShotExample] = None, + methods: List[ExposedFunction], + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + methods: List of filters exposed by the view. examples: List of examples to be injected into the conversation. aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] + self.methods = "\n".join([str(condition) for condition in methods]) if methods else [] + + +FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" + "Initial data filtering is a process in which the result set is reduced to only include the rows " + "that meet certain criteria specified in the question.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Hint: ${{hint}}" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ( + "Question: {question}\n" + "Hint: Look for words indicating data specific features.\n" + "Reasoning: Let's think step by step in order to " + ), + }, + ], + response_parser=_decision_parser, +) +AGGREGATION_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires computing the aggregation in order to compute it.\n" + "Aggregation is a process in which the result set is reduced to a single value.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ("Question: {question}\n" "Reasoning: Let's think step by step in order to "), + }, + ], + response_parser=_decision_parser, +) -IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( +FILTERS_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( "You have access to an API that lets you query a database:\n" - "\n{filters}\n" + "\n{methods}\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n{filters}\n" + "\n{methods}\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. " @@ -119,35 +212,31 @@ def __init__( "content": "{question}", }, ], - response_parser=_validate_iql_response, + response_parser=_iql_filters_parser, ) - -FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat]( +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( - "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" - "Initial data filtering is a process in which the result set is reduced to only include the rows " - "that meet certain criteria specified in the question.\n\n" - "---\n\n" - "Follow the following format.\n\n" - "Question: ${{question}}\n" - "Hint: ${{hint}}" - "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" - "Decision: indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)\n\n" + "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{methods}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n' ), }, { "role": "user", - "content": ( - "Question: {question}\n" - "Hint: Look for words indicating data specific features.\n" - "Reasoning: Let's think step by step in order to " - ), + "content": "{question}", }, ], - response_parser=_decision_iql_response_parser, + response_parser=_iql_aggregation_parser, ) diff --git a/src/dbally/prompt/aggregation.py b/src/dbally/prompt/aggregation.py deleted file mode 100644 index 8dedd95c..00000000 --- a/src/dbally/prompt/aggregation.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import List, Optional - -from dbally.audit import EventTracker -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.llms.base import LLM -from dbally.llms.clients import LLMOptions -from dbally.prompt.template import PromptFormat, PromptTemplate -from dbally.views.exposed_functions import ExposedFunction - - -def _validate_agg_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing aggregations. - - Raises: - UnsupportedAggregationError: When IQL generator is unable to construct a query - with given aggregation. - """ - if "unsupported query" in llm_response.lower(): - raise UnsupportedAggregationError - return llm_response - - -class AggregationPromptFormat(PromptFormat): - """ - Aggregation prompt format, providing a question and aggregation to be used in the conversation. - """ - - def __init__( - self, - question: str, - aggregations: List[ExposedFunction] = None, - ) -> None: - super().__init__() - self.question = question - self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] - - -class AggregationFormatter: - """ - Class used to manage choice and formatting of aggregation based on natural language question. - """ - - def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None: - """ - Constructs a new AggregationFormatter instance. - - Args: - llm: LLM used to generate IQL - prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE` - """ - self._llm = llm - self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE - - async def format_to_query_object( - self, - question: str, - event_tracker: EventTracker, - aggregations: List[ExposedFunction] = None, - llm_options: Optional[LLMOptions] = None, - ) -> IQLQuery: - """ - Generates IQL in text form using LLM. - - Args: - question: User question. - event_tracker: Event store used to audit the generation process. - aggregations: List of aggregations exposed by the view. - llm_options: Options to use for the LLM client. - - Returns: - Generated aggregation query. - """ - prompt_format = AggregationPromptFormat( - question=question, - aggregations=aggregations, - ) - - formatted_prompt = self._prompt_template.format_prompt(prompt_format) - - response = await self._llm.generate_text( - prompt=formatted_prompt, - event_tracker=event_tracker, - options=llm_options, - ) - # TODO: Move response parsing to llm generate_text method - agg = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=agg, - allowed_functions=aggregations or [], - event_tracker=event_tracker, - ) - - -AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat]( - [ - { - "role": "system", - "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" - "When prompted for an aggregation, use the following methods: \n" - "{aggregations}" - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{aggregations}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" - "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " - "Structure output to resemble the following pattern:\n" - 'aggregation1("arg1", arg2)\n', - }, - {"role": "user", "content": "{question}"}, - ], - response_parser=_validate_agg_response, -) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/views/exceptions.py b/src/dbally/views/exceptions.py index 277064a4..15770e9a 100644 --- a/src/dbally/views/exceptions.py +++ b/src/dbally/views/exceptions.py @@ -1,26 +1,22 @@ -from typing import Optional - from dbally.exceptions import DbAllyError +from dbally.iql_generator.iql_generator import IQLGeneratorState -class IQLGenerationError(DbAllyError): +class ViewExecutionError(DbAllyError): """ - Exception for when an error occurs while generating IQL for a view. + Exception for when an error occurs while executing a view. """ def __init__( self, view_name: str, - filters: Optional[str] = None, - aggregation: Optional[str] = None, + iql: IQLGeneratorState, ) -> None: """ Args: view_name: Name of the view that caused the error. - filters: Filters generated by the view. - aggregation: Aggregation generated by the view. + iql: View IQL generator state. """ - super().__init__(f"Error while generating IQL for view {view_name}") + super().__init__(f"Error while executing view {view_name}") self.view_name = view_name - self.filters = filters - self.aggregation = aggregation + self.iql = iql diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 977a2fa1..2a2c5d8e 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -1,15 +1,15 @@ import inspect import textwrap from abc import ABC -from typing import Any, Callable, Generic, List, Tuple +from typing import Any, Callable, List, Tuple from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView, DataT +from dbally.views.structured import BaseStructuredView -class MethodsBaseView(Generic[DataT], BaseStructuredView, ABC): +class MethodsBaseView(BaseStructuredView, ABC): """ Base class for views that use view methods to expose filters. """ @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 5f7bc8ce..c35fd30f 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,15 +1,17 @@ import asyncio from functools import reduce -from typing import Optional +from typing import List, Optional, Union import pandas as pd +from sqlalchemy import Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): +class DataFrameBaseView(MethodsBaseView): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -24,35 +26,31 @@ def __init__(self, df: pd.DataFrame) -> None: Args: df: Pandas DataFrame with the data to be filtered. """ - super().__init__(df) - - # The mask to be applied to the dataframe to filter the data + super().__init__() + self.df = df self._filter_mask: Optional[pd.Series] = None + self._groupbys: Optional[Union[str, List[str]]] = None + self._aggregations: Optional[List[Tuple[str, str]]] = None - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self._filter_mask = await self.build_filter_node(filters.root) - self.data = self.data.loc[self._filter_mask] + self._filter_mask = await self._build_filter_node(filters.root) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the aggregation of choice to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self._groupbys, self._aggregations = await self.call_aggregation_method(aggregation.root) - async def build_filter_node(self, node: syntax.Node) -> pd.Series: + async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing a boolean mask to be applied to the dataframe. @@ -69,13 +67,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x & y, children) if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x | y, children) if isinstance(node, syntax.Not): - child = await self.build_filter_node(node.child) + child = await self._build_filter_node(node.child) return ~child raise ValueError(f"Unsupported grammar: {node}") @@ -90,11 +88,25 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Returns: ExecutionResult object with the results and the context information with the binary mask. """ - results = pd.DataFrame.empty if dry_run else self.data + results = pd.DataFrame() + + if not dry_run: + results = self.df + if self._filter_mask is not None: + results = results.loc[self._filter_mask] + + if self._groupbys is not None: + results = results.groupby(self._groupbys) + + if self._aggregations is not None: + results = results.agg(**{"_".join(agg): agg for agg in self._aggregations}) + results = results.reset_index() return ViewExecutionResult( results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, + "groupbys": self._groupbys, + "aggregations": self._aggregations, }, ) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 4863aa6f..691797a1 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -4,11 +4,12 @@ import sqlalchemy from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): +class SqlAlchemyBaseView(MethodsBaseView): """ Base class for views that use SQLAlchemy to generate SQL queries. """ @@ -20,7 +21,8 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: Args: sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. """ - super().__init__(self.get_select()) + super().__init__() + self.select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine @abc.abstractmethod @@ -32,27 +34,23 @@ def get_select(self) -> sqlalchemy.Select: SQLAlchemy Select object for the view. """ - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = self.data.where(await self._build_filter_node(filters.root)) + self.select = self.select.where(await self._build_filter_node(filters.root)) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self.select = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -95,11 +93,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: - rows = connection.execute(self.data).fetchall() + rows = connection.execute(self.select).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access results = [dict(row._mapping) for row in rows] diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c3ac91e0..bab0f7b3 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,60 +1,34 @@ import abc from collections import defaultdict -from typing import Any, Dict, List, Optional, TypeVar +from typing import Dict, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.exposed_functions import ExposedFunction -from ..prompt.aggregation import AggregationFormatter from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation -DataT = TypeVar("DataT", bound=Any) - -# TODO(Python 3.9+): Make BaseStructuredView a generic class class BaseStructuredView(BaseView): """ - Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ + Base class for all structured views. All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """ - def __init__(self, data: DataT) -> None: - super().__init__() - self.data = data - - def get_iql_generator(self, llm: LLM) -> IQLGenerator: + def get_iql_generator(self) -> IQLGenerator: """ Returns the IQL generator for the view. - Args: - llm: LLM used to generate the IQL queries. - Returns: IQL generator for the view. """ - return IQLGenerator(llm=llm) - - def get_agg_formatter(self, llm: LLM) -> AggregationFormatter: - """ - Returns the AggregtionFormatter for the view. - - Args: - llm: LLM used to generate the queries. - - Returns: - AggregtionFormatter for the view. - """ - return AggregationFormatter(llm=llm) + return IQLGenerator() async def ask( self, @@ -81,68 +55,41 @@ async def ask( The result of the query. Raises: - LLMError: If LLM text generation API fails. - IQLGenerationError: If the IQL generation fails. + ViewExecutionError: When an error occurs while executing the view. """ - iql_generator = self.get_iql_generator(llm) - agg_formatter = self.get_agg_formatter(llm) filters = self.list_filters() examples = self.list_few_shots() aggregations = self.list_aggregations() - try: - iql = await iql_generator.generate( - question=query, - filters=filters, - examples=examples, - event_tracker=event_tracker, - llm_options=llm_options, - n_retries=n_retries, - ) - except UnsupportedQueryError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( + iql_generator = self.get_iql_generator() + iql = await iql_generator( + question=query, + filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + + if iql.failed: + raise ViewExecutionError( view_name=self.__class__.__name__, - filters=exc.source, - aggregation=None, - ) from exc - - if iql: - await self.apply_filters(iql) - - try: - agg_node = await agg_formatter.format_to_query_object( - question=query, - aggregations=aggregations, - event_tracker=event_tracker, - llm_options=llm_options, + iql=iql, ) - except UnsupportedAggregationError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=exc.source, - ) from exc - await self.apply_aggregation(agg_node) + if iql.filters: + await self.apply_filters(iql.filters) + + if iql.aggregation: + await self.apply_aggregation(iql.aggregation) result = self.execute(dry_run=dry_run) result.context["iql"] = { - "filters": str(iql) if iql else None, - "aggregation": str(agg_node), + "filters": str(iql.filters) if iql.filters else None, + "aggregation": str(iql.aggregation) if iql.aggregation else None, } - return result @abc.abstractmethod @@ -164,21 +111,21 @@ def list_aggregations(self) -> List[ExposedFunction]: """ @abc.abstractmethod - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + filters: IQLQuery object representing the filters to apply. """ @abc.abstractmethod - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: - aggregation: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + aggregation: IQLQuery object representing the aggregation to apply. """ @abc.abstractmethod diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index ae5d2269..bed83d0a 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -3,7 +3,7 @@ import pytest -from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax +from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, IQLFunctionNotExists, @@ -14,11 +14,12 @@ IQLSyntaxError, ) from dbally.iql._processor import IQLProcessor +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -async def test_iql_parser(): - parsed = await IQLQuery.parse( +async def test_iql_filter_parser(): + parsed = await IQLFiltersQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city('cracow') and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( @@ -51,9 +52,9 @@ async def test_iql_parser(): assert company_filter.arguments[0] == "deepsense.ai" -async def test_iql_parser_arg_error(): +async def test_iql_filter_parser_arg_error(): with pytest.raises(IQLArgumentParsingError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_city('Cracow') and filter_by_name(lambda x: x + 1)", allowed_functions=[ ExposedFunction( @@ -76,9 +77,9 @@ async def test_iql_parser_arg_error(): assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) -async def test_iql_parser_syntax_error(): +async def test_iql_filter_parser_syntax_error(): with pytest.raises(IQLSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age(", allowed_functions=[ ExposedFunction( @@ -94,9 +95,9 @@ async def test_iql_parser_syntax_error(): assert exc_info.match(re.escape("Syntax error in: filter_by_age(")) -async def test_iql_parser_multiple_expression_error(): +async def test_iql_filter_parser_multiple_expression_error(): with pytest.raises(IQLMultipleStatementsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age\nfilter_by_age", allowed_functions=[ ExposedFunction( @@ -112,9 +113,9 @@ async def test_iql_parser_multiple_expression_error(): assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) -async def test_iql_parser_empty_expression_error(): +async def test_iql_filter_parser_empty_expression_error(): with pytest.raises(IQLNoStatementError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "", allowed_functions=[ ExposedFunction( @@ -130,9 +131,9 @@ async def test_iql_parser_empty_expression_error(): assert exc_info.match(re.escape("Empty IQL")) -async def test_iql_parser_no_expression_error(): +async def test_iql_filter_parser_no_expression_error(): with pytest.raises(IQLNoExpressionError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "import filter_by_age", allowed_functions=[ ExposedFunction( @@ -148,9 +149,9 @@ async def test_iql_parser_no_expression_error(): assert exc_info.match(re.escape("No expression found in IQL: import filter_by_age")) -async def test_iql_parser_unsupported_syntax_error(): +async def test_iql_filter_parser_unsupported_syntax_error(): with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age() >= 30", allowed_functions=[ ExposedFunction( @@ -166,9 +167,9 @@ async def test_iql_parser_unsupported_syntax_error(): assert exc_info.match(re.escape("Compare syntax is not supported in IQL: filter_by_age() >= 30")) -async def test_iql_parser_method_not_exists(): +async def test_iql_filter_parser_method_not_exists(): with pytest.raises(IQLFunctionNotExists) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_how_old_somebody_is(40)", allowed_functions=[ ExposedFunction( @@ -184,9 +185,9 @@ async def test_iql_parser_method_not_exists(): assert exc_info.match(re.escape("Function filter_by_how_old_somebody_is not exists: filter_by_how_old_somebody_is")) -async def test_iql_parser_incorrect_number_of_arguments_fail(): +async def test_iql_filter_parser_incorrect_number_of_arguments_fail(): with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old', 40)", allowed_functions=[ ExposedFunction( @@ -204,9 +205,9 @@ async def test_iql_parser_incorrect_number_of_arguments_fail(): ) -async def test_iql_parser_argument_validation_fail(): +async def test_iql_filter_parser_argument_validation_fail(): with pytest.raises(IQLArgumentValidationError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old')", allowed_functions=[ ExposedFunction( @@ -222,6 +223,189 @@ async def test_iql_parser_argument_validation_fail(): assert exc_info.match(re.escape("'too old' is not of type int: 'too old'")) +async def test_iql_aggregation_parser(): + parsed = await IQLAggregationQuery.parse( + "mean_age_by_city('Paris')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert isinstance(parsed.root, syntax.FunctionCall) + assert parsed.root.name == "mean_age_by_city" + assert parsed.root.arguments == ["Paris"] + + +async def test_iql_aggregation_parser_arg_error(): + with pytest.raises(IQLArgumentParsingError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(lambda x: x + 1)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) + + +async def test_iql_aggregation_parser_syntax_error(): + with pytest.raises(IQLSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Syntax error in: mean_age_by_city(")) + + +async def test_iql_aggregation_parser_multiple_expression_error(): + with pytest.raises(IQLMultipleStatementsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city\nmean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) + + +async def test_iql_aggregation_parser_empty_expression_error(): + with pytest.raises(IQLNoStatementError) as exc_info: + await IQLAggregationQuery.parse( + "", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Empty IQL")) + + +async def test_iql_aggregation_parser_no_expression_error(): + with pytest.raises(IQLNoExpressionError) as exc_info: + await IQLAggregationQuery.parse( + "import mean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("No expression found in IQL: import mean_age_by_city")) + + +@pytest.mark.parametrize( + "iql, info", + [ + ("mean_age_by_city() >= 30", "Compare syntax is not supported in IQL: mean_age_by_city() >= 30"), + ( + "mean_age_by_city('Paris') and mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') and mean_age_by_city('London')", + ), + ( + "mean_age_by_city('Paris') or mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') or mean_age_by_city('London')", + ), + ("not mean_age_by_city('Paris')", "UnaryOp syntax is not supported in IQL: not mean_age_by_city('Paris')"), + ], +) +async def test_iql_aggregation_parser_unsupported_syntax_error(iql, info): + with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + iql, + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + assert exc_info.match(re.escape(info)) + + +async def test_iql_aggregation_parser_method_not_exists(): + with pytest.raises(IQLFunctionNotExists) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_town()", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Function mean_age_by_town not exists: mean_age_by_town")) + + +async def test_iql_aggregation_parser_incorrect_number_of_arguments_fail(): + with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city('too old')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match( + re.escape("The method mean_age_by_city has incorrect number of arguments: mean_age_by_city('too old')") + ) + + +async def test_iql_aggregation_parser_argument_validation_fail(): + with pytest.raises(IQLArgumentValidationError): + await IQLAggregationQuery.parse( + "mean_age_by_city(12)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + def test_keywords_lowercase(): rv = IQLProcessor._to_lower_except_in_quotes( """NOT filter1(230) AND (NOT filter_2("NOT ADMIN") AND filter_('IS NOT ADMIN')) OR NOT filter_4()""", diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 992fd03d..69174389 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,11 +9,10 @@ from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven -from dbally.iql import IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions -from dbally.prompt.aggregation import AggregationFormatter from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -24,19 +23,16 @@ class MockViewBase(BaseStructuredView): Mock view base class """ - def __init__(self) -> None: - super().__init__([]) - def list_filters(self) -> List[ExposedFunction]: return [] - async def apply_filters(self, filters: IQLQuery) -> None: - ... - def list_aggregations(self) -> List[ExposedFunction]: return [] - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -44,21 +40,12 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: IQLQuery) -> None: - self.iql = iql - super().__init__(llm=MockLLM()) - - async def generate(self, *_, **__) -> IQLQuery: - return self.iql - - -class MockAggregationFormatter(AggregationFormatter): - def __init__(self, iql_query: IQLQuery) -> None: - self.iql_query = iql_query - super().__init__(llm=MockLLM()) + def __init__(self, state: IQLGeneratorState) -> None: + self.state = state + super().__init__() - async def format_to_query_object(self, *_, **__) -> IQLQuery: - return self.iql_query + async def __call__(self, *_, **__) -> IQLGeneratorState: + return self.state class MockViewSelector(ViewSelector): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index 42e05c0a..ab4b6c7e 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -3,7 +3,7 @@ from typing_extensions import Annotated from dbally import MethodsBaseView, decorators -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.structured import ViewExecutionResult from tests.unit.mocks import MockSimilarityIndex @@ -20,7 +20,10 @@ def method_foo(self, idx: Annotated[str, index_foo]) -> str: def method_bar(self, city: Annotated[str, index_foo], year: Annotated[int, index_bar]) -> str: return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -39,7 +42,10 @@ def method_qux(self, city: str, year: int) -> str: """ return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index a077286d..1d675d84 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -10,17 +10,11 @@ from dbally.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql.syntax import FunctionCall +from dbally.iql_generator.iql_generator import IQLGeneratorState from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import ( - MockAggregationFormatter, - MockIQLGenerator, - MockLLM, - MockSimilarityIndex, - MockViewBase, - MockViewSelector, -) +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector class MockView1(MockViewBase): @@ -66,15 +60,17 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__) -> MockIQLGenerator: - return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def get_iql_generator(self) -> MockIQLGenerator: + return MockIQLGenerator( + IQLGeneratorState( + filters=IQLFiltersQuery(FunctionCall("test_filter", []), "test_filter()"), + aggregation=IQLAggregationQuery(FunctionCall("test_aggregation", []), "test_aggregation()"), + ), + ) def list_aggregations(self) -> List[ExposedFunction]: return [ExposedFunction("test_aggregation", "", [])] - def get_agg_formatter(self, *_, **__) -> MockAggregationFormatter: - return MockAggregationFormatter(IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()")) - @pytest.fixture(name="similarity_classes") def mock_similarity_classes() -> ( diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index b798e533..a2bf23c4 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,14 +1,14 @@ -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import FILTERS_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.prompt.elements import FewShotExample async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=[], ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -35,10 +35,10 @@ async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -67,12 +67,12 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) - assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert len(formatted_prompt.chat) == len(FILTERS_GENERATION_TEMPLATE.chat) + (len(examples) * 2) assert formatted_prompt.chat[1]["role"] == "user" assert formatted_prompt.chat[1]["content"] == examples[0].question assert formatted_prompt.chat[2]["role"] == "assistant" diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index b95fe585..0defc8e1 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,35 +1,23 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, call, patch +from unittest.mock import AsyncMock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import ( - FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, - IQLGenerationPromptFormat, -) +from dbally.iql import IQLAggregationQuery, IQLError, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM class MockView(MethodsBaseView): - def __init__(self) -> None: - super().__init__(None) - - def get_select(self) -> sqlalchemy.Select: - ... - - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False): @@ -62,125 +50,177 @@ def event_tracker() -> EventTracker: @pytest.fixture -def iql_generator(llm: MockLLM) -> IQLGenerator: - return IQLGenerator(llm) +def iql_generator() -> IQLGenerator: + return IQLGenerator() @pytest.mark.asyncio -async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: +async def test_iql_generation( + iql_generator: IQLGenerator, + llm: MockLLM, + event_tracker: EventTracker, + view: MockView, +) -> None: filters = view.list_filters() - - decision_format = FilteringDecisionPromptFormat( - question="Mock_question", - ) - generation_format = IQLGenerationPromptFormat( - question="Mock_question", - filters=filters, - ) - - decision_prompt = FILTERING_DECISION_TEMPLATE.format_prompt(decision_format) - generation_prompt = IQL_GENERATION_TEMPLATE.format_prompt(generation_format) + aggregations = view.list_aggregations() + examples = view.list_few_shots() llm_responses = [ "decision: true", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - iql = await iql_generator.generate( + iql_filter_parser_response = "filter_by_id(1)" + iql_aggregation_parser_response = "aggregate_by_id()" + + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch( + "dbally.iql.IQLFiltersQuery.parse", AsyncMock(return_value=iql_filter_parser_response) + ) as mock_filters_parse, patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(return_value=iql_aggregation_parser_response) + ) as mock_aggregation_parse: + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, ) - assert iql == "filter_by_id(1)" - iql_generator._llm.generate_text.assert_has_calls( - [ - call( - prompt=decision_prompt, - event_tracker=event_tracker, - options=None, - ), - call( - prompt=generation_prompt, - event_tracker=event_tracker, - options=None, - ), - ] + assert iql == IQLGeneratorState( + filters=iql_filter_parser_response, + aggregation=iql_aggregation_parser_response, ) - mock_parse.assert_called_once_with( - source="filter_by_id(1)", + assert llm.generate_text.call_count == 4 + mock_filters_parse.assert_called_once_with( + source=llm_responses[1], allowed_functions=filters, event_tracker=event_tracker, ) + mock_aggregation_parse.assert_called_once_with( + source=llm_responses[3], + allowed_functions=aggregations, + event_tracker=event_tracker, + ) @pytest.mark.asyncio async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), IQLError("err4", "src4"), ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - assert iql is None - assert iql_generator._llm.generate_text.call_count == 4 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == 10 + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] @pytest.mark.asyncio async def test_iql_generation_response_after_max_retries( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "filter_by_id(1)", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "aggregate_by_id()", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), "filter_by_id(1)", ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - - assert iql == "filter_by_id(1)" - assert iql_generator._llm.generate_text.call_count == 5 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[2:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == len(llm_responses) + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 8d90ffc3..57c0b68a 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,7 +4,7 @@ from typing import List, Literal, Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from dbally.views.methods_base import MethodsBaseView @@ -15,9 +15,6 @@ class MockMethodsBase(MethodsBaseView): Mock class for testing the MethodsBaseView """ - def __init__(self) -> None: - super().__init__(None) - @view_filter() def method_foo(self, idx: int) -> None: """ @@ -35,13 +32,13 @@ def method_baz(self) -> None: """ @view_aggregation() - def method_qux(self, ages: List[int], names: List[str]) -> None: + def method_qux(self, ages: List[int], names: List[str]) -> str: return f"hello {ages} and {names}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 52a8f405..b24a0398 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,8 +1,11 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name +from typing import List, Tuple + import pandas as pd -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.pandas_base import DataFrameBaseView @@ -39,23 +42,27 @@ class MockDataFrameView(DataFrameBaseView): @view_filter() def filter_city(self, city: str) -> pd.Series: - return self.data["city"] == city + return self.df["city"] == city @view_filter() def filter_year(self, year: int) -> pd.Series: - return self.data["year"] == year + return self.df["year"] == year @view_filter() def filter_age(self, age: int) -> pd.Series: - return self.data["age"] == age + return self.df["age"] == age @view_filter() def filter_name(self, name: str) -> pd.Series: - return self.data["name"] == name + return self.df["name"] == name + + @view_aggregation() + def mean_age_by_city(self) -> Tuple[str, List[Tuple[str, str]]]: + return "city", [("age", "mean")] @view_aggregation() - def mean_age_by_city(self) -> pd.DataFrame: - return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() + def count_records(self) -> Tuple[str, List[Tuple[str, str]]]: + return None, [("name", "count")] async def test_filter_or() -> None: @@ -63,7 +70,7 @@ async def test_filter_or() -> None: Test that the filtering the DataFrame with logical OR works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Berlin") or filter_city("London")', allowed_functions=mock_view.list_filters(), ) @@ -71,6 +78,8 @@ async def test_filter_or() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_and() -> None: @@ -78,7 +87,7 @@ async def test_filter_and() -> None: Test that the filtering the DataFrame with logical AND works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Paris") and filter_year(2020)', allowed_functions=mock_view.list_filters(), ) @@ -86,6 +95,8 @@ async def test_filter_and() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_not() -> None: @@ -93,7 +104,7 @@ async def test_filter_not() -> None: Test that the filtering the DataFrame with logical NOT works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'not (filter_city("Paris") and filter_year(2020))', allowed_functions=mock_view.list_filters(), ) @@ -101,25 +112,48 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None -async def test_aggregtion() -> None: +async def test_aggregation() -> None: """ Test that DataFrame aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( + "count_records()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"index": "name_count", "name": 5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] is None + assert result.context["aggregations"] == [("name", "count")] + + +async def test_aggregtion_with_groupby() -> None: + """ + Test that DataFrame aggregation with groupby works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() assert result.results == [ - {"city": "Berlin", "age": 45.0}, - {"city": "London", "age": 32.5}, - {"city": "Paris", "age": 32.5}, + {"city": "Berlin", "age_mean": 45.0}, + {"city": "London", "age_mean": 32.5}, + {"city": "Paris", "age_mean": 32.5}, ] assert result.context["filter_mask"] is None + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")] async def test_filters_and_aggregtion() -> None: @@ -127,16 +161,18 @@ async def test_filters_and_aggregtion() -> None: Test that DataFrame filtering and aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( "filter_city('Paris')", allowed_functions=mock_view.list_filters(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() - assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.results == [{"city": "Paris", "age_mean": 32.5}] assert result.context["filter_mask"].tolist() == [False, True, False, True, False] + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 435c8f8e..571e6a70 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -4,7 +4,8 @@ import sqlalchemy -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -33,7 +34,7 @@ def method_baz(self) -> sqlalchemy.Select: """ Some documentation string """ - return self.data.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) + return self.select.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) def normalize_whitespace(s: str) -> str: @@ -50,7 +51,7 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters(), ) @@ -66,7 +67,7 @@ async def test_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), ) @@ -82,12 +83,12 @@ async def test_filter_and_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters() + mock_view.list_aggregations(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), )