Skip to content

Commit

Permalink
Removes occurrences of DataFrame.values (ndarray)
Browse files Browse the repository at this point in the history
Uses the DataFrame everywhere it's possible.
  • Loading branch information
fberanizo committed Aug 29, 2022
1 parent e8a16e3 commit 03e84e0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/fklearn/training/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,21 +571,21 @@ def lgbm_classification_learner(df: pd.DataFrame,
params = assoc(params, "eta", learning_rate)
params = params if "objective" in params else assoc(params, "objective", 'binary')

weights = df[weight_column].values if weight_column else None
weights = df[weight_column] if weight_column else None

features = features if not encode_extra_cols else expand_features_encoded(df, features)

dtrain = lgbm.Dataset(df[features], label=df[target], feature_name=list(map(str, features)), weight=weights,
silent=True, categorical_feature=categorical_features)

bst = lgbm.train(params, dtrain, num_estimators)
bst = lgbm.train(params, dtrain, num_estimators, categorical_feature=categorical_features)

def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
if params["objective"] == "multiclass":
col_dict = {prediction_column + "_" + str(key): value
for (key, value) in enumerate(bst.predict(new_df[features].values).T)}
for (key, value) in enumerate(bst.predict(new_df[features]).T)}
else:
col_dict = {prediction_column: bst.predict(new_df[features].values)}
col_dict = {prediction_column: bst.predict(new_df[features])}

if apply_shap:
import shap
Expand Down

0 comments on commit 03e84e0

Please sign in to comment.