Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multilevel groupby #69

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class TableOne(object):
List of columns in the dataset to be included in the final table.
categorical : list, optional
List of columns that contain categorical variables.
groupby : str, optional
Optional column for stratifying the final table (default: None).
groupby : list, optional
Optional columns for stratifying the final table (default: None).
nonnormal : list, optional
List of columns that contain non-normal variables (default: None).
pval : bool, optional
Expand Down Expand Up @@ -83,9 +83,9 @@ def __init__(self, data, columns=None, categorical=None, groupby=None,

# check input arguments
if not groupby:
groupby = ''
elif groupby and type(groupby) == list:
groupby = groupby[0]
groupby = []
elif groupby and isinstance(groupby, str):
groupby = [groupby]

if not nonnormal:
nonnormal=[]
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(self, data, columns=None, categorical=None, groupby=None,

self._columns = list(columns)
self._isnull = isnull
self._continuous = [c for c in columns if c not in categorical + [groupby]]
self._continuous = [c for c in columns if c not in categorical + groupby]
self._categorical = categorical
self._nonnormal = nonnormal
self._pval = pval
Expand All @@ -131,13 +131,15 @@ def __init__(self, data, columns=None, categorical=None, groupby=None,
# output column names that cannot be contained in a groupby
self._reserved_columns = ['isnull', 'pval', 'ptest', 'pval (adjusted)']
if self._groupby:
self._groupbylvls = sorted(data.groupby(groupby).groups.keys())
for groupbyvar in groupby:
data[groupbyvar] = data[groupbyvar].astype(str) # Treat groupby variables as string to avoid problems with categorical groupby
self._groups = data.groupby(groupby).groups
# check that the group levels do not include reserved words
for level in self._groupbylvls:
for level in data.groupby(self._groupby[0]).groups:
if level in self._reserved_columns:
raise InputError('Group level contained "{}", a reserved keyword for tableone.'.format(level))
else:
self._groupbylvls = ['overall']
self._groups = {'overall': data.index}

# forgive me jraffa
if self._pval:
Expand Down Expand Up @@ -380,13 +382,15 @@ def _create_cont_describe(self,data):

if self._groupby:
# add the groupby column back
cont_data = cont_data.merge(data[[self._groupby]],
cont_data = cont_data.merge(data[self._groupby],
left_index=True, right_index=True)

# group and aggregate data
df_cont = pd.pivot_table(cont_data,
columns=[self._groupby],
columns=self._groupby,
aggfunc=aggfuncs)
if len(self._groupby) > 1:
df_cont = df_cont.unstack([i+1 for i in range(len(self._groupby))])
else:
# if no groupby, just add single group column
df_cont = cont_data.apply(aggfuncs).T
Expand Down Expand Up @@ -419,11 +423,8 @@ def _create_cat_describe(self,data):
"""
group_dict = {}

for g in self._groupbylvls:
if self._groupby:
d_slice = data.loc[data[self._groupby] == g, self._categorical]
else:
d_slice = data[self._categorical].copy()
for g, gdata in self._groups.items():
d_slice = data.loc[gdata, self._categorical].copy()

# create a dataframe with freq, proportion
df = d_slice.copy()
Expand Down Expand Up @@ -456,9 +457,9 @@ def _create_cat_describe(self,data):
group_dict[g] = df

df_cat = pd.concat(group_dict,axis=1)
# ensure the groups are the 2nd level of the column index
# ensure the groups are the final levels of the column index
if df_cat.columns.nlevels>1:
df_cat = df_cat.swaplevel(0, 1, axis=1).sort_index(axis=1,level=0)
df_cat = df_cat.reorder_levels([df_cat.columns.nlevels-1]+[i for i in range(df_cat.columns.nlevels-1)], axis=1).sort_index(axis=1,level=0)

return df_cat

Expand Down Expand Up @@ -495,8 +496,8 @@ def _create_significance_table(self,data):
if is_continuous:
catlevels = None
grouped_data = []
for s in self._groupbylvls:
lvl_data = data.loc[data[self._groupby]==s, v]
for g, gdata in self._groups.items():
lvl_data = data.loc[gdata, v]
# coerce to numeric and drop non-numeric data
lvl_data = lvl_data.apply(pd.to_numeric, errors='coerce').dropna()
# append to overall group data
Expand All @@ -505,7 +506,7 @@ def _create_significance_table(self,data):
# if categorical, create contingency table
elif is_categorical:
catlevels = sorted(data[v].astype('category').cat.categories)
grouped_data = pd.crosstab(data[self._groupby].rename('_groupby_var_'),data[v])
grouped_data = pd.crosstab([data[g] for g in self._groupby],data[v])
min_observed = grouped_data.sum(axis=1).min()

# minimum number of observations across all levels
Expand All @@ -516,6 +517,7 @@ def _create_significance_table(self,data):
grouped_data,is_continuous,is_categorical,
is_normal,min_observed,catlevels)

if len(self._groupby) > 1: df.columns = pd.MultiIndex.from_product([df.columns if i == 0 else [''] for i in range(len(self._groupby))])
return df

def _p_test(self,v,grouped_data,is_continuous,is_categorical,
Expand Down Expand Up @@ -600,12 +602,10 @@ def _create_cont_table(self,data):
table.columns = table.columns.droplevel(level=0)

# add a column of null counts as 1-count() from previous function
# isnull needs to be its own column
nulltable = data[self._continuous].isnull().sum().to_frame(name='isnull')
try:
table = table.join(nulltable)
except TypeError: # if columns form a CategoricalIndex, need to convert to string first
table.columns = table.columns.astype(str)
table = table.join(nulltable)
if len(self._groupby) > 1: nulltable.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))])
table = table.join(nulltable)

# add an empty level column, for joining with cat table
table['level'] = ''
Expand All @@ -632,11 +632,8 @@ def _create_cat_table(self,data):
# add the total count of null values across all levels
isnull = data[self._categorical].isnull().sum().to_frame(name='isnull')
isnull.index.rename('variable', inplace=True)
try:
table = table.join(isnull)
except TypeError: # if columns form a CategoricalIndex, need to convert to string first
table.columns = table.columns.astype(str)
table = table.join(isnull)
if len(self._groupby) > 1: isnull.columns = pd.MultiIndex.from_product([['isnull'] if i == 0 else [''] for i in range(len(self._groupby))])
table = table.join(isnull)

# add pval column
if self._pval and self._pval_adjust:
Expand Down Expand Up @@ -699,13 +696,15 @@ def _create_tableone(self,data):
n_row = pd.DataFrame(columns = ['variable','level','isnull'])
n_row.set_index(['variable','level'], inplace=True)
n_row.loc['n', ''] = None
if len(self._groupby) > 1: n_row.columns = pd.MultiIndex.from_tuples(
[tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))), tuple('' for i in range(len(self._groupby)))], names=table.columns.names)
table = pd.concat([n_row,table],sort=False)

if self._groupbylvls == ['overall']:
if not self._groupby:
table.loc['n','overall'] = len(data.index)
else:
for g in self._groupbylvls:
ct = data[self._groupby][data[self._groupby]==g].count()
for g, gdata in self._groups.items():
ct = len(gdata)
table.loc['n',g] = ct

# only display data in first level row
Expand All @@ -716,28 +715,32 @@ def _create_tableone(self,data):
if col in table.columns.values:
dupe_columns.append(col)

if len(self._groupby) > 1: dupe_columns = [tuple(c if i == 0 else '' for i in range(len(self._groupby))) for c in dupe_columns]
table[dupe_columns] = table[dupe_columns].mask(dupe_mask).fillna('')

# remove empty column added above
table.drop([''], axis=1, inplace=True)
if len(self._groupby) > 1: table.drop(tuple('' for i in range(len(self._groupby))), axis=1, inplace=True)
else: table.drop('', axis=1, inplace=True)

# remove isnull column if not needed
if not self._isnull:
table.drop('isnull',axis=1,inplace=True)
if len(self._groupby) > 1: table.drop(tuple('isnull' if i == 0 else '' for i in range(len(self._groupby))),axis=1,inplace=True)
else: table.drop('isnull',axis=1,inplace=True)

# replace nans with empty strings
table.fillna('',inplace=True)

# add column index
if not self._groupbylvls == ['overall']:
if self._groupby:
# rename groupby variable if requested
c = self._groupby
if self._alt_labels:
if self._groupby in self._alt_labels:
c = self._alt_labels[self._groupby]
if self._alt_labels: c = ', '.join([self._alt_labels.get(g, g) for g in self._groupby])
else: c = ', '.join(self._groupby)

c = 'Grouped by {}'.format(c)
table.columns = pd.MultiIndex.from_product([[c], table.columns])
if len(self._groupby) > 1:
table.columns = pd.MultiIndex.from_tuples(tuple([c])+table.columns[i] for i in range(len(table.columns)))
else:
table.columns = pd.MultiIndex.from_product([[c], table.columns])

# display alternative labels if assigned
table.rename(index=self._create_row_labels(), inplace=True, level=0)
Expand Down
91 changes: 91 additions & 0 deletions test_tableone.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,94 @@ def test_check_null_counts_are_correct_pn(self):
# check each null count is correct
col = isnull.index[i][0]
assert self.data_pn[col].isnull().sum() == v

@with_setup(setup, teardown)
def test_multilevel_groupby(self):
"""
Test multilevel groupby produces expected results
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'])
assert table.tableone.columns[0][0] == 'Grouped by death, MechVent'
table.tableone.columns = table.tableone.columns.droplevel(0)
assert len(table.tableone.columns) == 5
for i, correct_col in enumerate([('isnull', ''), ('0', '0'), ('0', '1'), ('1', '0'), ('1', '1')]):
assert table.tableone.columns[i] == correct_col
assert len(table.tableone.index) == 8
rows = [('n', ''), ('Age', ''), ('Height', ''), ('Weight', ''), ('ICU', 'CCU'), ('ICU', 'CSRU'), ('ICU', 'MICU'), ('ICU', 'SICU')]
for i, correct_row in enumerate(rows):
assert table.tableone.index[i] == correct_row
correct_value = {
('n', ''): ['', 468, 396, 72, 64],
('Age', ''): [0, '65.29 (17.94)', '62.47 (16.65)', '71.06 (13.90)', '72.42 (14.21)'],
('Height', ''): [475, '171.55 (31.78)', '169.24 (11.06)', '167.36 (11.32)', '169.86 (11.34)'],
('Weight', ''): [302, '81.03 (22.28)', '85.02 (24.67)', '83.89 (28.35)', '80.44 (21.66)'],
('ICU', 'CCU'): [0, '110 (23.5)', '27 (6.82)', '11 (15.28)', '14 (21.88)'],
('ICU', 'CSRU'): ['', '50 (10.68)', '144 (36.36)', '3 (4.17)', '5 (7.81)'],
('ICU', 'MICU'): ['', '205 (43.8)', '113 (28.54)', '47 (65.28)', '15 (23.44)'],
('ICU', 'SICU'): ['', '103 (22.01)', '112 (28.28)', '11 (15.28)', '30 (46.88)']
}
for row in rows:
assert list(table.tableone.loc[row]) == correct_value[row]

@with_setup(setup, teardown)
def test_multilevel_groupby_pval(self):
"""
Test multilevel groupby works when p-values are requested
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True)
table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, pval_adjust='bonferroni')
table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], pval=True, nonnormal=['Age'])
assert table.tableone.loc['Weight', ('Grouped by death, MechVent', 'pval', '')][0] == '0.187'

@with_setup(setup, teardown)
def test_multilevel_groupby_noisnull(self):
"""
Test multilevel groupby runs without error when isnull option is False
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], isnull=False)

@with_setup(setup, teardown)
def test_multilevel_groupby_sort(self):
"""
Test multilevel groupby runs without error when sort option is True
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], sort=True)

@with_setup(setup, teardown)
def test_multilevel_groupby_limit(self):
"""
Test multilevel groupby runs correctly when limit option is set
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

table = TableOne(self.data_pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'], limit=2)
assert list(table.tableone.loc['ICU'].index) == ['MICU', 'SICU']

@with_setup(setup, teardown)
def test_groupby_categorical(self):
"""
Test groupby runs without error with categorical groupby variable
"""
columns = ['Age', 'Height', 'Weight', 'ICU']
categorical = ['ICU']

pn = self.data_pn.copy()
pn['death'] = pn['death'].astype('category')
table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death'])
assert len(table.tableone.columns == 3)
table = TableOne(pn, columns=columns, categorical=categorical, groupby=['death', 'MechVent'])
assert len(table.tableone.columns) == 5