diff --git a/package/samplers/cmamae/README.md b/package/samplers/cmamae/README.md index 7801df9f..ac015ef5 100644 --- a/package/samplers/cmamae/README.md +++ b/package/samplers/cmamae/README.md @@ -27,6 +27,9 @@ with improvement ranking, all wrapped up in a However, it is possible to implement many variations of CMA-MAE and other quality diversity algorithms using pyribs. +For visualizing the results of the `CmaMaeSampler`, note that we use the +`plot_grid_archive_heatmap` function from the `plot_pyribs` plugin. + ## Class or Function Names - CmaMaeSampler @@ -46,12 +49,17 @@ $ pip install ribs ## Example ```python +import matplotlib.pyplot as plt import optuna import optunahub + module = optunahub.load_module("samplers/cmamae") CmaMaeSampler = module.CmaMaeSampler +plot_pyribs = optunahub.load_module(package="visualization/plot_pyribs") +plot_grid_archive_heatmap = plot_pyribs.plot_grid_archive_heatmap + def objective(trial: optuna.trial.Trial) -> float: """Returns an objective followed by two measures.""" @@ -80,6 +88,11 @@ if __name__ == "__main__": ) study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=10000) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_grid_archive_heatmap(study, ax=ax) + plt.savefig("archive.png") + plt.show() ``` ## Others diff --git a/package/samplers/cmamae/example.py b/package/samplers/cmamae/example.py index 2903a81a..97b88743 100644 --- a/package/samplers/cmamae/example.py +++ b/package/samplers/cmamae/example.py @@ -1,3 +1,4 @@ +import matplotlib.pyplot as plt import optuna import optunahub @@ -5,6 +6,9 @@ module = optunahub.load_module("samplers/cmamae") CmaMaeSampler = module.CmaMaeSampler +plot_pyribs = optunahub.load_module(package="visualization/plot_pyribs") +plot_grid_archive_heatmap = plot_pyribs.plot_grid_archive_heatmap + def objective(trial: optuna.trial.Trial) -> float: """Returns an objective followed by two measures.""" @@ -33,3 +37,8 @@ def objective(trial: optuna.trial.Trial) -> float: ) study = optuna.create_study(sampler=sampler) study.optimize(objective, n_trials=10000) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_grid_archive_heatmap(study, ax=ax) + plt.savefig("archive.png") + plt.show() diff --git a/package/visualization/plot_pyribs/LICENSE b/package/visualization/plot_pyribs/LICENSE new file mode 100644 index 00000000..51c2bdfd --- /dev/null +++ b/package/visualization/plot_pyribs/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Bryon Tjanaka + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/package/visualization/plot_pyribs/README.md b/package/visualization/plot_pyribs/README.md new file mode 100644 index 00000000..f7eb29ca --- /dev/null +++ b/package/visualization/plot_pyribs/README.md @@ -0,0 +1,74 @@ +--- +author: Bryon Tjanaka +title: Pyribs Visualization Wrappers +description: This visualizaton module provides wrappers around the visualization functions from pyribs, which is useful for plotting results from CmaMaeSampler. +tags: [visualization, quality diversity, pyribs] +optuna_versions: [4.0.0] +license: MIT License +--- + +## Class or Function Names + +- `plot_grid_archive_heatmap(study: optuna.Study, ax: plt.Axes, **kwargs)` + + - `study`: Optuna study with a sampler that uses pyribs. This function will plot the result archive from the sampler's scheduler. + - `ax`: Axes on which to plot the heatmap. If None, we retrieve the current axes. + - `**kwargs`: All remaining kwargs will be passed to [`grid_archive_heatmap`](https://docs.pyribs.org/en/stable/api/ribs.visualize.grid_archive_heatmap.html). + +## Installation + +```shell +$ pip install ribs[visualize] +``` + +## Example + +A minimal example would be the following: + +```python +import matplotlib.pyplot as plt +import optuna +import optunahub + +module = optunahub.load_module("samplers/cmamae") +CmaMaeSampler = module.CmaMaeSampler + +plot_pyribs = optunahub.load_module(package="visualization/plot_pyribs") +plot_grid_archive_heatmap = plot_pyribs.plot_grid_archive_heatmap + + +def objective(trial: optuna.trial.Trial) -> float: + """Returns an objective followed by two measures.""" + x = trial.suggest_float("x", -10, 10) + y = trial.suggest_float("y", -10, 10) + trial.set_user_attr("m0", 2 * x) + trial.set_user_attr("m1", x + y) + return x**2 + y**2 + + +if __name__ == "__main__": + sampler = CmaMaeSampler( + param_names=["x", "y"], + measure_names=["m0", "m1"], + archive_dims=[20, 20], + archive_ranges=[(-1, 1), (-1, 1)], + archive_learning_rate=0.1, + archive_threshold_min=-10, + n_emitters=1, + emitter_x0={ + "x": 0, + "y": 0, + }, + emitter_sigma0=0.1, + emitter_batch_size=20, + ) + study = optuna.create_study(sampler=sampler) + study.optimize(objective, n_trials=10000) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_grid_archive_heatmap(study, ax=ax) + plt.savefig("archive.png") + plt.show() +``` + +![Example of this Plot](images/archive.png) diff --git a/package/visualization/plot_pyribs/__init__.py b/package/visualization/plot_pyribs/__init__.py new file mode 100644 index 00000000..6d6630eb --- /dev/null +++ b/package/visualization/plot_pyribs/__init__.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import matplotlib.pyplot as plt +import optuna +from ribs.visualize import grid_archive_heatmap + + +if TYPE_CHECKING: + from matplotlib.axes._axes import Axes + + +def plot_grid_archive_heatmap( # type: ignore + study: optuna.Study, + ax: Axes | None = None, + **kwargs, +) -> Axes: + """Wrapper around pyribs grid_archive_heatmap. + + Refer to the `grid_archive_heatmap + `_ + function from pyribs for information. + + Args: + study: Optuna study with a sampler that uses pyribs. This function will + plot the result archive from the sampler's scheduler. + ax: Axes on which to plot the heatmap. If None, we retrieve the current + axes. + kwargs: All remaining kwargs will be passed to `grid_archive_heatmap + `_. + Returns: + The axes on which the plot was created. + """ + if ax is None: + ax = plt.gca() + + archive = study.sampler.scheduler.result_archive + grid_archive_heatmap(archive, ax=ax, **kwargs) + + return ax + + +__all__ = ["plot_grid_archive_heatmap"] diff --git a/package/visualization/plot_pyribs/example.py b/package/visualization/plot_pyribs/example.py new file mode 100644 index 00000000..97b88743 --- /dev/null +++ b/package/visualization/plot_pyribs/example.py @@ -0,0 +1,44 @@ +import matplotlib.pyplot as plt +import optuna +import optunahub + + +module = optunahub.load_module("samplers/cmamae") +CmaMaeSampler = module.CmaMaeSampler + +plot_pyribs = optunahub.load_module(package="visualization/plot_pyribs") +plot_grid_archive_heatmap = plot_pyribs.plot_grid_archive_heatmap + + +def objective(trial: optuna.trial.Trial) -> float: + """Returns an objective followed by two measures.""" + x = trial.suggest_float("x", -10, 10) + y = trial.suggest_float("y", -10, 10) + trial.set_user_attr("m0", 2 * x) + trial.set_user_attr("m1", x + y) + return x**2 + y**2 + + +if __name__ == "__main__": + sampler = CmaMaeSampler( + param_names=["x", "y"], + measure_names=["m0", "m1"], + archive_dims=[20, 20], + archive_ranges=[(-1, 1), (-1, 1)], + archive_learning_rate=0.1, + archive_threshold_min=-10, + n_emitters=1, + emitter_x0={ + "x": 0, + "y": 0, + }, + emitter_sigma0=0.1, + emitter_batch_size=20, + ) + study = optuna.create_study(sampler=sampler) + study.optimize(objective, n_trials=10000) + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_grid_archive_heatmap(study, ax=ax) + plt.savefig("archive.png") + plt.show() diff --git a/package/visualization/plot_pyribs/images/archive.png b/package/visualization/plot_pyribs/images/archive.png new file mode 100644 index 00000000..96a615e0 Binary files /dev/null and b/package/visualization/plot_pyribs/images/archive.png differ