Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 16, 2024
1 parent a2169f2 commit 013cb69
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 39 deletions.
4 changes: 2 additions & 2 deletions benchmarks/sql/bench/views/structured/superhero.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,14 @@ class SuperheroAggregationMixin:
"""

@view_aggregation()
def count_superheroes(self, data_source: Select) -> Select:
def count_superheroes(self) -> Select:
"""
Counts the number of superheroes.
Returns:
The superheroes count.
"""
return data_source.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)
return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id)


class SuperheroView(
Expand Down
12 changes: 6 additions & 6 deletions src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from dbally.iql import syntax
from dbally.views import decorators
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping
from dbally.views.structured import BaseStructuredView, DataSourceT
from dbally.views.structured import BaseStructuredView, DataT


class MethodsBaseView(Generic[DataSourceT], BaseStructuredView[DataSourceT], ABC):
class MethodsBaseView(Generic[DataT], BaseStructuredView[DataT], ABC):
"""
Base class for views that use view methods to expose filters.
"""

# Method arguments that should be skipped when listing methods
HIDDEN_ARGUMENTS = ["cls", "self", "return", "data_source"]
HIDDEN_ARGUMENTS = ["cls", "self", "return"]

@classmethod
def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]:
Expand Down Expand Up @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSourceT:
async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT:
"""
Converts an IQL FunctionCall aggregation to a method call. If the method is a coroutine, it will be awaited.
Expand All @@ -123,5 +123,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSource
method, args = self._method_with_args_from_call(func, decorators.view_aggregation)

if inspect.iscoroutinefunction(method):
return await method(self._data_source, *args)
return method(self._data_source, *args)
return await method(*args)
return method(*args)
34 changes: 18 additions & 16 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ class DataFrameBaseView(MethodsBaseView[pd.DataFrame]):

def __init__(self, df: pd.DataFrame) -> None:
"""
Creates a new instance of the DataFrame view.
Args:
df: Pandas DataFrame with the data to be filtered
df: Pandas DataFrame with the data to be filtered.
"""
super().__init__(df)

Expand All @@ -32,32 +34,37 @@ async def apply_filters(self, filters: IQLQuery) -> None:
Applies the chosen filters to the view.
Args:
filters: IQLQuery object representing the filters to apply
filters: IQLQuery object representing the filters to apply.
"""
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self._filter_mask = await self.build_filter_node(filters.root)
self.data = self.data.loc[self._filter_mask]

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Applies the aggregation of choice to the view.
Args:
aggregation: IQLQuery object representing the aggregation to apply
aggregation: IQLQuery object representing the aggregation to apply.
"""
# TODO - to be covered in a separate ticket.
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self.data = await self.call_aggregation_method(aggregation.root)

async def build_filter_node(self, node: syntax.Node) -> pd.Series:
"""
Converts a filter node from the IQLQuery to a Pandas Series representing
a boolean mask to be applied to the dataframe.
Args:
node: IQLQuery node representing the filter or logical operator
node: IQLQuery node representing the filter or logical operator.
Returns:
A boolean mask that can be used to filter the original DataFrame
A boolean mask that can be used to filter the original DataFrame.
Raises:
ValueError: If the node type is not supported
ValueError: If the node type is not supported.
"""
if isinstance(node, syntax.FunctionCall):
return await self.call_filter_method(node)
Expand All @@ -78,20 +85,15 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
Args:
dry_run: If True, the method will only add `context` field to the `ExecutionResult` with the\
mask that would be applied to the dataframe
mask that would be applied to the dataframe.
Returns:
ExecutionResult object with the results and the context information with the binary mask
ExecutionResult object with the results and the context information with the binary mask.
"""
filtered_data = pd.DataFrame.empty

if not dry_run:
filtered_data = self._data_source
if self._filter_mask is not None:
filtered_data = filtered_data.loc[self._filter_mask]
results = pd.DataFrame.empty if dry_run else self.data

return ViewExecutionResult(
results=filtered_data.to_dict(orient="records"),
results=results.to_dict(orient="records"),
context={
"filter_mask": self._filter_mask,
},
Expand Down
20 changes: 14 additions & 6 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]):
"""

def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None:
"""
Creates a new instance of the SQL view.
Args:
sqlalchemy_engine: SQLAlchemy engine to use for executing the queries.
"""
super().__init__(self.get_select())
self._sqlalchemy_engine = sqlalchemy_engine

Expand All @@ -33,8 +39,9 @@ async def apply_filters(self, filters: IQLQuery) -> None:
Args:
filters: IQLQuery object representing the filters to apply.
"""
# pylint: disable=W0201
self._data_source = self._data_source.where(await self._build_filter_node(filters.root))
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self.data = self.data.where(await self._build_filter_node(filters.root))

async def apply_aggregation(self, aggregation: IQLQuery) -> None:
"""
Expand All @@ -43,8 +50,9 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None:
Args:
aggregation: IQLQuery object representing the aggregation to apply.
"""
# pylint: disable=W0201
self._data_source = await self.call_aggregation_method(aggregation.root)
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self.data = await self.call_aggregation_method(aggregation.root)

async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
"""
Expand Down Expand Up @@ -87,13 +95,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored.
"""
results = []
sql = str(self._data_source.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))
sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
with self._sqlalchemy_engine.connect() as connection:
rows = connection.execute(self.data).fetchall()
# The underscore is used by sqlalchemy to avoid conflicts with column names
# pylint: disable=protected-access
rows = connection.execute(self._data_source).fetchall()
results = [dict(row._mapping) for row in rows]

return ViewExecutionResult(
Expand Down
8 changes: 4 additions & 4 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from ..similarity import AbstractSimilarityIndex
from .base import BaseView, IndexLocation

DataSourceT = TypeVar("DataSourceT", bound=Any)
DataT = TypeVar("DataT", bound=Any)


class BaseStructuredView(Generic[DataSourceT], BaseView):
class BaseStructuredView(Generic[DataT], BaseView):
"""
Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface have\
to be able to list all available filters, apply them and execute queries.
"""

def __init__(self, data_source: DataSourceT) -> None:
def __init__(self, data: DataT) -> None:
super().__init__()
self._data_source = data_source
self.data = data

def get_iql_generator(self, llm: LLM) -> IQLGenerator:
"""
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class MockViewBase(BaseStructuredView):
Mock view base class
"""

def __init__(self) -> None:
super().__init__(None)

def list_filters(self) -> List[ExposedFunction]:
return []

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@


class MockView(MethodsBaseView):
def __init__(self) -> None:
super().__init__(None)

def get_select(self) -> sqlalchemy.Select:
...

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/views/test_methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class MockMethodsBase(MethodsBaseView):
Mock class for testing the MethodsBaseView
"""

def __init__(self) -> None:
super().__init__(None)

@view_filter()
def method_foo(self, idx: int) -> None:
"""
Expand Down
53 changes: 48 additions & 5 deletions tests/unit/views/test_pandas_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd

from dbally.iql import IQLQuery
from dbally.views.decorators import view_filter
from dbally.views.decorators import view_aggregation, view_filter
from dbally.views.pandas_base import DataFrameBaseView

MOCK_DATA = [
Expand Down Expand Up @@ -39,19 +39,23 @@ class MockDataFrameView(DataFrameBaseView):

@view_filter()
def filter_city(self, city: str) -> pd.Series:
return self.df["city"] == city
return self.data["city"] == city

@view_filter()
def filter_year(self, year: int) -> pd.Series:
return self.df["year"] == year
return self.data["year"] == year

@view_filter()
def filter_age(self, age: int) -> pd.Series:
return self.df["age"] == age
return self.data["age"] == age

@view_filter()
def filter_name(self, name: str) -> pd.Series:
return self.df["name"] == name
return self.data["name"] == name

@view_aggregation()
def mean_age_by_city(self) -> pd.DataFrame:
    """
    Computes the mean of the `age` column for every value of the `city` column.

    Returns:
        DataFrame with `city` restored as a regular column next to the aggregated `age`.
    """
    grouped_by_city = self.data.groupby(["city"])
    mean_age = grouped_by_city.agg({"age": "mean"})
    # groupby moves `city` into the index; reset_index turns it back into a column
    return mean_age.reset_index()


async def test_filter_or() -> None:
Expand Down Expand Up @@ -97,3 +101,42 @@ async def test_filter_not() -> None:
result = mock_view.execute()
assert result.results == MOCK_DATA_NOT_PARIS_2020
assert result.context["filter_mask"].tolist() == [True, False, True, True, True]


async def test_aggregtion() -> None:
    """
    Check that an aggregation applied without any filters shows up in the executed results.
    """
    frame = pd.DataFrame.from_records(MOCK_DATA)
    view = MockDataFrameView(frame)

    aggregation = await IQLQuery.parse("mean_age_by_city()", allowed_functions=view.list_aggregations())
    await view.apply_aggregation(aggregation)

    outcome = view.execute()
    expected_rows = [
        {"city": "Berlin", "age": 45.0},
        {"city": "London", "age": 32.5},
        {"city": "Paris", "age": 32.5},
    ]

    assert outcome.results == expected_rows
    # No filter was applied in this scenario, so the context must not carry a mask
    assert outcome.context["filter_mask"] is None


async def test_filters_and_aggregtion() -> None:
    """
    Check that an aggregation applied after a filter only sees the filtered rows.
    """
    frame = pd.DataFrame.from_records(MOCK_DATA)
    view = MockDataFrameView(frame)

    # Step 1: narrow the data down to a single city
    city_filter = await IQLQuery.parse("filter_city('Paris')", allowed_functions=view.list_filters())
    await view.apply_filters(city_filter)

    # Step 2: aggregate whatever is left after filtering
    aggregation = await IQLQuery.parse("mean_age_by_city()", allowed_functions=view.list_aggregations())
    await view.apply_aggregation(aggregation)

    outcome = view.execute()
    expected_mask = [False, True, False, True, False]

    assert outcome.results == [{"city": "Paris", "age": 32.5}]
    assert outcome.context["filter_mask"].tolist() == expected_mask

0 comments on commit 013cb69

Please sign in to comment.