From 156752cfd33a0de793bcadc6b56a2cbd10faa9d1 Mon Sep 17 00:00:00 2001 From: Thomas Bury Date: Tue, 22 Aug 2023 09:56:13 +0200 Subject: [PATCH] fix: :bug: add safeguard if there is a single column of a specific dtype, closes #31 --- src/arfs/association.py | 95 ++++++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/src/arfs/association.py b/src/arfs/association.py index 084ea83..86d7b8b 100644 --- a/src/arfs/association.py +++ b/src/arfs/association.py @@ -194,7 +194,7 @@ def theils_u_matrix(X, sample_weight=None, n_jobs=-1, handle_na="drop"): dtypes_dic = create_dtype_dict(X, dic_keys="dtypes") cat_cols = dtypes_dic["cat"] - if cat_cols: + if cat_cols and (len(cat_cols) >= 2): # explicitely store the unique 2-permutation of column names # permutations and not combinations because U is asymmetric comb_list = [comb for comb in permutations(cat_cols, 2)] @@ -353,7 +353,7 @@ def cramer_v_matrix(X, sample_weight=None, n_jobs=-1, handle_na="drop"): # in GLM supposed to be all the columns cat_cols = dtypes_dic["cat"] - if cat_cols: + if cat_cols and (len(cat_cols) >= 2): # explicitely store the unique 2-combinations of column names comb_list = [comb for comb in combinations(cat_cols, 2)] # define the number of cores @@ -1212,51 +1212,66 @@ def association_matrix( """ # sanity checks X, sample_weight = _check_association_input(X, sample_weight, handle_na) + dtypes_dic = create_dtype_dict(X, dic_keys="dtypes") + + # Cramer's V only for categorical columns + # in GLM supposed to be all the columns + n_cat_cols = len(dtypes_dic["cat"]) + n_num_cols = len(dtypes_dic["num"]) + + df_to_concat = [] # num-num, NaNs already checked above, not repeating the process - if callable(num_num_assoc): - w_num_num = _callable_association_matrix_fn( - assoc_fn=num_num_assoc, - cols_comb=num_num_comb, - kind="num-num", - X=X, - sample_weight=sample_weight, - n_jobs=n_jobs, - ) - else: - w_num_num = wcorr_matrix( - X, sample_weight, n_jobs, handle_na=None, method=num_num_assoc - ) + if n_num_cols >= 2: + if callable(num_num_assoc): + w_num_num = _callable_association_matrix_fn( + assoc_fn=num_num_assoc, + cols_comb=num_num_comb, + kind="num-num", + X=X, + sample_weight=sample_weight, + n_jobs=n_jobs, + ) + else: + w_num_num = wcorr_matrix( + X, sample_weight, n_jobs, handle_na=None, method=num_num_assoc + ) + df_to_concat.append(w_num_num) # nom-num - if callable(nom_num_assoc): - w_nom_num = _callable_association_matrix_fn( - assoc_fn=nom_num_assoc, - cols_comb=nom_num_comb, - kind="nom-num", - X=X, - sample_weight=sample_weight, - n_jobs=n_jobs, - ) - else: - w_nom_num = correlation_ratio_matrix(X, sample_weight, n_jobs, handle_na=None) + if (n_num_cols >= 1) and (n_cat_cols >= 1): + if callable(nom_num_assoc): + w_nom_num = _callable_association_matrix_fn( + assoc_fn=nom_num_assoc, + cols_comb=nom_num_comb, + kind="nom-num", + X=X, + sample_weight=sample_weight, + n_jobs=n_jobs, + ) + else: + w_nom_num = correlation_ratio_matrix(X, sample_weight, n_jobs, handle_na=None) + df_to_concat.append(w_nom_num) # nom-nom - if callable(nom_nom_assoc): - w_nom_nom = _callable_association_matrix_fn( - assoc_fn=nom_nom_assoc, - cols_comb=nom_nom_comb, - kind="nom-nom", - X=X, - sample_weight=sample_weight, - n_jobs=n_jobs, - ) - elif nom_nom_assoc == "cramer": - w_nom_nom = cramer_v_matrix(X, sample_weight, n_jobs, handle_na=None) - else: - w_nom_nom = theils_u_matrix(X, sample_weight, n_jobs, handle_na=None) + if n_cat_cols >= 2: + if callable(nom_nom_assoc): + w_nom_nom = _callable_association_matrix_fn( + assoc_fn=nom_nom_assoc, + cols_comb=nom_nom_comb, + kind="nom-nom", + X=X, + sample_weight=sample_weight, + n_jobs=n_jobs, + ) + elif nom_nom_assoc == "cramer": + w_nom_nom = cramer_v_matrix(X, sample_weight, n_jobs, handle_na=None) + else: + w_nom_nom = theils_u_matrix(X, sample_weight, n_jobs, handle_na=None) + df_to_concat.append(w_nom_nom) + - return pd.concat([w_num_num, w_nom_num, w_nom_nom], ignore_index=True) + return pd.concat(df_to_concat, ignore_index=True) def _callable_association_series_fn(