
Merge pull request optuna#4896 from Alnusjaponica/use-isinstance
Use `isinstance` instead of `if type() is ...`
not522 authored Oct 5, 2023
2 parents cdcb645 + e0f6505 commit 760c774
Showing 1 changed file with 46 additions and 50 deletions.
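
For context on the rename itself: `type(x) is T` matches only the exact class, while `isinstance(x, T)` also accepts subclasses and abstract base classes such as `collections.abc.Sequence`, which is what lets the diff below collapse several `list`/`tuple`/`set` special cases. A minimal standalone sketch (illustrative only, not part of the commit):

```python
from collections.abc import Sequence

class MetricList(list):
    """Hypothetical list subclass, e.g. produced by a wrapper library."""

metrics = MetricList(["auc", "binary_logloss"])

print(type(metrics) is list)          # False: exact-type check rejects subclasses
print(isinstance(metrics, list))      # True: subclasses pass
print(isinstance(metrics, Sequence))  # True: lists, tuples, and other sequences pass
```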
optuna/integration/_lightgbm_tuner/optimize.py (46 additions, 50 deletions)
```diff
@@ -4,7 +4,9 @@
 from collections.abc import Callable
 from collections.abc import Container
 from collections.abc import Generator
+from collections.abc import Iterable
 from collections.abc import Iterator
+from collections.abc import Sequence
 import copy
 import json
 import os
```
```diff
@@ -78,11 +80,11 @@ def _get_metric_for_objective(self) -> str:
         metric = self.lgbm_params.get("metric", "binary_logloss")

         # todo (smly): This implementation is different logic from the LightGBM's python bindings.
-        if type(metric) is str:
+        if isinstance(metric, str):
             pass
-        elif type(metric) is list:
+        elif isinstance(metric, Sequence):
             metric = metric[-1]
-        elif type(metric) is set:
+        elif isinstance(metric, Iterable):
             metric = list(metric)[-1]
         else:
             raise NotImplementedError
```
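
One subtlety in the rewritten chain above: `str` is itself a `Sequence` and an `Iterable`, so the `str` branch must stay first; otherwise a plain string metric would fall into the `Sequence` branch and be indexed per character. A runnable sketch of the pitfall (standalone, not from the diff):

```python
from collections.abc import Iterable, Sequence

metric = "binary_logloss"

# A string satisfies both ABC checks, so it must be handled before them.
assert isinstance(metric, Sequence) and isinstance(metric, Iterable)

# If the Sequence branch ran first, metric[-1] would return the last
# character rather than the last metric of a list.
print(metric[-1])  # -> "s"
```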
```diff
@@ -95,19 +97,19 @@ def _get_booster_best_score(self, booster: "lgb.Booster") -> float:
         valid_sets: VALID_SET_TYPE | None = self.lgbm_kwargs.get("valid_sets")

         if self.lgbm_kwargs.get("valid_names") is not None:
-            if type(self.lgbm_kwargs["valid_names"]) is str:
+            if isinstance(self.lgbm_kwargs["valid_names"], str):
                 valid_name = self.lgbm_kwargs["valid_names"]
-            elif type(self.lgbm_kwargs["valid_names"]) in [list, tuple]:
+            elif isinstance(self.lgbm_kwargs["valid_names"], Sequence):
                 valid_name = self.lgbm_kwargs["valid_names"][-1]
             else:
                 raise NotImplementedError

-        elif type(valid_sets) is lgb.Dataset:
+        elif isinstance(valid_sets, lgb.Dataset):
             valid_name = "valid_0"

-        elif isinstance(valid_sets, (list, tuple)) and len(valid_sets) > 0:
+        elif isinstance(valid_sets, Sequence) and len(valid_sets) > 0:
             valid_set_idx = len(valid_sets) - 1
-            valid_name = "valid_{}".format(valid_set_idx)
+            valid_name = f"valid_{valid_set_idx}"

         else:
             raise NotImplementedError
```
```diff
@@ -116,27 +118,27 @@ def _get_booster_best_score(self, booster: "lgb.Booster") -> float:
         return val_score

     def _metric_with_eval_at(self, metric: str) -> str:
-        if metric != "ndcg" and metric != "map":
+        # The parameter eval_at is only available when the metric is ndcg or map.
+        if metric not in ["ndcg", "map"]:
             return metric

-        eval_at = self.lgbm_params.get("eval_at")
-        if eval_at is None:
-            eval_at = self.lgbm_params.get("{}_at".format(metric))
-        if eval_at is None:
-            eval_at = self.lgbm_params.get("{}_eval_at".format(metric))
-        if eval_at is None:
-            # Set default value of LightGBM.
-            eval_at = [1, 2, 3, 4, 5]
+        eval_at = (
+            self.lgbm_params.get("eval_at")
+            or self.lgbm_params.get(f"{metric}_at")
+            or self.lgbm_params.get(f"{metric}_eval_at")
+            # Fall back to LightGBM's default when none of the possible keys is present.
+            # See https://lightgbm.readthedocs.io/en/latest/Parameters.html#eval_at.
+            or [1, 2, 3, 4, 5]
+        )

         # Optuna can handle only a single metric. Choose first one.
-        if type(eval_at) in [list, tuple]:
-            return "{}@{}".format(metric, eval_at[0])
-        if type(eval_at) is int:
-            return "{}@{}".format(metric, eval_at)
+        if isinstance(eval_at, (list, tuple)):
+            return f"{metric}@{eval_at[0]}"
+        if isinstance(eval_at, int):
+            return f"{metric}@{eval_at}"
         raise ValueError(
-            "The value of eval_at is expected to be int or a list/tuple of int. "
-            "'{}' is specified.".format(eval_at)
+            f"The value of eval_at is expected to be int or a list/tuple of int. '{eval_at}' is "
+            "specified."
         )

     def higher_is_better(self) -> bool:
```
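
A behavioral nuance of the consolidated lookup above: chaining with `or` treats any falsy value (`0`, `[]`, `""`) the same as a missing key, whereas the old `is None` checks replaced only absent keys. A runnable standalone sketch of the difference:

```python
params = {"eval_at": 0}  # falsy but explicitly set

# Old style: only a missing key (None) falls through to the default.
eval_at = params.get("eval_at")
if eval_at is None:
    eval_at = [1, 2, 3, 4, 5]
print(eval_at)  # -> 0

# New style: `or` also skips falsy values, so the 0 is overridden.
eval_at = params.get("eval_at") or [1, 2, 3, 4, 5]
print(eval_at)  # -> [1, 2, 3, 4, 5]
```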
```diff
@@ -182,18 +184,12 @@ def __init__(
         self.pbar_fmt = "{}, val_score: {:.6f}"

     def _check_target_names_supported(self) -> None:
-        supported_param_names = [
-            "lambda_l1",
-            "lambda_l2",
-            "num_leaves",
-            "feature_fraction",
-            "bagging_fraction",
-            "bagging_freq",
-            "min_child_samples",
-        ]
         for target_param_name in self.target_param_names:
-            if target_param_name not in supported_param_names:
-                raise NotImplementedError("Parameter `{}` is not supported for tuning.")
+            if target_param_name in _DEFAULT_LIGHTGBM_PARAMETERS:
+                continue
+            raise NotImplementedError(
+                f"Parameter `{target_param_name}` is not supported for tuning."
+            )

     def _preprocess(self, trial: optuna.trial.Trial) -> None:
         if self.pbar is not None:
```
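
Incidentally, the removed `raise NotImplementedError(...)` line above carried a latent bug: it contained a `{}` placeholder but no `.format()` call, so the error message always printed a literal `{}`. The f-string replacement interpolates the offending name. A runnable sketch:

```python
param = "max_depth"  # illustrative unsupported parameter name

# Old: placeholder never filled because .format() was missing.
print("Parameter `{}` is not supported for tuning.")
# -> Parameter `{}` is not supported for tuning.

# New: the f-string interpolates the parameter name.
print(f"Parameter `{param}` is not supported for tuning.")
# -> Parameter `max_depth` is not supported for tuning.
```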
```diff
@@ -246,10 +242,10 @@ def __call__(self, trial: optuna.trial.Trial) -> float:
         average_iteration_time = elapsed_secs / booster.current_iteration()

         if self.model_dir is not None:
-            path = os.path.join(self.model_dir, "{}.pkl".format(trial.number))
+            path = os.path.join(self.model_dir, f"{trial.number}.pkl")
             with open(path, "wb") as fout:
                 pickle.dump(booster, fout)
-            _logger.info("The booster of trial#{} was saved as {}.".format(trial.number, path))
+            _logger.info(f"The booster of trial#{trial.number} was saved as {path}.")

         if self.compare_validation_metrics(val_score, self.best_score):
             self.best_score = val_score
```
```diff
@@ -326,14 +322,14 @@ def __call__(self, trial: optuna.trial.Trial) -> float:
         average_iteration_time = elapsed_secs / len(val_scores)

         if self.model_dir is not None and self.lgbm_kwargs.get("return_cvbooster"):
-            path = os.path.join(self.model_dir, "{}.pkl".format(trial.number))
+            path = os.path.join(self.model_dir, f"{trial.number}.pkl")
             with open(path, "wb") as fout:
                 # At version `lightgbm==3.0.0`, :class:`lightgbm.CVBooster` does not
                 # have `__getstate__` which is required for pickle serialization.
                 cvbooster = cv_results["cvbooster"]
                 assert isinstance(cvbooster, lgb.CVBooster)
                 pickle.dump((cvbooster.boosters, cvbooster.best_iteration), fout)
-            _logger.info("The booster of trial#{} was saved as {}.".format(trial.number, path))
+            _logger.info(f"The booster of trial#{trial.number} was saved as {path}.")

         if self.compare_validation_metrics(val_score, self.best_score):
             self.best_score = val_score
```
```diff
@@ -418,15 +414,15 @@ def __init__(
         if self.study.direction != optuna.study.StudyDirection.MAXIMIZE:
             metric_name = self.lgbm_params.get("metric", "binary_logloss")
             raise ValueError(
-                "Study direction is inconsistent with the metric {}. "
-                "Please set 'maximize' as the direction.".format(metric_name)
+                f"Study direction is inconsistent with the metric {metric_name}. "
+                "Please set 'maximize' as the direction."
             )
         else:
             if self.study.direction != optuna.study.StudyDirection.MINIMIZE:
                 metric_name = self.lgbm_params.get("metric", "binary_logloss")
                 raise ValueError(
-                    "Study direction is inconsistent with the metric {}. "
-                    "Please set 'minimize' as the direction.".format(metric_name)
+                    f"Study direction is inconsistent with the metric {metric_name}. "
+                    "Please set 'minimize' as the direction."
                 )

         if verbosity is not None:
```
```diff
@@ -843,12 +839,12 @@ def get_best_booster(self) -> "lgb.Booster":
         )

         best_trial = self.study.best_trial
-        path = os.path.join(self._model_dir, "{}.pkl".format(best_trial.number))
+        path = os.path.join(self._model_dir, f"{best_trial.number}.pkl")
         if not os.path.exists(path):
             raise ValueError(
-                "The best booster cannot be found in {}. If you execute `LightGBMTuner` in "
-                "distributed environment, please use network file system (e.g., NFS) to share "
-                "models with multiple workers.".format(self._model_dir)
+                f"The best booster cannot be found in {self._model_dir}. If you execute "
+                "`LightGBMTuner` in distributed environment, please use network file system "
+                "(e.g., NFS) to share models with multiple workers."
             )

         with open(path, "rb") as fin:
```
```diff
@@ -1040,12 +1036,12 @@ def get_best_booster(self) -> "lgb.CVBooster":
         )

         best_trial = self.study.best_trial
-        path = os.path.join(self._model_dir, "{}.pkl".format(best_trial.number))
+        path = os.path.join(self._model_dir, f"{best_trial.number}.pkl")
         if not os.path.exists(path):
             raise ValueError(
-                "The best booster cannot be found in {}. If you execute `LightGBMTunerCV` in "
-                "distributed environment, please use network file system (e.g., NFS) to share "
-                "models with multiple workers.".format(self._model_dir)
+                f"The best booster cannot be found in {self._model_dir}. If you execute "
+                "`LightGBMTunerCV` in distributed environment, please use network file system "
+                "(e.g., NFS) to share models with multiple workers."
             )

         with open(path, "rb") as fin:
```
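
Both `get_best_booster` variants rebuild the per-trial path from `study.best_trial.number` and unpickle it. A runnable standalone sketch of that round trip (a plain dict stands in for the real `lgb.Booster`):

```python
import os
import pickle
import tempfile

booster = {"best_iteration": 42}  # stand-in for a trained lgb.Booster
trial_number = 7

model_dir = tempfile.mkdtemp()
path = os.path.join(model_dir, f"{trial_number}.pkl")  # the tuner's naming scheme
with open(path, "wb") as fout:
    pickle.dump(booster, fout)

# get_best_booster() later recomputes the same path and unpickles it.
with open(path, "rb") as fin:
    restored = pickle.load(fin)
assert restored == booster
```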
