Skip to content

Commit

Permalink
add iql aggregation parser
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 25, 2024
1 parent ff55482 commit db82bba
Show file tree
Hide file tree
Showing 18 changed files with 534 additions and 361 deletions.
23 changes: 22 additions & 1 deletion benchmarks/sql/bench/pipelines/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from dbally.iql._exceptions import IQLError
from dbally.iql._query import IQLQuery
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.litellm import LiteLLM
from dbally.llms.local import LocalLLM
Expand All @@ -17,6 +20,24 @@ class IQL:
unsupported: bool = False
valid: bool = True

@classmethod
def from_generator_state(cls, state: Optional[Union[IQLQuery, Exception]]) -> "IQL":
    """
    Creates an IQL object from the IQL generator state.

    Args:
        state: The IQL generator state - an IQL query, an exception or None.

    Returns:
        The IQL object.
    """
    if isinstance(state, IQLError):
        # The source is read from the error itself.
        source = state.source
    elif isinstance(state, IQLQuery):
        source = str(state)
    else:
        # None or any other exception: no source is recorded.
        source = None

    return cls(
        source=source,
        unsupported=isinstance(state, UnsupportedQueryError),
        valid=not isinstance(state, IQLError),
    )


@dataclass
class IQLResult:
Expand Down
39 changes: 8 additions & 31 deletions benchmarks/sql/bench/pipelines/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import dbally
from dbally.collection.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.view_selection.llm_view_selector import LLMViewSelector
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError

from ..views import VIEWS_REGISTRY
from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult
Expand Down Expand Up @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
return_natural_response=False,
)
except NoViewFoundError:
prediction = ExecutionResult(
view_name=None,
iql=None,
sql=None,
)
except IQLGenerationError as exc:
prediction = ExecutionResult()
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=exc.view_name,
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_generator_state(exc.iql.filters),
aggregation=IQL.from_generator_state(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
view_name=result.view_name,
iql=IQLResult(
filters=IQL(
source=result.context.get("iql"),
unsupported=False,
valid=True,
),
aggregation=IQL(
source=None,
unsupported=False,
valid=True,
),
filters=IQL(source=result.context["iql"]["filters"]),
aggregation=IQL(source=result.context["iql"]["aggregation"]),
),
sql=result.context.get("sql"),
sql=result.context["sql"],
)

reference = ExecutionResult(
Expand Down
19 changes: 4 additions & 15 deletions benchmarks/sql/bench/pipelines/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

from sqlalchemy import create_engine

from dbally.iql._exceptions import IQLError
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exceptions import ViewExecutionError
from dbally.views.freeform.text2sql.view import BaseText2SQLView
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

Expand Down Expand Up @@ -94,22 +92,13 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult:
dry_run=True,
n_retries=0,
)
except IQLGenerationError as exc:
except ViewExecutionError as exc:
prediction = ExecutionResult(
view_name=data["view_name"],
iql=IQLResult(
filters=IQL(
source=exc.filters,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
aggregation=IQL(
source=exc.aggregation,
unsupported=isinstance(exc.__cause__, UnsupportedQueryError),
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)),
),
filters=IQL.from_generator_state(exc.iql.filters),
aggregation=IQL.from_generator_state(exc.iql.aggregation),
),
sql=None,
)
else:
prediction = ExecutionResult(
Expand Down
7 changes: 0 additions & 7 deletions src/dbally/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,3 @@ class DbAllyError(Exception):
"""
Base class for all exceptions raised by db-ally.
"""


class UnsupportedAggregationError(DbAllyError):
"""
Error raised when AggregationFormatter is unable to construct a query
with given aggregation.
"""
12 changes: 10 additions & 2 deletions src/dbally/iql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from . import syntax
from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError
from ._query import IQLQuery
from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery

__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"]
__all__ = [
"IQLQuery",
"IQLFiltersQuery",
"IQLAggregationQuery",
"syntax",
"IQLError",
"IQLArgumentParsingError",
"IQLUnsupportedSyntaxError",
]
77 changes: 55 additions & 22 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import TYPE_CHECKING, Any, List, Optional, Union
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.iql import syntax
Expand All @@ -19,10 +20,12 @@
if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction

RootT = TypeVar("RootT", bound=syntax.Node)

class IQLProcessor:

class IQLProcessor(Generic[RootT], ABC):
"""
Parses IQL string to tree structure.
Base class for IQL processors.
"""

def __init__(
Expand All @@ -32,9 +35,9 @@ def __init__(
self.allowed_functions = {func.name: func for func in allowed_functions}
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
async def process(self) -> RootT:
"""
Process IQL string to root IQL.Node.
Process IQL string to IQL root node.
Returns:
IQL node which is root of the tree representing IQL query.
Expand All @@ -60,25 +63,17 @@ async def process(self) -> syntax.Node:

return await self._parse_node(ast_tree.body[0].value)

async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
if isinstance(node, ast.BoolOp):
return await self._parse_bool_op(node)
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.operand))
if isinstance(node, ast.Call):
return await self._parse_call(node)

raise IQLUnsupportedSyntaxError(node, self.source)
@abstractmethod
async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT:
"""
Parses AST node to IQL node.
async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
if isinstance(node.op, ast.Not):
return syntax.Not(await self._parse_node(node.values[0]))
if isinstance(node.op, ast.And):
return syntax.And([await self._parse_node(x) for x in node.values])
if isinstance(node.op, ast.Or):
return syntax.Or([await self._parse_node(x) for x in node.values])
Args:
node: AST node to parse.
raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")
Returns:
IQL node.
"""

async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
func = node.func
Expand Down Expand Up @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str:
converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower()

return converted_text


class IQLFiltersProcessor(IQLProcessor[syntax.Node]):
    """
    IQL processor for filters.

    Accepts function calls, optionally combined with `and`, `or` and `not`.
    """

    async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node:
        """
        Parses AST node to IQL node.

        Args:
            node: AST node to parse.

        Returns:
            IQL node.

        Raises:
            IQLUnsupportedSyntaxError: If the node is not a boolean operation, a negation or a call.
        """
        if isinstance(node, ast.BoolOp):
            return await self._parse_bool_op(node)
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not):
            return syntax.Not(await self._parse_node(node.operand))
        if isinstance(node, ast.Call):
            return await self._parse_call(node)

        raise IQLUnsupportedSyntaxError(node, self.source)

    async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp:
        """
        Parses AST boolean operation to IQL boolean operation.

        Args:
            node: AST boolean operation to parse.

        Returns:
            IQL boolean operation node.

        Raises:
            IQLUnsupportedSyntaxError: If the operator is neither `and` nor `or`.
        """
        # `not` is a unary operator in the Python AST, so it never shows up as BoolOp.op;
        # negation is handled in _parse_node.
        if isinstance(node.op, ast.And):
            return syntax.And([await self._parse_node(x) for x in node.values])
        if isinstance(node.op, ast.Or):
            return syntax.Or([await self._parse_node(x) for x in node.values])

        raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp")


class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]):
    """
    IQL processor for aggregation.

    Only a function call is accepted as the root of the query.
    """

    async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall:
        """
        Parses AST node to IQL function call.

        Args:
            node: AST node to parse.

        Returns:
            IQL function call node.

        Raises:
            IQLUnsupportedSyntaxError: If the node is not a call.
        """
        if not isinstance(node, ast.Call):
            raise IQLUnsupportedSyntaxError(node, self.source)

        return await self._parse_call(node)
37 changes: 28 additions & 9 deletions src/dbally/iql/_query.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,37 @@
from typing import TYPE_CHECKING, List, Optional
from abc import ABC
from typing import TYPE_CHECKING, Generic, List, Optional, Type

from ..audit.event_tracker import EventTracker
from . import syntax
from ._processor import IQLProcessor
from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT

if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction


class IQLQuery:
class IQLQuery(Generic[RootT], ABC):
"""
IQLQuery container. It stores IQL as a syntax tree defined in `IQL` class.
"""

root: syntax.Node
root: RootT
source: str
_processor: Type[IQLProcessor[RootT]]

def __init__(self, root: syntax.Node, source: str) -> None:
def __init__(self, root: RootT, source: str) -> None:
self.root = root
self._source = source
self.source = source

def __str__(self) -> str:
return self._source
return self.source

@classmethod
async def parse(
cls,
source: str,
allowed_functions: List["ExposedFunction"],
event_tracker: Optional[EventTracker] = None,
) -> "IQLQuery":
) -> "IQLQuery[RootT]":
"""
Parse IQL string to IQLQuery object.
Expand All @@ -43,5 +46,21 @@ async def parse(
Raises:
IQLError: If parsing fails.
"""
root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()
root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process()
return cls(root=root, source=source)


class IQLFiltersQuery(IQLQuery[syntax.Node]):
    """
    Container for an IQL filters query, parsed with IQLFiltersProcessor.
    """

    # Processor class used to turn the source into the root node.
    _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor


class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]):
    """
    Container for an IQL aggregation query, parsed with IQLAggregationProcessor.
    """

    # Processor class used to turn the source into the root node.
    _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor
Loading

0 comments on commit db82bba

Please sign in to comment.