Skip to content

Commit

Permalink
Merge pull request optuna#4887 from Alnusjaponica/fix-alias-handler
Browse files Browse the repository at this point in the history
Fix alias handler
  • Loading branch information
HideakiImamura authored Sep 14, 2023
2 parents d0d89a9 + a3446cc commit 0ff4164
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
42 changes: 28 additions & 14 deletions optuna/integration/_lightgbm_tuner/alias.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Any
from typing import Dict
from typing import List # NOQA


_ALIAS_GROUP_LIST: List[Dict[str, Any]] = [
_ALIAS_GROUP_LIST: list[dict[str, Any]] = [
{"param_name": "bagging_fraction", "alias_names": ["sub_row", "subsample", "bagging"]},
{"param_name": "learning_rate", "alias_names": ["shrinkage_rate", "eta"]},
{
Expand All @@ -27,7 +28,7 @@
]


def _handling_alias_parameters(lgbm_params: Dict[str, Any]) -> None:
def _handling_alias_parameters(lgbm_params: dict[str, Any]) -> None:
"""Handling alias parameters."""

for alias_group in _ALIAS_GROUP_LIST:
Expand All @@ -40,7 +41,7 @@ def _handling_alias_parameters(lgbm_params: Dict[str, Any]) -> None:
del lgbm_params[alias_name]


_ALIAS_METRIC_LIST: List[Dict[str, Any]] = [
_ALIAS_METRIC_LIST: list[dict[str, Any]] = [
# The list `alias_names` do not include the `metric_name` itself.
{
"metric_name": "ndcg",
Expand Down Expand Up @@ -103,18 +104,31 @@ def _handling_alias_parameters(lgbm_params: Dict[str, Any]) -> None:
},
]

_ALIAS_METRIC_MAP: dict[str, str] = {
alias_name: canonical_metric["metric_name"]
for canonical_metric in _ALIAS_METRIC_LIST
for alias_name in canonical_metric["alias_names"]
}

def _handling_alias_metrics(lgbm_params: Dict[str, Any]) -> None:
"""Handling alias metrics."""

def _handling_alias_metrics(lgbm_params: dict[str, Any]) -> None:
"""Handling alias metrics."""
if "metric" not in lgbm_params.keys():
return

for metric in _ALIAS_METRIC_LIST:
metric_name = metric["metric_name"]
alias_names = metric["alias_names"]
if not isinstance(lgbm_params["metric"], (str, Iterable)):
raise ValueError(
"The `metric` parameter is expected to be a string or an iterable object, but got "
f"{type(lgbm_params['metric'])}."
)

for alias_name in alias_names:
if lgbm_params["metric"] == alias_name:
lgbm_params["metric"] = metric_name
break
if isinstance(lgbm_params["metric"], str):
lgbm_params["metric"] = (
_ALIAS_METRIC_MAP.get(lgbm_params["metric"]) or lgbm_params["metric"]
)
return

canonical_metrics = []
for metric in lgbm_params["metric"]:
canonical_metrics.append(_ALIAS_METRIC_MAP.get(metric) or metric)
lgbm_params["metric"] = canonical_metrics
10 changes: 10 additions & 0 deletions tests/integration_tests/lightgbm_tuner_tests/test_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def test_handling_alias_parameter() -> None:
(["auc_mu"], "auc_mu"),
(["custom", "none", "null", "na"], "custom"),
([], None), # If "metric" not in lgbm_params.keys(): return None.
([["lambdarank"]], ["ndcg"]),
(
[["lambdarank", "mean_average_precision", "root_mean_squared_error"]],
["ndcg", "map", "rmse"],
),
],
)
def test_handling_alias_metrics(aliases: List[str], expect: str) -> None:
Expand All @@ -93,3 +98,8 @@ def test_handling_alias_metrics(aliases: List[str], expect: str) -> None:
lgbm_params = {}
_handling_alias_metrics(lgbm_params)
assert lgbm_params == {}


def test_handling_unexpected_alias_metrics() -> None:
with pytest.raises(ValueError):
_handling_alias_metrics({"metric": 1})

0 comments on commit 0ff4164

Please sign in to comment.