MTN update readme #40

Open · wants to merge 51 commits into base: main
Changes from 26 commits

Commits (51)
8f70bb1
ENH plot_quadratic and update config file
MatDag Feb 19, 2024
61806e0
ENH add quadratics to readme
MatDag Feb 19, 2024
0ecc679
FIX latex readme
MatDag Feb 19, 2024
2307ccc
FIX latex readme
MatDag Feb 19, 2024
391f6cb
Merge branch 'main' into update_readme
MatDag Oct 11, 2024
f394a6d
FIX typo
MatDag Oct 11, 2024
8da9ac5
FIX typo
MatDag Oct 11, 2024
047c83e
FIX double backslash
MatDag Oct 11, 2024
0d8b556
FIX typo
MatDag Oct 11, 2024
17e2267
ENH doc eigenvalues of matrices
MatDag Oct 11, 2024
89f0966
FIX typo
MatDag Oct 11, 2024
479e1d5
ENH value function evaluation nit that expensive
MatDag Oct 11, 2024
76ac1fe
FIX double backquotes
MatDag Oct 11, 2024
8dc98c2
WIP doc
MatDag Oct 14, 2024
df74427
WIP complete doc how to create a solver
MatDag Oct 14, 2024
dd3380c
ENH add comments amigo
MatDag Oct 14, 2024
beb83f1
ENH readme
MatDag Oct 14, 2024
c77ba62
ENH docstring StochasticJaxSolver
MatDag Oct 14, 2024
05dfa21
ENH comment amigo
MatDag Oct 14, 2024
38f2197
FIX flake8
MatDag Oct 14, 2024
d4b1b35
FIX review suggestions README.rst
MatDag Oct 16, 2024
d093b8b
CLN create template_stochastic_solver and moove explanation from AmIGO
MatDag Oct 16, 2024
e0d785c
ENH add template_solver.py
MatDag Oct 17, 2024
f333a7c
ENH add template_dataset.py
MatDag Oct 17, 2024
c31f5cc
ENH ref to benchopt template
MatDag Oct 17, 2024
931e095
Update README.rst
tomMoral Oct 18, 2024
143ca61
ENH apply suggestion readme
MatDag Oct 18, 2024
481880f
ENH replace rst by md
MatDag Oct 18, 2024
ae2ce5d
FIX brackets
MatDag Oct 18, 2024
b0cb39a
FIX brackets
MatDag Oct 18, 2024
928b70a
FIX brackets
MatDag Oct 18, 2024
a29f393
FIX brackets
MatDag Oct 18, 2024
1038359
FIX brackets
MatDag Oct 18, 2024
50ca4aa
Update README.md
MatDag Oct 18, 2024
cd580c1
Update README.md
MatDag Oct 18, 2024
5f68c11
CLN remove tilde
MatDag Oct 18, 2024
0d1ab03
FIX ref
MatDag Oct 18, 2024
7d9508e
CLN remove useless params
MatDag Oct 18, 2024
59001b3
WIP
MatDag Oct 18, 2024
2fac0da
ENH simplify template_dataset
MatDag Oct 22, 2024
3d7bded
FIX typo
MatDag Oct 22, 2024
10b3bc0
FIX batched_quadratics disappeared in simulated.py...
MatDag Oct 22, 2024
5b80320
CLN remove plot_quadratics.py
MatDag Oct 22, 2024
302d8af
ENH rm generate_matrices
MatDag Oct 23, 2024
a7769ec
CLN docstring
MatDag Oct 23, 2024
7a881a8
FIX flake8
MatDag Oct 23, 2024
9b4c556
ENH callback info template dataset
MatDag Oct 24, 2024
510c812
ENH lr_scheduler template_stochastic_solver
MatDag Oct 24, 2024
e918f5a
ENH add comments oracles
MatDag Nov 21, 2024
f65755e
ENH docstring init
MatDag Nov 21, 2024
24c08fa
ENH docstring get_step
MatDag Nov 21, 2024
55 changes: 45 additions & 10 deletions README.rst
@@ -4,7 +4,7 @@ Bilevel Optimization Benchmark

*Results can be consulted on https://benchopt.github.io/results/benchmark_bilevel.html*

BenchOpt is a package to make the comparison of optimization algorithms simpler, more transparent, and more reproducible.
This benchmark is dedicated to solvers for bilevel optimization:

@@ -15,9 +15,30 @@ where $g$ and $f$ are two functions of two variables.
Different problems
------------------

This benchmark currently implements three bilevel optimization problems: a simulated quadratic problem, regularization selection, and hyper data cleaning.

1 - Simulated quadratic bilevel problem
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In this problem, the inner and the outer functions are quadratic functions defined on $\\mathbb{R}^d\\times\\mathbb{R}^p$:

$$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$

and

$$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$

where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_i^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $p$, and $d_i, \\tilde d_j$ are vectors of size $d$.

The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ lie between ``mu_inner`` and ``L_inner_inner``, the eigenvalues of $\\frac1n\\sum_i H_i^x$ lie between ``mu_inner`` and ``L_inner_outer``, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ lie between ``mu_inner`` and ``L_outer_inner``, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ lie between ``mu_inner`` and ``L_outer_outer``.

The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than ``L_cross_outer``.

Note that in this setting, the solution of the inner problem is the solution of a linear system.
Since the full-batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function can be evaluated in closed form.
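To make this concrete, here is a minimal sketch of how such an instance can be generated and how the inner solution is obtained from a linear system (NumPy; the helper ``random_spd`` and all constants are ours, not the benchmark's exact generator):

.. code-block:: python

    import numpy as np

    def random_spd(rng, dim, mu, L):
        # A random rotation with eigenvalues spread over [mu, L] gives a
        # symmetric positive definite matrix with a prescribed spectrum.
        Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))
        return Q @ np.diag(np.linspace(mu, L, dim)) @ Q.T

    rng = np.random.default_rng(0)
    d, p = 10, 100
    H_z = random_spd(rng, p, 0.1, 1.0)       # average inner Hessian in z
    C = rng.standard_normal((d, p)) * 0.01   # cross term, small spectral norm
    c = rng.standard_normal(p)

    # grad_z g(x, z) = H_z @ z + C.T @ x + c = 0, so the inner solution
    # for a given outer variable x solves a linear system:
    x = rng.standard_normal(d)
    z_star = np.linalg.solve(H_z, -(C.T @ x + c))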

2 - Regularization selection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In this problem, the inner function $g$ is defined by
@@ -41,7 +62,7 @@ Covtype

*Homepage : https://archive.ics.uci.edu/dataset/31/covertype*

This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$, where $a_i\\in\\mathbb{R}^p$ are the features and $y_i=\\pm1$ is the binary target.
For this problem, the loss is $\\ell(d_i, z) = \\log(1+\\exp(-y_i a_i^T z))$, and the regularization is simply given by
$$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^p\\exp(x_j)z_j^2,$$
that is, each coefficient in $z$ is independently regularized with strength $\\exp(x_j)$.
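As an illustration, the inner objective $g(x, z)$ of this problem could be written as follows (a JAX sketch with our own variable names, not the benchmark's code; the actual implementation lives in the datasets folder):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    def g(z, x, A, y):
        # Mean logistic loss over the training samples (rows of A).
        loss = -jnp.mean(jax.nn.log_sigmoid(y * (A @ z)))
        # Each coefficient z_j is penalized with strength exp(x_j).
        reg = 0.5 * jnp.sum(jnp.exp(x) * z ** 2)
        return loss + reg

    grad_g_z = jax.grad(g)  # gradient with respect to the inner variable z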
@@ -51,18 +72,18 @@ Ijcnn1

*Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active*

This is a multiclass logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$, where $a_i\\in\\mathbb{R}^p$ are the features and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with $k$ the number of classes.
For this problem, the loss is $\\ell(d_i, z) = \\text{CrossEntropy}(za_i, y_i)$, where $z$ is now a $k\\times p$ matrix. The regularization is given by
$$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^k\\exp(x_j)\\|z_j\\|^2,$$
that is, each row of $z$ is independently regularized with strength $\\exp(x_j)$.


3 - Hyper data cleaning
^^^^^^^^^^^^^^^^^^^^^^^

This problem was first introduced by [Fra2017]_.
In this problem, the data is the MNIST dataset.
The training set has been corrupted: with probability $p$, the label of each image, $y\\in\\{1,\\dots,10\\}$, is replaced by another random label between 1 and 10.
We do not know beforehand which samples have been corrupted.
We have a clean testing set, which has not been corrupted.
The goal is to fit a model on the corrupted training data that has good performance on the test set.
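A hypothetical sketch of this corruption process (our helper, not the benchmark's exact routine):

.. code-block:: python

    import numpy as np

    def corrupt_labels(y, p, rng):
        # With probability p, replace each label by a random class in
        # {1, ..., 10} (possibly equal to the original one).
        mask = rng.random(y.shape[0]) < p
        random_labels = rng.integers(1, 11, size=y.shape[0])
        return np.where(mask, random_labels, y)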
@@ -91,7 +112,7 @@ This benchmark can be run using the following commands:
$ git clone https://github.com/benchopt/benchmark_bilevel
$ benchopt run benchmark_bilevel

Apart from the problem, options can be passed to ``benchopt run`` to restrict the benchmarks to some solvers or datasets, e.g.:

.. code-block::

@@ -103,10 +124,24 @@ You can also use config files to setup the benchmark run:

$ benchopt run benchmark_bilevel --config config/X.yml

where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can instead use the file ``X_best_params.yml`` in order to launch an experiment with a single set of parameters for each solver.
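For illustration, a minimal config file for this benchmark might look as follows (hypothetical selection of datasets and solvers; see the ``config`` folder for the actual files):

.. code-block:: yaml

    dataset:
      - quadratic[mu_inner=[.1],dim_inner=100,dim_outer=10]
    solver:
      - AmIGO[batch_size=64,step_size=0.01,outer_ratio=0.1]
      - SABA[batch_size=64,step_size=0.1]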

Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html.

How to contribute to the benchmark?
-----------------------------------

If you want to add a solver or a new problem, you are welcome to open an issue or submit a pull request!

1 - How to add a new solver?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Each solver derives from the [`benchopt.BaseSolver` class](https://benchopt.github.io/user_guide/generated/benchopt.BaseSolver.html) and lives in the [solvers](solvers) folder. The solvers are split between the stochastic JAX solvers and the others:
* Stochastic JAX solvers: these solvers inherit from the [`StochasticJaxSolver` class](benchmark_utils/stochastic_jax_solver.py); see the detailed explanations in the [template stochastic solver](solvers/template_stochastic_solver.py).
* Other solvers: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). An example is provided in the [template solver](solvers/template_solver.py), and a rough skeleton is sketched below.
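For orientation, a minimal skeleton of such a solver might look as follows (a sketch assuming only the `BaseSolver` API and the `set_objective`/`get_result` conventions used in this benchmark; the names and the update loop are illustrative):

.. code-block:: python

    from benchopt import BaseSolver


    class Solver(BaseSolver):
        """Sketch of a bilevel solver; the actual updates are left abstract."""
        name = 'my-solver'  # hypothetical name

        # Hyperparameters exposed on the benchopt command line.
        parameters = {'step_size': [0.01, 0.1]}

        def set_objective(self, f_inner, f_outer, n_inner_samples,
                          n_outer_samples, inner_var0, outer_var0, **kwargs):
            # Store the pieces of the bilevel problem passed by the
            # benchmark; the exact signature is in the template solver.
            self.f_inner, self.f_outer = f_inner, f_outer
            self.inner_var0, self.outer_var0 = inner_var0, outer_var0

        def run(self, n_iter):
            inner_var, outer_var = self.inner_var0, self.outer_var0
            for _ in range(n_iter):
                pass  # alternate updates of inner_var and outer_var here
            self.inner_var, self.outer_var = inner_var, outer_var

        def get_result(self):
            # The keys must match the arguments of the `metrics` function
            # defined by the datasets.
            return dict(inner_var=self.inner_var, outer_var=self.outer_var)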

2 - How to add a new problem?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In this benchmark, each problem is defined by a [Dataset class](https://benchopt.github.io/user_guide/generated/benchopt.BaseDataset.html) in the [datasets](datasets) folder. A [template](datasets/template_dataset.py) is provided.

Cite
----
26 changes: 22 additions & 4 deletions benchmark_utils/stochastic_jax_solver.py
@@ -98,10 +98,28 @@ def set_objective(self, f_inner, f_outer, n_inner_samples, n_outer_samples,

inner_var0, outer_var0: array-like, shape (dim_inner,) (dim_outer,)

Attributes
----------
f_inner, f_outer: callable
Inner and outer objective function for the bilevel optimization
problem.

n_inner_samples, n_outer_samples: int
Number of samples to draw for the inner and outer objective
functions.

inner_var0, outer_var0: array-like, shape (dim_inner,) (dim_outer,)

batch_size_inner, batch_size_outer: int
Size of the minibatch to use for the inner and outer objective
functions.

state_inner_sampler, state_outer_sampler: dict
State of the minibatch samplers for the inner and outer objectives.

one_epoch: callable
Jitted function that runs the solver for one epoch. One epoch is
defined as `eval_freq` iterations of the solver.
"""

self.f_inner = f_inner
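To make the workflow concrete, here is a rough sketch of a stochastic solver built on this class. Based on the template referenced in the README, we assume a subclass provides an `init` method returning the initial state (the carry) and a `get_step` method returning the per-iteration update; the sampler interface and the update rule below are illustrative, and the authoritative reference is `solvers/template_stochastic_solver.py`.

.. code-block:: python

    import jax

    from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver


    class Solver(StochasticJaxSolver):
        """Sketch: SGD on the inner variable; the outer update is elided."""
        name = 'my-stochastic-solver'  # hypothetical name

        parameters = {'step_size': [0.1]}  # plus the base-class parameters

        def init(self):
            # Initial carry, threaded through the jitted `one_epoch` loop.
            # We assume the base class has already set the variables and
            # sampler states used below.
            return dict(
                inner_var=self.inner_var, outer_var=self.outer_var,
                state_inner_sampler=self.state_inner_sampler,
                state_outer_sampler=self.state_outer_sampler,
            )

        def get_step(self, inner_sampler, outer_sampler):
            grad_inner = jax.grad(self.f_inner, argnums=0)

            def one_step(carry, _):
                # Assumed sampler interface: returns the start index of a
                # minibatch and the updated sampler state (state last).
                start, *_, carry['state_inner_sampler'] = inner_sampler(
                    carry['state_inner_sampler'])
                # One SGD step on the inner variable.
                carry['inner_var'] -= self.step_size * grad_inner(
                    carry['inner_var'], carry['outer_var'], start)
                # An actual solver would also update carry['outer_var'] with
                # a hypergradient estimate here.
                return carry, None

            return one_step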
4 changes: 2 additions & 2 deletions config/quadratics_021424_best_params.yml
@@ -3,10 +3,10 @@ objective:
dataset:
- quadratic[L_cross_inner=0.1,L_cross_outer=0.1,mu_inner=[.1],n_samples_inner=[32768],n_samples_outer=[1024],dim_inner=100,dim_outer=10]
solver:
- AmIGO[batch_size=64,eval_freq=16,framework=none,n_inner_steps=10,outer_ratio=0.1,step_size=0.01,random_state=[1,2,3,4,5,6,7,8,9,10]]
- MRBO[batch_size=64,eta=0.5,eval_freq=16,framework=none,n_shia_steps=10,outer_ratio=0.1,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
- SABA[batch_size=64,eval_freq=64,framework=none,mode_init_memory=zero,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
- SRBA[batch_size=64,eval_freq=64,framework=none,outer_ratio=1.0,period_frac=0.5,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
- StocBiO[batch_size=64,eval_freq=16,framework=none,n_inner_steps=10,n_shia_steps=10,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
- VRBO[batch_size=64,eval_freq=2,framework=none,n_inner_steps=10,n_shia_steps=10,outer_ratio=1.0,period_frac=0.01,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
- F2SA[batch_size=64,delta_lmbda=0.01,eval_freq=16,framework=none,lmbda0=1,n_inner_steps=10,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]]
159 changes: 159 additions & 0 deletions datasets/template_dataset.py
@@ -0,0 +1,159 @@
from benchopt import BaseDataset
from benchopt import safe_import_context

# Protect the imports with `safe_import_context()`. This allows:
# - skipping import to speed up autocompletion in CLI.
# - getting requirements info when all dependencies are not installed.
with safe_import_context() as import_ctx:
    import numpy as np
    from libsvmdata import fetch_libsvm

    import jax
    import jax.numpy as jnp
    from functools import partial

    from jaxopt import LBFGS


def loss_sample(inner_var, outer_var, x, y):
    # Logistic loss of a single sample (x, y).
    return -jax.nn.log_sigmoid(y*jnp.dot(inner_var, x))


def loss(inner_var, outer_var, X, y):
    # Mean logistic loss over a batch, vectorized with `jax.vmap`.
    batched_loss = jax.vmap(loss_sample, in_axes=(None, None, 0, 0))
    return jnp.mean(batched_loss(inner_var, outer_var, X, y), axis=0)


# All datasets must be named `Dataset` and inherit from `BaseDataset`.
class Dataset(BaseDataset):
    """Hyperparameter optimization with the IJCNN1 dataset.

    This template dataset is an adaptation of the dataset from the benchopt
    template benchmark (https://github.com/benchopt/template_benchmark/) to
    the bilevel setting.
    """

    # Name to select the dataset in the CLI and to display the results.
    name = "ijcnn1"

    install_cmd = 'conda'
    # List of packages needed to run the dataset. See the corresponding
    # section in objective.py.
    requirements = ['pip:libsvmdata', 'scikit-learn']

    # List of parameters to generate the datasets. The benchmark will consider
    # the cross product for each key in the dictionary.
    # Any parameter 'param' defined here is available as `self.param`.
    parameters = {
        'reg_parametrization': ['exp'],
    }

    def get_data(self):
        # The return arguments of this function are passed as keyword
        # arguments to `Objective.set_data`. This defines the benchmark's
        # API to pass data.
        assert self.reg_parametrization in ['lin', 'exp'], (
            f"unknown reg parameter '{self.reg_parametrization}'. "
            "Should be 'lin' or 'exp'."
        )

        X_train, y_train = fetch_libsvm('ijcnn1')
        X_val, y_val = fetch_libsvm('ijcnn1_test')

        X_train, y_train = jnp.array(X_train), jnp.array(y_train)
        X_val, y_val = jnp.array(X_val), jnp.array(y_val)

        self.n_samples_inner = X_train.shape[0]
        self.dim_inner = X_train.shape[1]
        self.n_samples_outer = X_val.shape[0]
        self.dim_outer = X_val.shape[1]

        @partial(jax.jit, static_argnames=('batch_size',))
        def f_inner(inner_var, outer_var, start=0, batch_size=1):
            # Minibatch loss on the training set, plus the regularization.
            x = jax.lax.dynamic_slice(
                X_train, (start, 0), (batch_size, X_train.shape[1])
            )
            y = jax.lax.dynamic_slice(
                y_train, (start, ), (batch_size, )
            )
            res = loss(inner_var, outer_var, x, y)

            if self.reg_parametrization == 'exp':
                res += jnp.dot(jnp.exp(outer_var) * inner_var, inner_var)/2
            elif self.reg_parametrization == 'lin':
                res += jnp.dot(outer_var * inner_var, inner_var)/2
            return res

        @partial(jax.jit, static_argnames=('batch_size',))
        def f_outer(inner_var, outer_var, start=0, batch_size=1):
            # Minibatch loss on the validation set, without regularization.
            x = jax.lax.dynamic_slice(
                X_val, (start, 0), (batch_size, X_val.shape[1])
            )
            y = jax.lax.dynamic_slice(
                y_val, (start, ), (batch_size, )
            )
            res = loss(inner_var, outer_var, x, y)
            return res

        f_inner_fb = partial(
            f_inner, batch_size=X_train.shape[0], start=0
        )
        f_outer_fb = partial(
            f_outer, batch_size=X_val.shape[0], start=0
        )

        solver_inner = LBFGS(fun=f_inner_fb)

        def value_function(outer_var):
            # Value function h(x) = f(x, z^*(x)), where z^*(x) minimizes
            # the full-batch inner objective for the given outer variable.
            inner_var_star = solver_inner.run(
                jnp.zeros(X_train.shape[1]), outer_var
            ).params

            return f_outer_fb(inner_var_star, outer_var), inner_var_star

        value_and_grad = jax.jit(
            jax.value_and_grad(value_function, has_aux=True)
        )

        def metrics(inner_var, outer_var):
            # Defines the metrics that are computed when calling the method
            # Objective.evaluate_results(inner_var, outer_var) and saved
            # in the result file. The output is a dictionary that contains at
            # least the key `value`. The keyword arguments of this function
            # are the keys of the dictionary returned by `Solver.get_result`.
            (value_fun, inner_star), grad_value = value_and_grad(outer_var)
            return dict(
                value_func=float(value_fun),
                value=float(jnp.linalg.norm(grad_value)**2),
                inner_distance=float(jnp.linalg.norm(inner_star-inner_var)**2),
                norm_outer_var=float(jnp.linalg.norm(outer_var)**2),
                norm_regul=float(jnp.linalg.norm(np.exp(outer_var))**2),
            )

        def init_var(key):
            # Provides an initialization of inner_var and outer_var.
            keys = jax.random.split(key, 2)
            inner_var0 = jax.random.normal(keys[0], (self.dim_inner,))
            outer_var0 = jax.random.uniform(keys[1], (self.dim_outer,))
            if self.reg_parametrization == 'exp':
                outer_var0 = jnp.log(outer_var0)
            return inner_var0, outer_var0

        data = dict(
            pb_inner=(f_inner, self.n_samples_inner, self.dim_inner,
                      f_inner_fb),
            pb_outer=(f_outer, self.n_samples_outer, self.dim_outer,
                      f_outer_fb),
            metrics=metrics,
            init_var=init_var,
        )

        # The output should be a dict that contains the keys `pb_inner`,
        # `pb_outer`, `metrics`, and optionally `init_var`.
        # `pb_inner` is a tuple that contains the inner function, the number
        # of inner samples, the dimension of the inner variable, and the
        # full-batch version of the inner function.
        # `pb_outer` is analogous.
        # The key `metrics` contains the function `metrics`.
        # The key `init_var` contains the function `init_var` when applicable.
        return data
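Once such a dataset file is in place, it can be run with the usual Benchopt selection flags, for instance (illustrative):

.. code-block::

    $ benchopt run benchmark_bilevel -d ijcnn1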