diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 977a2fa1..0eb95315 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: return await method(*args) return method(*args) - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 1cca548c..02c503f8 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,8 +1,9 @@ import asyncio from functools import reduce -from typing import Optional +from typing import List, Optional, Union import pandas as pd +from sqlalchemy import Tuple from dbally.collection.results import ViewExecutionResult from dbally.iql import syntax @@ -26,9 +27,9 @@ def __init__(self, df: pd.DataFrame) -> None: df: Pandas DataFrame with the data to be filtered. """ super().__init__(df) - - # The mask to be applied to the dataframe to filter the data self._filter_mask: Optional[pd.Series] = None + self._groupbys: Optional[Union[str, List[str]]] = None + self._aggregations: Optional[List[Tuple[str, str]]] = None async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ @@ -37,10 +38,7 @@ async def apply_filters(self, filters: IQLFiltersQuery) -> None: Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self._filter_mask = await self.build_filter_node(filters.root) - self.data = self.data.loc[self._filter_mask] + self._filter_mask = await self._build_filter_node(filters.root) async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ @@ -49,11 +47,9 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self._groupbys, self._aggregations = await self.call_aggregation_method(aggregation.root) - async def build_filter_node(self, node: syntax.Node) -> pd.Series: + async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing a boolean mask to be applied to the dataframe. @@ -70,13 +66,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x & y, children) if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x | y, children) if isinstance(node, syntax.Not): - child = await self.build_filter_node(node.child) + child = await self._build_filter_node(node.child) return ~child raise ValueError(f"Unsupported grammar: {node}") @@ -91,11 +87,25 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Returns: ExecutionResult object with the results and the context information with the binary mask. """ - results = pd.DataFrame.empty if dry_run else self.data + results = pd.DataFrame() + + if not dry_run: + results = self.data + if self._filter_mask is not None: + results = results.loc[self._filter_mask] + + if self._groupbys is not None: + results = results.groupby(self._groupbys) + + if self._aggregations is not None: + results = results.agg(**{"_".join(agg): agg for agg in self._aggregations}) + results = results.reset_index() return ViewExecutionResult( results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, + "groupbys": self._groupbys, + "aggregations": self._aggregations, }, ) diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 5da5db18..47b8ee8f 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,5 +1,7 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name +from typing import List, Tuple + import pandas as pd from dbally.iql import IQLFiltersQuery @@ -55,8 +57,12 @@ def filter_name(self, name: str) -> pd.Series: return self.data["name"] == name @view_aggregation() - def mean_age_by_city(self) -> pd.DataFrame: - return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() + def mean_age_by_city(self) -> Tuple[str, List[Tuple[str, str]]]: + return "city", [("age", "mean")] + + @view_aggregation() + def count_records(self) -> Tuple[str, List[Tuple[str, str]]]: + return None, [("name", "count")] async def test_filter_or() -> None: @@ -72,6 +78,8 @@ async def test_filter_or() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_and() -> None: @@ -87,6 +95,8 @@ async def test_filter_and() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_not() -> None: @@ -102,13 +112,34 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None -async def test_aggregtion() -> None: +async def test_aggregation() -> None: """ Test that DataFrame aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( + "count_records()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"index": "name_count", "name": 5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] is None + assert result.context["aggregations"] == [("name", "count")] + + +async def test_aggregtion_with_groupby() -> None: + """ + Test that DataFrame aggregation with groupby works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), @@ -116,11 +147,13 @@ async def test_aggregtion() -> None: await mock_view.apply_aggregation(query) result = mock_view.execute() assert result.results == [ - {"city": "Berlin", "age": 45.0}, - {"city": "London", "age": 32.5}, - {"city": "Paris", "age": 32.5}, + {"city": "Berlin", "age_mean": 45.0}, + {"city": "London", "age_mean": 32.5}, + {"city": "Paris", "age_mean": 32.5}, ] assert result.context["filter_mask"] is None + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")] async def test_filters_and_aggregtion() -> None: @@ -139,5 +172,7 @@ async def test_filters_and_aggregtion() -> None: ) await mock_view.apply_aggregation(query) result = mock_view.execute() - assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.results == [{"city": "Paris", "age_mean": 32.5}] assert result.context["filter_mask"].tolist() == [False, True, False, True, False] + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [("age", "mean")]