Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aggregations in structured views #62

Merged
merged 23 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6d66b00
Adding aggregation handling for SqlAlchemyBaseView extending quicksta…
PatrykWyzgowski Jul 12, 2024
be12e6e
Applying initial review feedback. Adding both filters and aggregation…
PatrykWyzgowski Jul 15, 2024
33d5b2e
Renaming subquery attribute and method argument to filtered_query
PatrykWyzgowski Jul 16, 2024
2e77fbc
Simplified question to the model.
PatrykWyzgowski Jul 17, 2024
41e88ed
Fixing unnecessary-pass.
PatrykWyzgowski Jul 17, 2024
dfe3e13
Continuation of review feedback application.
PatrykWyzgowski Jul 17, 2024
c09e68e
Adjusting filter prompt not to mix IQL with 'UNSUPPORTED QUERY'. Furt…
PatrykWyzgowski Jul 17, 2024
c6bbf90
Applied changes suggested in a comment related to Aggregations not ge…
PatrykWyzgowski Jul 18, 2024
4765dde
Applying pre-commit hooks.
PatrykWyzgowski Jul 18, 2024
2918ba5
Mocking AggregationFormat in tests.
PatrykWyzgowski Jul 18, 2024
0b7e50a
Merge branch 'main' into pw/add-single-aggregation
PatrykWyzgowski Jul 19, 2024
5511729
Mocking methods of the view related to aggregations to make them comp…
PatrykWyzgowski Jul 19, 2024
ae26c8b
Pre-commit fixes.
PatrykWyzgowski Jul 19, 2024
63c3adc
merge main
micpst Aug 12, 2024
a2169f2
revert to prev approach
micpst Aug 16, 2024
013cb69
fix tests
micpst Aug 16, 2024
f0a2f6e
add more tests
micpst Aug 16, 2024
aeb6295
trying to fix tests (localy working)
micpst Aug 16, 2024
d21f4e1
fix tests for python 3.8
micpst Aug 17, 2024
9a8b63e
review: aggregations in structured views (#85)
micpst Aug 26, 2024
e0e9da8
update docstrings
micpst Aug 26, 2024
ab43238
add pandas agg wrapper
micpst Aug 27, 2024
74cabe2
restore links
micpst Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/how-to/views/custom_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:

return ViewExecutionResult(results=filtered_data, context={})


class CandidateView(FilteredIterableBaseView):
def get_data(self) -> Iterable:
return [
Expand Down
11 changes: 9 additions & 2 deletions docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms.litellm import LiteLLM
Expand Down Expand Up @@ -55,6 +54,13 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
"""
return Candidate.country == country

@decorators.view_aggregation()
def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011
micpst marked this conversation as resolved.
Show resolved Hide resolved
select = sqlalchemy.select(getattr(filtered_query.c, column_name),
sqlalchemy.func.count(filtered_query.c.name).label("count")) \
.group_by(getattr(filtered_query.c, column_name))
return select


async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
Expand All @@ -63,7 +69,8 @@ async def main():
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")
result = await collection.ask("Could you find French candidates suitable for a senior data scientist position"
"and count the candidates university-wise?")

print(f"The generated SQL query is: {result.context.get('sql')}")
print()
Expand Down
7 changes: 7 additions & 0 deletions src/dbally/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@ class DbAllyError(Exception):
"""
Base class for all exceptions raised by db-ally.
"""


class UnsupportedAggregationError(DbAllyError):
"""
Error raised when AggregationFormatter is unable to construct a query
with given aggregation.
"""
4 changes: 3 additions & 1 deletion src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerat

Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `default_iql_template`
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
Expand Down Expand Up @@ -62,6 +63,7 @@ async def generate_iql(
filters=filters,
examples=examples,
)

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

for _ in range(n_retries + 1):
Expand All @@ -76,7 +78,7 @@ async def generate_iql(
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters,
allowed_functions=filters or [],
event_tracker=event_tracker,
)
except IQLError as exc:
Expand Down
Empty file.
9 changes: 5 additions & 4 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,29 @@ def __init__(
question: Question to be asked.
filters: List of filters exposed by the view.
examples: List of examples to be injected into the conversation.
aggregations: List of aggregations exposed by the view.
"""
super().__init__(examples)
self.question = question
self.filters = "\n".join([str(filter) for filter in filters])
self.filters = "\n".join([str(condition) for condition in filters]) if filters else []


IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat](
[
{
"role": "system",
"content": (
"You have access to API that lets you query a database:\n"
"You have access to an API that lets you query a database:\n"
"\n{filters}\n"
"Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Remember! Don't give any comments, just the function calls.\n"
"The output will look like this:\n"
'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n{filters}\n"
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`"""
"This is CRUCIAL, otherwise the system will crash. "
),
},
Expand Down
122 changes: 122 additions & 0 deletions src/dbally/prompt/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import List, Optional

from dbally.audit import EventTracker
from dbally.exceptions import UnsupportedAggregationError
from dbally.iql import IQLQuery
from dbally.llms.base import LLM
from dbally.llms.clients import LLMOptions
from dbally.prompt.template import PromptFormat, PromptTemplate
from dbally.views.exposed_functions import ExposedFunction


def _validate_agg_response(llm_response: str) -> str:
"""
Validates LLM response to IQL

Args:
llm_response: LLM response

Returns:
A string containing aggregations.

Raises:
UnsupportedAggregationError: When IQL generator is unable to construct a query
with given aggregation.
"""
if "unsupported query" in llm_response.lower():
raise UnsupportedAggregationError
return llm_response


class AggregationPromptFormat(PromptFormat):
"""
Aggregation prompt format, providing a question and aggregation to be used in the conversation.
"""

def __init__(
self,
question: str,
aggregations: List[ExposedFunction] = None,
) -> None:
super().__init__()
self.question = question
self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else []


class AggregationFormatter:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Class used to manage choice and formatting of aggregation based on natural language question.
"""

def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None:
"""
Constructs a new AggregationFormatter instance.

Args:
llm: LLM used to generate IQL
prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE`
"""
self._llm = llm
self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE

async def format_to_query_object(
self,
question: str,
event_tracker: EventTracker,
aggregations: List[ExposedFunction] = None,
llm_options: Optional[LLMOptions] = None,
) -> IQLQuery:
"""
Generates IQL in text form using LLM.

Args:
question: User question.
event_tracker: Event store used to audit the generation process.
aggregations: List of aggregations exposed by the view.
llm_options: Options to use for the LLM client.

Returns:
Generated aggregation query.
"""
prompt_format = AggregationPromptFormat(
question=question,
aggregations=aggregations,
)

formatted_prompt = self._prompt_template.format_prompt(prompt_format)

response = await self._llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
# TODO: Move response parsing to llm generate_text method
agg = formatted_prompt.response_parser(response)
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=agg,
allowed_functions=aggregations or [],
event_tracker=event_tracker,
)


AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat](
micpst marked this conversation as resolved.
Show resolved Hide resolved
[
{
"role": "system",
"content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n"
"When prompted for an aggregation, use the following methods: \n"
"{aggregations}"
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n{aggregations}\n"
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`"""
"This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. "
"Structure output to resemble the following pattern:\n"
'aggregation1("arg1", arg2)\n',
},
{"role": "user", "content": "{question}"},
],
response_parser=_validate_agg_response,
)
2 changes: 1 addition & 1 deletion src/dbally/view_selection/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
"role": "system",
"content": (
"You are a very smart database programmer. "
"You have access to API that lets you query a database:\n"
"You have access to an API that lets you query a database:\n"
"First you need to select a class to query, based on its description and the user question. "
"You have the following classes to choose from:\n"
"{views}\n"
Expand Down
15 changes: 15 additions & 0 deletions src/dbally/views/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,18 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin
return func

return wrapped


def view_aggregation() -> typing.Callable:
"""
Decorator for marking a method as an aggregation
Returns:
Function that returns the decorated method
"""

def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc
func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access
return func

return wrapped
33 changes: 31 additions & 2 deletions src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta):
"""

# Method arguments that should be skipped when listing methods
HIDDEN_ARGUMENTS = ["self", "select", "return"]
HIDDEN_ARGUMENTS = ["self", "select", "return", "filtered_query"]

def __init__(self):
self._filtered_query = None

@classmethod
def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]:
Expand Down Expand Up @@ -56,6 +59,15 @@ def list_filters(self) -> List[ExposedFunction]:
"""
return self.list_methods_by_decorator(decorators.view_filter)

def list_aggregations(self) -> List[ExposedFunction]:
"""
List aggregations in the given view

Returns:
Aggregations defined inside the View and decorated with `decorators.view_aggregation`.
"""
return self.list_methods_by_decorator(decorators.view_aggregation)

def _method_with_args_from_call(
self, func: syntax.FunctionCall, method_decorator: Callable
) -> Tuple[Callable, list]:
Expand All @@ -64,7 +76,8 @@ def _method_with_args_from_call(

Args:
func: IQL FunctionCall node
method_decorator: The decorator that thhe method should have
method_decorator: The decorator that the method should have
(currently allows discrimination between filters and aggregations)

Returns:
Tuple with the method object and its arguments
Expand Down Expand Up @@ -99,3 +112,19 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
if inspect.iscoroutinefunction(method):
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
micpst marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited.

Args:
func: IQL FunctionCall node

Returns:
The result of the method call
"""
method, args = self._method_with_args_from_call(func, decorators.view_aggregation)

if inspect.iscoroutinefunction(method):
return await method(self._filtered_query, *args)
return method(self._filtered_query, *args)
9 changes: 9 additions & 0 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ async def apply_filters(self, filters: IQLQuery) -> None:
"""
self._filter_mask = await self.build_filter_node(filters.root)

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Applies the aggregation of choice to the view.

Args:
aggregation: IQLQuery object representing the aggregation to apply
"""
# TODO - to be covered in a separate ticket.

async def build_filter_node(self, node: syntax.Node) -> pd.Series:
"""
Converts a filter node from the IQLQuery to a Pandas Series representing
Expand Down
21 changes: 21 additions & 0 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
super().__init__()
self._select = self.get_select()
self._sqlalchemy_engine = sqlalchemy_engine
self._filtered_query = None

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -27,6 +28,15 @@ def get_select(self) -> sqlalchemy.Select:
which will be used to build the query.
"""

def _get_filtered_query(self) -> sqlalchemy.Subquery:
"""
Creates the initial sqlalchemy.Subquery object, which will be used to build the query.

Returns:
The sqlalchemy.Subquery object based on private _select attribute.
PatrykWyzgowski marked this conversation as resolved.
Show resolved Hide resolved
"""
return self._select.subquery("subquery")

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Applies the chosen filters to the view.
Expand Down Expand Up @@ -64,6 +74,16 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu
return alchemy_op(await self._build_filter_node(bool_op.child))
raise ValueError(f"BoolOp {bool_op} has no children")

async def apply_aggregation(self, aggregation: syntax.FunctionCall) -> None:
"""
Creates a subquery based on existing and calls the aggregation method.

Args:
aggregation: IQLQuery object representing the filters to apply
"""
self._filtered_query = self._get_filtered_query()
self._select = await self.call_aggregation_method(aggregation)

def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""
Executes the generated SQL query and returns the results.
Expand All @@ -77,6 +97,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
"""

results = []

sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
Expand Down
Loading
Loading