Skip to content

Commit

Permalink
review: aggregations in structured views (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst authored Aug 26, 2024
1 parent d21f4e1 commit 9a8b63e
Show file tree
Hide file tree
Showing 27 changed files with 959 additions and 629 deletions.
26 changes: 25 additions & 1 deletion benchmarks/sql/bench/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from dbally.iql._exceptions import IQLError
from dbally.iql._query import IQLQuery
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.exceptions import LLMError
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM

Expand All @@ -16,6 +20,25 @@ class IQL:
source: Optional[str] = None
unsupported: bool = False
valid: bool = True
generated: bool = True

@classmethod
def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL":
"""
Creates an IQL object from the query.
Args:
query: The IQL query or exception.
Returns:
The IQL object.
"""
return cls(
source=query.source if isinstance(query, (IQLQuery, IQLError)) else None,
unsupported=isinstance(query, UnsupportedQueryError),
valid=not isinstance(query, IQLError),
generated=not isinstance(query, LLMError),
)


@dataclass
Expand Down Expand Up @@ -47,6 +70,7 @@ class EvaluationResult:
"""

db_id: str
question_id: str
question: str
reference: ExecutionResult
prediction: ExecutionResult
Expand Down
40 changes: 9 additions & 31 deletions benchmarks/sql/bench/pipelines/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
Expand Down Expand Up @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
return_natural_response=False,
)
except NoViewFoundError:
prediction = ExecutionResult(
view_name=None,
iql=None,
sql=None,
)
except IQLGenerationError as exc:
prediction = ExecutionResult()
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=exc.view_name,
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
filters=IQL(
source=result.context.get("iql"),
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context.get("sql"),
sql=result.context["sql"],
)

reference = ExecutionResult(
Expand All @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down
35 changes: 8 additions & 27 deletions benchmarks/sql/bench/pipelines/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from sqlalchemy import create_engine

from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

Expand Down Expand Up @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
dry_run=True,
n_retries=0,
)
except IQLGenerationError as exc:
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_query(exc.iql.filters),
aggregation=IQL.from_query(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=result.context["iql"],
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context["sql"],
)
Expand All @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
filters=IQL(
source=data["iql_filters"],
unsupported=data["iql_filters_unsupported"],
valid=True,
),
aggregation=IQL(
source=data["iql_aggregation"],
unsupported=data["iql_aggregation_unsupported"],
valid=True,
),
context=data["iql_context"],
),
Expand All @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down Expand Up @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:

return EvaluationResult(
db_id=data["db_id"],
question_id=data["question_id"],
question=data["question"],
reference=reference,
prediction=prediction,
Expand Down
7 changes: 3 additions & 4 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,11 @@ class SuperheroColourFilterMixin:
"""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.eye_colour = aliased(Colour)
self.hair_colour = aliased(Colour)
self.skin_colour = aliased(Colour)

super().__init__(*args, **kwargs)

@view_filter()
def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement:
"""
Expand Down Expand Up @@ -441,19 +440,19 @@ def count_superheroes(self) -> Select:
Returns:
The superheros count.
"""
return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)
return self.select.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)


class SuperheroView(
DBInitMixin,
SqlAlchemyBaseView,
SuperheroFilterMixin,
SuperheroAggregationMixin,
SuperheroColourFilterMixin,
AlignmentFilterMixin,
GenderFilterMixin,
PublisherFilterMixin,
RaceFilterMixin,
SqlAlchemyBaseView,
):
"""
View for querying only superheros data. Contains the superhero id, superhero name, full name, height, weight,
Expand Down
7 changes: 0 additions & 7 deletions src/dbally/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,3 @@ class DbAllyError(Exception):
"""
Base class for all exceptions raised by db-ally.
"""


class UnsupportedAggregationError(DbAllyError):
"""
Error raised when AggregationFormatter is unable to construct a query
with given aggregation.
"""
12 changes: 10 additions & 2 deletions src/dbally/iql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from . import syntax
from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError
from ._query import IQLQuery
from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery

__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"]
__all__ = [
"IQLQuery",
"IQLFiltersQuery",
"IQLAggregationQuery",
"syntax",
"IQLError",
"IQLArgumentParsingError",
"IQLUnsupportedSyntaxError",
]
77 changes: 55 additions & 22 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import TYPE_CHECKING, Any, List, Optional, Union
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.iql import syntax
Expand All @@ -19,10 +20,12 @@
if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction

RootT = TypeVar("RootT", bound=syntax.Node)

class IQLProcessor:

class IQLProcessor(Generic[RootT], ABC):
"""
Parses IQL string to tree structure.
Base class for IQL processors.
"""

def __init__(
Expand All @@ -32,9 +35,9 @@ def __init__(
self.allowed_functions = {func.name: func for func in allowed_functions}
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
async def process(self) -> RootT:
"""
Process IQL string to root IQL.Node.
Process IQL string to IQL root node.
Returns:
IQL node which is root of the tree representing IQL query.
Expand All @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node:

return await self._parse_node(ast_tree.body[0].value)

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
@abstractmethod
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT:
"""
Parses AST node to IQL node.
async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])
Args:
node: AST node to parse.
raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
Returns:
IQL node.
"""

async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
func = node.func
Expand Down Expand Up @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str:
converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower()

return converted_text


class IQLFiltersProcessor(IQLProcessor[syntax.Node]):
"""
IQL processor for filters.
"""

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)

async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])

raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")


class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]):
"""
IQL processor for aggregation.
"""

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall:
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
Loading

0 comments on commit 9a8b63e

Please sign in to comment.