Skip to content

Commit

Permalink
add lazy pandas aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 25, 2024
1 parent db82bba commit 5f79a67
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/dbally/views/methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any:
return await method(*args)
return method(*args)

async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT:
async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any:
"""
Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited.
Expand Down
40 changes: 25 additions & 15 deletions src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from functools import reduce
from typing import Optional
from typing import List, Optional, Union

import pandas as pd
from sqlalchemy import Tuple

from dbally.collection.results import ViewExecutionResult
from dbally.iql import syntax
Expand All @@ -26,9 +27,9 @@ def __init__(self, df: pd.DataFrame) -> None:
df: Pandas DataFrame with the data to be filtered.
"""
super().__init__(df)

# The mask to be applied to the dataframe to filter the data
self._filter_mask: Optional[pd.Series] = None
self._groupbys: Optional[Union[str, List[str]]] = None
self._aggregations: Optional[List[Tuple[str, str]]] = None

async def apply_filters(self, filters: IQLFiltersQuery) -> None:
"""
Expand All @@ -37,10 +38,7 @@ async def apply_filters(self, filters: IQLFiltersQuery) -> None:
Args:
filters: IQLQuery object representing the filters to apply.
"""
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self._filter_mask = await self.build_filter_node(filters.root)
self.data = self.data.loc[self._filter_mask]
self._filter_mask = await self._build_filter_node(filters.root)

async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None:
"""
Expand All @@ -49,11 +47,9 @@ async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None:
Args:
aggregation: IQLQuery object representing the aggregation to apply.
"""
# data is defined in the parent class
# pylint: disable=attribute-defined-outside-init
self.data = await self.call_aggregation_method(aggregation.root)
self._groupbys, self._aggregations = await self.call_aggregation_method(aggregation.root)

async def build_filter_node(self, node: syntax.Node) -> pd.Series:
async def _build_filter_node(self, node: syntax.Node) -> pd.Series:
"""
Converts a filter node from the IQLQuery to a Pandas Series representing
a boolean mask to be applied to the dataframe.
Expand All @@ -70,13 +66,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series:
if isinstance(node, syntax.FunctionCall):
return await self.call_filter_method(node)
if isinstance(node, syntax.And): # logical AND
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children])
return reduce(lambda x, y: x & y, children)
if isinstance(node, syntax.Or): # logical OR
children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children])
children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children])
return reduce(lambda x, y: x | y, children)
if isinstance(node, syntax.Not):
child = await self.build_filter_node(node.child)
child = await self._build_filter_node(node.child)
return ~child
raise ValueError(f"Unsupported grammar: {node}")

Expand All @@ -91,11 +87,25 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
Returns:
ExecutionResult object with the results and the context information with the binary mask.
"""
results = pd.DataFrame.empty if dry_run else self.data
results = pd.DataFrame()

if not dry_run:
results = self.data
if self._filter_mask is not None:
results = results.loc[self._filter_mask]

if self._groupbys is not None:
results = results.groupby(self._groupbys)

if self._aggregations is not None:
results = results.agg(**{"_".join(agg): agg for agg in self._aggregations})
results = results.reset_index()

return ViewExecutionResult(
results=results.to_dict(orient="records"),
context={
"filter_mask": self._filter_mask,
"groupbys": self._groupbys,
"aggregations": self._aggregations,
},
)
49 changes: 42 additions & 7 deletions tests/unit/views/test_pandas_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name

from typing import List, Tuple

import pandas as pd

from dbally.iql import IQLFiltersQuery
Expand Down Expand Up @@ -55,8 +57,12 @@ def filter_name(self, name: str) -> pd.Series:
return self.data["name"] == name

@view_aggregation()
def mean_age_by_city(self) -> pd.DataFrame:
return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index()
def mean_age_by_city(self) -> Tuple[str, List[Tuple[str, str]]]:
return "city", [("age", "mean")]

@view_aggregation()
def count_records(self) -> Tuple[str, List[Tuple[str, str]]]:
return None, [("name", "count")]


async def test_filter_or() -> None:
Expand All @@ -72,6 +78,8 @@ async def test_filter_or() -> None:
result = mock_view.execute()
assert result.results == MOCK_DATA_BERLIN_OR_LONDON
assert result.context["filter_mask"].tolist() == [True, False, True, False, True]
assert result.context["groupbys"] is None
assert result.context["aggregations"] is None


async def test_filter_and() -> None:
Expand All @@ -87,6 +95,8 @@ async def test_filter_and() -> None:
result = mock_view.execute()
assert result.results == MOCK_DATA_PARIS_2020
assert result.context["filter_mask"].tolist() == [False, True, False, False, False]
assert result.context["groupbys"] is None
assert result.context["aggregations"] is None


async def test_filter_not() -> None:
Expand All @@ -102,25 +112,48 @@ async def test_filter_not() -> None:
result = mock_view.execute()
assert result.results == MOCK_DATA_NOT_PARIS_2020
assert result.context["filter_mask"].tolist() == [True, False, True, True, True]
assert result.context["groupbys"] is None
assert result.context["aggregations"] is None


async def test_aggregtion() -> None:
async def test_aggregation() -> None:
"""
Test that DataFrame aggregation works correctly
"""
mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA))
query = await IQLAggregationQuery.parse(
"count_records()",
allowed_functions=mock_view.list_aggregations(),
)
await mock_view.apply_aggregation(query)
result = mock_view.execute()
assert result.results == [
{"index": "name_count", "name": 5},
]
assert result.context["filter_mask"] is None
assert result.context["groupbys"] is None
assert result.context["aggregations"] == [("name", "count")]


async def test_aggregtion_with_groupby() -> None:
"""
Test that DataFrame aggregation with groupby works correctly
"""
mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA))
query = await IQLAggregationQuery.parse(
"mean_age_by_city()",
allowed_functions=mock_view.list_aggregations(),
)
await mock_view.apply_aggregation(query)
result = mock_view.execute()
assert result.results == [
{"city": "Berlin", "age": 45.0},
{"city": "London", "age": 32.5},
{"city": "Paris", "age": 32.5},
{"city": "Berlin", "age_mean": 45.0},
{"city": "London", "age_mean": 32.5},
{"city": "Paris", "age_mean": 32.5},
]
assert result.context["filter_mask"] is None
assert result.context["groupbys"] == "city"
assert result.context["aggregations"] == [("age", "mean")]


async def test_filters_and_aggregtion() -> None:
Expand All @@ -139,5 +172,7 @@ async def test_filters_and_aggregtion() -> None:
)
await mock_view.apply_aggregation(query)
result = mock_view.execute()
assert result.results == [{"city": "Paris", "age": 32.5}]
assert result.results == [{"city": "Paris", "age_mean": 32.5}]
assert result.context["filter_mask"].tolist() == [False, True, False, True, False]
assert result.context["groupbys"] == "city"
assert result.context["aggregations"] == [("age", "mean")]

0 comments on commit 5f79a67

Please sign in to comment.