Skip to content

Commit

Permalink
revert to prev approach
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 16, 2024
1 parent 63c3adc commit a2169f2
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 36 deletions.
13 changes: 7 additions & 6 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, singleton-comparison, consider-using-in, too-many-ancestors, too-many-public-methods
# flake8: noqa

from typing import Any, Literal
from typing import Literal

from sqlalchemy import ColumnElement, Engine, Select, func, select
from sqlalchemy.ext.declarative import DeferredReflection
Expand Down Expand Up @@ -285,12 +285,13 @@ class SuperheroColourFilterMixin:
Mixin for filtering the view by the superhero colour attributes.
"""

def __init__(self) -> None:
super().__init__()
def __init__(self, *args, **kwargs) -> None:
self.eye_colour = aliased(Colour)
self.hair_colour = aliased(Colour)
self.skin_colour = aliased(Colour)

super().__init__(*args, **kwargs)

@view_filter()
def filter_by_eye_colour(self, eye_colour: str) -> ColumnElement:
"""
Expand Down Expand Up @@ -433,26 +434,26 @@ class SuperheroAggregationMixin:
"""

@view_aggregation()
def count_superheroes(self) -> Any:
def count_superheroes(self, data_source: Select) -> Select:
"""
Counts the number of superheroes.
Returns:
The superhero count.
"""
return func.count(Superhero.id).label("count_superheroes")
return data_source.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)


class SuperheroView(
DBInitMixin,
SqlAlchemyBaseView,
SuperheroFilterMixin,
SuperheroAggregationMixin,
SuperheroColourFilterMixin,
AlignmentFilterMixin,
GenderFilterMixin,
PublisherFilterMixin,
RaceFilterMixin,
SqlAlchemyBaseView,
):
"""
View for querying only superhero data. Contains the superhero id, superhero name, full name, height, weight,
Expand Down
10 changes: 1 addition & 9 deletions docs/quickstart/quickstart_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,6 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement:
"""
return Candidate.country == country

@decorators.view_aggregation()
def count_by_column(self, filtered_query: sqlalchemy.Select, column_name: str) -> sqlalchemy.Select: # pylint: disable=W0602, C0116, W9011
select = sqlalchemy.select(getattr(filtered_query.c, column_name),
sqlalchemy.func.count(filtered_query.c.name).label("count")) \
.group_by(getattr(filtered_query.c, column_name))
return select


async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
Expand All @@ -69,8 +62,7 @@ async def main():
collection = dbally.create_collection("recruitment", llm)
collection.add(CandidateView, lambda: CandidateView(engine))

result = await collection.ask("Could you find French candidates suitable for a senior data scientist position"
"and count the candidates university-wise?")
result = await collection.ask("Find me French candidates suitable for a senior data scientist position.")

print(f"The generated SQL query is: {result.context.get('sql')}")
print()
Expand Down
16 changes: 8 additions & 8 deletions src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import abc
import inspect
import textwrap
from typing import Any, Callable, List, Tuple
from abc import ABC
from typing import Any, Callable, Generic, List, Tuple

from dbally.iql import syntax
from dbally.views import decorators
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping
from dbally.views.structured import BaseStructuredView
from dbally.views.structured import BaseStructuredView, DataSourceT


class MethodsBaseView(BaseStructuredView, metaclass=abc.ABCMeta):
class MethodsBaseView(Generic[DataSourceT], BaseStructuredView[DataSourceT], ABC):
"""
Base class for views that use view methods to expose filters.
"""

# Method arguments that should be skipped when listing methods
HIDDEN_ARGUMENTS = ["self", "select", "return"]
HIDDEN_ARGUMENTS = ["cls", "self", "return", "data_source"]

@classmethod
def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]:
Expand Down Expand Up @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSourceT:
"""
Converts an IQL FunctionCall aggregation to a method call. If the method is a coroutine, it will be awaited.
Expand All @@ -123,5 +123,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
method, args = self._method_with_args_from_call(func, decorators.view_aggregation)

if inspect.iscoroutinefunction(method):
return await method(*args)
return method(*args)
return await method(self._data_source, *args)
return method(self._data_source, *args)
7 changes: 3 additions & 4 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbally.views.methods_base import MethodsBaseView


class DataFrameBaseView(MethodsBaseView):
class DataFrameBaseView(MethodsBaseView[pd.DataFrame]):
"""
Base class for views that use Pandas DataFrames to store and filter data.
Expand All @@ -22,8 +22,7 @@ def __init__(self, df: pd.DataFrame) -> None:
Args:
df: Pandas DataFrame with the data to be filtered
"""
super().__init__()
self.df = df
super().__init__(df)

# The mask to be applied to the dataframe to filter the data
self._filter_mask: Optional[pd.Series] = None
Expand Down Expand Up @@ -87,7 +86,7 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
filtered_data = pd.DataFrame.empty

if not dry_run:
filtered_data = self.df
filtered_data = self._data_source
if self._filter_mask is not None:
filtered_data = filtered_data.loc[self._filter_mask]

Expand Down
15 changes: 8 additions & 7 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from dbally.views.methods_base import MethodsBaseView


class SqlAlchemyBaseView(MethodsBaseView):
class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]):
"""
Base class for views that use SQLAlchemy to generate SQL queries.
"""

def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None:
super().__init__()
super().__init__(self.get_select())
self._sqlalchemy_engine = sqlalchemy_engine
self._select = self.get_select()

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -34,7 +33,8 @@ async def apply_filters(self, filters: IQLQuery) -> None:
Args:
filters: IQLQuery object representing the filters to apply.
"""
self._select = self._select.where(await self._build_filter_node(filters.root))
# pylint: disable=W0201
self._data_source = self._data_source.where(await self._build_filter_node(filters.root))

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Expand All @@ -43,7 +43,8 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None:
Args:
aggregation: IQLQuery object representing the aggregation to apply.
"""
self._select = self._select.with_only_columns(await self.call_aggregation_method(aggregation.root))
# pylint: disable=W0201
self._data_source = await self.call_aggregation_method(aggregation.root)

async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
"""
Expand Down Expand Up @@ -86,13 +87,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored.
"""
results = []
sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))
sql = str(self._data_source.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
with self._sqlalchemy_engine.connect() as connection:
# The underscore is used by sqlalchemy to avoid conflicts with column names
# pylint: disable=protected-access
rows = connection.execute(self._select).fetchall()
rows = connection.execute(self._data_source).fetchall()
results = [dict(row._mapping) for row in rows]

return ViewExecutionResult(
Expand Down
10 changes: 8 additions & 2 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from collections import defaultdict
from typing import Dict, List, Optional
from typing import Any, Dict, Generic, List, Optional, TypeVar

from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
Expand All @@ -18,13 +18,19 @@
from ..similarity import AbstractSimilarityIndex
from .base import BaseView, IndexLocation

DataSourceT = TypeVar("DataSourceT", bound=Any)

class BaseStructuredView(BaseView):

class BaseStructuredView(Generic[DataSourceT], BaseView):
"""
Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface have\
to be able to list all available filters, apply them and execute queries.
"""

def __init__(self, data_source: DataSourceT) -> None:
super().__init__()
self._data_source = data_source

def get_iql_generator(self, llm: LLM) -> IQLGenerator:
"""
Returns the IQL generator for the view.
Expand Down

0 comments on commit a2169f2

Please sign in to comment.