From 6d66b007938513f73cb99366124f2b8847054fac Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Fri, 12 Jul 2024 16:40:06 +0200 Subject: [PATCH 01/21] Adding aggregation handling for SqlAlchemyBaseView extending quickstart example. --- docs/how-to/views/custom.md | 58 ++++++++++--------- docs/how-to/views/custom_views_code.py | 1 + docs/quickstart/quickstart_code.py | 10 +++- src/dbally/iql_generator/iql_generator.py | 15 ++++- .../iql_generator/iql_prompt_template.py | 0 src/dbally/iql_generator/prompt.py | 30 +++++++++- src/dbally/views/decorators.py | 15 +++++ src/dbally/views/methods_base.py | 33 ++++++++++- src/dbally/views/pandas_base.py | 9 +++ src/dbally/views/sqlalchemy_base.py | 28 ++++++++- src/dbally/views/structured.py | 31 +++++++++- tests/integration/test_llm_options.py | 2 +- tests/unit/mocks.py | 6 ++ tests/unit/test_iql_format.py | 12 ++-- tests/unit/test_iql_generator.py | 3 + tests/unit/views/test_methods_base.py | 3 + 16 files changed, 211 insertions(+), 45 deletions(-) create mode 100644 src/dbally/iql_generator/iql_prompt_template.py diff --git a/docs/how-to/views/custom.md b/docs/how-to/views/custom.md index ad8acb38..d9a50f74 100644 --- a/docs/how-to/views/custom.md +++ b/docs/how-to/views/custom.md @@ -106,37 +106,41 @@ Let's implement the required `apply_filters` method in our `FilteredIterableBase ```python def __init__(self) -> None: - super().__init__() - self._filter: Callable[[Any], bool] = lambda x: True - async def apply_filters(self, filters: IQLQuery) -> None: - """ - Applies the selected filters to the view. - Args: - filters: IQLQuery object representing the filters to apply - """ - self._filter = await self.build_filter_node(filters.root) +super().__init__() +self._filter: Callable[[Any], bool] = lambda x: True - async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]: - """ - Turns a filter node from the IQLQuery into a Python function. - Args: - node: IQLQuery node representing the filter or logical operator - """ - if isinstance(node, syntax.FunctionCall): # filter - return await self.call_filter_method(node) - if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) - return lambda x: all(child(x) for child in children) # combine children with logical AND - if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) - return lambda x: any(child(x) for child in children) # combine children with logical OR - if isinstance(node, syntax.Not): # logical NOT - child = await self.build_filter_node(node.child) - return lambda x: not child(x) - raise ValueError(f"Unsupported grammar: {node}") +async def apply_filters(self, filters: IQLQuery) -> None: + """ + Applies the selected filters to the view. + + Args: + filters: IQLQuery object representing the filters to apply + """ + self._filter = await self.build_filter_node(filters.root) + + +async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]: + """ + Turns a filter node from the IQLQuery into a Python function. + + Args: + node: IQLQuery node representing the filter or logical operator + """ + if isinstance(node, syntax.FunctionCall): # filter + return await self.call_filter_method(node) + if isinstance(node, syntax.And): # logical AND + children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + return lambda x: all(child(x) for child in children) # combine children with logical AND + if isinstance(node, syntax.Or): # logical OR + children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + return lambda x: any(child(x) for child in children) # combine children with logical OR + if isinstance(node, syntax.Not): # logical NOT + child = await self.build_filter_node(node.child) + return lambda x: not child(x) + raise ValueError(f"Unsupported grammar: {node}") ``` In the `apply_filters` method, we're calling the `build_filter_node` method on the root of the IQL tree. The `build_filter_node` method uses recursion to create an object that represents the combined logic of the IQL expression and the returned filter methods. For `FilteredIterableBaseView`, this object is a function that takes a single argument (a candidate) and returns a boolean. We save this function in the `_filter` attribute. diff --git a/docs/how-to/views/custom_views_code.py b/docs/how-to/views/custom_views_code.py index 33c954c7..c64a2ffb 100644 --- a/docs/how-to/views/custom_views_code.py +++ b/docs/how-to/views/custom_views_code.py @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=filtered_data, context={}) + class CandidateView(FilteredIterableBaseView): def get_data(self) -> Iterable: return [ diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index be2aab37..378c866b 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -6,7 +6,6 @@ from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -import dbally from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM @@ -55,6 +54,12 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: """ return Candidate.country == country + @decorators.view_aggregation() + def count_by_column(self, subquery: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 + select = sqlalchemy.select(getattr(subquery.c, column_name), sqlalchemy.func.count(subquery.c.name).label("count")) \ + .group_by(getattr(subquery.c, column_name)) + return select + async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") @@ -63,7 +68,8 @@ async def main(): collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) - result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") + result = await collection.ask("Could you find French candidates suitable for a senior data scientist position" + "and count the candidates university-wise and present the rows?") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 7eeb9154..a9656457 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -2,7 +2,11 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import ( + IQL_GENERATION_TEMPLATE, + IQL_GENERATION_TEMPLATE_AGGREGATION, + IQLGenerationPromptFormat, +) from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -30,6 +34,7 @@ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerat Args: llm: LLM used to generate IQL + prompt_template: If not provided by the users is set to `default_iql_template` """ self._llm = llm self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE @@ -39,6 +44,7 @@ async def generate_iql( question: str, filters: List[ExposedFunction], event_tracker: EventTracker, + aggregations: List[ExposedFunction] = None, examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, @@ -50,6 +56,7 @@ async def generate_iql( question: User question. filters: List of filters exposed by the view. event_tracker: Event store used to audit the generation process. + aggregations: List of aggregations to be applied on the view. examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. @@ -60,8 +67,12 @@ async def generate_iql( prompt_format = IQLGenerationPromptFormat( question=question, filters=filters, + aggregations=aggregations, examples=examples, ) + if aggregations and not filters: + self._prompt_template = IQL_GENERATION_TEMPLATE_AGGREGATION + formatted_prompt = self._prompt_template.format_prompt(prompt_format) for _ in range(n_retries + 1): @@ -76,7 +87,7 @@ async def generate_iql( # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( source=iql, - allowed_functions=filters, + allowed_functions=filters or [] + aggregations or [], event_tracker=event_tracker, ) except IQLError as exc: diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 44bb2cd4..2d1093fe 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -45,6 +45,7 @@ def __init__( question: str, filters: List[ExposedFunction], examples: List[FewShotExample] = None, + aggregations: List[ExposedFunction] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. @@ -53,10 +54,12 @@ def __init__( question: Question to be asked. filters: List of filters exposed by the view. examples: List of examples to be injected into the conversation. + aggregations: List of examples to be computed for the view (Currently 1 aggregation supported). """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(filter) for filter in filters]) + self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] + self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( @@ -66,7 +69,7 @@ def __init__( "content": ( "You have access to API that lets you query a database:\n" "\n{filters}\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' @@ -85,3 +88,26 @@ def __init__( ], response_parser=_validate_iql_response, ) + +IQL_GENERATION_TEMPLATE_AGGREGATION = PromptTemplate[IQLGenerationPromptFormat]( + chat=( + { + "role": "system", + "content": "You have access to API that lets you query a database supporting SINGLE aggregation.\n" + "When prompted for just aggregation, use the following methods: \n" + "{aggregations}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{aggregations}\n" + "Structure output to resemble the following pattern:\n" + 'aggregation1("Argument_in_lowercase")\n' + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\" anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("Argument_in_lowercase", another_base_python_datatype_argument)\n', + }, + {"role": "user", "content": "{question}"}, + ), + response_parser=_validate_iql_response, +) diff --git a/src/dbally/views/decorators.py b/src/dbally/views/decorators.py index ac537f5f..e9257d94 100644 --- a/src/dbally/views/decorators.py +++ b/src/dbally/views/decorators.py @@ -14,3 +14,18 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin return func return wrapped + + +def view_aggregation() -> typing.Callable: + """ + Decorator for marking a method as a filter + + Returns: + Function that returns the decorated method + """ + + def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc + func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access + return func + + return wrapped diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 8eeedfb0..4306a6c0 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -15,7 +15,10 @@ class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["self", "select", "return"] + HIDDEN_ARGUMENTS = ["self", "select", "return", "subquery"] + + def __init__(self): + self._subquery = None @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -56,6 +59,15 @@ def list_filters(self) -> List[ExposedFunction]: """ return self.list_methods_by_decorator(decorators.view_filter) + def list_aggregations(self) -> List[ExposedFunction]: + """ + List aggregations in the given view + + Returns: + Aggregations defined inside the View and decorated with `decorators.view_aggregation`. + """ + return self.list_methods_by_decorator(decorators.view_aggregation) + def _method_with_args_from_call( self, func: syntax.FunctionCall, method_decorator: Callable ) -> Tuple[Callable, list]: @@ -64,7 +76,8 @@ def _method_with_args_from_call( Args: func: IQL FunctionCall node - method_decorator: The decorator that thhe method should have + method_decorator: The decorator that the method should have + (currently allows discrimination between filters and aggregations) Returns: Tuple with the method object and its arguments @@ -99,3 +112,19 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: if inspect.iscoroutinefunction(method): return await method(*args) return method(*args) + + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: + """ + Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. + + Args: + func: IQL FunctionCall node + + Returns: + The result of the method call + """ + method, args = self._method_with_args_from_call(func, decorators.view_aggregation) + + if inspect.iscoroutinefunction(method): + return await method(self._subquery, *args) + return method(self._subquery, *args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 3d3831f7..f3a86cde 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -36,6 +36,15 @@ async def apply_filters(self, filters: IQLQuery) -> None: """ self._filter_mask = await self.build_filter_node(filters.root) + async def apply_aggregation(self, aggregation: IQLQuery) -> None: + """ + Applies the aggregation of choice to the view. + + Args: + aggregation: IQLQuery object representing the aggregation to apply + """ + pass # pylint: disable=unnecessary-pass + async def build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index b1783558..e2d62984 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -17,6 +17,7 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None: super().__init__() self._select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine + self._subquery = None @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: @@ -27,6 +28,15 @@ def get_select(self) -> sqlalchemy.Select: which will be used to build the query. """ + def get_subquery(self) -> sqlalchemy.Subquery: + """ + Creates the initial sqlalchemy.Subquery object, which will be used to build the query. + + Returns: + The sqlalchemy.Subquery object based on private _select attribute. + """ + return self._select.subquery("subquery") + async def apply_filters(self, filters: IQLQuery) -> None: """ Applies the chosen filters to the view. @@ -64,6 +74,16 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu return alchemy_op(await self._build_filter_node(bool_op.child)) raise ValueError(f"BoolOp {bool_op} has no children") + async def apply_aggregation(self, aggregation: IQLQuery) -> None: + """ + Creates a subquery based on existing + + Args: + aggregation: IQLQuery object representing the filters to apply + """ + self._subquery = self.get_subquery() + self._subquery = await self.call_aggregation_method(aggregation.root) + def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ Executes the generated SQL query and returns the results. @@ -77,13 +97,17 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ results = [] - sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + statement = self._select + if self._subquery is not None: + statement = self._subquery + + sql = str(statement.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(self._select).fetchall() + rows = connection.execute(statement).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index b5863075..cec6936c 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -57,20 +57,32 @@ async def ask( The result of the query. """ iql_generator = self.get_iql_generator(llm) - filters = self.list_filters() examples = self.list_few_shots() + aggregations = self.list_aggregations() iql = await iql_generator.generate_iql( question=query, filters=filters, examples=examples, + aggregations=[], event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, ) await self.apply_filters(iql) + iql = await iql_generator.generate_iql( + question=query, + filters=[], + examples=[], + aggregations=aggregations, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + await self.apply_aggregation(iql) + result = self.execute(dry_run=dry_run) result.context["iql"] = f"{iql}" @@ -94,6 +106,23 @@ async def apply_filters(self, filters: IQLQuery) -> None: filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply """ + @abc.abstractmethod + def list_aggregations(self) -> List[ExposedFunction]: + """ + + Returns: + Aggregations defined inside the View. + """ + + @abc.abstractmethod + async def apply_aggregation(self, aggregation: IQLQuery) -> None: + """ + Applies the chosen aggregation to the view. + + Args: + aggregation: [IQLQuery](../../concepts/iql.md) object representing the filters to apply + """ + @abc.abstractmethod def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index fb8cfba4..411ffe14 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -30,7 +30,7 @@ async def test_llm_options_propagation(): llm_options=custom_options, ) - assert llm.client.call.call_count == 3 + assert llm.client.call.call_count == 4 llm.client.call.assert_has_calls( [ diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 75cc914b..ba77a8ce 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -29,6 +29,12 @@ def list_filters(self) -> List[ExposedFunction]: async def apply_filters(self, filters: IQLQuery) -> None: ... + def list_aggregations(self) -> List[ExposedFunction]: + return [] + + async def apply_aggregation(self, filters: IQLQuery) -> None: + ... + def execute(self, dry_run=False) -> ViewExecutionResult: return ViewExecutionResult(results=[], context={}) diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 8f583c4c..b0bae183 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -14,14 +14,14 @@ async def test_iql_prompt_format_default() -> None: { "role": "system", "content": "You have access to API that lets you query a database:\n" - "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "\n[]\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n\n" + "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", @@ -44,14 +44,14 @@ async def test_iql_prompt_format_few_shots_injected() -> None: { "role": "system", "content": "You have access to API that lets you query a database:\n" - "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "\n[]\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n\n" + "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index ce3f593d..4a79c394 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -21,6 +21,9 @@ def get_select(self) -> sqlalchemy.Select: async def apply_filters(self, filters: IQLQuery) -> None: ... + async def apply_aggregation(self, filters: IQLQuery) -> None: + ... + def execute(self, dry_run: bool = False): ... diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 58959a64..841af1a7 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -28,6 +28,9 @@ def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: Li async def apply_filters(self, filters: IQLQuery) -> None: ... + async def apply_aggregation(self, filters: IQLQuery) -> None: + ... + def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=[], context={}) From be12e6ea67ed5720e73ffdbe2704e7f9e13f2128 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Mon, 15 Jul 2024 11:39:49 +0200 Subject: [PATCH 02/21] Applying initial review feedback. Adding both filters and aggregation to result context. --- docs/how-to/views/custom.md | 58 ++++++++++++++---------------- src/dbally/iql_generator/prompt.py | 4 +-- src/dbally/views/structured.py | 7 ++-- 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/docs/how-to/views/custom.md b/docs/how-to/views/custom.md index d9a50f74..ad8acb38 100644 --- a/docs/how-to/views/custom.md +++ b/docs/how-to/views/custom.md @@ -106,41 +106,37 @@ Let's implement the required `apply_filters` method in our `FilteredIterableBase ```python def __init__(self) -> None: + super().__init__() + self._filter: Callable[[Any], bool] = lambda x: True + async def apply_filters(self, filters: IQLQuery) -> None: + """ + Applies the selected filters to the view. -super().__init__() -self._filter: Callable[[Any], bool] = lambda x: True - - -async def apply_filters(self, filters: IQLQuery) -> None: - """ - Applies the selected filters to the view. - - Args: - filters: IQLQuery object representing the filters to apply - """ - self._filter = await self.build_filter_node(filters.root) - + Args: + filters: IQLQuery object representing the filters to apply + """ + self._filter = await self.build_filter_node(filters.root) -async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]: - """ - Turns a filter node from the IQLQuery into a Python function. + async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]: + """ + Turns a filter node from the IQLQuery into a Python function. - Args: - node: IQLQuery node representing the filter or logical operator - """ - if isinstance(node, syntax.FunctionCall): # filter - return await self.call_filter_method(node) - if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) - return lambda x: all(child(x) for child in children) # combine children with logical AND - if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) - return lambda x: any(child(x) for child in children) # combine children with logical OR - if isinstance(node, syntax.Not): # logical NOT - child = await self.build_filter_node(node.child) - return lambda x: not child(x) - raise ValueError(f"Unsupported grammar: {node}") + Args: + node: IQLQuery node representing the filter or logical operator + """ + if isinstance(node, syntax.FunctionCall): # filter + return await self.call_filter_method(node) + if isinstance(node, syntax.And): # logical AND + children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + return lambda x: all(child(x) for child in children) # combine children with logical AND + if isinstance(node, syntax.Or): # logical OR + children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + return lambda x: any(child(x) for child in children) # combine children with logical OR + if isinstance(node, syntax.Not): # logical NOT + child = await self.build_filter_node(node.child) + return lambda x: not child(x) + raise ValueError(f"Unsupported grammar: {node}") ``` In the `apply_filters` method, we're calling the `build_filter_node` method on the root of the IQL tree. The `build_filter_node` method uses recursion to create an object that represents the combined logic of the IQL expression and the returned filter methods. For `FilteredIterableBaseView`, this object is a function that takes a single argument (a candidate) and returns a boolean. We save this function in the `_filter` attribute. diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 2d1093fe..3d234930 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -99,13 +99,11 @@ def __init__( "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" "\n{aggregations}\n" - "Structure output to resemble the following pattern:\n" - 'aggregation1("Argument_in_lowercase")\n' "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\" anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " "Structure output to resemble the following pattern:\n" - 'aggregation1("Argument_in_lowercase", another_base_python_datatype_argument)\n', + 'aggregation1("arg1", arg2)\n', }, {"role": "user", "content": "{question}"}, ), diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index cec6936c..be2cb578 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -72,7 +72,7 @@ async def ask( ) await self.apply_filters(iql) - iql = await iql_generator.generate_iql( + iql_agg = await iql_generator.generate_iql( question=query, filters=[], examples=[], @@ -81,10 +81,11 @@ async def ask( llm_options=llm_options, n_retries=n_retries, ) - await self.apply_aggregation(iql) + await self.apply_aggregation(iql_agg) result = self.execute(dry_run=dry_run) - result.context["iql"] = f"{iql}" + result.context["iql"] = {"filters": f"{iql}", "aggregation": f"{iql_agg}"} + return result From 33d5b2ee92417f1b90e19d47d089307551af6341 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Tue, 16 Jul 2024 12:56:39 +0200 Subject: [PATCH 03/21] Renaming subquery attribute and method argument to filtered_query --- docs/quickstart/quickstart_code.py | 10 +++++----- src/dbally/views/methods_base.py | 8 ++++---- src/dbally/views/sqlalchemy_base.py | 14 +++++++------- src/dbally/views/structured.py | 1 - 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 378c866b..fce7bdac 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -55,9 +55,9 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: return Candidate.country == country @decorators.view_aggregation() - def count_by_column(self, subquery: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 - select = sqlalchemy.select(getattr(subquery.c, column_name), sqlalchemy.func.count(subquery.c.name).label("count")) \ - .group_by(getattr(subquery.c, column_name)) + def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 + select = sqlalchemy.select(getattr(filtered_query.c, column_name), sqlalchemy.func.count(filtered_query.c.name).label("count")) \ + .group_by(getattr(filtered_query.c, column_name)) return select @@ -68,8 +68,8 @@ async def main(): collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) - result = await collection.ask("Could you find French candidates suitable for a senior data scientist position" - "and count the candidates university-wise and present the rows?") + result = await collection.ask("Give me the number of French candidates suitable" + "for a senior data scientist position for each university") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 4306a6c0..7a4e1f79 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -15,10 +15,10 @@ class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["self", "select", "return", "subquery"] + HIDDEN_ARGUMENTS = ["self", "select", "return", "filtered_query"] def __init__(self): - self._subquery = None + self._filtered_query = None @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -126,5 +126,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: method, args = self._method_with_args_from_call(func, decorators.view_aggregation) if inspect.iscoroutinefunction(method): - return await method(self._subquery, *args) - return method(self._subquery, *args) + return await method(self._filtered_query, *args) + return method(self._filtered_query, *args) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index e2d62984..c0e0389b 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -17,7 +17,7 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None: super().__init__() self._select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine - self._subquery = None + self._filtered_query = None @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: @@ -28,7 +28,7 @@ def get_select(self) -> sqlalchemy.Select: which will be used to build the query. """ - def get_subquery(self) -> sqlalchemy.Subquery: + def get_filtered_query(self) -> sqlalchemy.Subquery: """ Creates the initial sqlalchemy.Subquery object, which will be used to build the query. @@ -76,13 +76,13 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ - Creates a subquery based on existing + Creates a subquery based on existing and calls the aggregation method. Args: aggregation: IQLQuery object representing the filters to apply """ - self._subquery = self.get_subquery() - self._subquery = await self.call_aggregation_method(aggregation.root) + self._filtered_query = self.get_filtered_query() + self._filtered_query = await self.call_aggregation_method(aggregation.root) def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ @@ -98,8 +98,8 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: results = [] statement = self._select - if self._subquery is not None: - statement = self._subquery + if self._filtered_query is not None: + statement = self._filtered_query sql = str(statement.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index be2cb578..145a9826 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -86,7 +86,6 @@ async def ask( result = self.execute(dry_run=dry_run) result.context["iql"] = {"filters": f"{iql}", "aggregation": f"{iql_agg}"} - return result @abc.abstractmethod From 2e77fbc4745deeb3d9181711e466d11009e094bb Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Wed, 17 Jul 2024 09:36:55 +0200 Subject: [PATCH 04/21] Simplified question to the model. --- docs/quickstart/quickstart_code.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index fce7bdac..6c5ddb11 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -68,8 +68,8 @@ async def main(): collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) - result = await collection.ask("Give me the number of French candidates suitable" - "for a senior data scientist position for each university") + result = await collection.ask("Could you find French candidates suitable for a senior data scientist position" + "and count them university-wise?") print(f"The generated SQL query is: {result.context.get('sql')}") print() From 41e88ed5114e1ea2506192d745a23897410b7006 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Wed, 17 Jul 2024 09:48:42 +0200 Subject: [PATCH 05/21] Fixing unnecessary-pass. --- src/dbally/views/pandas_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index f3a86cde..bb9d0516 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -43,7 +43,7 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply """ - pass # pylint: disable=unnecessary-pass + # TODO - to be covered in a separate ticket. async def build_filter_node(self, node: syntax.Node) -> pd.Series: """ From dfe3e13002c11a3574d161fe921b666d959965c2 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Wed, 17 Jul 2024 10:10:02 +0200 Subject: [PATCH 06/21] Continuation of review feedback application. --- src/dbally/iql_generator/iql_generator.py | 2 +- src/dbally/iql_generator/prompt.py | 10 +++++----- src/dbally/view_selection/prompt.py | 2 +- src/dbally/views/decorators.py | 2 +- src/dbally/views/sqlalchemy_base.py | 8 ++++---- tests/unit/test_iql_format.py | 4 ++-- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index a9656457..89ff0113 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -56,7 +56,7 @@ async def generate_iql( question: User question. filters: List of filters exposed by the view. event_tracker: Event store used to audit the generation process. - aggregations: List of aggregations to be applied on the view. + aggregations: List of aggregations exposed by the view. examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 3d234930..eb225575 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -54,7 +54,7 @@ def __init__( question: Question to be asked. filters: List of filters exposed by the view. examples: List of examples to be injected into the conversation. - aggregations: List of examples to be computed for the view (Currently 1 aggregation supported). + aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question @@ -67,7 +67,7 @@ def __init__( { "role": "system", "content": ( - "You have access to API that lets you query a database:\n" + "You have access to an API that lets you query a database:\n" "\n{filters}\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" @@ -93,14 +93,14 @@ def __init__( chat=( { "role": "system", - "content": "You have access to API that lets you query a database supporting SINGLE aggregation.\n" - "When prompted for just aggregation, use the following methods: \n" + "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" "{aggregations}" "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" "\n{aggregations}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\" anything other than `UNSUPPORTED QUERY`""" + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " "Structure output to resemble the following pattern:\n" 'aggregation1("arg1", arg2)\n', diff --git a/src/dbally/view_selection/prompt.py b/src/dbally/view_selection/prompt.py index cdbedf5a..2d49efa9 100644 --- a/src/dbally/view_selection/prompt.py +++ b/src/dbally/view_selection/prompt.py @@ -35,7 +35,7 @@ def __init__( "role": "system", "content": ( "You are a very smart database programmer. " - "You have access to API that lets you query a database:\n" + "You have access to an API that lets you query a database:\n" "First you need to select a class to query, based on its description and the user question. " "You have the following classes to choose from:\n" "{views}\n" diff --git a/src/dbally/views/decorators.py b/src/dbally/views/decorators.py index e9257d94..d318cfc4 100644 --- a/src/dbally/views/decorators.py +++ b/src/dbally/views/decorators.py @@ -18,7 +18,7 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin def view_aggregation() -> typing.Callable: """ - Decorator for marking a method as a filter + Decorator for marking a method as an aggregation Returns: Function that returns the decorated method diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index c0e0389b..bc395ac2 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -28,7 +28,7 @@ def get_select(self) -> sqlalchemy.Select: which will be used to build the query. """ - def get_filtered_query(self) -> sqlalchemy.Subquery: + def _get_filtered_query(self) -> sqlalchemy.Subquery: """ Creates the initial sqlalchemy.Subquery object, which will be used to build the query. @@ -74,15 +74,15 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu return alchemy_op(await self._build_filter_node(bool_op.child)) raise ValueError(f"BoolOp {bool_op} has no children") - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: syntax.FunctionCall) -> None: """ Creates a subquery based on existing and calls the aggregation method. Args: aggregation: IQLQuery object representing the filters to apply """ - self._filtered_query = self.get_filtered_query() - self._filtered_query = await self.call_aggregation_method(aggregation.root) + self._filtered_query = self._get_filtered_query() + self._filtered_query = await self.call_aggregation_method(aggregation) def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index b0bae183..7650c12f 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -13,7 +13,7 @@ async def test_iql_prompt_format_default() -> None: assert formatted_prompt.chat == [ { "role": "system", - "content": "You have access to API that lets you query a database:\n" + "content": "You have access to an API that lets you query a database:\n" "\n[]\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" @@ -43,7 +43,7 @@ async def test_iql_prompt_format_few_shots_injected() -> None: assert formatted_prompt.chat == [ { "role": "system", - "content": "You have access to API that lets you query a database:\n" + "content": "You have access to an API that lets you query a database:\n" "\n[]\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" From c09e68e1b7ad41a47ce6b39eb72ac695017a4764 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Wed, 17 Jul 2024 10:51:40 +0200 Subject: [PATCH 07/21] Adjusting filter prompt not to mix IQL with 'UNSUPPORTED QUERY'. Further review feedback application. --- docs/quickstart/quickstart_code.py | 9 ++++++++- src/dbally/iql_generator/prompt.py | 2 +- src/dbally/views/methods_base.py | 2 +- src/dbally/views/sqlalchemy_base.py | 2 +- tests/unit/test_iql_format.py | 4 ++-- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 6c5ddb11..a08f11dc 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -56,10 +56,17 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: @decorators.view_aggregation() def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 - select = sqlalchemy.select(getattr(filtered_query.c, column_name), sqlalchemy.func.count(filtered_query.c.name).label("count")) \ + select = sqlalchemy.select(getattr(filtered_query.c, column_name), + sqlalchemy.func.count(filtered_query.c.name).label("count")) \ .group_by(getattr(filtered_query.c, column_name)) return select + # @decorators.view_aggregation() + # def count_by_university(self, filtered_query: sqlalchemy.Select) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 + # select = sqlalchemy.select(filtered_query.c.university, sqlalchemy.func.count(filtered_query.c.name).label("count")) \ + # .group_by(filtered_query.c.university) + # return select + async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index eb225575..48bea4f6 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -77,7 +77,7 @@ def __init__( "You MUST use only these methods:\n" "\n{filters}\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. " ), }, diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 7a4e1f79..5672b453 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -113,7 +113,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: + async def call_aggregation_method(self, func: syntax.Node) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index bc395ac2..9a1301c6 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -82,7 +82,7 @@ async def apply_aggregation(self, aggregation: syntax.FunctionCall) -> None: aggregation: IQLQuery object representing the filters to apply """ self._filtered_query = self._get_filtered_query() - self._filtered_query = await self.call_aggregation_method(aggregation) + self._filtered_query = await self.call_aggregation_method(aggregation.root) def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 7650c12f..b798e533 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -23,7 +23,7 @@ async def test_iql_prompt_format_default() -> None: "You MUST use only these methods:\n" "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, }, @@ -53,7 +53,7 @@ async def test_iql_prompt_format_few_shots_injected() -> None: "You MUST use only these methods:\n" "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", "is_example": False, }, From c6bbf9079959b954036c25337109c66a5b8d7aed Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Thu, 18 Jul 2024 12:07:38 +0200 Subject: [PATCH 08/21] Applied changes suggested in a comment related to Aggregations not generated by IQLGenerator (separate formatter implemented). --- docs/quickstart/quickstart_code.py | 8 +- src/dbally/exceptions.py | 7 ++ src/dbally/iql_generator/iql_generator.py | 7 +- src/dbally/iql_generator/prompt.py | 23 ----- src/dbally/prompt/aggregation.py | 111 ++++++++++++++++++++++ src/dbally/views/methods_base.py | 2 +- src/dbally/views/sqlalchemy_base.py | 9 +- src/dbally/views/structured.py | 27 ++++-- 8 files changed, 143 insertions(+), 51 deletions(-) create mode 100644 src/dbally/prompt/aggregation.py diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index a08f11dc..a7bc700d 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -61,12 +61,6 @@ def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) - .group_by(getattr(filtered_query.c, column_name)) return select - # @decorators.view_aggregation() - # def count_by_university(self, filtered_query: sqlalchemy.Select) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 - # select = sqlalchemy.select(filtered_query.c.university, sqlalchemy.func.count(filtered_query.c.name).label("count")) \ - # .group_by(filtered_query.c.university) - # return select - async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") @@ -76,7 +70,7 @@ async def main(): collection.add(CandidateView, lambda: CandidateView(engine)) result = await collection.ask("Could you find French candidates suitable for a senior data scientist position" - "and count them university-wise?") + "and count the candidates university-wise?") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/exceptions.py b/src/dbally/exceptions.py index 6b095cd7..62faac37 100644 --- a/src/dbally/exceptions.py +++ b/src/dbally/exceptions.py @@ -2,3 +2,10 @@ class DbAllyError(Exception): """ Base class for all exceptions raised by db-ally. """ + + +class UnsupportedAggregationError(DbAllyError): + """ + Error raised when AggregationFormatter is unable to construct a query + with given aggregation. + """ diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 89ff0113..5a375497 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -4,7 +4,6 @@ from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import ( IQL_GENERATION_TEMPLATE, - IQL_GENERATION_TEMPLATE_AGGREGATION, IQLGenerationPromptFormat, ) from dbally.llms.base import LLM @@ -44,7 +43,6 @@ async def generate_iql( question: str, filters: List[ExposedFunction], event_tracker: EventTracker, - aggregations: List[ExposedFunction] = None, examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, @@ -67,11 +65,8 @@ async def generate_iql( prompt_format = IQLGenerationPromptFormat( question=question, filters=filters, - aggregations=aggregations, examples=examples, ) - if aggregations and not filters: - self._prompt_template = IQL_GENERATION_TEMPLATE_AGGREGATION formatted_prompt = self._prompt_template.format_prompt(prompt_format) @@ -87,7 +82,7 @@ async def generate_iql( # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( source=iql, - allowed_functions=filters or [] + aggregations or [], + allowed_functions=filters or [], event_tracker=event_tracker, ) except IQLError as exc: diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 48bea4f6..8099e8d6 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -45,7 +45,6 @@ def __init__( question: str, filters: List[ExposedFunction], examples: List[FewShotExample] = None, - aggregations: List[ExposedFunction] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. @@ -59,7 +58,6 @@ def __init__( super().__init__(examples) self.question = question self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] - self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( @@ -88,24 +86,3 @@ def __init__( ], response_parser=_validate_iql_response, ) - -IQL_GENERATION_TEMPLATE_AGGREGATION = PromptTemplate[IQLGenerationPromptFormat]( - chat=( - { - "role": "system", - "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" - "When prompted for an aggregation, use the following methods: \n" - "{aggregations}" - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{aggregations}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" - "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " - "Structure output to resemble the following pattern:\n" - 'aggregation1("arg1", arg2)\n', - }, - {"role": "user", "content": "{question}"}, - ), - response_parser=_validate_iql_response, -) diff --git a/src/dbally/prompt/aggregation.py b/src/dbally/prompt/aggregation.py new file mode 100644 index 00000000..71b3a81b --- /dev/null +++ b/src/dbally/prompt/aggregation.py @@ -0,0 +1,111 @@ +from typing import List, Optional + +from dbally.audit import EventTracker +from dbally.exceptions import UnsupportedAggregationError + +from dbally.iql import IQLQuery +from dbally.llms.base import LLM +from dbally.llms.clients import LLMOptions +from dbally.prompt.template import PromptFormat, PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + + +def _validate_agg_response(llm_response: str) -> str: + """ + Validates LLM response to IQL + + Args: + llm_response: LLM response + + Returns: + A string containing aggregations. + + Raises: + UnsupportedAggregationError: When IQL generator is unable to construct a query + with given aggregation. + """ + if "unsupported query" in llm_response.lower(): + raise UnsupportedAggregationError + return llm_response + + +class AggregationPromptFormat(PromptFormat): + """ + Aggregation prompt format, providing a question and aggregation to be used in the conversation. + """ + + def __init__( + self, + question: str, + aggregations: List[ExposedFunction] = None, + ) -> None: + super().__init__() + self.question = question + self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] + + +class AggregationFormatter: + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None: + """ + Constructs a new AggregationFormatter instance. + + Args: + llm: LLM used to generate IQL + prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE` + """ + self._llm = llm + self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE + + async def format_to_query_object( + self, + question: str, + event_tracker: EventTracker, + aggregations: List[ExposedFunction] = None, + llm_options: Optional[LLMOptions] = None, + ) -> IQLQuery: + + prompt_format = AggregationPromptFormat( + question=question, + aggregations=aggregations, + ) + + formatted_prompt = self._prompt_template.format_prompt(prompt_format) + + response = await self._llm.generate_text( + prompt=formatted_prompt, + event_tracker=event_tracker, + options=llm_options, + ) + # TODO: Move response parsing to llm generate_text method + agg = formatted_prompt.response_parser(response) + # TODO: Move IQL query parsing to prompt response parser + return await IQLQuery.parse( + source=agg, + allowed_functions=aggregations or [], + event_tracker=event_tracker, + ) + + +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat]( + [ + { + "role": "system", + "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{aggregations}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{aggregations}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n', + }, + { + "role": "user", + "content": "{question}" + }, + ], + response_parser=_validate_agg_response, +) diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 5672b453..7a4e1f79 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -113,7 +113,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.Node) -> Any: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 9a1301c6..04d1e80c 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -82,7 +82,7 @@ async def apply_aggregation(self, aggregation: syntax.FunctionCall) -> None: aggregation: IQLQuery object representing the filters to apply """ self._filtered_query = self._get_filtered_query() - self._filtered_query = await self.call_aggregation_method(aggregation.root) + self._select = await self.call_aggregation_method(aggregation) def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ @@ -97,17 +97,14 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ results = [] - statement = self._select - if self._filtered_query is not None: - statement = self._filtered_query - sql = str(statement.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(statement).fetchall() + rows = connection.execute(self._select).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 145a9826..df5a4d69 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -9,6 +9,8 @@ from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction +from ..iql.syntax import FunctionCall +from ..prompt.aggregation import AggregationFormatter from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -32,6 +34,18 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ return IQLGenerator(llm=llm) + def get_agg_formatter(self, llm: LLM) -> AggregationFormatter: + """ + Returns the AggregtionFormatter for the view. + + Args: + llm: LLM used to generate the queries. + + Returns: + AggregtionFormatter for the view. + """ + return AggregationFormatter(llm=llm) + async def ask( self, query: str, @@ -57,6 +71,7 @@ async def ask( The result of the query. """ iql_generator = self.get_iql_generator(llm) + agg_formatter = self.get_agg_formatter(llm) filters = self.list_filters() examples = self.list_few_shots() aggregations = self.list_aggregations() @@ -65,26 +80,22 @@ async def ask( question=query, filters=filters, examples=examples, - aggregations=[], event_tracker=event_tracker, llm_options=llm_options, n_retries=n_retries, ) await self.apply_filters(iql) - iql_agg = await iql_generator.generate_iql( + agg_node = await agg_formatter.format_to_query_object( question=query, - filters=[], - examples=[], aggregations=aggregations, event_tracker=event_tracker, llm_options=llm_options, - n_retries=n_retries, ) - await self.apply_aggregation(iql_agg) + await self.apply_aggregation(agg_node.root) result = self.execute(dry_run=dry_run) - result.context["iql"] = {"filters": f"{iql}", "aggregation": f"{iql_agg}"} + result.context["iql"] = {"filters": f"{iql}", "aggregation": f"{agg_node}"} return result @@ -115,7 +126,7 @@ def list_aggregations(self) -> List[ExposedFunction]: """ @abc.abstractmethod - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: FunctionCall) -> None: """ Applies the chosen aggregation to the view. From 4765dded2c4e8028fcb7f1e610cefd7542c6c63c Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Thu, 18 Jul 2024 13:47:36 +0200 Subject: [PATCH 09/21] Applying pre-commit hooks. --- src/dbally/iql_generator/iql_generator.py | 6 +----- src/dbally/prompt/aggregation.py | 21 ++++++++++++++++----- src/dbally/views/structured.py | 2 +- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 5a375497..4d028fa9 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -2,10 +2,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.prompt import ( - IQL_GENERATION_TEMPLATE, - IQLGenerationPromptFormat, -) +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.prompt.elements import FewShotExample @@ -54,7 +51,6 @@ async def generate_iql( question: User question. filters: List of filters exposed by the view. event_tracker: Event store used to audit the generation process. - aggregations: List of aggregations exposed by the view. examples: List of examples to be injected into the conversation. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors. diff --git a/src/dbally/prompt/aggregation.py b/src/dbally/prompt/aggregation.py index 71b3a81b..8dedd95c 100644 --- a/src/dbally/prompt/aggregation.py +++ b/src/dbally/prompt/aggregation.py @@ -2,7 +2,6 @@ from dbally.audit import EventTracker from dbally.exceptions import UnsupportedAggregationError - from dbally.iql import IQLQuery from dbally.llms.base import LLM from dbally.llms.clients import LLMOptions @@ -45,6 +44,10 @@ def __init__( class AggregationFormatter: + """ + Class used to manage choice and formatting of aggregation based on natural language question. + """ + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None: """ Constructs a new AggregationFormatter instance. @@ -63,7 +66,18 @@ async def format_to_query_object( aggregations: List[ExposedFunction] = None, llm_options: Optional[LLMOptions] = None, ) -> IQLQuery: + """ + Generates IQL in text form using LLM. + Args: + question: User question. + event_tracker: Event store used to audit the generation process. + aggregations: List of aggregations exposed by the view. + llm_options: Options to use for the LLM client. + + Returns: + Generated aggregation query. + """ prompt_format = AggregationPromptFormat( question=question, aggregations=aggregations, @@ -102,10 +116,7 @@ async def format_to_query_object( "Structure output to resemble the following pattern:\n" 'aggregation1("arg1", arg2)\n', }, - { - "role": "user", - "content": "{question}" - }, + {"role": "user", "content": "{question}"}, ], response_parser=_validate_agg_response, ) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index df5a4d69..db81894b 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -9,9 +9,9 @@ from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction + from ..iql.syntax import FunctionCall from ..prompt.aggregation import AggregationFormatter - from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation From 2918ba5a81d0da7e3247964b5d792f3a7c1ee080 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Thu, 18 Jul 2024 14:21:46 +0200 Subject: [PATCH 10/21] Mocking AggregationFormat in tests. --- tests/unit/mocks.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ba77a8ce..748632a5 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -13,6 +13,7 @@ from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions +from dbally.prompt.aggregation import AggregationFormatter from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -48,6 +49,15 @@ async def generate_iql(self, *_, **__) -> IQLQuery: return self.iql +class MockAggregationFormatter(AggregationFormatter): + def __init__(self, iql_query: IQLQuery) -> None: + self.iql_query = iql_query + super().__init__(llm=MockLLM()) + + async def format_to_query_object(self, *_, **__) -> IQLQuery: + return self.iql_query + + class MockViewSelector(ViewSelector): def __init__(self, name: str) -> None: self.name = name From 551172915f92ac810c7ba551c07e717dceaff9e8 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Fri, 19 Jul 2024 10:18:10 +0200 Subject: [PATCH 11/21] Mocking methods of the view related to aggregations to make them compliant to tests. --- tests/unit/test_collection.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 058da20b..38adf5f8 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -13,7 +13,8 @@ from dbally.iql import IQLQuery from dbally.iql.syntax import FunctionCall from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector, \ + MockAggregationFormatter class MockView1(MockViewBase): @@ -62,6 +63,12 @@ def list_filters(self) -> List[ExposedFunction]: def get_iql_generator(self, *_, **__) -> MockIQLGenerator: return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def list_aggregations(self) -> List[ExposedFunction]: + return [ExposedFunction("test_aggregation", "", [])] + + def get_agg_formatter(self, *_, **__) -> MockAggregationFormatter: + return MockAggregationFormatter(IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()")) + @pytest.fixture(name="similarity_classes") def mock_similarity_classes() -> ( @@ -291,7 +298,7 @@ async def test_ask_view_selection_single_view() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.context == {"baz": "qux", "iql": {'aggregation': 'test_aggregation()', 'filters': 'test_filter()'}} async def test_ask_view_selection_multiple_views() -> None: @@ -312,7 +319,7 @@ async def test_ask_view_selection_multiple_views() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": "test_filter()"} + assert result.context == {"baz": "qux", "iql": {'aggregation': 'test_aggregation()', 'filters': 'test_filter()'}} async def test_ask_view_selection_no_views() -> None: From ae26c8b50dc1ad6ab6e92f9c333b5131b53caa02 Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Fri, 19 Jul 2024 10:20:56 +0200 Subject: [PATCH 12/21] Pre-commit fixes. --- tests/unit/test_collection.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 38adf5f8..a077286d 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -13,8 +13,14 @@ from dbally.iql import IQLQuery from dbally.iql.syntax import FunctionCall from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector, \ - MockAggregationFormatter +from tests.unit.mocks import ( + MockAggregationFormatter, + MockIQLGenerator, + MockLLM, + MockSimilarityIndex, + MockViewBase, + MockViewSelector, +) class MockView1(MockViewBase): @@ -298,7 +304,7 @@ async def test_ask_view_selection_single_view() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": {'aggregation': 'test_aggregation()', 'filters': 'test_filter()'}} + assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_multiple_views() -> None: @@ -319,7 +325,7 @@ async def test_ask_view_selection_multiple_views() -> None: result = await collection.ask("Mock question") assert result.view_name == "MockViewWithResults" assert result.results == [{"foo": "bar"}] - assert result.context == {"baz": "qux", "iql": {'aggregation': 'test_aggregation()', 'filters': 'test_filter()'}} + assert result.context == {"baz": "qux", "iql": {"aggregation": "test_aggregation()", "filters": "test_filter()"}} async def test_ask_view_selection_no_views() -> None: From a2169f2bb83630ac8d847e505ef41fffd253f34f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 16 Aug 2024 04:03:44 +0200 Subject: [PATCH 13/21] revert to prev approach --- .../sql/bench/views/structured/superhero.py | 13 +++++++------ docs/quickstart/quickstart_code.py | 10 +--------- src/dbally/views/methods_base.py | 16 ++++++++-------- src/dbally/views/pandas_base.py | 7 +++---- src/dbally/views/sqlalchemy_base.py | 15 ++++++++------- src/dbally/views/structured.py | 10 ++++++++-- 6 files changed, 35 insertions(+), 36 deletions(-) diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index f391887d..e5b8f890 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -1,7 +1,7 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, singleton-comparison, consider-using-in, too-many-ancestors, too-many-public-methods # flake8: noqa -from typing import Any, Literal +from typing import Literal from sqlalchemy import ColumnElement, Engine, Select, func, select from sqlalchemy.ext.declarative import DeferredReflection @@ -285,12 +285,13 @@ class SuperheroColourFilterMixin: Mixin for filtering the view by the superhero colour attributes. """ - def __init__(self) -> None: - super().__init__() + def __init__(self, *args, **kwargs) -> None: self.eye_colour = aliased(Colour) self.hair_colour = aliased(Colour) self.skin_colour = aliased(Colour) + super().__init__(*args, **kwargs) + @view_filter() def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: """ @@ -433,19 +434,18 @@ class SuperheroAggregationMixin: """ @view_aggregation() - def count_superheroes(self) -> Any: + def count_superheroes(self, data_source: Select) -> Select: """ Counts the number of superheros. Returns: The superheros count. """ - return func.count(Superhero.id).label("count_superheroes") + return data_source.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) class SuperheroView( DBInitMixin, - SqlAlchemyBaseView, SuperheroFilterMixin, SuperheroAggregationMixin, SuperheroColourFilterMixin, @@ -453,6 +453,7 @@ class SuperheroView( GenderFilterMixin, PublisherFilterMixin, RaceFilterMixin, + SqlAlchemyBaseView, ): """ View for querying only superheros data. Contains the superhero id, superhero name, full name, height, weight, diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index a7bc700d..ef73cad0 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -54,13 +54,6 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: """ return Candidate.country == country - @decorators.view_aggregation() - def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011 - select = sqlalchemy.select(getattr(filtered_query.c, column_name), - sqlalchemy.func.count(filtered_query.c.name).label("count")) \ - .group_by(getattr(filtered_query.c, column_name)) - return select - async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") @@ -69,8 +62,7 @@ async def main(): collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) - result = await collection.ask("Could you find French candidates suitable for a senior data scientist position" - "and count the candidates university-wise?") + result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index c1d15451..1be1aebc 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -1,21 +1,21 @@ -import abc import inspect import textwrap -from typing import Any, Callable, List, Tuple +from abc import ABC +from typing import Any, Callable, Generic, List, Tuple from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView +from dbally.views.structured import BaseStructuredView, DataSourceT -class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta): +class MethodsBaseView(Generic[DataSourceT], BaseStructuredView[DataSourceT], ABC): """ Base class for views that use view methods to expose filters. """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["self", "select", "return"] + HIDDEN_ARGUMENTS = ["cls", "self", "return", "data_source"] @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSourceT: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. @@ -123,5 +123,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: method, args = self._method_with_args_from_call(func, decorators.view_aggregation) if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) + return await method(self._data_source, *args) + return method(self._data_source, *args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index a3bd8d0a..2d628303 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -9,7 +9,7 @@ from dbally.views.methods_base import MethodsBaseView -class DataFrameBaseView(MethodsBaseView): +class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -22,8 +22,7 @@ def __init__(self, df: pd.DataFrame) -> None: Args: df: Pandas DataFrame with the data to be filtered """ - super().__init__() - self.df = df + super().__init__(df) # The mask to be applied to the dataframe to filter the data self._filter_mask: Optional[pd.Series] = None @@ -87,7 +86,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: filtered_data = pd.DataFrame.empty if not dry_run: - filtered_data = self.df + filtered_data = self._data_source if self._filter_mask is not None: filtered_data = filtered_data.loc[self._filter_mask] diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 9e708b5c..c8cdcad1 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -8,15 +8,14 @@ from dbally.views.methods_base import MethodsBaseView -class SqlAlchemyBaseView(MethodsBaseView): +class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): """ Base class for views that use SQLAlchemy to generate SQL queries. """ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: - super().__init__() + super().__init__(self.get_select()) self._sqlalchemy_engine = sqlalchemy_engine - self._select = self.get_select() @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: @@ -34,7 +33,8 @@ async def apply_filters(self, filters: IQLQuery) -> None: Args: filters: IQLQuery object representing the filters to apply. """ - self._select = self._select.where(await self._build_filter_node(filters.root)) + # pylint: disable=W0201 + self._data_source = self._data_source.where(await self._build_filter_node(filters.root)) async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ @@ -43,7 +43,8 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply. """ - self._select = self._select.with_only_columns(await self.call_aggregation_method(aggregation.root)) + # pylint: disable=W0201 + self._data_source = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -86,13 +87,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self._data_source.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(self._select).fetchall() + rows = connection.execute(self._data_source).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 9c1203ef..c695b684 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Dict, List, Optional +from typing import Any, Dict, Generic, List, Optional, TypeVar from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult @@ -18,13 +18,19 @@ from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation +DataSourceT = TypeVar("DataSourceT", bound=Any) -class BaseStructuredView(BaseView): + +class BaseStructuredView(Generic[DataSourceT], BaseView): """ Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """ + def __init__(self, data_source: DataSourceT) -> None: + super().__init__() + self._data_source = data_source + def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ Returns the IQL generator for the view. From 013cb695fa701db9037233b0a9f2743dca46b4af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 16 Aug 2024 23:02:38 +0200 Subject: [PATCH 14/21] fix tests --- .../sql/bench/views/structured/superhero.py | 4 +- src/dbally/views/methods_base.py | 12 ++--- src/dbally/views/pandas_base.py | 34 ++++++------ src/dbally/views/sqlalchemy_base.py | 20 ++++--- src/dbally/views/structured.py | 8 +-- tests/unit/mocks.py | 3 ++ tests/unit/test_iql_generator.py | 3 ++ tests/unit/views/test_methods_base.py | 3 ++ tests/unit/views/test_pandas_base.py | 53 +++++++++++++++++-- 9 files changed, 101 insertions(+), 39 deletions(-) diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index e5b8f890..56932947 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -434,14 +434,14 @@ class SuperheroAggregationMixin: """ @view_aggregation() - def count_superheroes(self, data_source: Select) -> Select: + def count_superheroes(self) -> Select: """ Counts the number of superheros. Returns: The superheros count. """ - return data_source.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) class SuperheroView( diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 1be1aebc..f90d6d7d 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -6,16 +6,16 @@ from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView, DataSourceT +from dbally.views.structured import BaseStructuredView, DataT -class MethodsBaseView(Generic[DataSourceT], BaseStructuredView[DataSourceT], ABC): +class MethodsBaseView(Generic[DataT], BaseStructuredView[DataT], ABC): """ Base class for views that use view methods to expose filters. """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["cls", "self", "return", "data_source"] + HIDDEN_ARGUMENTS = ["cls", "self", "return"] @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSourceT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. @@ -123,5 +123,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSource method, args = self._method_with_args_from_call(func, decorators.view_aggregation) if inspect.iscoroutinefunction(method): - return await method(self._data_source, *args) - return method(self._data_source, *args) + return await method(*args) + return method(*args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 2d628303..5f7bc8ce 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -19,8 +19,10 @@ class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): def __init__(self, df: pd.DataFrame) -> None: """ + Creates a new instance of the DataFrame view. + Args: - df: Pandas DataFrame with the data to be filtered + df: Pandas DataFrame with the data to be filtered. """ super().__init__(df) @@ -32,18 +34,23 @@ async def apply_filters(self, filters: IQLQuery) -> None: Applies the chosen filters to the view. Args: - filters: IQLQuery object representing the filters to apply + filters: IQLQuery object representing the filters to apply. """ + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init self._filter_mask = await self.build_filter_node(filters.root) + self.data = self.data.loc[self._filter_mask] async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ Applies the aggregation of choice to the view. Args: - aggregation: IQLQuery object representing the aggregation to apply + aggregation: IQLQuery object representing the aggregation to apply. """ - # TODO - to be covered in a separate ticket. + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = await self.call_aggregation_method(aggregation.root) async def build_filter_node(self, node: syntax.Node) -> pd.Series: """ @@ -51,13 +58,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: a boolean mask to be applied to the dataframe. Args: - node: IQLQuery node representing the filter or logical operator + node: IQLQuery node representing the filter or logical operator. Returns: - A boolean mask that can be used to filter the original DataFrame + A boolean mask that can be used to filter the original DataFrame. Raises: - ValueError: If the node type is not supported + ValueError: If the node type is not supported. """ if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) @@ -78,20 +85,15 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Args: dry_run: If True, the method will only add `context` field to the `ExecutionResult` with the\ - mask that would be applied to the dataframe + mask that would be applied to the dataframe. Returns: - ExecutionResult object with the results and the context information with the binary mask + ExecutionResult object with the results and the context information with the binary mask. """ - filtered_data = pd.DataFrame.empty - - if not dry_run: - filtered_data = self._data_source - if self._filter_mask is not None: - filtered_data = filtered_data.loc[self._filter_mask] + results = pd.DataFrame.empty if dry_run else self.data return ViewExecutionResult( - results=filtered_data.to_dict(orient="records"), + results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, }, diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index c8cdcad1..4863aa6f 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -14,6 +14,12 @@ class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): """ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: + """ + Creates a new instance of the SQL view. + + Args: + sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. + """ super().__init__(self.get_select()) self._sqlalchemy_engine = sqlalchemy_engine @@ -33,8 +39,9 @@ async def apply_filters(self, filters: IQLQuery) -> None: Args: filters: IQLQuery object representing the filters to apply. """ - # pylint: disable=W0201 - self._data_source = self._data_source.where(await self._build_filter_node(filters.root)) + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = self.data.where(await self._build_filter_node(filters.root)) async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ @@ -43,8 +50,9 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # pylint: disable=W0201 - self._data_source = await self.call_aggregation_method(aggregation.root) + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -87,13 +95,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - sql = str(self._data_source.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: + rows = connection.execute(self.data).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(self._data_source).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c695b684..c61ee0dd 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -18,18 +18,18 @@ from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation -DataSourceT = TypeVar("DataSourceT", bound=Any) +DataT = TypeVar("DataT", bound=Any) -class BaseStructuredView(Generic[DataSourceT], BaseView): +class BaseStructuredView(Generic[DataT], BaseView): """ Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """ - def __init__(self, data_source: DataSourceT) -> None: + def __init__(self, data: DataT) -> None: super().__init__() - self._data_source = data_source + self.data = data def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 0d66df3b..29a5cc83 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -24,6 +24,9 @@ class MockViewBase(BaseStructuredView): Mock view base class """ + def __init__(self) -> None: + super().__init__(None) + def list_filters(self) -> List[ExposedFunction]: return [] diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 401c09d1..b95fe585 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -20,6 +20,9 @@ class MockView(MethodsBaseView): + def __init__(self) -> None: + super().__init__(None) + def get_select(self) -> sqlalchemy.Select: ... diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 841af1a7..e8a2bb56 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -15,6 +15,9 @@ class MockMethodsBase(MethodsBaseView): Mock class for testing the MethodsBaseView """ + def __init__(self) -> None: + super().__init__(None) + @view_filter() def method_foo(self, idx: int) -> None: """ diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 51eea791..52a8f405 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -3,7 +3,7 @@ import pandas as pd from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.pandas_base import DataFrameBaseView MOCK_DATA = [ @@ -39,19 +39,23 @@ class MockDataFrameView(DataFrameBaseView): @view_filter() def filter_city(self, city: str) -> pd.Series: - return self.df["city"] == city + return self.data["city"] == city @view_filter() def filter_year(self, year: int) -> pd.Series: - return self.df["year"] == year + return self.data["year"] == year @view_filter() def filter_age(self, age: int) -> pd.Series: - return self.df["age"] == age + return self.data["age"] == age @view_filter() def filter_name(self, name: str) -> pd.Series: - return self.df["name"] == name + return self.data["name"] == name + + @view_aggregation() + def mean_age_by_city(self) -> pd.DataFrame: + return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() async def test_filter_or() -> None: @@ -97,3 +101,42 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + + +async def test_aggregtion() -> None: + """ + Test that DataFrame aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"city": "Berlin", "age": 45.0}, + {"city": "London", "age": 32.5}, + {"city": "Paris", "age": 32.5}, + ] + assert result.context["filter_mask"] is None + + +async def test_filters_and_aggregtion() -> None: + """ + Test that DataFrame filtering and aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLQuery.parse( + "filter_city('Paris')", + allowed_functions=mock_view.list_filters(), + ) + await mock_view.apply_filters(query) + query = await IQLQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.context["filter_mask"].tolist() == [False, True, False, True, False] From f0a2f6eaa2903c724a0fffe274ad83f2b394557a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Sat, 17 Aug 2024 00:33:49 +0200 Subject: [PATCH 15/21] add more tests --- tests/unit/views/test_methods_base.py | 32 ++++++++++++++++- tests/unit/views/test_sqlalchemy_base.py | 46 +++++++++++++++++++++++- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index e8a2bb56..8d90ffc3 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -5,7 +5,7 @@ from dbally.collection.results import ViewExecutionResult from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from dbally.views.methods_base import MethodsBaseView @@ -28,6 +28,16 @@ def method_foo(self, idx: int) -> None: def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: List[Tuple[str, int]]) -> str: return f"hello {cities} in {year} of {pairs}" + @view_aggregation() + def method_baz(self) -> None: + """ + Some documentation string + """ + + @view_aggregation() + def method_qux(self, ages: List[int], names: List[str]) -> None: + return f"hello {ages} and {names}" + async def apply_filters(self, filters: IQLQuery) -> None: ... @@ -59,3 +69,23 @@ def test_list_filters() -> None: assert ( str(method_bar) == "method_bar(cities: List[str], year: Literal['2023', '2024'], pairs: List[Tuple[str, int]])" ) + + +def test_list_aggregations() -> None: + """ + Tests that the list_aggregations method works correctly + """ + mock_view = MockMethodsBase() + aggregations = mock_view.list_aggregations() + assert len(aggregations) == 2 + method_baz = [f for f in aggregations if f.name == "method_baz"][0] + assert method_baz.description == "Some documentation string" + assert method_baz.parameters == [] + assert str(method_baz) == "method_baz() - Some documentation string" + method_qux = [f for f in aggregations if f.name == "method_qux"][0] + assert method_qux.description == "" + assert method_qux.parameters == [ + MethodParamWithTyping("ages", List[int]), + MethodParamWithTyping("names", List[str]), + ] + assert str(method_qux) == "method_qux(ages: List[int], names: List[str])" diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 079a2135..435c8f8e 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -5,7 +5,7 @@ import sqlalchemy from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -28,6 +28,13 @@ def method_foo(self, idx: int) -> sqlalchemy.ColumnElement: async def method_bar(self, city: str, year: int) -> sqlalchemy.ColumnElement: return sqlalchemy.literal(f"hello {city} in {year}") + @view_aggregation() + def method_baz(self) -> sqlalchemy.Select: + """ + Some documentation string + """ + return self.data.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) + def normalize_whitespace(s: str) -> str: """ @@ -50,3 +57,40 @@ async def test_filter_sql_generation() -> None: await mock_view.apply_filters(query) sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) assert sql == "SELECT 'test' AS foo WHERE 1 AND 'hello London in 2020'" + + +async def test_aggregation_sql_generation() -> None: + """ + Tests that the SQL generation based on aggregations works correctly + """ + + mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) + mock_view = MockSqlAlchemyView(mock_connection.engine) + query = await IQLQuery.parse( + "method_baz()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 GROUP BY 'baz'" + + +async def test_filter_and_aggregation_sql_generation() -> None: + """ + Tests that the SQL generation based on filters and aggregations works correctly + """ + + mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) + mock_view = MockSqlAlchemyView(mock_connection.engine) + query = await IQLQuery.parse( + 'method_foo(1) and method_bar("London", 2020)', + allowed_functions=mock_view.list_filters() + mock_view.list_aggregations(), + ) + await mock_view.apply_filters(query) + query = await IQLQuery.parse( + "method_baz()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + sql = normalize_whitespace(mock_view.execute(dry_run=True).context["sql"]) + assert sql == "SELECT 'test' AS foo, 'baz' AS anon_1 WHERE 1 AND 'hello London in 2020' GROUP BY 'baz'" From aeb6295fdc5452fdab05fa95060ad3ac906f677c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Sat, 17 Aug 2024 00:50:04 +0200 Subject: [PATCH 16/21] trying to fix tests (localy working) --- tests/unit/mocks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 29a5cc83..8b01c69f 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -19,13 +19,13 @@ from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult -class MockViewBase(BaseStructuredView): +class MockViewBase(BaseStructuredView[List]): """ Mock view base class """ def __init__(self) -> None: - super().__init__(None) + super().__init__([]) def list_filters(self) -> List[ExposedFunction]: return [] @@ -39,7 +39,7 @@ def list_aggregations(self) -> List[ExposedFunction]: async def apply_aggregation(self, filters: IQLQuery) -> None: ... - def execute(self, dry_run=False) -> ViewExecutionResult: + def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=[], context={}) From d21f4e1608905ecfbad1bd5038e9d5761e6884d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Sat, 17 Aug 2024 02:44:22 +0200 Subject: [PATCH 17/21] fix tests for python 3.8 --- src/dbally/views/methods_base.py | 2 +- src/dbally/views/structured.py | 5 +++-- tests/unit/mocks.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index f90d6d7d..977a2fa1 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -9,7 +9,7 @@ from dbally.views.structured import BaseStructuredView, DataT -class MethodsBaseView(Generic[DataT], BaseStructuredView[DataT], ABC): +class MethodsBaseView(Generic[DataT], BaseStructuredView, ABC): """ Base class for views that use view methods to expose filters. """ diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c61ee0dd..c3ac91e0 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import Any, Dict, List, Optional, TypeVar from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult @@ -21,7 +21,8 @@ DataT = TypeVar("DataT", bound=Any) -class BaseStructuredView(Generic[DataT], BaseView): +# TODO(Python 3.9+): Make BaseStructuredView a generic class +class BaseStructuredView(BaseView): """ Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 8b01c69f..992fd03d 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -19,7 +19,7 @@ from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult -class MockViewBase(BaseStructuredView[List]): +class MockViewBase(BaseStructuredView): """ Mock view base class """ From 9a8b63e6cd74cb999ad48eb423ea6e58c6b18467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 26 Aug 2024 09:12:04 +0000 Subject: [PATCH 18/21] review: aggregations in structured views (#85) --- benchmarks/sql/bench/pipelines/base.py | 26 +- benchmarks/sql/bench/pipelines/collection.py | 40 +-- benchmarks/sql/bench/pipelines/view.py | 35 +-- .../sql/bench/views/structured/superhero.py | 7 +- src/dbally/exceptions.py | 7 - src/dbally/iql/__init__.py | 12 +- src/dbally/iql/_processor.py | 77 ++++-- src/dbally/iql/_query.py | 37 ++- src/dbally/iql_generator/iql_generator.py | 232 ++++++++++++++---- .../iql_generator/iql_prompt_template.py | 0 src/dbally/iql_generator/prompt.py | 165 ++++++++++--- src/dbally/prompt/aggregation.py | 122 --------- src/dbally/prompt/template.py | 4 +- src/dbally/views/exceptions.py | 18 +- src/dbally/views/methods_base.py | 8 +- src/dbally/views/pandas_base.py | 52 ++-- src/dbally/views/sqlalchemy_base.py | 24 +- src/dbally/views/structured.py | 119 +++------ tests/unit/iql/test_iql_parser.py | 226 +++++++++++++++-- tests/unit/mocks.py | 35 +-- .../similarity/sample_module/submodule.py | 12 +- tests/unit/test_collection.py | 24 +- tests/unit/test_iql_format.py | 16 +- tests/unit/test_iql_generator.py | 194 +++++++++------ tests/unit/views/test_methods_base.py | 11 +- tests/unit/views/test_pandas_base.py | 72 ++++-- tests/unit/views/test_sqlalchemy_base.py | 13 +- 27 files changed, 959 insertions(+), 629 deletions(-) delete mode 100644 src/dbally/iql_generator/iql_prompt_template.py delete mode 100644 src/dbally/prompt/aggregation.py diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index 38bcb304..dc8d83ea 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union +from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLQuery +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM +from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM @@ -16,6 +20,25 @@ class IQL: source: Optional[str] = None unsupported: bool = False valid: bool = True + generated: bool = True + + @classmethod + def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + """ + Creates an IQL object from the query. + + Args: + query: The IQL query or exception. + + Returns: + The IQL object. + """ + return cls( + source=query.source if isinstance(query, (IQLQuery, IQLError)) else None, + unsupported=isinstance(query, UnsupportedQueryError), + valid=not isinstance(query, IQLError), + generated=not isinstance(query, LLMError), + ) @dataclass @@ -47,6 +70,7 @@ class EvaluationResult: """ db_id: str + question_id: str question: str reference: ExecutionResult prediction: ExecutionResult diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index dfc127cf..19831b0d 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -5,10 +5,8 @@ import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.view_selection.llm_view_selector import LLMViewSelector -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from ..views import VIEWS_REGISTRY from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return_natural_response=False, ) except NoViewFoundError: - prediction = ExecutionResult( - view_name=None, - iql=None, - sql=None, - ) - except IQLGenerationError as exc: + prediction = ExecutionResult() + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=exc.view_name, iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL( - source=result.context.get("iql"), - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), - sql=result.context.get("sql"), + sql=result.context["sql"], ) reference = ExecutionResult( @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index d4ae8515..be9d8263 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -5,9 +5,7 @@ from sqlalchemy import create_engine -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, n_retries=0, ) - except IQLGenerationError as exc: + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=result.context["iql"], - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), sql=result.context["sql"], ) @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: filters=IQL( source=data["iql_filters"], unsupported=data["iql_filters_unsupported"], - valid=True, ), aggregation=IQL( source=data["iql_aggregation"], unsupported=data["iql_aggregation_unsupported"], - valid=True, ), context=data["iql_context"], ), @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 56932947..2a6a75a0 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -286,12 +286,11 @@ class SuperheroColourFilterMixin: """ def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) self.eye_colour = aliased(Colour) self.hair_colour = aliased(Colour) self.skin_colour = aliased(Colour) - super().__init__(*args, **kwargs) - @view_filter() def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement: """ @@ -441,11 +440,12 @@ def count_superheroes(self) -> Select: Returns: The superheros count. """ - return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) class SuperheroView( DBInitMixin, + SqlAlchemyBaseView, SuperheroFilterMixin, SuperheroAggregationMixin, SuperheroColourFilterMixin, @@ -453,7 +453,6 @@ class SuperheroView( GenderFilterMixin, PublisherFilterMixin, RaceFilterMixin, - SqlAlchemyBaseView, ): """ View for querying only superheros data. Contains the superhero id, superhero name, full name, height, weight, diff --git a/src/dbally/exceptions.py b/src/dbally/exceptions.py index 62faac37..6b095cd7 100644 --- a/src/dbally/exceptions.py +++ b/src/dbally/exceptions.py @@ -2,10 +2,3 @@ class DbAllyError(Exception): """ Base class for all exceptions raised by db-ally. """ - - -class UnsupportedAggregationError(DbAllyError): - """ - Error raised when AggregationFormatter is unable to construct a query - with given aggregation. - """ diff --git a/src/dbally/iql/__init__.py b/src/dbally/iql/__init__.py index 0df0a766..20bde9eb 100644 --- a/src/dbally/iql/__init__.py +++ b/src/dbally/iql/__init__.py @@ -1,5 +1,13 @@ from . import syntax from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError -from ._query import IQLQuery +from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery -__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"] +__all__ = [ + "IQLQuery", + "IQLFiltersQuery", + "IQLAggregationQuery", + "syntax", + "IQLError", + "IQLArgumentParsingError", + "IQLUnsupportedSyntaxError", +] diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index f1adf64c..1bd72bcc 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,5 +1,6 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -19,10 +20,12 @@ if TYPE_CHECKING: from dbally.views.structured import ExposedFunction +RootT = TypeVar("RootT", bound=syntax.Node) -class IQLProcessor: + +class IQLProcessor(Generic[RootT], ABC): """ - Parses IQL string to tree structure. + Base class for IQL processors. """ def __init__( @@ -32,9 +35,9 @@ def __init__( self.allowed_functions = {func.name: func for func in allowed_functions} self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: + async def process(self) -> RootT: """ - Process IQL string to root IQL.Node. + Process IQL string to IQL root node. Returns: IQL node which is root of the tree representing IQL query. @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node: return await self._parse_node(ast_tree.body[0].value) - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: - if isinstance(node, ast.BoolOp): - return await self._parse_bool_op(node) - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.operand)) - if isinstance(node, ast.Call): - return await self._parse_call(node) - - raise IQLUnsupportedSyntaxError(node, self.source) + @abstractmethod + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT: + """ + Parses AST node to IQL node. - async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: - if isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.values[0])) - if isinstance(node.op, ast.And): - return syntax.And([await self._parse_node(x) for x in node.values]) - if isinstance(node.op, ast.Or): - return syntax.Or([await self._parse_node(x) for x in node.values]) + Args: + node: AST node to parse. - raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + Returns: + IQL node. + """ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: func = node.func @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str: converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower() return converted_text + + +class IQLFiltersProcessor(IQLProcessor[syntax.Node]): + """ + IQL processor for filters. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: + if isinstance(node, ast.BoolOp): + return await self._parse_bool_op(node) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.operand)) + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) + + async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: + if isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.values[0])) + if isinstance(node.op, ast.And): + return syntax.And([await self._parse_node(x) for x in node.values]) + if isinstance(node.op, ast.Or): + return syntax.Or([await self._parse_node(x) for x in node.values]) + + raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + + +class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]): + """ + IQL processor for aggregation. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall: + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index dd831a91..57b3b4ed 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,26 +1,29 @@ -from typing import TYPE_CHECKING, List, Optional +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type from ..audit.event_tracker import EventTracker from . import syntax -from ._processor import IQLProcessor +from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: from dbally.views.structured import ExposedFunction -class IQLQuery: +class IQLQuery(Generic[RootT], ABC): """ IQLQuery container. It stores IQL as a syntax tree defined in `IQL` class. """ - root: syntax.Node + root: RootT + source: str + _processor: Type[IQLProcessor[RootT]] - def __init__(self, root: syntax.Node, source: str) -> None: + def __init__(self, root: RootT, source: str) -> None: self.root = root - self._source = source + self.source = source def __str__(self) -> str: - return self._source + return self.source @classmethod async def parse( @@ -28,7 +31,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - ) -> "IQLQuery": + ) -> "IQLQuery[RootT]": """ Parse IQL string to IQLQuery object. @@ -43,5 +46,21 @@ async def parse( Raises: IQLError: If parsing fails. """ - root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process() return cls(root=root, source=source) + + +class IQLFiltersQuery(IQLQuery[syntax.Node]): + """ + IQL filters query container. + """ + + _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor + + +class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): + """ + IQL aggregation query container. + """ + + _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 27347734..2222e179 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,11 +1,16 @@ -from typing import List, Optional +import asyncio +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, + FILTERS_GENERATION_TEMPLATE, + DecisionPromptFormat, IQLGenerationPromptFormat, ) from dbally.llms.base import LLM @@ -15,57 +20,151 @@ from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" +IQLQueryT = TypeVar("IQLQueryT", bound=IQLQuery) -class IQLGenerator: +@dataclass +class IQLGeneratorState: + """ + State of the IQL generator. """ - Class used to generate IQL from natural language question. - In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way. - The class used to generate IQL from natural language query is `IQLGenerator`. + filters: Optional[Union[IQLFiltersQuery, Exception]] = None + aggregation: Optional[Union[IQLAggregationQuery, Exception]] = None + + @property + def failed(self) -> bool: + """ + Checks if the generation failed. - IQL generation is done using the method `self.generate_iql`. - It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. + Returns: + True if the generation failed, False otherwise. + """ + return isinstance(self.filters, Exception) or isinstance(self.aggregation, Exception) + + +class IQLGenerator: + """ + Program that orchestrates all IQL operations for the given question. """ def __init__( self, - llm: LLM, - *, - decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None, - generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None, + filters_generation: Optional["IQLOperationGenerator"] = None, + aggregation_generation: Optional["IQLOperationGenerator"] = None, ) -> None: """ Constructs a new IQLGenerator instance. Args: - llm: LLM used to generate IQL. decision_prompt: Prompt template for filtering decision making. generation_prompt: Prompt template for IQL generation. """ - self._llm = llm - self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE - self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE + self._filters_generation = filters_generation or IQLOperationGenerator[IQLFiltersQuery]( + FILTERING_DECISION_TEMPLATE, + FILTERS_GENERATION_TEMPLATE, + ) + self._aggregation_generation = aggregation_generation or IQLOperationGenerator[IQLAggregationQuery]( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, + ) - async def generate( + # pylint: disable=too-many-arguments + async def __call__( self, + *, question: str, filters: List[ExposedFunction], - event_tracker: EventTracker, - examples: Optional[List[FewShotExample]] = None, + aggregations: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> Optional[IQLQuery]: + ) -> IQLGeneratorState: """ - Generates IQL in text form using LLM. + Generates IQL operations for the given question. Args: question: User question. filters: List of filters exposed by the view. + aggregations: List of aggregations exposed by the view. + examples: List of examples to be injected during filters and aggregation generation. + llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. + + Returns: + Generated IQL operations. + """ + filters, aggregation = await asyncio.gather( + self._filters_generation( + question=question, + methods=filters, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + self._aggregation_generation( + question=question, + methods=aggregations, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + return_exceptions=True, + ) + return IQLGeneratorState( + filters=filters, + aggregation=aggregation, + ) + + +class IQLOperationGenerator(Generic[IQLQueryT]): + """ + Program that generates IQL queries for the given question. + """ + + def __init__( + self, + assessor_prompt: PromptTemplate[DecisionPromptFormat], + generator_prompt: PromptTemplate[IQLGenerationPromptFormat], + ) -> None: + """ + Constructs a new IQLGenerator instance. + + Args: + assessor_prompt: Prompt template for filtering decision making. + generator_prompt: Prompt template for IQL generation. + """ + self.assessor = IQLQuestionAssessor(assessor_prompt) + self.generator = IQLQueryGenerator[IQLQueryT](generator_prompt) + + async def __call__( + self, + *, + question: str, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, + llm_options: Optional[LLMOptions] = None, + n_retries: int = 3, + ) -> Optional[IQLQueryT]: + """ + Generates IQL query for the given question. + + Args: + llm: LLM used to generate IQL. + question: User question. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. @@ -77,38 +176,52 @@ async def generate( IQLError: If IQL parsing fails after all retries. UnsupportedQueryError: If the question is not supported by the view. """ - decision = await self._decide_on_generation( + decision = await self.assessor( question=question, - event_tracker=event_tracker, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) if not decision: return None - return await self._generate_iql( + return await self.generator( question=question, - filters=filters, - event_tracker=event_tracker, + methods=methods, examples=examples, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) - async def _decide_on_generation( + +class IQLQuestionAssessor: + """ + Program that assesses whether a question requires applying IQL operation or not. + """ + + def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - event_tracker: EventTracker, + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> bool: """ - Decides whether the question requires filtering or not. + Decides whether the question requires generating IQL or not. Args: question: User question. - event_tracker: Event store used to audit the generation process. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to LLM API in case of errors. Returns: @@ -117,12 +230,14 @@ async def _decide_on_generation( Raises: LLMError: If LLM text generation fails after all retries. """ - prompt_format = FilteringDecisionPromptFormat(question=question) - formatted_prompt = self._decision_prompt.format_prompt(prompt_format) + prompt_format = DecisionPromptFormat( + question=question, + ) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -133,24 +248,39 @@ async def _decide_on_generation( if retry == n_retries: raise exc - async def _generate_iql( + +class IQLQueryGenerator(Generic[IQLQueryT]): + """ + Program that generates IQL queries for the given question. + """ + + ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" + + def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - filters: List[ExposedFunction], - event_tracker: Optional[EventTracker] = None, - examples: Optional[List[FewShotExample]] = None, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> IQLQuery: + ) -> IQLQueryT: """ - Generates IQL in text form using LLM. + Generates IQL query for the given question. Args: question: User question. filters: List of filters exposed by the view. - event_tracker: Event store used to audit the generation process. examples: List of examples to be injected into the conversation. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. Returns: @@ -163,24 +293,22 @@ async def _generate_iql( """ prompt_format = IQLGenerationPromptFormat( question=question, - filters=filters, + methods=methods, examples=examples, ) - formatted_prompt = self._generation_prompt.format_prompt(prompt_format) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) # TODO: Move response parsing to llm generate_text method - iql = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, - allowed_functions=filters, + return await formatted_prompt.response_parser( + response=response, + allowed_functions=methods, event_tracker=event_tracker, ) except LLMError as exc: @@ -190,4 +318,4 @@ async def _generate_iql( if retry == n_retries: raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 4e5a45ec..bf33fbe0 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -1,8 +1,10 @@ # pylint: disable=C0301 -from typing import List +from typing import List, Optional +from dbally.audit.event_tracker import EventTracker from dbally.exceptions import DbAllyError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.views.exposed_functions import ExposedFunction @@ -15,26 +17,65 @@ class UnsupportedQueryError(DbAllyError): """ -def _validate_iql_response(llm_response: str) -> str: +async def _iql_filters_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLFiltersQuery: """ - Validates LLM response to IQL + Parses the response from the LLM to IQL. Args: - llm_response: LLM response + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. Returns: - A string containing IQL for filters. + IQL query for filters. Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. + UnsuppotedQueryError: When IQL generator is unable to construct a query with given filters. """ - if "unsupported query" in llm_response.lower(): + if "unsupported query" in response.lower(): raise UnsupportedQueryError - return llm_response + return await IQLFiltersQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) -def _decision_iql_response_parser(response: str) -> bool: + +async def _iql_aggregation_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLAggregationQuery: + """ + Parses the response from the LLM to IQL. + + Args: + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. + + Returns: + IQL query for aggregations. + + Raises: + UnsuppotedQueryError: When IQL generator is unable to construct a query with given aggregations. + """ + if "unsupported query" in response.lower(): + raise UnsupportedQueryError + + return await IQLAggregationQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) + + +def _decision_parser(response: str) -> bool: """ Parses the response from the decision prompt. @@ -52,7 +93,7 @@ def _decision_iql_response_parser(response: str) -> bool: return "true" in decision -class FilteringDecisionPromptFormat(PromptFormat): +class DecisionPromptFormat(PromptFormat): """ IQL prompt format, providing a question and filters to be used in the conversation. """ @@ -71,44 +112,96 @@ def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> N class IQLGenerationPromptFormat(PromptFormat): """ - IQL prompt format, providing a question and filters to be used in the conversation. + IQL prompt format, providing a question and methods to be used in the conversation. """ def __init__( self, *, question: str, - filters: List[ExposedFunction], - examples: List[FewShotExample] = None, + methods: List[ExposedFunction], + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + methods: List of filters exposed by the view. examples: List of examples to be injected into the conversation. aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] + self.methods = "\n".join([str(condition) for condition in methods]) if methods else [] + + +FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" + "Initial data filtering is a process in which the result set is reduced to only include the rows " + "that meet certain criteria specified in the question.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Hint: ${{hint}}" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ( + "Question: {question}\n" + "Hint: Look for words indicating data specific features.\n" + "Reasoning: Let's think step by step in order to " + ), + }, + ], + response_parser=_decision_parser, +) +AGGREGATION_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires computing the aggregation in order to compute it.\n" + "Aggregation is a process in which the result set is reduced to a single value.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ("Question: {question}\n" "Reasoning: Let's think step by step in order to "), + }, + ], + response_parser=_decision_parser, +) -IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( +FILTERS_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( "You have access to an API that lets you query a database:\n" - "\n{filters}\n" + "\n{methods}\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n{filters}\n" + "\n{methods}\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. " @@ -119,35 +212,31 @@ def __init__( "content": "{question}", }, ], - response_parser=_validate_iql_response, + response_parser=_iql_filters_parser, ) - -FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat]( +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( - "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" - "Initial data filtering is a process in which the result set is reduced to only include the rows " - "that meet certain criteria specified in the question.\n\n" - "---\n\n" - "Follow the following format.\n\n" - "Question: ${{question}}\n" - "Hint: ${{hint}}" - "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" - "Decision: indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)\n\n" + "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{methods}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n' ), }, { "role": "user", - "content": ( - "Question: {question}\n" - "Hint: Look for words indicating data specific features.\n" - "Reasoning: Let's think step by step in order to " - ), + "content": "{question}", }, ], - response_parser=_decision_iql_response_parser, + response_parser=_iql_aggregation_parser, ) diff --git a/src/dbally/prompt/aggregation.py b/src/dbally/prompt/aggregation.py deleted file mode 100644 index 8dedd95c..00000000 --- a/src/dbally/prompt/aggregation.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import List, Optional - -from dbally.audit import EventTracker -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.llms.base import LLM -from dbally.llms.clients import LLMOptions -from dbally.prompt.template import PromptFormat, PromptTemplate -from dbally.views.exposed_functions import ExposedFunction - - -def _validate_agg_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing aggregations. - - Raises: - UnsupportedAggregationError: When IQL generator is unable to construct a query - with given aggregation. - """ - if "unsupported query" in llm_response.lower(): - raise UnsupportedAggregationError - return llm_response - - -class AggregationPromptFormat(PromptFormat): - """ - Aggregation prompt format, providing a question and aggregation to be used in the conversation. - """ - - def __init__( - self, - question: str, - aggregations: List[ExposedFunction] = None, - ) -> None: - super().__init__() - self.question = question - self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] - - -class AggregationFormatter: - """ - Class used to manage choice and formatting of aggregation based on natural language question. - """ - - def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None: - """ - Constructs a new AggregationFormatter instance. - - Args: - llm: LLM used to generate IQL - prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE` - """ - self._llm = llm - self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE - - async def format_to_query_object( - self, - question: str, - event_tracker: EventTracker, - aggregations: List[ExposedFunction] = None, - llm_options: Optional[LLMOptions] = None, - ) -> IQLQuery: - """ - Generates IQL in text form using LLM. - - Args: - question: User question. - event_tracker: Event store used to audit the generation process. - aggregations: List of aggregations exposed by the view. - llm_options: Options to use for the LLM client. - - Returns: - Generated aggregation query. - """ - prompt_format = AggregationPromptFormat( - question=question, - aggregations=aggregations, - ) - - formatted_prompt = self._prompt_template.format_prompt(prompt_format) - - response = await self._llm.generate_text( - prompt=formatted_prompt, - event_tracker=event_tracker, - options=llm_options, - ) - # TODO: Move response parsing to llm generate_text method - agg = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=agg, - allowed_functions=aggregations or [], - event_tracker=event_tracker, - ) - - -AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat]( - [ - { - "role": "system", - "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" - "When prompted for an aggregation, use the following methods: \n" - "{aggregations}" - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{aggregations}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" - "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " - "Structure output to resemble the following pattern:\n" - 'aggregation1("arg1", arg2)\n', - }, - {"role": "user", "content": "{question}"}, - ], - response_parser=_validate_agg_response, -) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/views/exceptions.py b/src/dbally/views/exceptions.py index 277064a4..15770e9a 100644 --- a/src/dbally/views/exceptions.py +++ b/src/dbally/views/exceptions.py @@ -1,26 +1,22 @@ -from typing import Optional - from dbally.exceptions import DbAllyError +from dbally.iql_generator.iql_generator import IQLGeneratorState -class IQLGenerationError(DbAllyError): +class ViewExecutionError(DbAllyError): """ - Exception for when an error occurs while generating IQL for a view. + Exception for when an error occurs while executing a view. """ def __init__( self, view_name: str, - filters: Optional[str] = None, - aggregation: Optional[str] = None, + iql: IQLGeneratorState, ) -> None: """ Args: view_name: Name of the view that caused the error. - filters: Filters generated by the view. - aggregation: Aggregation generated by the view. + iql: View IQL generator state. """ - super().__init__(f"Error while generating IQL for view {view_name}") + super().__init__(f"Error while executing view {view_name}") self.view_name = view_name - self.filters = filters - self.aggregation = aggregation + self.iql = iql diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 977a2fa1..2a2c5d8e 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -1,15 +1,15 @@ import inspect import textwrap from abc import ABC -from typing import Any, Callable, Generic, List, Tuple +from typing import Any, Callable, List, Tuple from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView, DataT +from dbally.views.structured import BaseStructuredView -class MethodsBaseView(Generic[DataT], BaseStructuredView, ABC): +class MethodsBaseView(BaseStructuredView, ABC): """ Base class for views that use view methods to expose filters. """ @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 5f7bc8ce..c35fd30f 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,15 +1,17 @@ import asyncio from functools import reduce -from typing import Optional +from typing import List, Optional, Union import pandas as pd +from sqlalchemy import Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): +class DataFrameBaseView(MethodsBaseView): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -24,35 +26,31 @@ def __init__(self, df: pd.DataFrame) -> None: Args: df: Pandas DataFrame with the data to be filtered. """ - super().__init__(df) - - # The mask to be applied to the dataframe to filter the data + super().__init__() + self.df = df self._filter_mask: Optional[pd.Series] = None + self._groupbys: Optional[Union[str, List[str]]] = None + self._aggregations: Optional[List[Tuple[str, str]]] = None - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self._filter_mask = await self.build_filter_node(filters.root) - self.data = self.data.loc[self._filter_mask] + self._filter_mask = await self._build_filter_node(filters.root) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the aggregation of choice to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self._groupbys, self._aggregations = await self.call_aggregation_method(aggregation.root) - async def build_filter_node(self, node: syntax.Node) -> pd.Series: + async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing a boolean mask to be applied to the dataframe. @@ -69,13 +67,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x & y, children) if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x | y, children) if isinstance(node, syntax.Not): - child = await self.build_filter_node(node.child) + child = await self._build_filter_node(node.child) return ~child raise ValueError(f"Unsupported grammar: {node}") @@ -90,11 +88,25 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Returns: ExecutionResult object with the results and the context information with the binary mask. """ - results = pd.DataFrame.empty if dry_run else self.data + results = pd.DataFrame() + + if not dry_run: + results = self.df + if self._filter_mask is not None: + results = results.loc[self._filter_mask] + + if self._groupbys is not None: + results = results.groupby(self._groupbys) + + if self._aggregations is not None: + results = results.agg(**{"_".join(agg): agg for agg in self._aggregations}) + results = results.reset_index() return ViewExecutionResult( results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, + "groupbys": self._groupbys, + "aggregations": self._aggregations, }, ) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 4863aa6f..691797a1 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -4,11 +4,12 @@ import sqlalchemy from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): +class SqlAlchemyBaseView(MethodsBaseView): """ Base class for views that use SQLAlchemy to generate SQL queries. """ @@ -20,7 +21,8 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: Args: sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. """ - super().__init__(self.get_select()) + super().__init__() + self.select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine @abc.abstractmethod @@ -32,27 +34,23 @@ def get_select(self) -> sqlalchemy.Select: SQLAlchemy Select object for the view. """ - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = self.data.where(await self._build_filter_node(filters.root)) + self.select = self.select.where(await self._build_filter_node(filters.root)) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self.select = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -95,11 +93,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: - rows = connection.execute(self.data).fetchall() + rows = connection.execute(self.select).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access results = [dict(row._mapping) for row in rows] diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c3ac91e0..bab0f7b3 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,60 +1,34 @@ import abc from collections import defaultdict -from typing import Any, Dict, List, Optional, TypeVar +from typing import Dict, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.exposed_functions import ExposedFunction -from ..prompt.aggregation import AggregationFormatter from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation -DataT = TypeVar("DataT", bound=Any) - -# TODO(Python 3.9+): Make BaseStructuredView a generic class class BaseStructuredView(BaseView): """ - Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ + Base class for all structured views. All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """ - def __init__(self, data: DataT) -> None: - super().__init__() - self.data = data - - def get_iql_generator(self, llm: LLM) -> IQLGenerator: + def get_iql_generator(self) -> IQLGenerator: """ Returns the IQL generator for the view. - Args: - llm: LLM used to generate the IQL queries. - Returns: IQL generator for the view. """ - return IQLGenerator(llm=llm) - - def get_agg_formatter(self, llm: LLM) -> AggregationFormatter: - """ - Returns the AggregtionFormatter for the view. - - Args: - llm: LLM used to generate the queries. - - Returns: - AggregtionFormatter for the view. - """ - return AggregationFormatter(llm=llm) + return IQLGenerator() async def ask( self, @@ -81,68 +55,41 @@ async def ask( The result of the query. Raises: - LLMError: If LLM text generation API fails. - IQLGenerationError: If the IQL generation fails. + ViewExecutionError: When an error occurs while executing the view. """ - iql_generator = self.get_iql_generator(llm) - agg_formatter = self.get_agg_formatter(llm) filters = self.list_filters() examples = self.list_few_shots() aggregations = self.list_aggregations() - try: - iql = await iql_generator.generate( - question=query, - filters=filters, - examples=examples, - event_tracker=event_tracker, - llm_options=llm_options, - n_retries=n_retries, - ) - except UnsupportedQueryError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( + iql_generator = self.get_iql_generator() + iql = await iql_generator( + question=query, + filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + + if iql.failed: + raise ViewExecutionError( view_name=self.__class__.__name__, - filters=exc.source, - aggregation=None, - ) from exc - - if iql: - await self.apply_filters(iql) - - try: - agg_node = await agg_formatter.format_to_query_object( - question=query, - aggregations=aggregations, - event_tracker=event_tracker, - llm_options=llm_options, + iql=iql, ) - except UnsupportedAggregationError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=exc.source, - ) from exc - await self.apply_aggregation(agg_node) + if iql.filters: + await self.apply_filters(iql.filters) + + if iql.aggregation: + await self.apply_aggregation(iql.aggregation) result = self.execute(dry_run=dry_run) result.context["iql"] = { - "filters": str(iql) if iql else None, - "aggregation": str(agg_node), + "filters": str(iql.filters) if iql.filters else None, + "aggregation": str(iql.aggregation) if iql.aggregation else None, } - return result @abc.abstractmethod @@ -164,21 +111,21 @@ def list_aggregations(self) -> List[ExposedFunction]: """ @abc.abstractmethod - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + filters: IQLQuery object representing the filters to apply. """ @abc.abstractmethod - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: - aggregation: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + aggregation: IQLQuery object representing the aggregation to apply. """ @abc.abstractmethod diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index ae5d2269..bed83d0a 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -3,7 +3,7 @@ import pytest -from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax +from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, IQLFunctionNotExists, @@ -14,11 +14,12 @@ IQLSyntaxError, ) from dbally.iql._processor import IQLProcessor +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -async def test_iql_parser(): - parsed = await IQLQuery.parse( +async def test_iql_filter_parser(): + parsed = await IQLFiltersQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city('cracow') and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( @@ -51,9 +52,9 @@ async def test_iql_parser(): assert company_filter.arguments[0] == "deepsense.ai" -async def test_iql_parser_arg_error(): +async def test_iql_filter_parser_arg_error(): with pytest.raises(IQLArgumentParsingError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_city('Cracow') and filter_by_name(lambda x: x + 1)", allowed_functions=[ ExposedFunction( @@ -76,9 +77,9 @@ async def test_iql_parser_arg_error(): assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) -async def test_iql_parser_syntax_error(): +async def test_iql_filter_parser_syntax_error(): with pytest.raises(IQLSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age(", allowed_functions=[ ExposedFunction( @@ -94,9 +95,9 @@ async def test_iql_parser_syntax_error(): assert exc_info.match(re.escape("Syntax error in: filter_by_age(")) -async def test_iql_parser_multiple_expression_error(): +async def test_iql_filter_parser_multiple_expression_error(): with pytest.raises(IQLMultipleStatementsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age\nfilter_by_age", allowed_functions=[ ExposedFunction( @@ -112,9 +113,9 @@ async def test_iql_parser_multiple_expression_error(): assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) -async def test_iql_parser_empty_expression_error(): +async def test_iql_filter_parser_empty_expression_error(): with pytest.raises(IQLNoStatementError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "", allowed_functions=[ ExposedFunction( @@ -130,9 +131,9 @@ async def test_iql_parser_empty_expression_error(): assert exc_info.match(re.escape("Empty IQL")) -async def test_iql_parser_no_expression_error(): +async def test_iql_filter_parser_no_expression_error(): with pytest.raises(IQLNoExpressionError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "import filter_by_age", allowed_functions=[ ExposedFunction( @@ -148,9 +149,9 @@ async def test_iql_parser_no_expression_error(): assert exc_info.match(re.escape("No expression found in IQL: import filter_by_age")) -async def test_iql_parser_unsupported_syntax_error(): +async def test_iql_filter_parser_unsupported_syntax_error(): with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age() >= 30", allowed_functions=[ ExposedFunction( @@ -166,9 +167,9 @@ async def test_iql_parser_unsupported_syntax_error(): assert exc_info.match(re.escape("Compare syntax is not supported in IQL: filter_by_age() >= 30")) -async def test_iql_parser_method_not_exists(): +async def test_iql_filter_parser_method_not_exists(): with pytest.raises(IQLFunctionNotExists) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_how_old_somebody_is(40)", allowed_functions=[ ExposedFunction( @@ -184,9 +185,9 @@ async def test_iql_parser_method_not_exists(): assert exc_info.match(re.escape("Function filter_by_how_old_somebody_is not exists: filter_by_how_old_somebody_is")) -async def test_iql_parser_incorrect_number_of_arguments_fail(): +async def test_iql_filter_parser_incorrect_number_of_arguments_fail(): with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old', 40)", allowed_functions=[ ExposedFunction( @@ -204,9 +205,9 @@ async def test_iql_parser_incorrect_number_of_arguments_fail(): ) -async def test_iql_parser_argument_validation_fail(): +async def test_iql_filter_parser_argument_validation_fail(): with pytest.raises(IQLArgumentValidationError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old')", allowed_functions=[ ExposedFunction( @@ -222,6 +223,189 @@ async def test_iql_parser_argument_validation_fail(): assert exc_info.match(re.escape("'too old' is not of type int: 'too old'")) +async def test_iql_aggregation_parser(): + parsed = await IQLAggregationQuery.parse( + "mean_age_by_city('Paris')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert isinstance(parsed.root, syntax.FunctionCall) + assert parsed.root.name == "mean_age_by_city" + assert parsed.root.arguments == ["Paris"] + + +async def test_iql_aggregation_parser_arg_error(): + with pytest.raises(IQLArgumentParsingError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(lambda x: x + 1)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) + + +async def test_iql_aggregation_parser_syntax_error(): + with pytest.raises(IQLSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Syntax error in: mean_age_by_city(")) + + +async def test_iql_aggregation_parser_multiple_expression_error(): + with pytest.raises(IQLMultipleStatementsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city\nmean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) + + +async def test_iql_aggregation_parser_empty_expression_error(): + with pytest.raises(IQLNoStatementError) as exc_info: + await IQLAggregationQuery.parse( + "", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Empty IQL")) + + +async def test_iql_aggregation_parser_no_expression_error(): + with pytest.raises(IQLNoExpressionError) as exc_info: + await IQLAggregationQuery.parse( + "import mean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("No expression found in IQL: import mean_age_by_city")) + + +@pytest.mark.parametrize( + "iql, info", + [ + ("mean_age_by_city() >= 30", "Compare syntax is not supported in IQL: mean_age_by_city() >= 30"), + ( + "mean_age_by_city('Paris') and mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') and mean_age_by_city('London')", + ), + ( + "mean_age_by_city('Paris') or mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') or mean_age_by_city('London')", + ), + ("not mean_age_by_city('Paris')", "UnaryOp syntax is not supported in IQL: not mean_age_by_city('Paris')"), + ], +) +async def test_iql_aggregation_parser_unsupported_syntax_error(iql, info): + with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + iql, + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + assert exc_info.match(re.escape(info)) + + +async def test_iql_aggregation_parser_method_not_exists(): + with pytest.raises(IQLFunctionNotExists) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_town()", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Function mean_age_by_town not exists: mean_age_by_town")) + + +async def test_iql_aggregation_parser_incorrect_number_of_arguments_fail(): + with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city('too old')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match( + re.escape("The method mean_age_by_city has incorrect number of arguments: mean_age_by_city('too old')") + ) + + +async def test_iql_aggregation_parser_argument_validation_fail(): + with pytest.raises(IQLArgumentValidationError): + await IQLAggregationQuery.parse( + "mean_age_by_city(12)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + def test_keywords_lowercase(): rv = IQLProcessor._to_lower_except_in_quotes( """NOT filter1(230) AND (NOT filter_2("NOT ADMIN") AND filter_('IS NOT ADMIN')) OR NOT filter_4()""", diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 992fd03d..69174389 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,11 +9,10 @@ from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven -from dbally.iql import IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions -from dbally.prompt.aggregation import AggregationFormatter from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -24,19 +23,16 @@ class MockViewBase(BaseStructuredView): Mock view base class """ - def __init__(self) -> None: - super().__init__([]) - def list_filters(self) -> List[ExposedFunction]: return [] - async def apply_filters(self, filters: IQLQuery) -> None: - ... - def list_aggregations(self) -> List[ExposedFunction]: return [] - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -44,21 +40,12 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: IQLQuery) -> None: - self.iql = iql - super().__init__(llm=MockLLM()) - - async def generate(self, *_, **__) -> IQLQuery: - return self.iql - - -class MockAggregationFormatter(AggregationFormatter): - def __init__(self, iql_query: IQLQuery) -> None: - self.iql_query = iql_query - super().__init__(llm=MockLLM()) + def __init__(self, state: IQLGeneratorState) -> None: + self.state = state + super().__init__() - async def format_to_query_object(self, *_, **__) -> IQLQuery: - return self.iql_query + async def __call__(self, *_, **__) -> IQLGeneratorState: + return self.state class MockViewSelector(ViewSelector): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index 42e05c0a..ab4b6c7e 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -3,7 +3,7 @@ from typing_extensions import Annotated from dbally import MethodsBaseView, decorators -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.structured import ViewExecutionResult from tests.unit.mocks import MockSimilarityIndex @@ -20,7 +20,10 @@ def method_foo(self, idx: Annotated[str, index_foo]) -> str: def method_bar(self, city: Annotated[str, index_foo], year: Annotated[int, index_bar]) -> str: return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -39,7 +42,10 @@ def method_qux(self, city: str, year: int) -> str: """ return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index a077286d..1d675d84 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -10,17 +10,11 @@ from dbally.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql.syntax import FunctionCall +from dbally.iql_generator.iql_generator import IQLGeneratorState from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import ( - MockAggregationFormatter, - MockIQLGenerator, - MockLLM, - MockSimilarityIndex, - MockViewBase, - MockViewSelector, -) +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector class MockView1(MockViewBase): @@ -66,15 +60,17 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__) -> MockIQLGenerator: - return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def get_iql_generator(self) -> MockIQLGenerator: + return MockIQLGenerator( + IQLGeneratorState( + filters=IQLFiltersQuery(FunctionCall("test_filter", []), "test_filter()"), + aggregation=IQLAggregationQuery(FunctionCall("test_aggregation", []), "test_aggregation()"), + ), + ) def list_aggregations(self) -> List[ExposedFunction]: return [ExposedFunction("test_aggregation", "", [])] - def get_agg_formatter(self, *_, **__) -> MockAggregationFormatter: - return MockAggregationFormatter(IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()")) - @pytest.fixture(name="similarity_classes") def mock_similarity_classes() -> ( diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index b798e533..a2bf23c4 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,14 +1,14 @@ -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import FILTERS_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.prompt.elements import FewShotExample async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=[], ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -35,10 +35,10 @@ async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -67,12 +67,12 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) - assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert len(formatted_prompt.chat) == len(FILTERS_GENERATION_TEMPLATE.chat) + (len(examples) * 2) assert formatted_prompt.chat[1]["role"] == "user" assert formatted_prompt.chat[1]["content"] == examples[0].question assert formatted_prompt.chat[2]["role"] == "assistant" diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index b95fe585..0defc8e1 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,35 +1,23 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, call, patch +from unittest.mock import AsyncMock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import ( - FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, - IQLGenerationPromptFormat, -) +from dbally.iql import IQLAggregationQuery, IQLError, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM class MockView(MethodsBaseView): - def __init__(self) -> None: - super().__init__(None) - - def get_select(self) -> sqlalchemy.Select: - ... - - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False): @@ -62,125 +50,177 @@ def event_tracker() -> EventTracker: @pytest.fixture -def iql_generator(llm: MockLLM) -> IQLGenerator: - return IQLGenerator(llm) +def iql_generator() -> IQLGenerator: + return IQLGenerator() @pytest.mark.asyncio -async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: +async def test_iql_generation( + iql_generator: IQLGenerator, + llm: MockLLM, + event_tracker: EventTracker, + view: MockView, +) -> None: filters = view.list_filters() - - decision_format = FilteringDecisionPromptFormat( - question="Mock_question", - ) - generation_format = IQLGenerationPromptFormat( - question="Mock_question", - filters=filters, - ) - - decision_prompt = FILTERING_DECISION_TEMPLATE.format_prompt(decision_format) - generation_prompt = IQL_GENERATION_TEMPLATE.format_prompt(generation_format) + aggregations = view.list_aggregations() + examples = view.list_few_shots() llm_responses = [ "decision: true", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - iql = await iql_generator.generate( + iql_filter_parser_response = "filter_by_id(1)" + iql_aggregation_parser_response = "aggregate_by_id()" + + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch( + "dbally.iql.IQLFiltersQuery.parse", AsyncMock(return_value=iql_filter_parser_response) + ) as mock_filters_parse, patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(return_value=iql_aggregation_parser_response) + ) as mock_aggregation_parse: + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, ) - assert iql == "filter_by_id(1)" - iql_generator._llm.generate_text.assert_has_calls( - [ - call( - prompt=decision_prompt, - event_tracker=event_tracker, - options=None, - ), - call( - prompt=generation_prompt, - event_tracker=event_tracker, - options=None, - ), - ] + assert iql == IQLGeneratorState( + filters=iql_filter_parser_response, + aggregation=iql_aggregation_parser_response, ) - mock_parse.assert_called_once_with( - source="filter_by_id(1)", + assert llm.generate_text.call_count == 4 + mock_filters_parse.assert_called_once_with( + source=llm_responses[1], allowed_functions=filters, event_tracker=event_tracker, ) + mock_aggregation_parse.assert_called_once_with( + source=llm_responses[3], + allowed_functions=aggregations, + event_tracker=event_tracker, + ) @pytest.mark.asyncio async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), IQLError("err4", "src4"), ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - assert iql is None - assert iql_generator._llm.generate_text.call_count == 4 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == 10 + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] @pytest.mark.asyncio async def test_iql_generation_response_after_max_retries( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "filter_by_id(1)", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "aggregate_by_id()", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), "filter_by_id(1)", ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - - assert iql == "filter_by_id(1)" - assert iql_generator._llm.generate_text.call_count == 5 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[2:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == len(llm_responses) + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 8d90ffc3..57c0b68a 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,7 +4,7 @@ from typing import List, Literal, Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from dbally.views.methods_base import MethodsBaseView @@ -15,9 +15,6 @@ class MockMethodsBase(MethodsBaseView): Mock class for testing the MethodsBaseView """ - def __init__(self) -> None: - super().__init__(None) - @view_filter() def method_foo(self, idx: int) -> None: """ @@ -35,13 +32,13 @@ def method_baz(self) -> None: """ @view_aggregation() - def method_qux(self, ages: List[int], names: List[str]) -> None: + def method_qux(self, ages: List[int], names: List[str]) -> str: return f"hello {ages} and {names}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 52a8f405..b24a0398 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,8 +1,11 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name +from typing import List, Tuple + import pandas as pd -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.pandas_base import DataFrameBaseView @@ -39,23 +42,27 @@ class MockDataFrameView(DataFrameBaseView): @view_filter() def filter_city(self, city: str) -> pd.Series: - return self.data["city"] == city + return self.df["city"] == city @view_filter() def filter_year(self, year: int) -> pd.Series: - return self.data["year"] == year + return self.df["year"] == year @view_filter() def filter_age(self, age: int) -> pd.Series: - return self.data["age"] == age + return self.df["age"] == age @view_filter() def filter_name(self, name: str) -> pd.Series: - return self.data["name"] == name + return self.df["name"] == name + + @view_aggregation() + def mean_age_by_city(self) -> Tuple[str, List[Tuple[str, str]]]: + return "city", [("age", "mean")] @view_aggregation() - def mean_age_by_city(self) -> pd.DataFrame: - return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() + def count_records(self) -> Tuple[str, List[Tuple[str, str]]]: + return None, [("name", "count")] async def test_filter_or() -> None: @@ -63,7 +70,7 @@ async def test_filter_or() -> None: Test that the filtering the DataFrame with logical OR works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Berlin") or filter_city("London")', allowed_functions=mock_view.list_filters(), ) @@ -71,6 +78,8 @@ async def test_filter_or() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_and() -> None: @@ -78,7 +87,7 @@ async def test_filter_and() -> None: Test that the filtering the DataFrame with logical AND works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Paris") and filter_year(2020)', allowed_functions=mock_view.list_filters(), ) @@ -86,6 +95,8 @@ async def test_filter_and() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_not() -> None: @@ -93,7 +104,7 @@ async def test_filter_not() -> None: Test that the filtering the DataFrame with logical NOT works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'not (filter_city("Paris") and filter_year(2020))', allowed_functions=mock_view.list_filters(), ) @@ -101,25 +112,48 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None -async def test_aggregtion() -> None: +async def test_aggregation() -> None: """ Test that DataFrame aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( + "count_records()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"index": "name_count", "name": 5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] is None + assert result.context["aggregations"] == [("name", "count")] + + +async def test_aggregtion_with_groupby() -> None: + """ + Test that DataFrame aggregation with groupby works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() assert result.results == [ - {"city": "Berlin", "age": 45.0}, - {"city": "London", "age": 32.5}, - {"city": "Paris", "age": 32.5}, + {"city": "Berlin", "age_mean": 45.0}, + {"city": "London", "age_mean": 32.5}, + {"city": "Paris", "age_mean": 32.5}, ] assert result.context["filter_mask"] is None + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")] async def test_filters_and_aggregtion() -> None: @@ -127,16 +161,18 @@ async def test_filters_and_aggregtion() -> None: Test that DataFrame filtering and aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( "filter_city('Paris')", allowed_functions=mock_view.list_filters(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() - assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.results == [{"city": "Paris", "age_mean": 32.5}] assert result.context["filter_mask"].tolist() == [False, True, False, True, False] + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 435c8f8e..571e6a70 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -4,7 +4,8 @@ import sqlalchemy -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -33,7 +34,7 @@ def method_baz(self) -> sqlalchemy.Select: """ Some documentation string """ - return self.data.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) + return self.select.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) def normalize_whitespace(s: str) -> str: @@ -50,7 +51,7 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters(), ) @@ -66,7 +67,7 @@ async def test_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), ) @@ -82,12 +83,12 @@ async def test_filter_and_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters() + mock_view.list_aggregations(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), ) From e0e9da863bd0ff5a6deee77172edf9d1e59fe4f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 26 Aug 2024 14:52:44 +0200 Subject: [PATCH 19/21] update docstrings --- src/dbally/iql_generator/iql_generator.py | 8 ++++---- src/dbally/iql_generator/prompt.py | 4 ++-- tests/unit/test_iql_format.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 2222e179..4ea65340 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -45,7 +45,7 @@ def failed(self) -> bool: class IQLGenerator: """ - Program that orchestrates all IQL operations for the given question. + Orchestrates all IQL operations for the given question. """ def __init__( @@ -127,7 +127,7 @@ async def __call__( class IQLOperationGenerator(Generic[IQLQueryT]): """ - Program that generates IQL queries for the given question. + Generates IQL queries for the given question. """ def __init__( @@ -199,7 +199,7 @@ async def __call__( class IQLQuestionAssessor: """ - Program that assesses whether a question requires applying IQL operation or not. + Assesses whether a question requires applying IQL operation or not. """ def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: @@ -251,7 +251,7 @@ async def __call__( class IQLQueryGenerator(Generic[IQLQueryT]): """ - Program that generates IQL queries for the given question. + Generates IQL queries for the given question. """ ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index bf33fbe0..8d8e7101 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -127,13 +127,13 @@ def __init__( Args: question: Question to be asked. - methods: List of filters exposed by the view. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.methods = "\n".join([str(condition) for condition in methods]) if methods else [] + self.methods = "\n".join(str(method) for method in methods) FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index a2bf23c4..3a21a1fe 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -14,14 +14,14 @@ async def test_iql_prompt_format_default() -> None: { "role": "system", "content": "You have access to an API that lets you query a database:\n" - "\n[]\n" + "\n\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n[]\n" + "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", @@ -44,14 +44,14 @@ async def test_iql_prompt_format_few_shots_injected() -> None: { "role": "system", "content": "You have access to an API that lets you query a database:\n" - "\n[]\n" + "\n\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n[]\n" + "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. ", From ab432382c0d85b04364533b854011978e114339e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 27 Aug 2024 15:14:44 +0200 Subject: [PATCH 20/21] add pandas agg wrapper --- src/dbally/views/methods_base.py | 29 +++++++++++------- src/dbally/views/pandas_base.py | 44 +++++++++++++++++++++------- tests/unit/views/test_pandas_base.py | 26 ++++++++++------ 3 files changed, 70 insertions(+), 29 deletions(-) diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 2a2c5d8e..8bf93363 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -67,14 +67,14 @@ def list_aggregations(self) -> List[ExposedFunction]: def _method_with_args_from_call( self, func: syntax.FunctionCall, method_decorator: Callable - ) -> Tuple[Callable, list]: + ) -> Tuple[Callable, List]: """ Converts a IQL FunctionCall node to a method object and its arguments. Args: func: IQL FunctionCall node method_decorator: The decorator that the method should have - (currently allows discrimination between filters and aggregations) + (currently allows discrimination between filters and aggregations) Returns: Tuple with the method object and its arguments @@ -94,6 +94,21 @@ def _method_with_args_from_call( return method, func.arguments + async def _call_method(self, method: Callable, args: List) -> Any: + """ + Calls the method with the given arguments. If the method is a coroutine, it will be awaited. + + Args: + method: The method to call. + args: The arguments to pass to the method. + + Returns: + The result of the method call. + """ + if inspect.iscoroutinefunction(method): + return await method(*args) + return method(*args) + async def call_filter_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall filter to a method call. If the method is a coroutine, it will be awaited. @@ -105,10 +120,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: The result of the method call """ method, args = self._method_with_args_from_call(func, decorators.view_filter) - - if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) + return await self._call_method(method, args) async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ @@ -121,7 +133,4 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: The result of the method call """ method, args = self._method_with_args_from_call(func, decorators.view_aggregation) - - if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) + return await self._call_method(method, args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index c35fd30f..e4da84c4 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,9 +1,9 @@ import asyncio +from dataclasses import dataclass from functools import reduce from typing import List, Optional, Union import pandas as pd -from sqlalchemy import Tuple from dbally.collection.results import ViewExecutionResult from dbally.iql import syntax @@ -11,6 +11,26 @@ from dbally.views.methods_base import MethodsBaseView +@dataclass(frozen=True) +class Aggregation: + """ + Represents an aggregation to be applied to a Pandas DataFrame. + """ + + column: str + function: str + + +@dataclass(frozen=True) +class AggregationGroup: + """ + Represents an aggregations and groupbys to be applied to a Pandas DataFrame. + """ + + aggregations: Optional[List[Aggregation]] = None + groupbys: Optional[Union[str, List[str]]] = None + + class DataFrameBaseView(MethodsBaseView): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -29,8 +49,7 @@ def __init__(self, df: pd.DataFrame) -> None: super().__init__() self.df = df self._filter_mask: Optional[pd.Series] = None - self._groupbys: Optional[Union[str, List[str]]] = None - self._aggregations: Optional[List[Tuple[str, str]]] = None + self._aggregation_group: AggregationGroup = AggregationGroup() async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ @@ -48,7 +67,7 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply. """ - self._groupbys, self._aggregations = await self.call_aggregation_method(aggregation.root) + self._aggregation_group = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ @@ -95,18 +114,23 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: if self._filter_mask is not None: results = results.loc[self._filter_mask] - if self._groupbys is not None: - results = results.groupby(self._groupbys) + if self._aggregation_group.groupbys is not None: + results = results.groupby(self._aggregation_group.groupbys) - if self._aggregations is not None: - results = results.agg(**{"_".join(agg): agg for agg in self._aggregations}) + if self._aggregation_group.aggregations is not None: + results = results.agg( + **{ + f"{agg.column}_{agg.function}": (agg.column, agg.function) + for agg in self._aggregation_group.aggregations + } + ) results = results.reset_index() return ViewExecutionResult( results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, - "groupbys": self._groupbys, - "aggregations": self._aggregations, + "groupbys": self._aggregation_group.groupbys, + "aggregations": self._aggregation_group.aggregations, }, ) diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index b24a0398..029fe30f 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,13 +1,12 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name -from typing import List, Tuple import pandas as pd from dbally.iql import IQLFiltersQuery from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter -from dbally.views.pandas_base import DataFrameBaseView +from dbally.views.pandas_base import Aggregation, AggregationGroup, DataFrameBaseView MOCK_DATA = [ {"name": "Alice", "city": "London", "year": 2020, "age": 30}, @@ -57,12 +56,21 @@ def filter_name(self, name: str) -> pd.Series: return self.df["name"] == name @view_aggregation() - def mean_age_by_city(self) -> Tuple[str, List[Tuple[str, str]]]: - return "city", [("age", "mean")] + def mean_age_by_city(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="age", function="mean"), + ], + groupbys="city", + ) @view_aggregation() - def count_records(self) -> Tuple[str, List[Tuple[str, str]]]: - return None, [("name", "count")] + def count_records(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="name", function="count"), + ], + ) async def test_filter_or() -> None: @@ -132,7 +140,7 @@ async def test_aggregation() -> None: ] assert result.context["filter_mask"] is None assert result.context["groupbys"] is None - assert result.context["aggregations"] == [("name", "count")] + assert result.context["aggregations"] == [Aggregation(column="name", function="count")] async def test_aggregtion_with_groupby() -> None: @@ -153,7 +161,7 @@ async def test_aggregtion_with_groupby() -> None: ] assert result.context["filter_mask"] is None assert result.context["groupbys"] == "city" - assert result.context["aggregations"] == [("age", "mean")] + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] async def test_filters_and_aggregtion() -> None: @@ -175,4 +183,4 @@ async def test_filters_and_aggregtion() -> None: assert result.results == [{"city": "Paris", "age_mean": 32.5}] assert result.context["filter_mask"].tolist() == [False, True, False, True, False] assert result.context["groupbys"] == "city" - assert result.context["aggregations"] == [("age", "mean")] + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] From 74cabe2a6f57bb6694405502ff32719f6cd28daf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 28 Aug 2024 15:07:01 +0200 Subject: [PATCH 21/21] restore links --- src/dbally/views/sqlalchemy_base.py | 8 ++++---- src/dbally/views/structured.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 691797a1..3a7c7981 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -28,10 +28,10 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: """ - Creates initial SELECT statement for the view. - - Returns: - SQLAlchemy Select object for the view. + Creates the initial + [SqlAlchemy select object + ](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select) + which will be used to build the query. """ async def apply_filters(self, filters: IQLFiltersQuery) -> None: diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index bab0f7b3..2e5cff85 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -17,7 +17,7 @@ class BaseStructuredView(BaseView): """ - Base class for all structured views. All classes implementing this interface has\ + Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """