Add the official implementation of c-TPE #177

Merged: 11 commits into optuna:main on Dec 4, 2024

Conversation

@nabenabe0928 (Contributor) commented Nov 13, 2024

Contributor Agreements

Please read the contributor agreements and if you agree, please click the checkbox below.

  • I agree to the contributor agreements.

Tip

Please follow the Quick TODO list to smoothly merge your PR.

Motivation

This PR migrates the official implementation of c-TPE into OptunaHub.

TODO List towards PR Merge

Please remove this section if this PR does not add a new package.
Otherwise, please work through the following TODO list:

  • Copy ./template/ to create your package
  • Replace <COPYRIGHT HOLDER> in LICENSE of your package with your name
  • Fill out README.md in your package
  • Add import statements for the functions or classes to be exposed in __init__.py (see the sketch after this list)
  • (Optional) Add from __future__ import annotations at the top of any Python file that uses typing, to support older Python versions
  • Apply the formatter based on the tips in README.md
  • Check whether your module works as intended based on the tips in README.md

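A minimal sketch of the __init__.py item above, assuming the sampler class is defined in a module named sampler.py inside the package directory (the module filename here is hypothetical):

from .sampler import cTPESampler

__all__ = ["cTPESampler"]

With this, loading the package via optunahub exposes mod.cTPESampler, as used in the verification code below.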
@nabenabe0928 marked this pull request as draft on November 13, 2024 08:33
@y0z added the new-package label on Nov 18, 2024
@nabenabe0928 marked this pull request as ready for review on December 3, 2024 20:37
@nabenabe0928 changed the title from "Add a draft of c-TPE" to "Add the official implementation of c-TPE" on Dec 4, 2024
@nabenabe0928 (Contributor, Author) commented:

I verified the performance of c-TPE by using the experiment setup of Fig. 3 (Top Row) in the original paper.
[Result figure: slice-localization benchmark]

@y0z self-assigned this on Dec 4, 2024
@y0z (Member) left a comment:

I confirmed that the example in README.md works and left some comments.

package/samplers/ctpe/README.md: 2 review comments (resolved)
@nabenabe0928 (Contributor, Author) commented Dec 4, 2024

The experiment used optuna==4.1.0.

For this experiment, I extracted HPOLib using hpolib-extractor.

Please execute the following commands:

$ cd ~/hpo_benchmarks/hpolib
$ wget http://ml4aad.org/wp-content/uploads/2019/01/fcnet_tabular_benchmarks.tar.gz
$ tar xf fcnet_tabular_benchmarks.tar.gz
$ mv fcnet_tabular_benchmarks/*.hdf5 .
$ rm -r fcnet_tabular_benchmarks/
$ pip install hpolib-extractor
$ python -c "from hpolib_extractor import extract_hpolib; extract_hpolib(data_dir='./', epochs=[100])"
Verification Code
from __future__ import annotations

import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import optuna
import optunahub


plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["font.size"] = 18
plt.rcParams["mathtext.fontset"] = "stix"  # The setting of math font
plt.rcParams["text.usetex"] = True

# Slice Localization with constraint levels of 0.1, 0.5, and 0.9.
# This information is to reproduce the results of Fig. 3 (Top Row).
n_params_key = "n_params"
runtime_key = "runtime"
benchmark_info = {
    0.1: {runtime_key: 119.38172, n_params_key: 8401.0, "oracle": 0.0010303777},
    0.5: {runtime_key: 366.13934, n_params_key: 52929.0, "oracle": 0.00028238542},
    0.9: {runtime_key: 1113.1345, n_params_key: 229633.0, "oracle": 0.00018452932},
}
data_path = f"{os.environ['HOME']}/hpo_benchmarks/hpolib/slice_localization.pkl"
DATASET = pickle.load(open(data_path, mode="rb"))


class HPOLib:
    def __init__(self, quantile: float, seed: int | None = None) -> None:
        assert quantile in [0.1, 0.5, 0.9]
        self._rng = np.random.RandomState(seed)
        self._thresholds = {
            runtime_key: benchmark_info[quantile][runtime_key],
            n_params_key: benchmark_info[quantile][n_params_key],
        }
        self._oracle = benchmark_info[quantile]["oracle"]

    def reseed(self, seed: int | None = None) -> None:
        self._rng = np.random.RandomState(seed)

    def __call__(self, trial: optuna.Trial) -> float:
        param_indices = [
            trial.suggest_categorical("activation_fn_1", list(range(2))),
            trial.suggest_categorical("activation_fn_2", list(range(2))),
            trial.suggest_int("batch_size", 0, 3),
            trial.suggest_int("dropout_1", 0, 2),
            trial.suggest_int("dropout_2", 0, 2),
            trial.suggest_int("init_lr", 0, 5),
            trial.suggest_categorical("lr_schedule", list(range(2))),
            trial.suggest_int("n_units_1", 0, 5),
            trial.suggest_int("n_units_2", 0, 5),
        ]
        config_id = "".join([str(i) for i in param_indices])
        seed = self._rng.randint(4)
        result = DATASET[config_id]
        loss = result["valid_mse"][seed][100]
        trial.set_user_attr(runtime_key, result[runtime_key][seed])
        trial.set_user_attr(n_params_key, result[n_params_key])
        is_feasible = all(
            trial.user_attrs[key] <= self._thresholds[key] for key in [n_params_key, runtime_key]
        )
        trial.set_user_attr("feasible", is_feasible)
        return loss

    def constraints_func(self, trial: optuna.trial.FrozenTrial) -> list[float]:
        return [
            trial.user_attrs[key] - self._thresholds[key] for key in [n_params_key, runtime_key]
        ]

    def compute_absolute_percentage_loss(self, study: optuna.Study) -> np.ndarray:
        trials = study.trials
        is_feasible = np.array([t.user_attrs["feasible"] for t in trials])
        loss_vals = np.array([t.value for t in trials])
        loss_vals[~is_feasible] = np.inf
        return (np.minimum.accumulate(loss_vals) - self._oracle) / self._oracle


def collect_results(n_seeds: int, n_trials: int, sampler_name: str) -> dict[float, np.ndarray]:
    if sampler_name == "ctpe":
        package_name = "samplers/ctpe"
        repo_owner = "nabenabe0928"
        mod = optunahub.load_local_module(package=package_name, registry_root="./package/")
        sampler_cls = mod.cTPESampler
    elif sampler_name == "random":
        sampler_cls = lambda constraints_func, seed: optuna.samplers.RandomSampler(seed)
    elif sampler_name == "nsgaii":
        sampler_cls = lambda constraints_func, seed: optuna.samplers.NSGAIISampler(
            seed=seed, constraints_func=constraints_func, population_size=8
        )
    elif sampler_name == "tpe":
        sampler_cls = lambda constraints_func, seed: optuna.samplers.TPESampler(
            multivariate=True, constraints_func=constraints_func, seed=seed
        )
    else:
        assert False, sampler_name

    results = {0.1: [], 0.5: [], 0.9: []}
    for q in results:
        hpolib = HPOLib(quantile=q)
        for seed in range(n_seeds):
            print(f"Start with {sampler_name=}, {q=}, and {seed=}.")
            hpolib.reseed(seed)
            sampler = sampler_cls(constraints_func=hpolib.constraints_func, seed=seed)
            study = optuna.create_study(sampler=sampler)
            study.optimize(hpolib, n_trials=n_trials)
            results[q].append(hpolib.compute_absolute_percentage_loss(study))

    return {k: np.asarray(v) for k, v in results.items()}


def visualize(axes: plt.Axes, n_seeds: int, n_trials: int, sampler_name: str) -> plt.Line2D:
    results = collect_results(n_seeds, n_trials, sampler_name)
    dx = np.arange(n_trials) + 1
    color = {
        "ctpe": "red",
        "nsgaii": "magenta",
        "random": "olive",
        "tpe": "blue",
    }[sampler_name]
    for ax, (q, res) in zip(axes, results.items()):
        m = np.mean(res, axis=0)
        s = np.std(res, axis=0) / np.sqrt(n_seeds)
        line, = ax.plot(dx, m, color=color)
        ax.fill_between(dx, m - s, m + s, color=color, alpha=0.2)
        ax.set_title(f"Quantile: {q:.1f}")

    return line


def main(n_seeds: int, n_trials: int) -> None:
    fig, axes = plt.subplots(
        ncols=3, figsize=(27, 4.5), sharex=True, sharey=True, gridspec_kw={"wspace": 0.05}
    )
    for ax in axes:
        ax.grid(which="minor", color="gray", linestyle=":")
        ax.grid(which="major", color="black")
        ax.set_yscale("log")
        ax.set_xlim(1, n_trials)
        ax.set_ylim(0.1, 2000)

    sampler_names = ["ctpe", "tpe", "nsgaii", "random"]
    labels = [
        {"ctpe": "c-TPE", "tpe": "Optuna TPE", "nsgaii": "CNSGA-II", "random": "Random"}[name]
        for name in sampler_names
    ]
    lines = [visualize(axes, n_seeds, n_trials, sampler_name) for sampler_name in sampler_names]
    fig.supxlabel(r"\# of Trials", y=-0.04)
    fig.supylabel("Absolute Percentage Loss", x=0.09)
    fig.legend(
        handles=lines,
        labels=labels,
        loc="lower center",
        ncol=len(labels),
        fontsize=24,
        bbox_to_anchor=(0.5, -0.25),
    )
    plt.savefig("slice-localization.png", bbox_inches="tight")


optuna.logging.set_verbosity(optuna.logging.CRITICAL)
main(n_seeds=50, n_trials=200)

@y0z (Member) left a comment:

LGTM

@y0z merged commit a1e8463 into optuna:main on Dec 4, 2024
4 checks passed
@y0z (Member) commented Dec 4, 2024

Published!
https://hub.optuna.org/samplers/ctpe/
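
A minimal usage sketch of the published sampler (the objective and constraint below are hypothetical; as in the verification code above, cTPESampler takes constraints_func and seed, and a trial is feasible when every constraint value is at most 0):

from __future__ import annotations

import optuna
import optunahub


# Load the published c-TPE sampler from OptunaHub.
mod = optunahub.load_module(package="samplers/ctpe")


def objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10.0, 10.0)
    # Store the constraint value so that constraints_func can read it from the frozen trial.
    trial.set_user_attr("c", x - 2.0)  # Feasible when x <= 2.
    return x**2


def constraints_func(trial: optuna.trial.FrozenTrial) -> list[float]:
    # Optuna convention: the trial is feasible when every returned value is <= 0.
    return [trial.user_attrs["c"]]


sampler = mod.cTPESampler(constraints_func=constraints_func, seed=0)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=50)
print(study.best_trial.params)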
