Skip to content

Commit

Permalink
Merge pull request optuna#5867 from not522/refactor-plot-contour
Browse files Browse the repository at this point in the history
Refactor plot contour
  • Loading branch information
HideakiImamura authored Dec 19, 2024
2 parents 945f856 + 2e143f2 commit e6e0cb4
Showing 1 changed file with 31 additions and 79 deletions.
110 changes: 31 additions & 79 deletions optuna/visualization/matplotlib/_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,7 @@ def _calculate_axis_data(
return ci, cat_param_labels, cat_param_pos, list(returned_values)


def _calculate_griddata(
info: _SubContourInfo,
) -> tuple[
np.ndarray,
np.ndarray,
np.ndarray,
list[int],
list[str],
list[int],
list[str],
_PlotValues,
_PlotValues,
]:
def _calculate_griddata(info: _SubContourInfo) -> tuple[np.ndarray, _PlotValues, _PlotValues]:
xaxis = info.xaxis
yaxis = info.yaxis
z_values_dict = info.z_values
Expand All @@ -220,17 +208,7 @@ def _calculate_griddata(

# Return empty values when x or y has no value.
if len(x_values) == 0 or len(y_values) == 0:
return (
np.array([]),
np.array([]),
np.array([]),
[],
[],
[],
[],
_PlotValues([], []),
_PlotValues([], []),
)
return np.array([]), _PlotValues([], []), _PlotValues([], [])

xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data(
xaxis,
Expand Down Expand Up @@ -261,90 +239,64 @@ def _calculate_griddata(
infeasible.x.append(x_value)
infeasible.y.append(y_value)

return (
xi,
yi,
zi,
cat_param_pos_x,
cat_param_labels_x,
cat_param_pos_y,
cat_param_labels_y,
feasible,
infeasible,
)
return zi, feasible, infeasible


def _generate_contour_subplot(
info: _SubContourInfo, ax: "Axes", cmap: "Colormap"
) -> "ContourSet" | None:
ax.label_outer()

if len(info.xaxis.indices) < 2 or len(info.yaxis.indices) < 2:
ax.label_outer()
return None

ax.set(xlabel=info.xaxis.name, ylabel=info.yaxis.name)
ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1])
ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1])
x_values, y_values = _filter_missing_values(info.xaxis, info.yaxis)
xi, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
yi, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
if info.xaxis.is_cat:
_, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
ax.set_xticks(x_cat_param_pos)
ax.set_xticklabels(x_cat_param_label)
else:
ax.set_xscale("log" if info.xaxis.is_log else "linear")
if info.yaxis.is_cat:
_, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
ax.set_yticks(y_cat_param_pos)
ax.set_yticklabels(y_cat_param_label)
else:
ax.set_yscale("log" if info.yaxis.is_log else "linear")

if info.xaxis.name == info.yaxis.name:
ax.label_outer()
return None

(
xi,
yi,
zi,
x_cat_param_pos,
x_cat_param_label,
y_cat_param_pos,
y_cat_param_label,
feasible_plot_values,
infeasible_plot_values,
) = _calculate_griddata(info)
zi, feasible_plot_values, infeasible_plot_values = _calculate_griddata(info)
cs = None
if len(zi) > 0:
if info.xaxis.is_log:
ax.set_xscale("log")
if info.yaxis.is_log:
ax.set_yscale("log")
if info.xaxis.name != info.yaxis.name:
# Contour the gridded data.
ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k")
cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed())
assert isinstance(cs, ContourSet)
# Plot data points.
ax.scatter(
feasible_plot_values.x,
feasible_plot_values.y,
marker="o",
c="black",
s=20,
edgecolors="grey",
linewidth=2.0,
)
ax.scatter(
infeasible_plot_values.x,
infeasible_plot_values.y,
marker="o",
c="#cccccc",
s=20,
edgecolors="grey",
linewidth=2.0,
)
# Contour the gridded data.
ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k")
cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed())
assert isinstance(cs, ContourSet)
# Plot data points.
ax.scatter(
feasible_plot_values.x,
feasible_plot_values.y,
marker="o",
c="black",
s=20,
edgecolors="grey",
linewidth=2.0,
)
ax.scatter(
infeasible_plot_values.x,
infeasible_plot_values.y,
marker="o",
c="#cccccc",
s=20,
edgecolors="grey",
linewidth=2.0,
)

ax.label_outer()
return cs


Expand Down

0 comments on commit e6e0cb4

Please sign in to comment.