Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aggregations in structured views #62

Merged
merged 23 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6d66b00
Adding aggregation handling for SqlAlchemyBaseView extending quicksta…
PatrykWyzgowski Jul 12, 2024
be12e6e
Applying initial review feedback. Adding both filters and aggregation…
PatrykWyzgowski Jul 15, 2024
33d5b2e
Renaming subquery attribute and method argument to filtered_query
PatrykWyzgowski Jul 16, 2024
2e77fbc
Simplified question to the model.
PatrykWyzgowski Jul 17, 2024
41e88ed
Fixing unnecessary-pass.
PatrykWyzgowski Jul 17, 2024
dfe3e13
Continuation of review feedback application.
PatrykWyzgowski Jul 17, 2024
c09e68e
Adjusting filter prompt not to mix IQL with 'UNSUPPORTED QUERY'. Furt…
PatrykWyzgowski Jul 17, 2024
c6bbf90
Applied changes suggested in a comment related to Aggregations not ge…
PatrykWyzgowski Jul 18, 2024
4765dde
Applying pre-commit hooks.
PatrykWyzgowski Jul 18, 2024
2918ba5
Mocking AggregationFormat in tests.
PatrykWyzgowski Jul 18, 2024
0b7e50a
Merge branch 'main' into pw/add-single-aggregation
PatrykWyzgowski Jul 19, 2024
5511729
Mocking methods of the view related to aggregations to make them comp…
PatrykWyzgowski Jul 19, 2024
ae26c8b
Pre-commit fixes.
PatrykWyzgowski Jul 19, 2024
63c3adc
merge main
micpst Aug 12, 2024
a2169f2
revert to prev approach
micpst Aug 16, 2024
013cb69
fix tests
micpst Aug 16, 2024
f0a2f6e
add more tests
micpst Aug 16, 2024
aeb6295
trying to fix tests (localy working)
micpst Aug 16, 2024
d21f4e1
fix tests for python 3.8
micpst Aug 17, 2024
9a8b63e
review: aggregations in structured views (#85)
micpst Aug 26, 2024
e0e9da8
update docstrings
micpst Aug 26, 2024
ab43238
add pandas agg wrapper
micpst Aug 27, 2024
74cabe2
restore links
micpst Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion benchmarks/sql/bench/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from dbally.iql._exceptions import IQLError
from dbally.iql._query import IQLQuery
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.exceptions import LLMError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM

Expand All @@ -16,6 +20,25 @@ class IQL:
source: Optional[str] = None
unsupported: bool = False
valid: bool = True
generated: bool = True

@classmethod
def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL":
    """
    Builds an IQL object describing the outcome of a query generation attempt.

    Args:
        query: The generated IQL query, the exception raised while producing it, or None.

    Returns:
        The IQL object with its source and status flags derived from the type of `query`.
    """
    # Only parsed queries and IQL errors are read for a `source` attribute here.
    carries_source = isinstance(query, (IQLQuery, IQLError))
    is_unsupported = isinstance(query, UnsupportedQueryError)
    is_iql_error = isinstance(query, IQLError)
    is_llm_error = isinstance(query, LLMError)

    return cls(
        source=query.source if carries_source else None,
        unsupported=is_unsupported,
        valid=not is_iql_error,
        generated=not is_llm_error,
    )


@dataclass
Expand Down Expand Up @@ -47,6 +70,7 @@ class EvaluationResult:
"""

db_id: str
question_id: str
question: str
reference: ExecutionResult
prediction: ExecutionResult
Expand Down
40 changes: 9 additions & 31 deletions benchmarks/sql/bench/pipelines/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
Expand Down Expand Up @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
return_natural_response=False,
)
except NoViewFoundError:
prediction = ExecutionResult(
view_name=None,
iql=None,
sql=None,
)
except IQLGenerationError as exc:
prediction = ExecutionResult()
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=exc.view_name,
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
filters=IQL(
source=result.context.get("iql"),
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context.get("sql"),
sql=result.context["sql"],
)

reference = ExecutionResult(
Expand All @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down
35 changes: 8 additions & 27 deletions benchmarks/sql/bench/pipelines/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from sqlalchemy import create_engine

from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

Expand Down Expand Up @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
dry_run=True,
n_retries=0,
)
except IQLGenerationError as exc:
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=result.context["iql"],
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context["sql"],
)
Expand All @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
filters=IQL(
source=data["iql_filters"],
unsupported=data["iql_filters_unsupported"],
valid=True,
),
aggregation=IQL(
source=data["iql_aggregation"],
unsupported=data["iql_aggregation_unsupported"],
valid=True,
),
context=data["iql_context"],
),
Expand All @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down Expand Up @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down
23 changes: 20 additions & 3 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.ext.declarative import DeferredReflection
from sqlalchemy.orm import aliased, declarative_base

from dbally.views.decorators import view_filter
from dbally.views.decorators import view_aggregation, view_filter
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

Base = declarative_base(cls=DeferredReflection)
Expand Down Expand Up @@ -285,8 +285,8 @@ class SuperheroColourFilterMixin:
Mixin for filtering the view by the superhero colour attributes.
"""

def __init__(self) -> None:
super().__init__()
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.eye_colour = aliased(Colour)
self.hair_colour = aliased(Colour)
self.skin_colour = aliased(Colour)
Expand Down Expand Up @@ -427,10 +427,27 @@ def filter_by_race(self, race: str) -> ColumnElement:
return Race.race == race


class SuperheroAggregationMixin:
    """
    Mixin providing aggregations over the superhero attributes of the view.
    """

    @view_aggregation()
    def count_superheroes(self) -> Select:
        """
        Counts the number of superheroes.

        Returns:
            The select statement producing the superheroes count.
        """
        base_query = self.select
        # The result column is exposed under the "count_superheroes" label.
        superheroes_count = func.count(Superhero.id).label("count_superheroes")
        counted_query = base_query.with_only_columns(superheroes_count)
        # NOTE(review): grouping by Superhero.id while counting Superhero.id looks like it
        # yields one row per superhero rather than a single total - confirm this is intended.
        return counted_query.group_by(Superhero.id)


class SuperheroView(
DBInitMixin,
SqlAlchemyBaseView,
SuperheroFilterMixin,
SuperheroAggregationMixin,
SuperheroColourFilterMixin,
AlignmentFilterMixin,
GenderFilterMixin,
Expand Down
1 change: 1 addition & 0 deletions docs/how-to/views/custom_views_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:

return ViewExecutionResult(results=filtered_data, context={})


class CandidateView(FilteredIterableBaseView):
def get_data(self) -> Iterable:
return [
Expand Down
1 change: 0 additions & 1 deletion docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

import dbally
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.llms.litellm import LiteLLM
Expand Down
12 changes: 10 additions & 2 deletions src/dbally/iql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from . import syntax
from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError
from ._query import IQLQuery
from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery

__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"]
__all__ = [
"IQLQuery",
"IQLFiltersQuery",
"IQLAggregationQuery",
"syntax",
"IQLError",
"IQLArgumentParsingError",
"IQLUnsupportedSyntaxError",
]
77 changes: 55 additions & 22 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import TYPE_CHECKING, Any, List, Optional, Union
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.iql import syntax
Expand All @@ -19,10 +20,12 @@
if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction

RootT = TypeVar("RootT", bound=syntax.Node)

class IQLProcessor:

class IQLProcessor(Generic[RootT], ABC):
"""
Parses IQL string to tree structure.
Base class for IQL processors.
"""

def __init__(
Expand All @@ -32,9 +35,9 @@ def __init__(
self.allowed_functions = {func.name: func for func in allowed_functions}
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
async def process(self) -> RootT:
"""
Process IQL string to root IQL.Node.
Process IQL string to IQL root node.

Returns:
IQL node which is root of the tree representing IQL query.
Expand All @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node:

return await self._parse_node(ast_tree.body[0].value)

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
@abstractmethod
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT:
"""
Parses AST node to IQL node.

async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])
Args:
node: AST node to parse.

raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
Returns:
IQL node.
"""

async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
func = node.func
Expand Down Expand Up @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str:
converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower()

return converted_text


class IQLFiltersProcessor(IQLProcessor[syntax.Node]):
    """
    IQL processor for filters.

    Translates the Python AST of a filter expression into IQL syntax nodes. Boolean
    operations (and / or), negation (not) and function calls are accepted; any other
    AST node is rejected as unsupported syntax.
    """

    async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
        """
        Parses AST node to IQL node.

        Args:
            node: AST node to parse.

        Returns:
            IQL node.

        Raises:
            IQLUnsupportedSyntaxError: If the node is not a boolean operation, a negation or a call.
        """
        if isinstance(node, ast.BoolOp):
            return await self._parse_bool_op(node)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
            return syntax.Not(await self._parse_node(node.operand))
        if isinstance(node, ast.Call):
            return await self._parse_call(node)

        raise IQLUnsupportedSyntaxError(node, self.source)

    async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
        """
        Parses a boolean operation (and / or) to the matching IQL node.

        Args:
            node: AST boolean operation to parse.

        Returns:
            IQL And / Or node wrapping the parsed operands.

        Raises:
            IQLUnsupportedSyntaxError: If the operator is neither `and` nor `or`.
        """
        # `ast.Not` is a unary operator and never appears as `BoolOp.op`, so negation
        # is handled exclusively in `_parse_node`; only And / Or need to be dispatched here.
        if isinstance(node.op, ast.And):
            return syntax.And([await self._parse_node(x) for x in node.values])
        if isinstance(node.op, ast.Or):
            return syntax.Or([await self._parse_node(x) for x in node.values])

        raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")


class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]):
    """
    IQL processor for aggregation.

    Accepts only a function call as the root of the expression.
    """

    async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall:
        """
        Parses AST node to IQL function call.

        Args:
            node: AST node to parse.

        Returns:
            IQL function call node.

        Raises:
            IQLUnsupportedSyntaxError: If the node is anything other than a call.
        """
        if not isinstance(node, ast.Call):
            raise IQLUnsupportedSyntaxError(node, self.source)

        parsed_call = await self._parse_call(node)
        return parsed_call
Loading
Loading