Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aggregations in structured views #62

Merged
merged 23 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6d66b00
Adding aggregation handling for SqlAlchemyBaseView extending quicksta…
PatrykWyzgowski Jul 12, 2024
be12e6e
Applying initial review feedback. Adding both filters and aggregation…
PatrykWyzgowski Jul 15, 2024
33d5b2e
Renaming subquery attribute and method argument to filtered_query
PatrykWyzgowski Jul 16, 2024
2e77fbc
Simplified question to the model.
PatrykWyzgowski Jul 17, 2024
41e88ed
Fixing unnecessary-pass.
PatrykWyzgowski Jul 17, 2024
dfe3e13
Continuation of review feedback application.
PatrykWyzgowski Jul 17, 2024
c09e68e
Adjusting filter prompt not to mix IQL with 'UNSUPPORTED QUERY'. Furt…
PatrykWyzgowski Jul 17, 2024
c6bbf90
Applied changes suggested in a comment related to Aggregations not ge…
PatrykWyzgowski Jul 18, 2024
4765dde
Applying pre-commit hooks.
PatrykWyzgowski Jul 18, 2024
2918ba5
Mocking AggregationFormat in tests.
PatrykWyzgowski Jul 18, 2024
0b7e50a
Merge branch 'main' into pw/add-single-aggregation
PatrykWyzgowski Jul 19, 2024
5511729
Mocking methods of the view related to aggregations to make them comp…
PatrykWyzgowski Jul 19, 2024
ae26c8b
Pre-commit fixes.
PatrykWyzgowski Jul 19, 2024
63c3adc
merge main
micpst Aug 12, 2024
a2169f2
revert to prev approach
micpst Aug 16, 2024
013cb69
fix tests
micpst Aug 16, 2024
f0a2f6e
add more tests
micpst Aug 16, 2024
aeb6295
trying to fix tests (localy working)
micpst Aug 16, 2024
d21f4e1
fix tests for python 3.8
micpst Aug 17, 2024
9a8b63e
review: aggregations in structured views (#85)
micpst Aug 26, 2024
e0e9da8
update docstrings
micpst Aug 26, 2024
ab43238
add pandas agg wrapper
micpst Aug 27, 2024
74cabe2
restore links
micpst Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions docs/how-to/views/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,41 @@ Let's implement the required `apply_filters` method in our `FilteredIterableBase

```python
def __init__(self) -> None:
super().__init__()
self._filter: Callable[[Any], bool] = lambda x: True

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the selected filters to the view.

Args:
filters: IQLQuery object representing the filters to apply
"""
self._filter = await self.build_filter_node(filters.root)
super().__init__()
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
self._filter: Callable[[Any], bool] = lambda x: True

async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]:
"""
Turns a filter node from the IQLQuery into a Python function.

Args:
node: IQLQuery node representing the filter or logical operator
"""
if isinstance(node, syntax.FunctionCall): # filter
return await self.call_filter_method(node)
if isinstance(node, syntax.And): # logical AND
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: all(child(x) for child in children) # combine children with logical AND
if isinstance(node, syntax.Or): # logical OR
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: any(child(x) for child in children) # combine children with logical OR
if isinstance(node, syntax.Not): # logical NOT
child = await self.build_filter_node(node.child)
return lambda x: not child(x)
raise ValueError(f"Unsupported grammar: {node}")
async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the selected filters to the view.

Args:
filters: IQLQuery object representing the filters to apply
"""
self._filter = await self.build_filter_node(filters.root)


async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]:
"""
Turns a filter node from the IQLQuery into a Python function.

Args:
node: IQLQuery node representing the filter or logical operator
"""
if isinstance(node, syntax.FunctionCall): # filter
return await self.call_filter_method(node)
if isinstance(node, syntax.And): # logical AND
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: all(child(x) for child in children) # combine children with logical AND
if isinstance(node, syntax.Or): # logical OR
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: any(child(x) for child in children) # combine children with logical OR
if isinstance(node, syntax.Not): # logical NOT
child = await self.build_filter_node(node.child)
return lambda x: not child(x)
raise ValueError(f"Unsupported grammar: {node}")
```

In the `apply_filters` method, we're calling the `build_filter_node` method on the root of the IQL tree. The `build_filter_node` method uses recursion to create an object that represents the combined logic of the IQL expression and the returned filter methods. For `FilteredIterableBaseView`, this object is a function that takes a single argument (a candidate) and returns a boolean. We save this function in the `_filter` attribute.
Expand Down
1 change: 1 addition & 0 deletions docs/how-to/views/custom_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:

return ViewExecutionResult(results=filtered_data, context={})


class CandidateView(FilteredIterableBaseView):
def get_data(self) -> Iterable:
return [
Expand Down
10 changes: 8 additions & 2 deletions docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms.litellm import LiteLLM
Expand Down Expand Up @@ -55,6 +54,12 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
"""
return Candidate.country == country

@decorators.view_aggregation()
def count_by_column(self, subquery: sqlalchemy.Subquery, column_name: str) -> sqlalchemy.Select:
    """
    Counts the rows of the subquery grouped by the values of the given column.

    Args:
        subquery: Subquery whose columns (`subquery.c`) are aggregated.
        column_name: Name of the subquery column to group the rows by.

    Returns:
        Select statement yielding the grouping column together with a "count" column
        (the number of non-null `name` values in each group).
    """
    # Resolve the grouping column once so SELECT and GROUP BY are guaranteed to use the same one.
    group_column = getattr(subquery.c, column_name)
    return sqlalchemy.select(group_column, sqlalchemy.func.count(subquery.c.name).label("count")).group_by(
        group_column
    )


async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
Expand All @@ -63,7 +68,8 @@ async def main():
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
result = await collection.ask("Could you find French candidates suitable for a senior data scientist position"
"and count the candidates university-wise and present the rows?")
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved

print(f"The generated SQL query is: {result.context.get('sql')}")
print()
Expand Down
15 changes: 13 additions & 2 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.iql_generator.prompt import (
IQL_GENERATION_TEMPLATE,
IQL_GENERATION_TEMPLATE_AGGREGATION,
IQLGenerationPromptFormat,
)
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.elements import FewShotExample
Expand Down Expand Up @@ -30,6 +34,7 @@ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerat

Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `default_iql_template`
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
Expand All @@ -39,6 +44,7 @@ async def generate_iql(
question: str,
filters: List[ExposedFunction],
event_tracker: EventTracker,
aggregations: List[ExposedFunction] = None,
examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
Expand All @@ -50,6 +56,7 @@ async def generate_iql(
question: User question.
filters: List of filters exposed by the view.
event_tracker: Event store used to audit the generation process.
aggregations: List of aggregations to be applied on the view.
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.
Expand All @@ -60,8 +67,12 @@ async def generate_iql(
prompt_format = IQLGenerationPromptFormat(
question=question,
filters=filters,
aggregations=aggregations,
examples=examples,
)
if aggregations and not filters:
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
self._prompt_template = IQL_GENERATION_TEMPLATE_AGGREGATION

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

for _ in range(n_retries + 1):
Expand All @@ -76,7 +87,7 @@ async def generate_iql(
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters,
allowed_functions=filters or [] + aggregations or [],
event_tracker=event_tracker,
)
except IQLError as exc:
Expand Down
Empty file.
30 changes: 28 additions & 2 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
question: str,
filters: List[ExposedFunction],
examples: List[FewShotExample] = None,
aggregations: List[ExposedFunction] = None,
) -> None:
"""
Constructs a new IQLGenerationPromptFormat instance.
Expand All @@ -53,10 +54,12 @@ def __init__(
question: Question to be asked.
filters: List of filters exposed by the view.
examples: List of examples to be injected into the conversation.
aggregations: List of examples to be computed for the view (Currently 1 aggregation supported).
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__(examples)
self.question = question
self.filters = "\n".join([str(filter) for filter in filters])
self.filters = "\n".join([str(condition) for condition in filters]) if filters else []
self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else []


IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat](
Expand All @@ -66,7 +69,7 @@ def __init__(
"content": (
"You have access to API that lets you query a database:\n"
"\n{filters}\n"
"Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Remember! Don't give any comments, just the function calls.\n"
"The output will look like this:\n"
'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
Expand All @@ -85,3 +88,26 @@ def __init__(
],
response_parser=_validate_iql_response,
)

# Prompt used to generate IQL when the view exposes aggregations only.
# Every fragment below ends with an explicit separator: adjacent string literals are concatenated
# verbatim, so a missing "\n" would glue the injected {aggregations} list to the next sentence.
IQL_GENERATION_TEMPLATE_AGGREGATION = PromptTemplate[IQLGenerationPromptFormat](
    chat=(
        {
            "role": "system",
            "content": (
                # NOTE(review): reviewer suggested "an API" / "a SINGLE aggregation" and
                # "query a database with"; wording left as is pending benchmark results.
                "You have access to API that lets you query a database supporting SINGLE aggregation.\n"
                # NOTE(review): reviewer found "When prompted for just aggregation" unclear.
                "When prompted for just aggregation, use the following methods: \n"
                "{aggregations}\n"
                "DO NOT INCLUDE arguments names in your response. Only the values.\n"
                "You MUST use only these methods:\n"
                "\n{aggregations}\n"
                "Structure output to resemble the following pattern:\n"
                'aggregation1("Argument_in_lowercase")\n'
                "It is VERY IMPORTANT not to use methods other than those listed above.\n"
                "If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`.\n"
                "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. "
                "Structure output to resemble the following pattern:\n"
                'aggregation1("Argument_in_lowercase", another_base_python_datatype_argument)\n'
            ),
        },
        {"role": "user", "content": "{question}"},
    ),
    response_parser=_validate_iql_response,
)
15 changes: 15 additions & 0 deletions src/dbally/views/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,18 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin
return func

return wrapped


def view_aggregation() -> typing.Callable:
    """
    Decorator for marking a method as an aggregation

    Returns:
        Function that returns the decorated method
    """

    def wrapped(func: typing.Callable) -> typing.Callable:  # pylint: disable=missing-return-doc
        # Tag the function with the decorator that marked it; the function itself is returned unwrapped.
        func._methodDecorator = view_aggregation  # type:ignore # pylint: disable=protected-access
        return func

    return wrapped
33 changes: 31 additions & 2 deletions src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta):
"""

# Method arguments that should be skipped when listing methods
HIDDEN_ARGUMENTS = ["self", "select", "return"]
HIDDEN_ARGUMENTS = ["self", "select", "return", "subquery"]

def __init__(self):
self._subquery = None

@classmethod
def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]:
Expand Down Expand Up @@ -56,6 +59,15 @@ def list_filters(self) -> List[ExposedFunction]:
"""
return self.list_methods_by_decorator(decorators.view_filter)

def list_aggregations(self) -> List[ExposedFunction]:
    """
    Lists the aggregations exposed by this view.

    Returns:
        Methods defined inside the View that are decorated with `decorators.view_aggregation`.
    """
    aggregation_marker = decorators.view_aggregation
    return self.list_methods_by_decorator(aggregation_marker)

def _method_with_args_from_call(
self, func: syntax.FunctionCall, method_decorator: Callable
) -> Tuple[Callable, list]:
Expand All @@ -64,7 +76,8 @@ def _method_with_args_from_call(

Args:
func: IQL FunctionCall node
method_decorator: The decorator that thhe method should have
method_decorator: The decorator that the method should have
(currently allows discrimination between filters and aggregations)

Returns:
Tuple with the method object and its arguments
Expand Down Expand Up @@ -99,3 +112,19 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
if inspect.iscoroutinefunction(method):
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
    """
    Calls the aggregation method referenced by an IQL FunctionCall node.
    If the method is a coroutine function, its result is awaited.

    Args:
        func: IQL FunctionCall node

    Returns:
        The result of the method call
    """
    aggregation, call_args = self._method_with_args_from_call(func, decorators.view_aggregation)

    # The view's subquery is always passed as the first positional argument, ahead of the IQL arguments.
    outcome = aggregation(self._subquery, *call_args)
    if inspect.iscoroutinefunction(aggregation):
        outcome = await outcome
    return outcome
9 changes: 9 additions & 0 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ async def apply_filters(self, filters: IQLQuery) -> None:
"""
self._filter_mask = await self.build_filter_node(filters.root)

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
    """
    Applies the aggregation of choice to the view.

    Currently does nothing: the aggregation is accepted but not applied.

    Args:
        aggregation: IQLQuery object representing the aggregation to apply
    """
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved

async def build_filter_node(self, node: syntax.Node) -> pd.Series:
"""
Converts a filter node from the IQLQuery to a Pandas Series representing
Expand Down
28 changes: 26 additions & 2 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
super().__init__()
self._select = self.get_select()
self._sqlalchemy_engine = sqlalchemy_engine
self._subquery = None

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -27,6 +28,15 @@ def get_select(self) -> sqlalchemy.Select:
which will be used to build the query.
"""

def get_subquery(self) -> sqlalchemy.Subquery:
    """
    Wraps the view's current select statement in a subquery named "subquery".

    Returns:
        The sqlalchemy.Subquery object built from the private _select attribute.
    """
    select_statement = self._select
    return select_statement.subquery("subquery")

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the chosen filters to the view.
Expand Down Expand Up @@ -64,6 +74,16 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu
return alchemy_op(await self._build_filter_node(bool_op.child))
raise ValueError(f"BoolOp {bool_op} has no children")

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
    """
    Applies the chosen aggregation to the view.

    Args:
        aggregation: IQLQuery object representing the aggregation to apply
    """
    # call_aggregation_method hands self._subquery to the aggregation method, so it has to be set first.
    self._subquery = self.get_subquery()
    # The attribute is then overwritten with whatever the aggregation method returns.
    self._subquery = await self.call_aggregation_method(aggregation.root)
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved

def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""
Executes the generated SQL query and returns the results.
Expand All @@ -77,13 +97,17 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""

results = []
sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))
statement = self._select
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
if self._subquery is not None:
statement = self._subquery

sql = str(statement.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
with self._sqlalchemy_engine.connect() as connection:
# The underscore is used by sqlalchemy to avoid conflicts with column names
# pylint: disable=protected-access
rows = connection.execute(self._select).fetchall()
rows = connection.execute(statement).fetchall()
results = [dict(row._mapping) for row in rows]

return ViewExecutionResult(
Expand Down
Loading
Loading