Skip to content

Commit

Permalink
add smoothing to target encoder + improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Dec 20, 2024
1 parent 8d34db3 commit 39ed08c
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"Operating System :: OS Independent",
]

dependencies = ["narwhals", "pydantic"]
dependencies = ["narwhals", "pydantic", "scikit-learn"]

[project.optional-dependencies]
dev = [
Expand Down
93 changes: 87 additions & 6 deletions sklearo/encoding/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,104 @@ def __init__(
nw.Categorical,
nw.String,
),
unseen: Literal["raise", "ignore"] = "raise",
fill_value_unseen: int | float | None | Literal["mean"] = "mean",
unseen: Literal["raise", "ignore", "fill"] = "raise",
fill_value_unseen: float | None | Literal["mean"] = "mean",
missing_values: Literal["encode", "ignore", "raise"] = "encode",
underrepresented_categories: Literal["raise", "fill"] = "raise",
fill_values_underrepresented: float | None | Literal["mean"] = "mean",
target_type: Literal["auto", "binary", "multiclass", "continuous"] = "auto",
smooth: Literal["auto"] | float = "auto",
) -> None:

self.columns = columns
self.missing_values = missing_values
self.unseen = unseen
self.fill_value_unseen = fill_value_unseen
self.target_type = target_type
self.smooth = smooth
self.underrepresented_categories = underrepresented_categories
self.fill_values_underrepresented = fill_values_underrepresented

def _calculate_target_statistic(
self, x_y: IntoFrameT, target_col: str, column: str
) -> dict:
mean_target_all_categories = (
x_y.group_by(column).agg(nw.col(target_col).mean()).rows()

if column in (
"category_count",
"sum_target",
"std_target",
"smoothing",
"shrinkage",
"smoothed_target",
):
# rename the column to avoid conflict
original_column_name = column
x_y = x_y.rename(columns={column: f"{column}_original"})
column = f"{column}_original"
else:
original_column_name = column

x_y_grouped = x_y.group_by(column, drop_null_keys=True).agg(
category_count=nw.col(target_col).count(),
sum_target=nw.col(target_col).sum(),
**(
{"std_target": nw.col(target_col).std()}
if self.smooth == "auto"
else {}
),
)
underrepresented_categories = x_y_grouped.filter(nw.col("category_count") == 1)[
column
].to_list()
if underrepresented_categories:
if self.underrepresented_categories == "raise":
raise ValueError(
f"Found underrepresented categories for the column {original_column_name}: "
f"{underrepresented_categories}. Please consider handling underrepresented "
"categories by using a RareLabelEncoder. Alternatively, set "
"underrepresented_categories to 'fill'."
)
else:
if self.fill_values_underrepresented == "mean":
fill_values_underrepresented = x_y[target_col].mean()
else:
fill_values_underrepresented = self.fill_values_underrepresented

x_y_grouped = x_y_grouped.filter(
~nw.col(column).is_in(underrepresented_categories)
)
encoding_dict = {
category: fill_values_underrepresented
for category in underrepresented_categories
}
else:
encoding_dict = {}

if self.smooth == "auto":
target_std = x_y[target_col].std()
x_y_grouped = x_y_grouped.with_columns(
smoothing=nw.col("std_target") / target_std
)
else:
x_y_grouped = x_y_grouped.with_columns(smoothing=nw.lit(self.smooth))

categories_encoding_as_list = (
x_y_grouped.with_columns(
shrinkage=nw.col("category_count")
/ (nw.col("category_count") + nw.col("smoothing"))
)
.with_columns(
smoothed_target=nw.col("shrinkage")
* nw.col("sum_target")
/ nw.col("category_count")
+ (1 - nw.col("shrinkage"))
* nw.col("sum_target")
/ nw.col("category_count")
)
.select(column, "smoothed_target")
.rows()
)
mean_target = dict(mean_target_all_categories)
return mean_target

encoding_dict.update(dict(categories_encoding_as_list))

return encoding_dict
Loading

0 comments on commit 39ed08c

Please sign in to comment.