Skip to content

Commit

Permalink
feat: ✨ reduce the length of concatenation to avoid long entries
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Bury committed Aug 11, 2023
1 parent 2269736 commit de7ed74
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/arfs/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

# ARFS
from .gbm import GradientBoosting
from .utils import create_dtype_dict
from .utils import create_dtype_dict, concat_or_group


# fix random seed for reproducibility
Expand Down Expand Up @@ -609,7 +609,7 @@ def fit(self, X, y, sample_weight=None):
self.cat_bin_dict[col] = (
X[[f"{col}_g", col]]
.groupby(f"{col}_g")
.apply(lambda x: " / ".join(map(str, x[col].unique())))
.apply(lambda x: concat_or_group(col, x, max_length=25)) #" / ".join(map(str, x[col].unique())))
.to_dict()
)
else:
Expand Down
45 changes: 45 additions & 0 deletions src/arfs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,51 @@
# #
#####################

def concat_or_group(col, x, max_length=25):
"""
Concatenate unique values from a column or return a group value.
Parameters
----------
col : str
The name of the column to process.
x : pd.DataFrame
The DataFrame containing the data.
max_length : int, optional
The maximum length for concatenated strings, beyond which grouping is performed,
by default 40.
Returns
-------
str
A concatenated string of unique values if the length is less than `max_length`,
otherwise, a unique group value from the specified column.
Notes
-----
If the concatenated string length is greater than or equal to `max_length`, this
function returns the unique group value from the column with a "_g" suffix.
Examples
--------
>>> data = {
>>> 'Category_g': [1, 1, 2, 2, 3],
>>> 'Category': ['AAAAAAAAAAAAAAA', 'Bovoh', 'Ccccccccccccccc', 'D', 'E']}
>>> cat_bin_dict = {}
>>> col = 'Category'
>>> cat_bin_dict[col] = (
>>> X[[f"{col}_g", col]]
>>> .groupby(f"{col}_g")
>>> .apply(lambda x: concat_or_group(col, x))
>>> .to_dict()
>>> )
>>> print(cat_bin_dict)
>>> {'Category': {1: 'gr_1', 2: 'gr_2', 3: 'E'}}
"""
unique_values = x[col].unique()
concat_str = " / ".join(map(str, unique_values))
return concat_str if len(concat_str) < max_length else concat_str[:7] + "/.../" + concat_str[-7:]


def reset_plot():
"""Reset plot style"""
Expand Down

0 comments on commit de7ed74

Please sign in to comment.