From e4447b9de59d15a2d0387ba99a537f1d6cecd7b5 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sun, 14 Jul 2024 18:38:43 +0900 Subject: [PATCH 1/6] refactorying _get_rank_subplot_info --- optuna/visualization/_rank.py | 36 +++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 0ec4028a10..2e165660eb 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -210,24 +210,24 @@ def _get_rank_subplot_info( xaxis = _get_axis_info(trials, x_param) yaxis = _get_axis_info(trials, y_param) - infeasible_trial_ids = [] - for i in range(len(trials)): - constraints = trials[i].system_attrs.get(_CONSTRAINTS_KEY) + xs: list[Any] = [] + ys: list[Any] = [] + zs: list[np.ndarray] = [] + filtered_trials: list[FrozenTrial] = [] + filtered_colors: list[np.ndarray] = [] + + for idx, trial in enumerate(trials): + constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) if constraints is not None and any([x > 0.0 for x in constraints]): - infeasible_trial_ids.append(i) - - colors[infeasible_trial_ids] = (204, 204, 204) # equal to "#CCCCCC" - - filtered_ids = [ - i - for i in range(len(trials)) - if x_param in trials[i].params and y_param in trials[i].params - ] - filtered_trials = [trials[i] for i in filtered_ids] - xs = [trial.params[x_param] for trial in filtered_trials] - ys = [trial.params[y_param] for trial in filtered_trials] - zs = target_values[filtered_ids] - colors = colors[filtered_ids] + colors[idx] = (204, 204, 204) # equal to "#CCCCCC" + + if x_param in trial.params and y_param in trial.params: + xs.append(trial.params[x_param]) + ys.append(trial.params[y_param]) + zs.append(target_values[idx]) + filtered_trials.append(trial) + filtered_colors.append(colors[idx]) + return _RankSubplotInfo( xaxis=xaxis, yaxis=yaxis, @@ -235,7 +235,7 @@ def _get_rank_subplot_info( ys=ys, trials=filtered_trials, zs=np.array(zs), - colors=colors, + colors=filtered_colors, ) From 8d06bba7f20e50e77c33df823f63d4bff5e5a75e Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sun, 14 Jul 2024 20:38:21 +0900 Subject: [PATCH 2/6] mypy _rank --- optuna/visualization/_rank.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 2e165660eb..bbbf2c3c65 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -212,9 +212,9 @@ def _get_rank_subplot_info( xs: list[Any] = [] ys: list[Any] = [] - zs: list[np.ndarray] = [] + zs = [] filtered_trials: list[FrozenTrial] = [] - filtered_colors: list[np.ndarray] = [] + filtered_colors = [] for idx, trial in enumerate(trials): constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) @@ -228,6 +228,10 @@ def _get_rank_subplot_info( filtered_trials.append(trial) filtered_colors.append(colors[idx]) + filtered_colors = np.array(filtered_colors) + if filtered_colors.ndim == 1: + filtered_colors = filtered_colors.reshape(-1,1) + return _RankSubplotInfo( xaxis=xaxis, yaxis=yaxis, From e4250be729680495bf5062fd2ec52c4d829ca039 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sun, 14 Jul 2024 20:40:01 +0900 Subject: [PATCH 3/6] black _rank --- optuna/visualization/_rank.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index bbbf2c3c65..85e6a8c15f 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -212,7 +212,7 @@ def _get_rank_subplot_info( xs: list[Any] = [] ys: list[Any] = [] - zs = [] + zs = [] filtered_trials: list[FrozenTrial] = [] filtered_colors = [] @@ -230,7 +230,7 @@ def _get_rank_subplot_info( filtered_colors = np.array(filtered_colors) if filtered_colors.ndim == 1: - filtered_colors = filtered_colors.reshape(-1,1) + filtered_colors = filtered_colors.reshape(-1, 1) return _RankSubplotInfo( xaxis=xaxis, From 33c5b37f7f221d3aea1406bf8fcdf5c1318ac8e6 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sun, 14 Jul 2024 20:45:47 +0900 Subject: [PATCH 4/6] mypy _rank --- optuna/visualization/_rank.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 85e6a8c15f..126b936a64 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -214,7 +214,7 @@ def _get_rank_subplot_info( ys: list[Any] = [] zs = [] filtered_trials: list[FrozenTrial] = [] - filtered_colors = [] + _filtered_colors = [] for idx, trial in enumerate(trials): constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) @@ -226,9 +226,9 @@ def _get_rank_subplot_info( ys.append(trial.params[y_param]) zs.append(target_values[idx]) filtered_trials.append(trial) - filtered_colors.append(colors[idx]) + _filtered_colors.append(colors[idx]) - filtered_colors = np.array(filtered_colors) + filtered_colors: np.ndarray = np.array(_filtered_colors) if filtered_colors.ndim == 1: filtered_colors = filtered_colors.reshape(-1, 1) From 70662d61eb08d56f4358fa3807abb49662f5c7fd Mon Sep 17 00:00:00 2001 From: RektPunk Date: Wed, 17 Jul 2024 17:36:14 +0900 Subject: [PATCH 5/6] reconstruct code --- optuna/visualization/_rank.py | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 126b936a64..98ab8bc9a3 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -210,27 +210,22 @@ def _get_rank_subplot_info( xaxis = _get_axis_info(trials, x_param) yaxis = _get_axis_info(trials, y_param) - xs: list[Any] = [] - ys: list[Any] = [] - zs = [] - filtered_trials: list[FrozenTrial] = [] - _filtered_colors = [] - + infeasible_trial_ids = [] + filtered_ids = [] for idx, trial in enumerate(trials): - constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) + constraints = trial.system_attrs.get("constraints") if constraints is not None and any([x > 0.0 for x in constraints]): - colors[idx] = (204, 204, 204) # equal to "#CCCCCC" - + infeasible_trial_ids.append(idx) if x_param in trial.params and y_param in trial.params: - xs.append(trial.params[x_param]) - ys.append(trial.params[y_param]) - zs.append(target_values[idx]) - filtered_trials.append(trial) - _filtered_colors.append(colors[idx]) + filtered_ids.append(idx) - filtered_colors: np.ndarray = np.array(_filtered_colors) - if filtered_colors.ndim == 1: - filtered_colors = filtered_colors.reshape(-1, 1) + filtered_trials = [trials[i] for i in filtered_ids] + xs = [trial.params[x_param] for trial in filtered_trials] + ys = [trial.params[y_param] for trial in filtered_trials] + zs = target_values[filtered_ids] + + colors[infeasible_trial_ids] = (204, 204, 204) + colors = colors[filtered_ids] return _RankSubplotInfo( xaxis=xaxis, @@ -239,7 +234,7 @@ def _get_rank_subplot_info( ys=ys, trials=filtered_trials, zs=np.array(zs), - colors=filtered_colors, + colors=colors, ) From 172a0f5529fda0309a51e0c05d3fd33820260ca0 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Wed, 17 Jul 2024 17:38:51 +0900 Subject: [PATCH 6/6] fix typo --- optuna/visualization/_rank.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 98ab8bc9a3..ead722555c 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -213,7 +213,7 @@ def _get_rank_subplot_info( infeasible_trial_ids = [] filtered_ids = [] for idx, trial in enumerate(trials): - constraints = trial.system_attrs.get("constraints") + constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) if constraints is not None and any([x > 0.0 for x in constraints]): infeasible_trial_ids.append(idx) if x_param in trial.params and y_param in trial.params: @@ -226,7 +226,6 @@ def _get_rank_subplot_info( colors[infeasible_trial_ids] = (204, 204, 204) colors = colors[filtered_ids] - return _RankSubplotInfo( xaxis=xaxis, yaxis=yaxis,