diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 0ec4028a10..ead722555c 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -211,22 +211,20 @@ def _get_rank_subplot_info( yaxis = _get_axis_info(trials, y_param) infeasible_trial_ids = [] - for i in range(len(trials)): - constraints = trials[i].system_attrs.get(_CONSTRAINTS_KEY) + filtered_ids = [] + for idx, trial in enumerate(trials): + constraints = trial.system_attrs.get(_CONSTRAINTS_KEY) if constraints is not None and any([x > 0.0 for x in constraints]): - infeasible_trial_ids.append(i) + infeasible_trial_ids.append(idx) + if x_param in trial.params and y_param in trial.params: + filtered_ids.append(idx) - colors[infeasible_trial_ids] = (204, 204, 204) # equal to "#CCCCCC" - - filtered_ids = [ - i - for i in range(len(trials)) - if x_param in trials[i].params and y_param in trials[i].params - ] filtered_trials = [trials[i] for i in filtered_ids] xs = [trial.params[x_param] for trial in filtered_trials] ys = [trial.params[y_param] for trial in filtered_trials] zs = target_values[filtered_ids] + + colors[infeasible_trial_ids] = (204, 204, 204) colors = colors[filtered_ids] return _RankSubplotInfo( xaxis=xaxis,