diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index e5b8f890..56932947 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -434,14 +434,14 @@ class SuperheroAggregationMixin: """ @view_aggregation() - def count_superheroes(self, data_source: Select) -> Select: + def count_superheroes(self) -> Select: """ Counts the number of superheros. Returns: The superheros count. """ - return data_source.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) + return self.data.with_only_columns(func.count(Superhero.id).label("count_superheroes")).group_by(Superhero.id) class SuperheroView( diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 1be1aebc..f90d6d7d 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -6,16 +6,16 @@ from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView, DataSourceT +from dbally.views.structured import BaseStructuredView, DataT -class MethodsBaseView(Generic[DataSourceT], BaseStructuredView[DataSourceT], ABC): +class MethodsBaseView(Generic[DataT], BaseStructuredView[DataT], ABC): """ Base class for views that use view methods to expose filters. """ # Method arguments that should be skipped when listing methods - HIDDEN_ARGUMENTS = ["cls", "self", "return", "data_source"] + HIDDEN_ARGUMENTS = ["cls", "self", "return"] @classmethod def list_methods_by_decorator(cls, decorator: Callable) -> List[ExposedFunction]: @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSourceT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. @@ -123,5 +123,5 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataSource method, args = self._method_with_args_from_call(func, decorators.view_aggregation) if inspect.iscoroutinefunction(method): - return await method(self._data_source, *args) - return method(self._data_source, *args) + return await method(*args) + return method(*args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 2d628303..5f7bc8ce 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -19,8 +19,10 @@ class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): def __init__(self, df: pd.DataFrame) -> None: """ + Creates a new instance of the DataFrame view. + Args: - df: Pandas DataFrame with the data to be filtered + df: Pandas DataFrame with the data to be filtered. """ super().__init__(df) @@ -32,18 +34,23 @@ async def apply_filters(self, filters: IQLQuery) -> None: Applies the chosen filters to the view. Args: - filters: IQLQuery object representing the filters to apply + filters: IQLQuery object representing the filters to apply. """ + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init self._filter_mask = await self.build_filter_node(filters.root) + self.data = self.data.loc[self._filter_mask] async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ Applies the aggregation of choice to the view. Args: - aggregation: IQLQuery object representing the aggregation to apply + aggregation: IQLQuery object representing the aggregation to apply. """ - # TODO - to be covered in a separate ticket. + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = await self.call_aggregation_method(aggregation.root) async def build_filter_node(self, node: syntax.Node) -> pd.Series: """ @@ -51,13 +58,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: a boolean mask to be applied to the dataframe. Args: - node: IQLQuery node representing the filter or logical operator + node: IQLQuery node representing the filter or logical operator. Returns: - A boolean mask that can be used to filter the original DataFrame + A boolean mask that can be used to filter the original DataFrame. Raises: - ValueError: If the node type is not supported + ValueError: If the node type is not supported. """ if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) @@ -78,20 +85,15 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Args: dry_run: If True, the method will only add `context` field to the `ExecutionResult` with the\ - mask that would be applied to the dataframe + mask that would be applied to the dataframe. Returns: - ExecutionResult object with the results and the context information with the binary mask + ExecutionResult object with the results and the context information with the binary mask. """ - filtered_data = pd.DataFrame.empty - - if not dry_run: - filtered_data = self._data_source - if self._filter_mask is not None: - filtered_data = filtered_data.loc[self._filter_mask] + results = pd.DataFrame.empty if dry_run else self.data return ViewExecutionResult( - results=filtered_data.to_dict(orient="records"), + results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, }, diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index c8cdcad1..4863aa6f 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -14,6 +14,12 @@ class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): """ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: + """ + Creates a new instance of the SQL view. + + Args: + sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. + """ super().__init__(self.get_select()) self._sqlalchemy_engine = sqlalchemy_engine @@ -33,8 +39,9 @@ async def apply_filters(self, filters: IQLQuery) -> None: Args: filters: IQLQuery object representing the filters to apply. """ - # pylint: disable=W0201 - self._data_source = self._data_source.where(await self._build_filter_node(filters.root)) + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = self.data.where(await self._build_filter_node(filters.root)) async def apply_aggregation(self, aggregation: IQLQuery) -> None: """ @@ -43,8 +50,9 @@ async def apply_aggregation(self, aggregation: IQLQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # pylint: disable=W0201 - self._data_source = await self.call_aggregation_method(aggregation.root) + # data is defined in the parent class + # pylint: disable=attribute-defined-outside-init + self.data = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -87,13 +95,13 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. """ results = [] - sql = str(self._data_source.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: + rows = connection.execute(self.data).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access - rows = connection.execute(self._data_source).fetchall() results = [dict(row._mapping) for row in rows] return ViewExecutionResult( diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c695b684..c61ee0dd 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -18,18 +18,18 @@ from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation -DataSourceT = TypeVar("DataSourceT", bound=Any) +DataT = TypeVar("DataT", bound=Any) -class BaseStructuredView(Generic[DataSourceT], BaseView): +class BaseStructuredView(Generic[DataT], BaseView): """ Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. """ - def __init__(self, data_source: DataSourceT) -> None: + def __init__(self, data: DataT) -> None: super().__init__() - self._data_source = data_source + self.data = data def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 0d66df3b..29a5cc83 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -24,6 +24,9 @@ class MockViewBase(BaseStructuredView): Mock view base class """ + def __init__(self) -> None: + super().__init__(None) + def list_filters(self) -> List[ExposedFunction]: return [] diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 401c09d1..b95fe585 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -20,6 +20,9 @@ class MockView(MethodsBaseView): + def __init__(self) -> None: + super().__init__(None) + def get_select(self) -> sqlalchemy.Select: ... diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 841af1a7..e8a2bb56 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -15,6 +15,9 @@ class MockMethodsBase(MethodsBaseView): Mock class for testing the MethodsBaseView """ + def __init__(self) -> None: + super().__init__(None) + @view_filter() def method_foo(self, idx: int) -> None: """ diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 51eea791..52a8f405 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -3,7 +3,7 @@ import pandas as pd from dbally.iql import IQLQuery -from dbally.views.decorators import view_filter +from dbally.views.decorators import view_aggregation, view_filter from dbally.views.pandas_base import DataFrameBaseView MOCK_DATA = [ @@ -39,19 +39,23 @@ class MockDataFrameView(DataFrameBaseView): @view_filter() def filter_city(self, city: str) -> pd.Series: - return self.df["city"] == city + return self.data["city"] == city @view_filter() def filter_year(self, year: int) -> pd.Series: - return self.df["year"] == year + return self.data["year"] == year @view_filter() def filter_age(self, age: int) -> pd.Series: - return self.df["age"] == age + return self.data["age"] == age @view_filter() def filter_name(self, name: str) -> pd.Series: - return self.df["name"] == name + return self.data["name"] == name + + @view_aggregation() + def mean_age_by_city(self) -> pd.DataFrame: + return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() async def test_filter_or() -> None: @@ -97,3 +101,42 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + + +async def test_aggregtion() -> None: + """ + Test that DataFrame aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"city": "Berlin", "age": 45.0}, + {"city": "London", "age": 32.5}, + {"city": "Paris", "age": 32.5}, + ] + assert result.context["filter_mask"] is None + + +async def test_filters_and_aggregtion() -> None: + """ + Test that DataFrame filtering and aggregation works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLQuery.parse( + "filter_city('Paris')", + allowed_functions=mock_view.list_filters(), + ) + await mock_view.apply_filters(query) + query = await IQLQuery.parse( + "mean_age_by_city()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.context["filter_mask"].tolist() == [False, True, False, True, False]