Skip to content

Commit

Permalink
Support new multi-class objectives in lgbm_classification_learner (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
IgorKhomyanin authored Oct 15, 2023
1 parent 8fe0823 commit 054d319
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/fklearn/training/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,14 @@ def lgbm_classification_learner(df: pd.DataFrame,

import lightgbm as lgbm

LGBM_MULTICLASS_OBJECTIVES = {'multiclass', 'softmax', 'multiclassova', 'multiclass_ova', 'ova', 'ovr'}

params = extra_params if extra_params else {}
params = assoc(params, "eta", learning_rate)
params = params if "objective" in params else assoc(params, "objective", 'binary')

is_multiclass_classification = params["objective"] in LGBM_MULTICLASS_OBJECTIVES

weights = df[weight_column].values if weight_column else None

features = features if not encode_extra_cols else expand_features_encoded(df, features)
Expand All @@ -637,7 +641,7 @@ def lgbm_classification_learner(df: pd.DataFrame,
callbacks=callbacks)

def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
if params["objective"] == "multiclass":
if is_multiclass_classification:
col_dict = {prediction_column + "_" + str(key): value
for (key, value) in enumerate(bst.predict(new_df[features].values).T)}
else:
Expand All @@ -649,7 +653,7 @@ def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
shap_values = explainer.shap_values(new_df[features])
shap_expected_value = explainer.expected_value

if params["objective"] == "multiclass":
if is_multiclass_classification:
shap_values_multiclass = {f"shap_values_{class_index}": list(value)
for (class_index, value) in enumerate(shap_values)}
shap_expected_value_multiclass = {
Expand Down

0 comments on commit 054d319

Please sign in to comment.