Skip to content

Commit

Permalink
feat: refactor cross-shop code to ibis
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurkmmt committed Feb 26, 2025
1 parent 2ce2eec commit eb6a16b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 108 deletions.
121 changes: 71 additions & 50 deletions pyretailscience/cross_shop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains the CrossShop class that is used to create a cross-shop diagram."""

import ibis
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -16,24 +17,27 @@ class CrossShop:

def __init__(
self,
df: pd.DataFrame,
group_1_idx: list[bool] | pd.Series,
group_2_idx: list[bool] | pd.Series,
group_3_idx: list[bool] | pd.Series | None = None,
df: pd.DataFrame | ibis.Table,
group_1_col: str,
group_1_val: str,
group_2_col: str,
group_2_val: str,
group_3_col: str | None = None,
group_3_val: str | None = None,
labels: list[str] | None = None,
value_col: str = get_option("column.unit_spend"),
agg_func: str = "sum",
) -> None:
"""Creates a cross-shop diagram that is used to show the overlap of customers between different groups.
Args:
df (pd.DataFrame): The dataframe with transactional data.
group_1_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
first group.
group_2_idx (list[bool], pd.Series): A list of bool values determining whether the row is a part of the
second group.
group_3_idx (list[bool], pd.Series, optional): An optional list of bool values determining whether the
row is a part of the third group. Defaults to None. If not supplied, only two groups will be considered.
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
group_1_col (str): The column name for the first group.
group_1_val (str): The value of the first group to match.
group_2_col (str): The column name for the second group.
group_2_val (str): The value of the second group to match.
group_3_col (str, optional): The column name for the third group. Defaults to None.
group_3_val (str, optional): The value of the third group to match. Defaults to None.
labels (list[str], optional): The labels for the groups. Defaults to None.
value_col (str, optional): The column to aggregate. Defaults to the option column.unit_spend.
agg_func (str, optional): The aggregation function. Defaults to "sum".
Expand All @@ -51,7 +55,7 @@ def __init__(
msg = f"The following columns are required but missing: {missing_cols}"
raise ValueError(msg)

self.group_count = 2 if group_3_idx is None else 3
self.group_count = 2 if group_3_col is None else 3

Check warning on line 58 in pyretailscience/cross_shop.py

View check run for this annotation

Codecov / codecov/patch

pyretailscience/cross_shop.py#L58

Added line #L58 was not covered by tests

if (labels is not None) and (len(labels) != self.group_count):
raise ValueError("The number of labels must be equal to the number of group indexes given")
Expand All @@ -60,9 +64,12 @@ def __init__(

self.cross_shop_df = self._calc_cross_shop(
df=df,
group_1_idx=group_1_idx,
group_2_idx=group_2_idx,
group_3_idx=group_3_idx,
group_1_col=group_1_col,
group_1_val=group_1_val,
group_2_col=group_2_col,
group_2_val=group_2_val,
group_3_col=group_3_col,
group_3_val=group_3_val,
value_col=value_col,
agg_func=agg_func,
)
Expand All @@ -73,21 +80,26 @@ def __init__(

@staticmethod
def _calc_cross_shop(
df: pd.DataFrame,
group_1_idx: list[bool],
group_2_idx: list[bool],
group_3_idx: list[bool] | None = None,
df: pd.DataFrame | ibis.Table,
group_1_col: str,
group_1_val: str,
group_2_col: str,
group_2_val: str,
group_3_col: str | None = None,
group_3_val: str | None = None,
value_col: str = get_option("column.unit_spend"),
agg_func: str = "sum",
) -> pd.DataFrame:
"""Calculate the cross-shop dataframe that will be used to plot the diagram.
Args:
df (pd.DataFrame): The dataframe with transactional data.
group_1_idx (list[bool]): A list of bool values determining whether the row is a part of the first group.
group_2_idx (list[bool]): A list of bool values determining whether the row is a part of the second group.
group_3_idx (list[bool], optional): An optional list of bool values determining whether the row is a part
of the third group. Defaults to None. If not supplied, only two groups will be considered.
df (pd.DataFrame | ibis.Table): The input DataFrame or ibis Table containing transactional data.
group_1_col (str): Column name for the first group.
group_1_val (str): Value to filter for the first group.
group_2_col (str): Column name for the second group.
group_2_val (str): Value to filter for the second group.
group_3_col (str, optional): Column name for the third group. Defaults to None.
group_3_val (str, optional): Value to filter for the third group. Defaults to None.
value_col (str, optional): The column to aggregate. Defaults to option column.unit_spend.
agg_func (str, optional): The aggregation function. Defaults to "sum".
Expand All @@ -98,35 +110,44 @@ def _calc_cross_shop(
ValueError: If the groups are not mutually exclusive.
"""
cols = ColumnHelper()
if isinstance(group_1_idx, list):
group_1_idx = pd.Series(group_1_idx)
if isinstance(group_2_idx, list):
group_2_idx = pd.Series(group_2_idx)
if group_3_idx is not None and isinstance(group_3_idx, list):
group_3_idx = pd.Series(group_3_idx)

cs_df = df[[cols.customer_id]].copy()

cs_df["group_1"] = group_1_idx.astype(int)
cs_df["group_2"] = group_2_idx.astype(int)
group_cols = ["group_1", "group_2"]

if group_3_idx is not None:
cs_df["group_3"] = group_3_idx.astype(int)
group_cols += ["group_3"]

if (cs_df[group_cols].sum(axis=1) > 1).any():
raise ValueError("The groups must be mutually exclusive.")
if isinstance(df, pd.DataFrame):
df: ibis.Table = ibis.memtable(df)
temp_value_col = "temp_value_col"
df = df.mutate(**{temp_value_col: df[value_col]})

if not any(group_1_idx) or not any(group_2_idx) or (group_3_idx is not None and not any(group_3_idx)):
raise ValueError("There must at least one row selected for group_1_idx, group_2_idx, and group_3_idx.")
group_1 = (df[group_1_col] == group_1_val).cast("int64").name("group_1")
group_2 = (df[group_2_col] == group_2_val).cast("int64").name("group_2")
group_3 = (df[group_3_col] == group_3_val).cast("int64").name("group_3") if group_3_col else None

cs_df = cs_df.groupby(cols.customer_id)[group_cols].max()
cs_df["groups"] = cs_df[group_cols].apply(lambda x: tuple(x), axis=1)

kpi_df = df.groupby(cols.customer_id)[value_col].agg(agg_func)

return cs_df.merge(kpi_df, left_index=True, right_index=True)
group_cols = ["group_1", "group_2"]
select_cols = [df[cols.customer_id], group_1, group_2]
if group_3 is not None:
group_cols.append("group_3")
select_cols.append(group_3)

cs_df = df.select([*select_cols, df[temp_value_col]]).order_by(cols.customer_id)
cs_df = (
cs_df.group_by(cols.customer_id)
.aggregate(
**{col: cs_df[col].max().name(col) for col in group_cols},
**{temp_value_col: getattr(cs_df[temp_value_col], agg_func)().name(temp_value_col)},
)
.order_by(cols.customer_id)
)
cs_df = cs_df.mutate(
groups=ibis.literal("(")
+ cs_df[group_cols[0]].cast("string")
+ ", "
+ cs_df[group_cols[1]].cast("string")
+ (", " + cs_df[group_cols[2]].cast("string") if group_3 is not None else "")
+ ibis.literal(")"),
).execute()
cs_df["groups"] = cs_df["groups"].apply(eval)
column_order = [cols.customer_id, *group_cols, "groups", temp_value_col]
cs_df = cs_df[column_order]
cs_df.set_index(cols.customer_id, inplace=True)
return cs_df.rename(columns={temp_value_col: value_col})

@staticmethod
def _calc_cross_shop_table(
Expand Down
100 changes: 42 additions & 58 deletions tests/test_cross_shop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,20 @@ def sample_data():
return pd.DataFrame(
{
cols.customer_id: [1, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10],
"group_1_idx": [True, False, False, False, False, True, True, False, False, True, False, True],
"group_2_idx": [False, True, False, False, True, False, False, True, False, False, True, False],
"group_3_idx": [False, False, True, False, False, False, False, False, True, False, False, False],
"category_1_name": [
"Jeans",
"Shoes",
"Dresses",
"Hats",
"Shoes",
"Jeans",
"Jeans",
"Shoes",
"Dresses",
"Jeans",
"Shoes",
"Jeans",
],
cols.unit_spend: [10, 20, 30, 40, 20, 50, 10, 20, 30, 15, 40, 50],
},
)
Expand All @@ -27,8 +38,10 @@ def test_calc_cross_shop_two_groups(sample_data):
"""Test the _calc_cross_shop method with two groups."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
)
ret_df = pd.DataFrame(
{
Expand All @@ -47,9 +60,12 @@ def test_calc_cross_shop_three_groups(sample_data):
"""Test the _calc_cross_shop method with three groups."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
)
ret_df = pd.DataFrame(
{
Expand All @@ -76,36 +92,16 @@ def test_calc_cross_shop_three_groups(sample_data):
assert cross_shop_df.equals(ret_df)


def test_calc_cross_shop_two_groups_overlap_error(sample_data):
"""Test the _calc_cross_shop method with two groups and overlapping group indices.

The same boolean mask is supplied for both groups, so every selected row
belongs to two groups at once; the call is expected to raise ValueError.
"""
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
# Pass the same group index for both groups
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_1_idx"],
)


def test_calc_cross_shop_three_groups_overlap_error(sample_data):
"""Test the _calc_cross_shop method with three groups and overlapping group indices.

Groups 1 and 3 share the same boolean mask while group 2 uses its own, so
the overlap is only between the first and third groups; the call is
expected to raise ValueError.
"""
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
# Pass the same group index for groups 1 and 3
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_1_idx"],
)


def test_calc_cross_shop_three_groups_customer_id_nunique(sample_data):
"""Test the _calc_cross_shop method with three groups and customer_id as the value column."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.customer_id,
agg_func="nunique",
)
Expand Down Expand Up @@ -139,9 +135,12 @@ def test_calc_cross_shop_table(sample_data):
"""Test the _calc_cross_shop_table method."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.unit_spend,
)
cross_shop_table = CrossShop._calc_cross_shop_table(
Expand Down Expand Up @@ -174,9 +173,12 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
"""Test the _calc_cross_shop_table method with customer_id as the value column."""
cross_shop_df = CrossShop._calc_cross_shop(
sample_data,
group_1_idx=sample_data["group_1_idx"],
group_2_idx=sample_data["group_2_idx"],
group_3_idx=sample_data["group_3_idx"],
group_1_col="category_1_name",
group_1_val="Jeans",
group_2_col="category_1_name",
group_2_val="Shoes",
group_3_col="category_1_name",
group_3_val="Dresses",
value_col=cols.customer_id,
agg_func="nunique",
)
Expand All @@ -193,21 +195,3 @@ def test_calc_cross_shop_table_customer_id_nunique(sample_data):
)

assert cross_shop_table.equals(ret_df)


def test_calc_cross_shop_all_groups_false(sample_data):
"""Test the _calc_cross_shop method with all group indices set to False.

Covers both the two-group and the three-group call shapes; in each case
every mask is a plain list of False values with one entry per row of
sample_data, and the call is expected to raise ValueError.
"""
# Two-group case: neither mask selects any row.
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
group_1_idx=[False] * len(sample_data),
group_2_idx=[False] * len(sample_data),
)

# Three-group case: the optional third mask is also all False.
with pytest.raises(ValueError):
CrossShop._calc_cross_shop(
sample_data,
group_1_idx=[False] * len(sample_data),
group_2_idx=[False] * len(sample_data),
group_3_idx=[False] * len(sample_data),
)

0 comments on commit eb6a16b

Please sign in to comment.