Skip to content

Commit

Permalink
Switch if/else conditions order in fit
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Dec 18, 2024
1 parent 38edc42 commit 9a32376
Showing 1 changed file with 50 additions and 50 deletions.
100 changes: 50 additions & 50 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,56 @@ def fit(
params=params,
)

if self.early_stopping is True and eval_set is None:
valid_sets: List[Dataset] = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)

valid_sets.append(valid_set)

elif self.early_stopping is True:
if self.validation_fraction is not None:
n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
stratified = isinstance(self, LGBMClassifier)
Expand All @@ -1001,55 +1050,6 @@ def fit(
valid_set = train_set
valid_set = valid_set.construct()
valid_sets = [valid_set]
else:
valid_sets: List[Dataset] = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)

valid_sets.append(valid_set)

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
Expand Down

0 comments on commit 9a32376

Please sign in to comment.