Skip to content

Commit

Permalink
fix: 🐛 add safeguard if there is a single column of a specific dtype, c…
Browse files Browse the repository at this point in the history
…loses #31
  • Loading branch information
Thomas Bury committed Aug 22, 2023
1 parent de0d1de commit 156752c
Showing 1 changed file with 55 additions and 40 deletions.
95 changes: 55 additions & 40 deletions src/arfs/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def theils_u_matrix(X, sample_weight=None, n_jobs=-1, handle_na="drop"):
dtypes_dic = create_dtype_dict(X, dic_keys="dtypes")
cat_cols = dtypes_dic["cat"]

if cat_cols:
if cat_cols and (len(cat_cols) >= 2):
# explicitely store the unique 2-permutation of column names
# permutations and not combinations because U is asymmetric
comb_list = [comb for comb in permutations(cat_cols, 2)]
Expand Down Expand Up @@ -353,7 +353,7 @@ def cramer_v_matrix(X, sample_weight=None, n_jobs=-1, handle_na="drop"):
# in GLM supposed to be all the columns
cat_cols = dtypes_dic["cat"]

if cat_cols:
if cat_cols and (len(cat_cols) >= 2):
# explicitely store the unique 2-combinations of column names
comb_list = [comb for comb in combinations(cat_cols, 2)]
# define the number of cores
Expand Down Expand Up @@ -1212,51 +1212,66 @@ def association_matrix(
"""
# sanity checks
X, sample_weight = _check_association_input(X, sample_weight, handle_na)
dtypes_dic = create_dtype_dict(X, dic_keys="dtypes")

# Cramer's V only for categorical columns
# in GLM supposed to be all the columns
n_cat_cols = len(dtypes_dic["cat"])
n_num_cols = len(dtypes_dic["num"])

df_to_concat = []

# num-num, NaNs already checked above, not repeating the process
if callable(num_num_assoc):
w_num_num = _callable_association_matrix_fn(
assoc_fn=num_num_assoc,
cols_comb=num_num_comb,
kind="num-num",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
else:
w_num_num = wcorr_matrix(
X, sample_weight, n_jobs, handle_na=None, method=num_num_assoc
)
if n_num_cols >= 2:
if callable(num_num_assoc):
w_num_num = _callable_association_matrix_fn(
assoc_fn=num_num_assoc,
cols_comb=num_num_comb,
kind="num-num",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
else:
w_num_num = wcorr_matrix(
X, sample_weight, n_jobs, handle_na=None, method=num_num_assoc
)
df_to_concat.append(w_num_num)

# nom-num
if callable(nom_num_assoc):
w_nom_num = _callable_association_matrix_fn(
assoc_fn=nom_num_assoc,
cols_comb=nom_num_comb,
kind="nom-num",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
else:
w_nom_num = correlation_ratio_matrix(X, sample_weight, n_jobs, handle_na=None)
if (n_num_cols >= 1) and (n_cat_cols >= 1):
if callable(nom_num_assoc):
w_nom_num = _callable_association_matrix_fn(
assoc_fn=nom_num_assoc,
cols_comb=nom_num_comb,
kind="nom-num",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
else:
w_nom_num = correlation_ratio_matrix(X, sample_weight, n_jobs, handle_na=None)
df_to_concat.append(w_nom_num)

# nom-nom
if callable(nom_nom_assoc):
w_nom_nom = _callable_association_matrix_fn(
assoc_fn=nom_nom_assoc,
cols_comb=nom_nom_comb,
kind="nom-nom",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
elif nom_nom_assoc == "cramer":
w_nom_nom = cramer_v_matrix(X, sample_weight, n_jobs, handle_na=None)
else:
w_nom_nom = theils_u_matrix(X, sample_weight, n_jobs, handle_na=None)
if n_cat_cols >= 2:
if callable(nom_nom_assoc):
w_nom_nom = _callable_association_matrix_fn(
assoc_fn=nom_nom_assoc,
cols_comb=nom_nom_comb,
kind="nom-nom",
X=X,
sample_weight=sample_weight,
n_jobs=n_jobs,
)
elif nom_nom_assoc == "cramer":
w_nom_nom = cramer_v_matrix(X, sample_weight, n_jobs, handle_na=None)
else:
w_nom_nom = theils_u_matrix(X, sample_weight, n_jobs, handle_na=None)
df_to_concat.append(w_nom_nom)


return pd.concat([w_num_num, w_nom_num, w_nom_nom], ignore_index=True)
return pd.concat(df_to_concat, ignore_index=True)


def _callable_association_series_fn(
Expand Down

0 comments on commit 156752c

Please sign in to comment.