Skip to content

Commit

Permalink
Adding aggregation handling for SqlAlchemyBaseView extending quicksta…
Browse files Browse the repository at this point in the history
…rt example.
  • Loading branch information
PatrykWyzgowski committed Jul 12, 2024
1 parent 221f6e1 commit 6d66b00
Show file tree
Hide file tree
Showing 16 changed files with 211 additions and 45 deletions.
58 changes: 31 additions & 27 deletions docs/how-to/views/custom.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,41 @@ Let's implement the required `apply_filters` method in our `FilteredIterableBase

```python
def __init__(self) -> None:
super().__init__()
self._filter: Callable[[Any], bool] = lambda x: True

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the selected filters to the view.

Args:
filters: IQLQuery object representing the filters to apply
"""
self._filter = await self.build_filter_node(filters.root)
super().__init__()
self._filter: Callable[[Any], bool] = lambda x: True

async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]:
"""
Turns a filter node from the IQLQuery into a Python function.

Args:
node: IQLQuery node representing the filter or logical operator
"""
if isinstance(node, syntax.FunctionCall): # filter
return await self.call_filter_method(node)
if isinstance(node, syntax.And): # logical AND
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: all(child(x) for child in children) # combine children with logical AND
if isinstance(node, syntax.Or): # logical OR
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: any(child(x) for child in children) # combine children with logical OR
if isinstance(node, syntax.Not): # logical NOT
child = await self.build_filter_node(node.child)
return lambda x: not child(x)
raise ValueError(f"Unsupported grammar: {node}")
async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the selected filters to the view.
Args:
filters: IQLQuery object representing the filters to apply
"""
self._filter = await self.build_filter_node(filters.root)


async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]:
"""
Turns a filter node from the IQLQuery into a Python function.
Args:
node: IQLQuery node representing the filter or logical operator
"""
if isinstance(node, syntax.FunctionCall): # filter
return await self.call_filter_method(node)
if isinstance(node, syntax.And): # logical AND
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: all(child(x) for child in children) # combine children with logical AND
if isinstance(node, syntax.Or): # logical OR
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
return lambda x: any(child(x) for child in children) # combine children with logical OR
if isinstance(node, syntax.Not): # logical NOT
child = await self.build_filter_node(node.child)
return lambda x: not child(x)
raise ValueError(f"Unsupported grammar: {node}")
```

In the `apply_filters` method, we're calling the `build_filter_node` method on the root of the IQL tree. The `build_filter_node` method uses recursion to create an object that represents the combined logic of the IQL expression and the returned filter methods. For `FilteredIterableBaseView`, this object is a function that takes a single argument (a candidate) and returns a boolean. We save this function in the `_filter` attribute.
Expand Down
1 change: 1 addition & 0 deletions docs/how-to/views/custom_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:

return ViewExecutionResult(results=filtered_data, context={})


class CandidateView(FilteredIterableBaseView):
def get_data(self) -> Iterable:
return [
Expand Down
10 changes: 8 additions & 2 deletions docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms.litellm import LiteLLM
Expand Down Expand Up @@ -55,6 +54,12 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
"""
return Candidate.country == country

@decorators.view_aggregation()
def count_by_column(self, subquery: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011
select = sqlalchemy.select(getattr(subquery.c, column_name), sqlalchemy.func.count(subquery.c.name).label("count")) \
.group_by(getattr(subquery.c, column_name))
return select


async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
Expand All @@ -63,7 +68,8 @@ async def main():
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
result = await collection.ask("Could you find French candidates suitable for a senior data scientist position"
"and count the candidates university-wise and present the rows?")

print(f"The generated SQL query is: {result.context.get('sql')}")
print()
Expand Down
15 changes: 13 additions & 2 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.iql_generator.prompt import (
IQL_GENERATION_TEMPLATE,
IQL_GENERATION_TEMPLATE_AGGREGATION,
IQLGenerationPromptFormat,
)
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.elements import FewShotExample
Expand Down Expand Up @@ -30,6 +34,7 @@ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerat
Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `default_iql_template`
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
Expand All @@ -39,6 +44,7 @@ async def generate_iql(
question: str,
filters: List[ExposedFunction],
event_tracker: EventTracker,
aggregations: List[ExposedFunction] = None,
examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
Expand All @@ -50,6 +56,7 @@ async def generate_iql(
question: User question.
filters: List of filters exposed by the view.
event_tracker: Event store used to audit the generation process.
aggregations: List of aggregations to be applied on the view.
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.
Expand All @@ -60,8 +67,12 @@ async def generate_iql(
prompt_format = IQLGenerationPromptFormat(
question=question,
filters=filters,
aggregations=aggregations,
examples=examples,
)
if aggregations and not filters:
self._prompt_template = IQL_GENERATION_TEMPLATE_AGGREGATION

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

for _ in range(n_retries + 1):
Expand All @@ -76,7 +87,7 @@ async def generate_iql(
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters,
allowed_functions=filters or [] + aggregations or [],
event_tracker=event_tracker,
)
except IQLError as exc:
Expand Down
Empty file.
30 changes: 28 additions & 2 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
question: str,
filters: List[ExposedFunction],
examples: List[FewShotExample] = None,
aggregations: List[ExposedFunction] = None,
) -> None:
"""
Constructs a new IQLGenerationPromptFormat instance.
Expand All @@ -53,10 +54,12 @@ def __init__(
question: Question to be asked.
filters: List of filters exposed by the view.
examples: List of examples to be injected into the conversation.
aggregations: List of examples to be computed for the view (Currently 1 aggregation supported).
"""
super().__init__(examples)
self.question = question
self.filters = "\n".join([str(filter) for filter in filters])
self.filters = "\n".join([str(condition) for condition in filters]) if filters else []
self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else []


IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat](
Expand All @@ -66,7 +69,7 @@ def __init__(
"content": (
"You have access to API that lets you query a database:\n"
"\n{filters}\n"
"Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Remember! Don't give any comments, just the function calls.\n"
"The output will look like this:\n"
'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
Expand All @@ -85,3 +88,26 @@ def __init__(
],
response_parser=_validate_iql_response,
)

IQL_GENERATION_TEMPLATE_AGGREGATION = PromptTemplate[IQLGenerationPromptFormat](
chat=(
{
"role": "system",
"content": "You have access to API that lets you query a database supporting SINGLE aggregation.\n"
"When prompted for just aggregation, use the following methods: \n"
"{aggregations}"
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n{aggregations}\n"
"Structure output to resemble the following pattern:\n"
'aggregation1("Argument_in_lowercase")\n'
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\" anything other than `UNSUPPORTED QUERY`"""
"This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. "
"Structure output to resemble the following pattern:\n"
'aggregation1("Argument_in_lowercase", another_base_python_datatype_argument)\n',
},
{"role": "user", "content": "{question}"},
),
response_parser=_validate_iql_response,
)
15 changes: 15 additions & 0 deletions src/dbally/views/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,18 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin
return func

return wrapped


def view_aggregation() -> typing.Callable:
"""
Decorator for marking a method as a filter
Returns:
Function that returns the decorated method
"""

def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc
func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access
return func

return wrapped
33 changes: 31 additions & 2 deletions src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta):
"""

# Method arguments that should be skipped when listing methods
HIDDEN_ARGUMENTS = ["self", "select", "return"]
HIDDEN_ARGUMENTS = ["self", "select", "return", "subquery"]

def __init__(self):
self._subquery = None

@classmethod
def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]:
Expand Down Expand Up @@ -56,6 +59,15 @@ def list_filters(self) -> List[ExposedFunction]:
"""
return self.list_methods_by_decorator(decorators.view_filter)

def list_aggregations(self) -> List[ExposedFunction]:
"""
List aggregations in the given view
Returns:
Aggregations defined inside the View and decorated with `decorators.view_aggregation`.
"""
return self.list_methods_by_decorator(decorators.view_aggregation)

def _method_with_args_from_call(
self, func: syntax.FunctionCall, method_decorator: Callable
) -> Tuple[Callable, list]:
Expand All @@ -64,7 +76,8 @@ def _method_with_args_from_call(
Args:
func: IQL FunctionCall node
method_decorator: The decorator that thhe method should have
method_decorator: The decorator that the method should have
(currently allows discrimination between filters and aggregations)
Returns:
Tuple with the method object and its arguments
Expand Down Expand Up @@ -99,3 +112,19 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
if inspect.iscoroutinefunction(method):
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
"""
Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited.
Args:
func: IQL FunctionCall node
Returns:
The result of the method call
"""
method, args = self._method_with_args_from_call(func, decorators.view_aggregation)

if inspect.iscoroutinefunction(method):
return await method(self._subquery, *args)
return method(self._subquery, *args)
9 changes: 9 additions & 0 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ async def apply_filters(self, filters: IQLQuery) -> None:
"""
self._filter_mask = await self.build_filter_node(filters.root)

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Applies the aggregation of choice to the view.
Args:
aggregation: IQLQuery object representing the aggregation to apply
"""
pass # pylint: disable=unnecessary-pass

async def build_filter_node(self, node: syntax.Node) -> pd.Series:
"""
Converts a filter node from the IQLQuery to a Pandas Series representing
Expand Down
28 changes: 26 additions & 2 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
super().__init__()
self._select = self.get_select()
self._sqlalchemy_engine = sqlalchemy_engine
self._subquery = None

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -27,6 +28,15 @@ def get_select(self) -> sqlalchemy.Select:
which will be used to build the query.
"""

def get_subquery(self) -> sqlalchemy.Subquery:
"""
Creates the initial sqlalchemy.Subquery object, which will be used to build the query.
Returns:
The sqlalchemy.Subquery object based on private _select attribute.
"""
return self._select.subquery("subquery")

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the chosen filters to the view.
Expand Down Expand Up @@ -64,6 +74,16 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu
return alchemy_op(await self._build_filter_node(bool_op.child))
raise ValueError(f"BoolOp {bool_op} has no children")

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Creates a subquery based on existing
Args:
aggregation: IQLQuery object representing the filters to apply
"""
self._subquery = self.get_subquery()
self._subquery = await self.call_aggregation_method(aggregation.root)

def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""
Executes the generated SQL query and returns the results.
Expand All @@ -77,13 +97,17 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""

results = []
sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))
statement = self._select
if self._subquery is not None:
statement = self._subquery

sql = str(statement.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
with self._sqlalchemy_engine.connect() as connection:
# The underscore is used by sqlalchemy to avoid conflicts with column names
# pylint: disable=protected-access
rows = connection.execute(self._select).fetchall()
rows = connection.execute(statement).fetchall()
results = [dict(row._mapping) for row in rows]

return ViewExecutionResult(
Expand Down
Loading

0 comments on commit 6d66b00

Please sign in to comment.