From 8f70bb1ebfa27af402df0141afc53acc1598c2d4 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 19 Feb 2024 21:41:06 +0100 Subject: [PATCH 01/50] ENH plot_quadratic and update config file --- config/quadratics_021424_best_params.yml | 4 +- figures/plot_quadratics.py | 201 +++++++++++++++++++++++ 2 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 figures/plot_quadratics.py diff --git a/config/quadratics_021424_best_params.yml b/config/quadratics_021424_best_params.yml index caedaca..15bf7a4 100644 --- a/config/quadratics_021424_best_params.yml +++ b/config/quadratics_021424_best_params.yml @@ -3,10 +3,10 @@ objective: dataset: - quadratic[L_cross_inner=0.1,L_cross_outer=0.1,mu_inner=[.1],n_samples_inner=[32768],n_samples_outer=[1024],dim_inner=100,dim_outer=10] solver: - - AmIGO[batch_size=64,eval_freq=16,framework=none,n_inner_steps=10,outer_ratio=1.0,step_size=0.01,random_state=[1,2,3,4,5,6,7,8,9,10]] + - AmIGO[batch_size=64,eval_freq=16,framework=none,n_inner_steps=10,outer_ratio=0.1,step_size=0.01,random_state=[1,2,3,4,5,6,7,8,9,10]] - MRBO[batch_size=64,eta=0.5,eval_freq=16,framework=none,n_shia_steps=10,outer_ratio=0.1,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] - SABA[batch_size=64,eval_freq=64,framework=none,mode_init_memory=zero,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] - - SRBA[batch_size=64,eval_freq=64,framework=none,outer_ratio=0.1,period_frac=0.5,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] + - SRBA[batch_size=64,eval_freq=64,framework=none,outer_ratio=1.0,period_frac=0.5,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] - StocBiO[batch_size=64,eval_freq=16,framework=none,n_inner_steps=10,n_shia_steps=10,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] - VRBO[batch_size=64,eval_freq=2,framework=none,n_inner_steps=10,n_shia_steps=10,outer_ratio=1.0,period_frac=0.01,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] - F2SA[batch_size=64,delta_lmbda=0.01,eval_freq=16,framework=none,lmbda0=1,n_inner_steps=10,outer_ratio=1.0,step_size=0.1,random_state=[1,2,3,4,5,6,7,8,9,10]] diff --git a/figures/plot_quadratics.py b/figures/plot_quadratics.py new file mode 100644 index 0000000..b4b640c --- /dev/null +++ b/figures/plot_quadratics.py @@ -0,0 +1,201 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +mpl.rc('text', usetex=True) + +FILE_NAME = Path(__file__).with_suffix('') +METRIC = 'objective_value' + +# DEFAULT_WIDTH = 3.25 +DEFAULT_WIDTH = 3 +DEFAULT_HEIGHT = 2 +LEGEND_RATIO = 0.1 + +N_POINTS = 500 +X_LIM = 250 + +# Utils to get common STYLES object and setup matplotlib +# for all plots + +mpl.rcParams.update({ + 'font.size': 10, + 'legend.fontsize': 'small', + 'axes.labelsize': 'small', + 'xtick.labelsize': 'small', + 'ytick.labelsize': 'small' +}) + +STYLES = { + '*': dict(lw=1.5), + + 'amigo': dict(color='#5778a4', label=r'AmIGO'), + 'mrbo': dict(color='#e49444', label=r'MRBO'), + 'vrbo': dict(color='#e7ca60', label=r'VRBO'), + 'saba': dict(color='#d1615d', label=r'SABA'), + 'stocbio': dict(color='#85b6b2', label=r'StocBiO'), + 'srba': dict(color='#6a9f58', label=r'\textbf{SRBA}', lw=2), + 'f2sa': dict(color='#bcbd22', label=r'F2SA'), +} + + +def get_param(name, param='period_frac'): + params = {} + for vals in name.split("[", maxsplit=1)[1][:-1].split(","): + k, v = vals.split("=") + if v.replace(".", "").isnumeric(): + params[k] = float(v) + else: + params[k] = v + return params[param] + + +def drop_param(name, param='period_frac'): + new_name = name.split("[", maxsplit=1)[0] + '[' + for vals in name.split("[", maxsplit=1)[1][:-1].split(","): + k, v = vals.split("=") + if k != param: + new_name += f'{k}={v},' + return new_name[:-1] + ']' + + +if __name__ == "__main__": + fname = "quadratic.parquet" + fname = FILE_NAME.parent / fname + + if Path(f'{fname.stem}_stable.parquet').is_file(): + df = pd.read_parquet(f'{fname.stem}_stable.parquet') + print(f'{fname.stem}_stable.parquet') + else: + df = pd.read_parquet(fname) + print(fname) + + # normalize names + df['solver'] = df['solver_name'].apply( + lambda x: x.split('[')[0].lower() + ) + df['seed_solver'] = df['solver_name'].apply( + lambda x: get_param(x, 'random_state') + ) + df['seed_data'] = df['data_name'].apply( + lambda x: get_param(x, 'random_state') + ) + + df['solver_name'] = df['solver_name'].apply( + lambda x: drop_param(x, 'random_state') + ) + df['data_name'] = df['data_name'].apply( + lambda x: drop_param(x, 'random_state') + ) + df['cond'] = df['data_name'].apply( + lambda x: get_param(x, 'L_inner_inner')/get_param(x, 'mu_inner') + ) + df['n_inner'] = df['data_name'].apply( + lambda x: get_param(x, 'n_samples_inner') + ) + df['n_outer'] = df['data_name'].apply( + lambda x: get_param(x, 'n_samples_outer') + ) + df['n_tot'] = df['n_inner'] + df['n_outer'] + + # keep only runs all the random seeds + df['full'] = False + n_seeds = df.groupby('solver_name')['seed_data'].nunique() + n_seeds *= df.groupby('solver_name')['seed_solver'].nunique() + for s in n_seeds.index: + if n_seeds[s] == 10: + df.loc[df['solver_name'] == s, 'full'] = True + df = df.query('full == True') + df.to_parquet(f'{fname.stem}_stable.parquet') + + fig = plt.figure( + figsize=(DEFAULT_WIDTH, DEFAULT_HEIGHT * (1 + LEGEND_RATIO)) + ) + + gs = plt.GridSpec( + len(df['n_tot'].unique()), len(df['cond'].unique()), + height_ratios=[1] * len(df['n_tot'].unique()), + width_ratios=[1] * len(df['cond'].unique()), + hspace=0.5, wspace=0.3 + ) + + lines = [] + for i, n_tot in enumerate(df['n_tot'].unique()): + for j, cond in enumerate(df['cond'].unique()): + df_pb = df.query("cond == @cond & n_tot == @n_tot") + print(f"Cond: {cond}, n: {df_pb['n_inner'].iloc[0]}, " + + f"m: {df_pb['n_outer'].iloc[0]}") + to_plot = ( + df.query("cond == @cond & n_tot == @n_tot & stop_val <= 100") + .groupby(['solver', 'solver_name', 'data_name', 'stop_val']) + .median(METRIC) + .reset_index().sort_values(METRIC) + .groupby('solver').first()[['solver_name']] + ) + ( + df.query("solver_name in @to_plot.values.ravel()") + .to_parquet(f'{fname.stem}_best_params.parquet') + ) + print("Chosen parameters:") + for s in to_plot['solver_name']: + print(f"- {s}") + ax = fig.add_subplot(gs[i, j]) + for solver_name in to_plot['solver_name']: + df_solver = df_pb.query("solver_name == @solver_name") + solver = df_solver['solver'].iloc[0] + style = STYLES['*'].copy() + style.update(STYLES[solver]) + curves = [data[['time', METRIC]].values + for _, data in df_solver.groupby(['seed_data', + 'seed_solver'])] + vals = [c[:, 1] for c in curves] + times = [c[:, 0] for c in curves] + tmin = np.min([np.min(t) for t in times]) + tmax = np.max([np.max(t) for t in times]) + time_grid = np.linspace(np.log(tmin), np.log(tmax + 1), + N_POINTS) + interp_vals = np.zeros((len(times), N_POINTS)) + for k, (t, val) in enumerate(zip(times, vals)): + interp_vals[k] = np.exp(np.interp(time_grid, np.log(t), + np.log(val))) + time_grid = np.exp(time_grid) + medval = np.quantile(interp_vals, .5, axis=0) + q1 = np.quantile(interp_vals, .2, axis=0) + q2 = np.quantile(interp_vals, .8, axis=0) + if i == 0 and j == 0: + lines.append(ax.semilogy( + time_grid, np.sqrt(medval), + **style + )[0]) + else: + ax.semilogy( + time_grid, np.sqrt(medval), + **style + ) + ax.fill_between( + time_grid, + np.sqrt(q1), + np.sqrt(q2), + color=style['color'], alpha=0.3 + ) + ax.set_xlabel('Time (s)') + ax.set_ylabel(r'$\|\nabla h(x^t)\|$') + print(f"Min score ({solver}):", df_solver[METRIC].min()) + ax.grid() + ax.set_xlim([0, X_LIM]) + + if i == 0 and j == 0: + ax_legend = ax.legend( + handles=lines, + ncol=2, + prop={'size': 6.5} + ) + print(f"Saving {fname.with_suffix('.pdf')}") + fig.savefig( + fname.with_suffix('.pdf'), + bbox_inches='tight', + bbox_extra_artists=[ax_legend] + ) From 61806e02ac81db13475ff8e6655ec8c7aeaeb8fc Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 19 Feb 2024 21:41:26 +0100 Subject: [PATCH 02/50] ENH add quadratics to readme --- README.rst | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 485707f..cd777a5 100644 --- a/README.rst +++ b/README.rst @@ -15,9 +15,23 @@ where $g$ and $f$ are two functions of two variables. Different problems ------------------ -This benchmark currently implements two bilevel optimization problems: regularization selection, and hyper data cleaning. +This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and hyper data cleaning. -1 - Regularization selection +1 - Quadratic bilevel problem +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In this problem, the inner and the outer functions are quadritics functions defined of $\mathbb{R}^{d\times p} + +$$g(x, z) = \frac1n \sum_{i=1}^n \frac12 z^\top H_i^z z + \frac12 x^\top H_i^x x + x^\top C_i z + c_i^\top z + d_i^\top x$$ + +and + +$$f(x, z) = \frac1m \sum_{j=1}^m \frac12 z^\top \tilde H_j^z z + \frac12 x^\top \tilde H_j^x x + x^\top \tilde C_j z + \tilde c_j^\top z + \tilde d_j^\top x$$ + +where $H_i^z, \tilde H_j^z$ are symmetric positive definite matrices of size $p\times p$, H_j^x, \tilde H_j^x$are symmetric positive definite matrices of size $d\times d$, $C_i, \tilde C_j$ are matrices of size $d\times p$, $c_i, \tilde c_j$ are vectors of size $d$ and $d_i, \tilde d_j$ are vectors of size $p$. + + +2 - Regularization selection ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In this problem, the inner function $g$ is defined by @@ -57,7 +71,7 @@ $$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^k\\exp(x_j)\\|z_j\\|^2,$$ each line in $z$ is independently regularized with the strength $\\exp(x_j)$. -2 - Hyper data cleaning +3 - Hyper data cleaning ^^^^^^^^^^^^^^^^^^^^^^^ This problem was first introduced by [Fra2017]_ . From 0ecc679538b7897f98ca8655c1cb3fc53b443f82 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 19 Feb 2024 21:43:33 +0100 Subject: [PATCH 03/50] FIX latex readme --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index cd777a5..7756e34 100644 --- a/README.rst +++ b/README.rst @@ -22,11 +22,11 @@ This benchmark currently implements three bilevel optimization problems: quadrat In this problem, the inner and the outer functions are quadritics functions defined of $\mathbb{R}^{d\times p} -$$g(x, z) = \frac1n \sum_{i=1}^n \frac12 z^\top H_i^z z + \frac12 x^\top H_i^x x + x^\top C_i z + c_i^\top z + d_i^\top x$$ +$$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ and -$$f(x, z) = \frac1m \sum_{j=1}^m \frac12 z^\top \tilde H_j^z z + \frac12 x^\top \tilde H_j^x x + x^\top \tilde C_j z + \tilde c_j^\top z + \tilde d_j^\top x$$ +$$f(x, z) = \\frac{1}{m} \sum_{j=1}^m \\frac{1}{2} z^\\top \tilde H_j^z z + \\frac{1}{2} x^\\top \tilde H_j^x x + x^\\top \tilde C_j z + \tilde c_j^\\top z + \tilde d_j^\\top x$$ where $H_i^z, \tilde H_j^z$ are symmetric positive definite matrices of size $p\times p$, H_j^x, \tilde H_j^x$are symmetric positive definite matrices of size $d\times d$, $C_i, \tilde C_j$ are matrices of size $d\times p$, $c_i, \tilde c_j$ are vectors of size $d$ and $d_i, \tilde d_j$ are vectors of size $p$. From 2307cccd0695ccbb0b900158ecbd55abd85b397c Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 19 Feb 2024 21:44:44 +0100 Subject: [PATCH 04/50] FIX latex readme --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 7756e34..d8c75f3 100644 --- a/README.rst +++ b/README.rst @@ -26,9 +26,9 @@ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2 and -$$f(x, z) = \\frac{1}{m} \sum_{j=1}^m \\frac{1}{2} z^\\top \tilde H_j^z z + \\frac{1}{2} x^\\top \tilde H_j^x x + x^\\top \tilde C_j z + \tilde c_j^\\top z + \tilde d_j^\\top x$$ +$$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ -where $H_i^z, \tilde H_j^z$ are symmetric positive definite matrices of size $p\times p$, H_j^x, \tilde H_j^x$are symmetric positive definite matrices of size $d\times d$, $C_i, \tilde C_j$ are matrices of size $d\times p$, $c_i, \tilde c_j$ are vectors of size $d$ and $d_i, \tilde d_j$ are vectors of size $p$. +where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, H_j^x, \\tilde H_j^x$are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. 2 - Regularization selection From f394a6d5b472aca2064bb7a70fd65110a7f84507 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 10:30:12 +0200 Subject: [PATCH 05/50] FIX typo --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index d8c75f3..0969bda 100644 --- a/README.rst +++ b/README.rst @@ -28,7 +28,7 @@ and $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ -where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, H_j^x, \\tilde H_j^x$are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. +where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. 2 - Regularization selection From 8da9ac56bb69656fff0a321b875b63bca69ecb95 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 10:31:35 +0200 Subject: [PATCH 06/50] FIX typo --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 0969bda..cb5eb24 100644 --- a/README.rst +++ b/README.rst @@ -20,7 +20,7 @@ This benchmark currently implements three bilevel optimization problems: quadrat 1 - Quadratic bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this problem, the inner and the outer functions are quadritics functions defined of $\mathbb{R}^{d\times p} +In this problem, the inner and the outer functions are quadritics functions defined of $\mathbb{R}^{d\times p}$ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ From 047c83eef5803d838d8edffeb0304692d0bd8e7a Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 10:32:09 +0200 Subject: [PATCH 07/50] FIX double backslash --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index cb5eb24..956dbdf 100644 --- a/README.rst +++ b/README.rst @@ -20,7 +20,7 @@ This benchmark currently implements three bilevel optimization problems: quadrat 1 - Quadratic bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this problem, the inner and the outer functions are quadritics functions defined of $\mathbb{R}^{d\times p}$ +In this problem, the inner and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ From 0d8b5562d00e7e2d6da5618bee85d85a22154163 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 10:34:04 +0200 Subject: [PATCH 08/50] FIX typo --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 956dbdf..b264240 100644 --- a/README.rst +++ b/README.rst @@ -28,7 +28,7 @@ and $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ -where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. +where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. 2 - Regularization selection From 17e226797bf1245c641a70702b7ecea71fafb2d4 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 11:34:02 +0200 Subject: [PATCH 09/50] ENH doc eigenvalues of matrices --- README.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b264240..e29c200 100644 --- a/README.rst +++ b/README.rst @@ -17,7 +17,7 @@ Different problems This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and hyper data cleaning. -1 - Quadratic bilevel problem +1 - Simulated bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In this problem, the inner and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ @@ -30,6 +30,10 @@ $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\ where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. +The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between `mu_inner` and `L_inner_inner`, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between `mu_inner` and `L_inner_outer`, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between `mu_inner` and `L_outer_inner`, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between `mu_inner` and `L_outer_outer`. + +The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than `L_cross_inner`, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than `L_cross_outer`. + 2 - Regularization selection ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 89f096681361e60e6115c96ad734ab4ca04260db Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 11:36:03 +0200 Subject: [PATCH 10/50] FIX typo --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index e29c200..352f293 100644 --- a/README.rst +++ b/README.rst @@ -28,7 +28,7 @@ and $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ -where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i, \\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. +where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between `mu_inner` and `L_inner_inner`, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between `mu_inner` and `L_inner_outer`, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between `mu_inner` and `L_outer_inner`, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between `mu_inner` and `L_outer_outer`. From 479e1d52f36911f36cdd20babbad23ac1d445c49 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 11:45:42 +0200 Subject: [PATCH 11/50] ENH value function evaluation nit that expensive --- README.rst | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index 352f293..c287505 100644 --- a/README.rst +++ b/README.rst @@ -4,13 +4,13 @@ Bilevel Optimization Benchmark *Results can be consulted on https://benchopt.github.io/results/benchmark_bilevel.html* -BenchOpt is a package to simplify and make more transparent and +BenchOpt is a package to simplify, and make more transparent, and reproducible the comparisons of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: $$\\min_{x} f(x, z^*(x)) \\quad \\text{with} \\quad z^*(x) = \\arg\\min_z g(x, z), $$ -where $g$ and $f$ are two functions of two variables. +where $g$, and $f$ are two functions of two variables. Different problems ------------------ @@ -20,7 +20,7 @@ This benchmark currently implements three bilevel optimization problems: quadrat 1 - Simulated bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this problem, the inner and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ +In this problem, the inner, and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ @@ -28,12 +28,14 @@ and $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ -where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$ and $d_i, \\tilde d_j$ are vectors of size $p$. +where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$, and $d_i, \\tilde d_j$ are vectors of size $p$. -The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between `mu_inner` and `L_inner_inner`, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between `mu_inner` and `L_inner_outer`, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between `mu_inner` and `L_outer_inner`, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between `mu_inner` and `L_outer_outer`. +The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between `mu_inner`, and `L_inner_inner`, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between `mu_inner`, and `L_inner_outer`, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between `mu_inner`, and `L_outer_inner`, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between `mu_inner`, and `L_outer_outer`. The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than `L_cross_inner`, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than `L_cross_outer`. +Note that in this setting, the solution of the inner problem is a linear system. Moreover, the full batch inner and outer functions can be cheaply computed by storing the average of the Hessian matrices. Thus, the value function can be cheaply evaluated in closed form in medium dimension. + 2 - Regularization selection ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -59,7 +61,7 @@ Covtype *Homepage : https://archive.ics.uci.edu/dataset/31/covertype* -This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i=\\pm1$ is the binary target. +This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features, and $y_i=\\pm1$ is the binary target. For this problem, the loss is $\\ell(d_i, z) = \\log(1+\\exp(-y_i a_i^T z))$, and the regularization is simply given by $$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^p\\exp(x_j)z_j^2,$$ each coefficient in $z$ is independently regularized with the strength $\\exp(x_j)$. @@ -69,7 +71,7 @@ Ijcnn1 *Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active* -This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with k the number of classes. +This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features, and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with k the number of classes. For this problem, the loss is $\\ell(d_i, z) = \\text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by $$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^k\\exp(x_j)\\|z_j\\|^2,$$ each line in $z$ is independently regularized with the strength $\\exp(x_j)$. @@ -80,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\\exp(x_j)$. This problem was first introduced by [Fra2017]_ . In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\\in\\{1,\\dots,10\\}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\\in\\{1,\\dots,10\\}$ is replaced by another random label between 1, and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From 76ac1fe30c30ffc72d9dc2aa7c3bf5d5e658fbcc Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 11 Oct 2024 11:51:01 +0200 Subject: [PATCH 12/50] FIX double backquotes --- README.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index c287505..06702fa 100644 --- a/README.rst +++ b/README.rst @@ -30,9 +30,9 @@ $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\ where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$, and $d_i, \\tilde d_j$ are vectors of size $p$. -The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between `mu_inner`, and `L_inner_inner`, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between `mu_inner`, and `L_inner_outer`, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between `mu_inner`, and `L_outer_inner`, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between `mu_inner`, and `L_outer_outer`. +The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. -The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than `L_cross_inner`, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than `L_cross_outer`. +The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than ``L_cross_outer``. Note that in this setting, the solution of the inner problem is a linear system. Moreover, the full batch inner and outer functions can be cheaply computed by storing the average of the Hessian matrices. Thus, the value function can be cheaply evaluated in closed form in medium dimension. @@ -111,7 +111,7 @@ This benchmark can be run using the following commands: $ git clone https://github.com/benchopt/benchmark_bilevel $ benchopt run benchmark_bilevel -Apart from the problem, options can be passed to `benchopt run`, to restrict the benchmarks to some solvers or datasets, e.g.: +Apart from the problem, options can be passed to ``benchopt run``, to restrict the benchmarks to some solvers or datasets, e.g.: .. code-block:: @@ -123,9 +123,9 @@ You can also use config files to setup the benchmark run: $ benchopt run benchmark_bilevel --config config/X.yml -where `X.yml` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file `X_best_params.yml` in order to launch an experiment with a single set of parameters for each solver. +where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file ``X_best_params.yml`` in order to launch an experiment with a single set of parameters for each solver. -Use `benchopt run -h` for more details about these options, or visit https://benchopt.github.io/api.html. +Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html. Cite From 8dc98c2025e6579edc95acbf4f421aa035deea5b Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 10:20:05 +0200 Subject: [PATCH 13/50] WIP doc --- README.rst | 8 ++++++++ benchmark_utils/stochastic_jax_solver.py | 18 ++++++++++++++---- solvers/amigo.py | 21 +++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 06702fa..ce551cd 100644 --- a/README.rst +++ b/README.rst @@ -127,6 +127,14 @@ where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run- Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html. +How to contribute to the benchmark? +----------------------------------- + +If you think that a solver is missing, or if you want to add a new problem, feel free to open a pull request or an issue! + +1 - How to add a new solvers? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* Stochastic solver: see the detailed explanations in the [AmIGO solver](solvers/amigo.py). Cite ---- diff --git a/benchmark_utils/stochastic_jax_solver.py b/benchmark_utils/stochastic_jax_solver.py index bb48ab8..cd09d61 100644 --- a/benchmark_utils/stochastic_jax_solver.py +++ b/benchmark_utils/stochastic_jax_solver.py @@ -98,10 +98,20 @@ def set_objective(self, f_inner, f_outer, n_inner_samples, n_outer_samples, inner_var0, outer_var0: array-like, shape (dim_inner,) (dim_outer,) - f_inner_fb, f_outer_fb: callable - Full batch version of f_inner and f_outer. Should take as input: - * inner_var: array-like, shape (dim_inner,) - * outer_var: array-like, shape (dim_outer,) + Attributes + ---------- + f_inner, f_outer: callable + Inner and outer objective function for the bilevel optimization + problem. + + n_inner_samples, n_outer_samples: int + Number of samples to draw for the inner and outer objective + functions. + + inner_var0, outer_var0: array-like, shape (dim_inner,) (dim_outer,) + + batch_size_inner, batch_size_outer: int + """ self.f_inner = f_inner diff --git a/solvers/amigo.py b/solvers/amigo.py index 87bd350..8a35f5c 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -20,6 +20,27 @@ class Solver(StochasticJaxSolver): Bilevel Optimization". ICLR 2022""" name = 'AmIGO' + """How to add a new stochastic solver to the benchmark? + + Stochastic solvers are Solver classes that inherit from the + `StochasticJaxSolver` class. They should implement the `init` and the + `get_step_methods` and the class variable `parameters`. + + * The variable `parameters` is a dictionary that contains the solver's + parameters. In the case of AmIGO, it contains + - step_size: the step_size of the inner and linear system solvers + - outer_ratio: the ratio between the step sizes of the inner and the + outer updates + - n_inner_steps: the number of steps of the inner and the linear system + solvers + - batch_size: the size of the minibatch (assumed to be the same for the + inner and outer functions) + - **StochasticJaxSolver.parameters: the parameters shared by all the + stochastic solvers based on the StochasticJaxSolver class + + * + """ + # any parameter defined here is accessible as a class attribute parameters = { 'step_size': [.1], From df744278b9ec252507ae9e7204a41ef9346f8ccd Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:35:28 +0200 Subject: [PATCH 14/50] WIP complete doc how to create a solver --- solvers/amigo.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/solvers/amigo.py b/solvers/amigo.py index 8a35f5c..8b6073e 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -26,19 +26,31 @@ class Solver(StochasticJaxSolver): `StochasticJaxSolver` class. They should implement the `init` and the `get_step_methods` and the class variable `parameters`. - * The variable `parameters` is a dictionary that contains the solver's + * The variable `parameters` is a dictionary that contains the solver's parameters. In the case of AmIGO, it contains - step_size: the step_size of the inner and linear system solvers - outer_ratio: the ratio between the step sizes of the inner and the outer updates - n_inner_steps: the number of steps of the inner and the linear system solvers - - batch_size: the size of the minibatch (assumed to be the same for the + - batch_size: the size of the minibatch (assumed to be the same for the inner and outer functions) - **StochasticJaxSolver.parameters: the parameters shared by all the stochastic solvers based on the StochasticJaxSolver class - * + * The `init` methods initializes variables that are udapted during the + optimization process. In the case of AmIGO, it initializes the inner and + outer variables, the linear system variable v and the learning rate + scheduler. It returns a dictionary containing these variables and the + initial state of the samplers. Those ones are already provided by the + attributes `state_inner_sampler` and `state_outer_sampler`. + + * The `get_step` method returns a function that performs one iteration of + the optimization algorithm. This function should be jittable by JAX. In + this function are also initialized the eventual subroutines such as the + inner SGD and the linear system solver in the case of AmIGO. Note that the + variable updated during the process are stored in the `carry` dictionary, + whose initial state is the output of the `init` method. """ # any parameter defined here is accessible as a class attribute From dd3380cfd3c31448f5a06da733bf79bbbfe55474 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:38:31 +0200 Subject: [PATCH 15/50] ENH add comments amigo --- solvers/amigo.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/solvers/amigo.py b/solvers/amigo.py index 8b6073e..9aec79f 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -76,6 +76,11 @@ def init(self): ) exponents = jnp.zeros(3) state_lr = init_lr_scheduler(step_sizes, exponents) + + # The return dictionary should contain all the variables that are + # updated during the optimization process. The state of the samplers + # are already provided by the attributes `state_inner_sampler` and + # `state_outer_sampler`. return dict( inner_var=self.inner_var, outer_var=self.outer_var, v=v, state_lr=state_lr, @@ -96,6 +101,9 @@ def get_step(self, inner_sampler, outer_sampler): sampler=inner_sampler, n_steps=self.n_inner_steps ) + # This function should be jittable by JAX. It returns the output of + # one iteration of the optimization algorithm (one iteration = one + # outer vairable update in this case). def amigo_one_iter(carry, _): (inner_lr, v_lr, outer_lr), carry['state_lr'] = update_lr( From beb83f1b8fbfe2b77aa6d80733c0d1d325d1a422 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:41:50 +0200 Subject: [PATCH 16/50] ENH readme --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index ce551cd..852c2aa 100644 --- a/README.rst +++ b/README.rst @@ -135,6 +135,7 @@ If you think that a solver is missing, or if you want to add a new problem, feel 1 - How to add a new solvers? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * Stochastic solver: see the detailed explanations in the [AmIGO solver](solvers/amigo.py). +* Other solver: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). Cite ---- From c77ba62d1ae2421ab9c3ae3557c7d7fa2cd4929e Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:47:34 +0200 Subject: [PATCH 17/50] ENH docstring StochasticJaxSolver --- benchmark_utils/stochastic_jax_solver.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/benchmark_utils/stochastic_jax_solver.py b/benchmark_utils/stochastic_jax_solver.py index cd09d61..f9debf3 100644 --- a/benchmark_utils/stochastic_jax_solver.py +++ b/benchmark_utils/stochastic_jax_solver.py @@ -111,7 +111,15 @@ def set_objective(self, f_inner, f_outer, n_inner_samples, n_outer_samples, inner_var0, outer_var0: array-like, shape (dim_inner,) (dim_outer,) batch_size_inner, batch_size_outer: int - + Size of the minibatch to use for the inner and outer objective + functions. + + state_inner_sampler, state_outer_sampler: dict + State of the minibatch samplers for the inner and outer objectives. + + one_epoch: callable + Jitted function that runs the solver for one epoch. One epoch is + defined as `eval_freq` iterations of the solver. """ self.f_inner = f_inner From 05dfa21bae292595819d5d3c243269a930387e33 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:50:53 +0200 Subject: [PATCH 18/50] ENH comment amigo --- solvers/amigo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/solvers/amigo.py b/solvers/amigo.py index 9aec79f..d2bc996 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -59,7 +59,8 @@ class Solver(StochasticJaxSolver): 'outer_ratio': [1.], 'n_inner_steps': [10], 'batch_size': [64], - **StochasticJaxSolver.parameters + **StochasticJaxSolver.parameters + # Contains the `eval_freq` and `random_state`parameters } def init(self): From 38f219784b0d632890158e0c17e6e68b4cb43ad6 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Mon, 14 Oct 2024 11:51:11 +0200 Subject: [PATCH 19/50] FIX flake8 --- solvers/amigo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/solvers/amigo.py b/solvers/amigo.py index d2bc996..730f979 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -59,7 +59,7 @@ class Solver(StochasticJaxSolver): 'outer_ratio': [1.], 'n_inner_steps': [10], 'batch_size': [64], - **StochasticJaxSolver.parameters + **StochasticJaxSolver.parameters # Contains the `eval_freq` and `random_state`parameters } From d4b1b35678ef9c6a5683dba94a3453cbd850922a Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Wed, 16 Oct 2024 16:49:42 +0200 Subject: [PATCH 20/50] FIX review suggestions README.rst --- README.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.rst b/README.rst index 852c2aa..2e03606 100644 --- a/README.rst +++ b/README.rst @@ -4,7 +4,7 @@ Bilevel Optimization Benchmark *Results can be consulted on https://benchopt.github.io/results/benchmark_bilevel.html* -BenchOpt is a package to simplify, and make more transparent, and +BenchOpt is a package to simplify, make more transparent, and reproducible the comparisons of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: @@ -17,10 +17,10 @@ Different problems This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and hyper data cleaning. -1 - Simulated bilevel problem +1 - Simulated quadratic bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this problem, the inner, and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ +In this problem, the inner and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ @@ -34,7 +34,8 @@ The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than ``L_cross_outer``. -Note that in this setting, the solution of the inner problem is a linear system. Moreover, the full batch inner and outer functions can be cheaply computed by storing the average of the Hessian matrices. Thus, the value function can be cheaply evaluated in closed form in medium dimension. +Note that in this setting, the solution of the inner problem is a linear system. +As, the full batch inner and outer functions can be computed efficiently directly with the average Hessian matrices, the value function can be evaluated in closed form. 2 - Regularization selection @@ -130,7 +131,7 @@ Use ``benchopt run -h`` for more details about these options, or visit https://b How to contribute to the benchmark? ----------------------------------- -If you think that a solver is missing, or if you want to add a new problem, feel free to open a pull request or an issue! +If you want to add a solver or a new problem, you are welcome to open an issue or submit a pull request! 1 - How to add a new solvers? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From d093b8bbb0097ced7f5c0e0106028673cfb43d32 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Wed, 16 Oct 2024 16:58:14 +0200 Subject: [PATCH 21/50] CLN create template_stochastic_solver and moove explanation from AmIGO --- solvers/amigo.py | 33 ------- solvers/template_stochastic_solver.py | 124 ++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 33 deletions(-) create mode 100644 solvers/template_stochastic_solver.py diff --git a/solvers/amigo.py b/solvers/amigo.py index 730f979..28e3fc0 100644 --- a/solvers/amigo.py +++ b/solvers/amigo.py @@ -20,39 +20,6 @@ class Solver(StochasticJaxSolver): Bilevel Optimization". ICLR 2022""" name = 'AmIGO' - """How to add a new stochastic solver to the benchmark? - - Stochastic solvers are Solver classes that inherit from the - `StochasticJaxSolver` class. They should implement the `init` and the - `get_step_methods` and the class variable `parameters`. - - * The variable `parameters` is a dictionary that contains the solver's - parameters. In the case of AmIGO, it contains - - step_size: the step_size of the inner and linear system solvers - - outer_ratio: the ratio between the step sizes of the inner and the - outer updates - - n_inner_steps: the number of steps of the inner and the linear system - solvers - - batch_size: the size of the minibatch (assumed to be the same for the - inner and outer functions) - - **StochasticJaxSolver.parameters: the parameters shared by all the - stochastic solvers based on the StochasticJaxSolver class - - * The `init` methods initializes variables that are udapted during the - optimization process. In the case of AmIGO, it initializes the inner and - outer variables, the linear system variable v and the learning rate - scheduler. It returns a dictionary containing these variables and the - initial state of the samplers. Those ones are already provided by the - attributes `state_inner_sampler` and `state_outer_sampler`. - - * The `get_step` method returns a function that performs one iteration of - the optimization algorithm. This function should be jittable by JAX. In - this function are also initialized the eventual subroutines such as the - inner SGD and the linear system solver in the case of AmIGO. Note that the - variable updated during the process are stored in the `carry` dictionary, - whose initial state is the output of the `init` method. - """ - # any parameter defined here is accessible as a class attribute parameters = { 'step_size': [.1], diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py new file mode 100644 index 0000000..c22fdb9 --- /dev/null +++ b/solvers/template_stochastic_solver.py @@ -0,0 +1,124 @@ +from benchmark_utils.stochastic_jax_solver import StochasticJaxSolver + +from benchopt import safe_import_context + +with safe_import_context() as import_ctx: + from benchmark_utils.learning_rate_scheduler import update_lr + from benchmark_utils.learning_rate_scheduler import init_lr_scheduler + + import jax + import jax.numpy as jnp + + +class Solver(StochasticJaxSolver): + # The docstring should contain the solver's name and a reference to the + # paper where it is introduced. This will be displayed in the HTML result + # page. + """Stochastic Bilevel Algorithm (SOBA). + + M. Dagréou, P. Ablin, S. Vaiter and T. Moreau, "A framework for bilevel + optimization that enables stochastic and global variance reduction + algorithms", NeurIPS 2022.""" + name = 'Template Stochastic Solver' + + """How to add a new stochastic solver to the benchmark? + + Stochastic solvers are Solver classes that inherit from the + `StochasticJaxSolver` class. They should implement the `init` and the + `get_step_methods` and the class variable `parameters`. One epoch of + StochasticJaxSolver corresponds to `eval_freq` outer iterations of the + solver. The epochs of these solvers are jitted by JAX to get fast + stochastic iterations. + + * The variable `parameters` is a dictionary that contains the solver's + parameters. Here, it contains + - step_size: the step_size of the inner and linear system solvers + - outer_ratio: the ratio between the step sizes of the inner and the + outer updates + - n_inner_steps: the number of steps of the inner and the linear system + solvers + - batch_size: the size of the minibatch (assumed to be the same for the + inner and outer functions) + - **StochasticJaxSolver.parameters: the parameters shared by all the + stochastic solvers based on the StochasticJaxSolver class + + * The `init` methods initializes variables that are udapted during the + optimization process. Here, it initializes the inner and + outer variables, the linear system variable v and the learning rate + scheduler. It returns a dictionary containing these variables and the + initial state of the samplers. Those ones are already provided by the + attributes `state_inner_sampler` and `state_outer_sampler`. + + * The `get_step` method returns a function that performs one iteration of + the optimization algorithm. This function should be jittable by JAX. In + this function are also initialized the eventual subroutines such as the + inner SGD and the linear system solver in the case of AmIGO. Note that the + variable updated during the process are stored in the `carry` dictionary, + whose initial state is the output of the `init` method. + """ + + # any parameter defined here is accessible as a class attribute + parameters = { + 'step_size': [.1], + 'outer_ratio': [1.], + 'batch_size': [64], + **StochasticJaxSolver.parameters + } + + def init(self): + # Init variables + self.inner_var = self.inner_var0.copy() + self.outer_var = self.outer_var0.copy() + v = jnp.zeros_like(self.inner_var) + + # Init lr scheduler + step_sizes = jnp.array( + [self.step_size, self.step_size / self.outer_ratio] + ) + exponents = jnp.array( + [.5, .5] + ) + state_lr = init_lr_scheduler(step_sizes, exponents) + return dict( + inner_var=self.inner_var, outer_var=self.outer_var, v=v, + state_lr=state_lr, + state_inner_sampler=self.state_inner_sampler, + state_outer_sampler=self.state_outer_sampler, + ) + + def get_step(self, inner_sampler, outer_sampler): + + grad_inner = jax.grad(self.f_inner, argnums=0) + grad_outer = jax.grad(self.f_outer, argnums=(0, 1)) + + def soba_one_iter(carry, _): + + (inner_step_size, outer_step_size), carry['state_lr'] = update_lr( + carry['state_lr'] + ) + + # Step.1 - get all gradients and compute the implicit gradient. + start_inner, *_, carry['state_inner_sampler'] = inner_sampler( + carry['state_inner_sampler'] + ) + grad_inner_var, vjp_train = jax.vjp( + lambda z, x: grad_inner(z, x, start_inner), carry['inner_var'], + carry['outer_var'] + ) + hvp, cross_v = vjp_train(carry['v']) + + start_outer, *_, carry['state_outer_sampler'] = outer_sampler( + carry['state_outer_sampler'] + ) + grad_in_outer, grad_out_outer = grad_outer( + carry['inner_var'], carry['outer_var'], start_outer + ) + + # Step.2 - update inner variable with SGD. + carry['inner_var'] -= inner_step_size * grad_inner_var + carry['v'] -= inner_step_size * (hvp + grad_in_outer) + carry['outer_var'] -= outer_step_size * (cross_v + grad_out_outer) + + return carry, _ + + return soba_one_iter From e0d785cffbc863333ebaa771dea0a5d04d41d361 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 17 Oct 2024 18:05:03 +0200 Subject: [PATCH 22/50] ENH add template_solver.py --- README.rst | 9 +- solvers/template_solver.py | 125 ++++++++++++++++++++++++++ solvers/template_stochastic_solver.py | 5 +- 3 files changed, 136 insertions(+), 3 deletions(-) create mode 100644 solvers/template_solver.py diff --git a/README.rst b/README.rst index 2e03606..9bd95c3 100644 --- a/README.rst +++ b/README.rst @@ -135,8 +135,13 @@ If you want to add a solver or a new problem, you are welcome to open an issue o 1 - How to add a new solvers? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -* Stochastic solver: see the detailed explanations in the [AmIGO solver](solvers/amigo.py). -* Other solver: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). +Each solver derive from the [`benchopt.BaseSolver` class](https://benchopt.github.io/user_guide/generated/benchopt.BaseSolver.html) in the [solvers](solvers) folder. The solvers are separated among the stochastic JAX solvers and the others: +* Stochastic Jax solver: these solvers inherit from the [`StochasticJaxSolver` class](benchmark_utils/stochastic_jax_solver.py) see the detailed explanations in the [template stochastic solver](solvers/template_stochastic_solver.py). +* Other solver: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). An example is provided in the [template solver](solvers/template_solver.py). + +2 - How to add a new problem? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +In this benchmark, each problem is defined by a [Dataset class](https://benchopt.github.io/user_guide/generated/benchopt.BaseDataset.html) in the [datasets](datasets) folder. Cite ---- diff --git a/solvers/template_solver.py b/solvers/template_solver.py new file mode 100644 index 0000000..ace7770 --- /dev/null +++ b/solvers/template_solver.py @@ -0,0 +1,125 @@ +from benchopt import BaseSolver +from benchopt.stopping_criterion import SufficientProgressCriterion + +from benchopt import safe_import_context + +with safe_import_context() as import_ctx: + # import your reusable functions here + from benchmark_utils import constants + from benchmark_utils.learning_rate_scheduler import update_lr + from benchmark_utils.learning_rate_scheduler import init_lr_scheduler + + import jax + import jax.numpy as jnp + from functools import partial + + import jaxopt + + +class Solver(BaseSolver): + """Gradient descent with JAXopt solvers. + + M. Blondel, Q. Berthet, M. Cuturi, R. Frosting, S. Hoyer, F. + Llinares-Lopez, F. Pedregosa and J.-P. Vert. "Efficient and Modular + Implicit Differentiation". NeurIPS 2022""" + # Name to select the solver in the CLI and to display the results. + name = 'jaxopt_GD' + + """How to add a new stochastic solver to the benchmark? + + This template solver is an adaptation of the solver from the benchopt + template benchmark (https://github.com/benchopt/template_benchmark/) to + the bilevel setting. Other explanations can be found in + https://benchopt.github.io/tutorials/add_solver.html. + """ + + # List of packages needed to run the solver. + requirements = ["pip:jaxopt"] + + # Stopping criterion for the solver. + # See https://benchopt.github.io/user_guide/performance_curves.html for + # more information on benchopt stopping criteria. + stopping_criterion = SufficientProgressCriterion( + patience=constants.PATIENCE, strategy='callback' + ) + + # List of parameters for the solver. The benchmark will consider + # the cross product for each key in the dictionary. + # All parameters 'p' defined here are available as 'self.p'. + parameters = { + 'step_size_outer': [10], + 'n_inner_steps': [100], + } + + @staticmethod + def get_next(stop_val): + return stop_val + 1 + + def set_objective(self, f_inner, f_outer, n_inner_samples, n_outer_samples, + inner_var0, outer_var0): + # Define the information received by each solver from the objective. + # The arguments of this function are the results of the + # `Objective.get_objective`. For the bilevel benchmark, these + # informations are the inner and outer objective functions, the number + # of samples to draw for the inner and outer objective functions, the + # initial values of the inner and outer variables. + self.f_inner = partial(f_inner, start=0, batch_size=n_inner_samples) + self.f_outer = partial(f_outer, start=0, batch_size=n_outer_samples) + inner_solver = jaxopt.GradientDescent( + fun=self.f_inner, maxiter=self.n_inner_steps, + implicit_diff=True, acceleration=False + ) + + # The value function is defined for this specific solver, but it is + # not mandatory in general. + def value_fun(inner_var, outer_var): + """Solver used to solve the inner problem. + + The output of this function is differentiable w.r.t. the + outer_variable. The Jacobian is computed using implicit + differentiation with a conjugate gradient solver. + """ + inner_var = inner_solver.run(inner_var, outer_var).params + return self.f_outer(inner_var, outer_var), inner_var + + self.value_grad = jax.jit(jax.value_and_grad( + value_fun, argnums=1, has_aux=True + )) + + self.inner_var0 = inner_var0 + self.outer_var0 = outer_var0 + + # Run the solver for 2 iterations for the JAX compilation if + # applicable. + self.run_once(2) + + def run(self, callback): + # This is the function that is called to evaluate the solver. + # It runs the algorithm for a given a number of iterations `n_iter`. + # You can also use a `tolerance` or a `callback`, as described in + # https://benchopt.github.io/performance_curves.html + + # Init variables + self.inner_var = self.inner_var0.copy() + self.outer_var = self.outer_var0.copy() + + step_sizes = jnp.array( + [self.step_size_outer] + ) + exponents = jnp.zeros(1) + state_lr = init_lr_scheduler(step_sizes, exponents) + + while callback(): + outer_lr, state_lr = update_lr(state_lr) + (_, self.inner_var), implicit_grad = self.value_grad( + self.inner_var, self.outer_var + ) + self.outer_var -= outer_lr * implicit_grad + + def get_result(self): + # Return the result from one optimization run. + # The outputs of this function is a dictionary which defines the + # keyword arguments for `Objective.evaluate_result` + # This defines the benchmark's API for solvers' results. + # it is customizable for each benchmark. + return dict(inner_var=self.inner_var, outer_var=self.outer_var) diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py index c22fdb9..86c3ad7 100644 --- a/solvers/template_stochastic_solver.py +++ b/solvers/template_stochastic_solver.py @@ -19,6 +19,7 @@ class Solver(StochasticJaxSolver): M. Dagréou, P. Ablin, S. Vaiter and T. Moreau, "A framework for bilevel optimization that enables stochastic and global variance reduction algorithms", NeurIPS 2022.""" + # Name to select the solver in the CLI and to display the results. name = 'Template Stochastic Solver' """How to add a new stochastic solver to the benchmark? @@ -57,7 +58,9 @@ class Solver(StochasticJaxSolver): whose initial state is the output of the `init` method. """ - # any parameter defined here is accessible as a class attribute + # List of parameters for the solver. The benchmark will consider + # the cross product for each key in the dictionary. + # All parameters 'p' defined here are available as 'self.p'. parameters = { 'step_size': [.1], 'outer_ratio': [1.], From f333a7c0bae2a5168ec0a707291df4e6a5c481c3 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 17 Oct 2024 18:58:47 +0200 Subject: [PATCH 23/50] ENH add template_dataset.py --- datasets/template_dataset.py | 153 +++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 datasets/template_dataset.py diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py new file mode 100644 index 0000000..cd9fdf6 --- /dev/null +++ b/datasets/template_dataset.py @@ -0,0 +1,153 @@ +from benchopt import BaseDataset +from benchopt import safe_import_context + +# Protect the import with `safe_import_context()`. This allows: +# - skipping import to speed up autocompletion in CLI. +# - getting requirements info when all dependencies are not installed. +with safe_import_context() as import_ctx: + import numpy as np + from libsvmdata import fetch_libsvm + + import jax + import jax.numpy as jnp + from functools import partial + + from jaxopt import LBFGS + + +def loss_sample(inner_var, outer_var, x, y): + return -jax.nn.log_sigmoid(y*jnp.dot(inner_var, x)) + + +def loss(inner_var, outer_var, X, y): + batched_loss = jax.vmap(loss_sample, in_axes=(None, None, 0, 0)) + return jnp.mean(batched_loss(inner_var, outer_var, X, y), axis=0) + + +# All datasets must be named `Dataset` and inherit from `BaseDataset` +class Dataset(BaseDataset): + """Hyperparameter optimization with IJCNN1 dataset.""" + # Name to select the dataset in the CLI and to display the results. + name = "ijcnn1" + + install_cmd = 'conda' + # List of packages needed to run the dataset. See the corresponding + # section in objective.py + requirements = ['pip:libsvmdata', 'scikit-learn'] + + # List of parameters to generate the datasets. The benchmark will consider + # the cross product for each key in the dictionary. + # Any parameters 'param' defined here is available as `self.param`. + parameters = { + 'reg_parametrization': ['exp'], + } + + def get_data(self): + # The return arguments of this function are passed as keyword arguments + # to `Objective.set_data`. This defines the benchmark's + # API to pass data. + assert self.reg_parametrization in ['lin', 'exp'], ( + f"unknown reg parameter '{self.reg_parametrization}'. " + "Should be 'lin' or 'exp'." + ) + + X_train, y_train = fetch_libsvm('ijcnn1') + X_val, y_val = fetch_libsvm('ijcnn1_test') + + X_train, y_train = jnp.array(X_train), jnp.array(y_train) + X_val, y_val = jnp.array(X_val), jnp.array(y_val) + + self.n_samples_inner = X_train.shape[0] + self.dim_inner = X_train.shape[1] + self.n_samples_outer = X_val.shape[0] + self.dim_outer = X_val.shape[1] + + @partial(jax.jit, static_argnames=('batch_size')) + def f_inner(inner_var, outer_var, start=0, batch_size=1): + x = jax.lax.dynamic_slice( + X_train, (start, 0), (batch_size, X_train.shape[1]) + ) + y = jax.lax.dynamic_slice( + y_train, (start, ), (batch_size, ) + ) + res = loss(inner_var, outer_var, x, y) + + if self.reg_parametrization == 'exp': + res += jnp.dot(jnp.exp(outer_var) * inner_var, inner_var)/2 + elif self.reg_parametrization == 'lin': + res += jnp.dot(outer_var * inner_var, inner_var)/2 + return res + + @partial(jax.jit, static_argnames=('batch_size')) + def f_outer(inner_var, outer_var, start=0, batch_size=1): + x = jax.lax.dynamic_slice( + X_val, (start, 0), (batch_size, X_val.shape[1]) + ) + y = jax.lax.dynamic_slice( + y_val, (start, ), (batch_size, ) + ) + res = loss(inner_var, outer_var, x, y) + return res + + f_inner_fb = partial( + f_inner, batch_size=X_train.shape[0], start=0 + ) + f_outer_fb = partial( + f_outer, batch_size=X_val.shape[0], start=0 + ) + + solver_inner = LBFGS(fun=f_inner_fb) + + def value_function(outer_var): + inner_var_star = solver_inner.run( + jnp.zeros(X_train.shape[1]), outer_var + ).params + + return f_outer_fb(inner_var_star, outer_var), inner_var_star + + value_and_grad = jax.jit( + jax.value_and_grad(value_function, has_aux=True) + ) + + def metrics(inner_var, outer_var): + # Defines the metrics that are computed when calling the method + # Objective.evaluating_results(inner_var, outer_var) and saved + # in the result file. The output is a dictionary that contains at + # least the key `value`. The keyword arguments of this function are + # the keys of the dictionary returned by `Solver.get_result`. + (value_fun, inner_star), grad_value = value_and_grad(outer_var) + return dict( + value_func=float(value_fun), + value=float(jnp.linalg.norm(grad_value)**2), + inner_distance=float(jnp.linalg.norm(inner_star-inner_var)**2), + norm_outer_var=float(jnp.linalg.norm(outer_var)**2), + norm_regul=float(jnp.linalg.norm(np.exp(outer_var))**2), + ) + + def init_var(key): + # Provides an initialization of inner_var and outer_var. + keys = jax.random.split(key, 2) + inner_var0 = jax.random.normal(keys[0], (self.dim_inner,)) + outer_var0 = jax.random.uniform(keys[1], (self.dim_outer,)) + if self.reg_parametrization == 'exp': + outer_var0 = jnp.log(outer_var0) + return inner_var0, outer_var0 + + data = dict( + pb_inner=(f_inner, self.n_samples_inner, self.dim_inner, + f_inner_fb), + pb_outer=(f_outer, self.n_samples_outer, self.dim_outer, + f_outer_fb), + metrics=metrics, + init_var=init_var, + ) + + # The output should be a dict that contains the keys `pb_inner`, + # `pb_outer`, `metrics`, and optionnally `init_var`. + # `pb_inner`` is a tuple that contains the inner function, the number + # of inner samples, the dimension of the inner variable and the full + # batch version of the inner version. + # `pb_outer` in analogous. + # The key `metrics` contains the function `metrics`. + # The key `init_var` contains the function `init_var` when applicable. + return data From c31f5cce180a84ac52351f6ffda6d8884aa44166 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 17 Oct 2024 19:01:04 +0200 Subject: [PATCH 24/50] ENH ref to benchopt template --- README.rst | 2 +- datasets/template_dataset.py | 6 ++++++ solvers/template_solver.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 9bd95c3..4434c6d 100644 --- a/README.rst +++ b/README.rst @@ -141,7 +141,7 @@ Each solver derive from the [`benchopt.BaseSolver` class](https://benchopt.githu 2 - How to add a new problem? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this benchmark, each problem is defined by a [Dataset class](https://benchopt.github.io/user_guide/generated/benchopt.BaseDataset.html) in the [datasets](datasets) folder. +In this benchmark, each problem is defined by a [Dataset class](https://benchopt.github.io/user_guide/generated/benchopt.BaseDataset.html) in the [datasets](datasets) folder. A [template](datasets/template_dataset.py) is provided. Cite ---- diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index cd9fdf6..a126bbf 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -29,6 +29,12 @@ class Dataset(BaseDataset): """Hyperparameter optimization with IJCNN1 dataset.""" # Name to select the dataset in the CLI and to display the results. name = "ijcnn1" + """How to add a new problem to the benchmark? + + This template dataset is an adaptation of the dataset from the benchopt + template benchmark (https://github.com/benchopt/template_benchmark/) to + the bilevel setting. + """ install_cmd = 'conda' # List of packages needed to run the dataset. See the corresponding diff --git a/solvers/template_solver.py b/solvers/template_solver.py index ace7770..e57af64 100644 --- a/solvers/template_solver.py +++ b/solvers/template_solver.py @@ -25,7 +25,7 @@ class Solver(BaseSolver): # Name to select the solver in the CLI and to display the results. name = 'jaxopt_GD' - """How to add a new stochastic solver to the benchmark? + """How to add a new solver to the benchmark? This template solver is an adaptation of the solver from the benchopt template benchmark (https://github.com/benchopt/template_benchmark/) to From 931e095413461507bfb573378ba513b2d650077d Mon Sep 17 00:00:00 2001 From: Thomas Moreau <thomas.moreau.2010@gmail.com> Date: Fri, 18 Oct 2024 10:26:52 +0200 Subject: [PATCH 25/50] Update README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 4434c6d..61a342e 100644 --- a/README.rst +++ b/README.rst @@ -10,7 +10,7 @@ This benchmark is dedicated to solvers for bilevel optimization: $$\\min_{x} f(x, z^*(x)) \\quad \\text{with} \\quad z^*(x) = \\arg\\min_z g(x, z), $$ -where $g$, and $f$ are two functions of two variables. +where $g$ and $f$ are two functions of two variables. Different problems ------------------ From 143ca617d0f3f8106df7df12181626ed1ccf33b0 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 10:48:05 +0200 Subject: [PATCH 26/50] ENH apply suggestion readme --- README.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.rst b/README.rst index 61a342e..b1f4a08 100644 --- a/README.rst +++ b/README.rst @@ -20,7 +20,7 @@ This benchmark currently implements three bilevel optimization problems: quadrat 1 - Simulated quadratic bilevel problem ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In this problem, the inner and the outer functions are quadritics functions defined of $\\mathbb{R}^{d\\times p}$ +In this problem, the inner and the outer functions are quadritic functions defined on $\\mathbb{R}^{d\\times p}$ $$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ @@ -30,12 +30,12 @@ $$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\ where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$, and $d_i, \\tilde d_j$ are vectors of size $p$. -The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are generated randomly such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. +The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are randomly generated such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than ``L_cross_outer``. Note that in this setting, the solution of the inner problem is a linear system. -As, the full batch inner and outer functions can be computed efficiently directly with the average Hessian matrices, the value function can be evaluated in closed form. +As, the full batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function is evaluated in closed form. 2 - Regularization selection @@ -62,7 +62,7 @@ Covtype *Homepage : https://archive.ics.uci.edu/dataset/31/covertype* -This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features, and $y_i=\\pm1$ is the binary target. +This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i=\\pm1$ is the binary target. For this problem, the loss is $\\ell(d_i, z) = \\log(1+\\exp(-y_i a_i^T z))$, and the regularization is simply given by $$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^p\\exp(x_j)z_j^2,$$ each coefficient in $z$ is independently regularized with the strength $\\exp(x_j)$. @@ -72,7 +72,7 @@ Ijcnn1 *Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active* -This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features, and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with k the number of classes. +This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with k the number of classes. For this problem, the loss is $\\ell(d_i, z) = \\text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by $$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^k\\exp(x_j)\\|z_j\\|^2,$$ each line in $z$ is independently regularized with the strength $\\exp(x_j)$. @@ -83,7 +83,7 @@ each line in $z$ is independently regularized with the strength $\\exp(x_j)$. This problem was first introduced by [Fra2017]_ . In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\\in\\{1,\\dots,10\\}$ is replaced by another random label between 1, and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\\in\\{1,\\dots,10\\}$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From 481880f50c95d71ab65d5554c73046235fdbda34 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:02:38 +0200 Subject: [PATCH 27/50] ENH replace rst by md --- README.rst => README.md | 107 ++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 60 deletions(-) rename README.rst => README.md (52%) diff --git a/README.rst b/README.md similarity index 52% rename from README.rst rename to README.md index b1f4a08..ed8c7e1 100644 --- a/README.rst +++ b/README.md @@ -1,6 +1,7 @@ Bilevel Optimization Benchmark =============================== -|Build Status| |Python 3.6+| +[](https://github.com/benchopt/benchmark_bilevel/actions) +[](https://www.python.org/downloads/release/python-360/) *Results can be consulted on https://benchopt.github.io/results/benchmark_bilevel.html* @@ -8,7 +9,7 @@ BenchOpt is a package to simplify, make more transparent, and reproducible the comparisons of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: -$$\\min_{x} f(x, z^*(x)) \\quad \\text{with} \\quad z^*(x) = \\arg\\min_z g(x, z), $$ +$$\min_{x} f(x, z^*(x)) \quad \text{with} \quad z^*(x) = \arg\min_z g(x, z), $$ where $g$ and $f$ are two functions of two variables. @@ -17,73 +18,69 @@ Different problems This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and hyper data cleaning. -1 - Simulated quadratic bilevel problem -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +### 1 - Simulated quadratic bilevel problem -In this problem, the inner and the outer functions are quadritic functions defined on $\\mathbb{R}^{d\\times p}$ -$$g(x, z) = \\frac{1}{n}\\sum_{i=1}^n \\frac{1}{2} z^\\top H_i^z z + \\frac{1}{2} x^\\top H_i^x x + x^\\top C_i z + c_i^\\top z + d_i^\\top x$$ +In this problem, the inner and the outer functions are quadritic functions defined on $\mathbb{R}^{d\times p}$ + +$$g(x, z) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2} z^\top H_i^z z + \frac{1}{2} x^\top H_i^x x + x^\top C_i z + c_i^\top z + d_i^\top x$$ and -$$f(x, z) = \\frac{1}{m} \\sum_{j=1}^m \\frac{1}{2} z^\\top \\tilde H_j^z z + \\frac{1}{2} x^\\top \\tilde H_j^x x + x^\\top \\tilde C_j z + \\tilde c_j^\\top z + \\tilde d_j^\\top x$$ +$$f(x, z) = \frac{1}{m} \sum_{j=1}^m \frac{1}{2} z^\top \tilde H_j^z z + \frac{1}{2} x^\top \tilde H_j^x x + x^\top \tilde C_j z + \tilde c_j^\top z + \tilde d_j^\top x$$ -where $H_i^z, \\tilde H_j^z$ are symmetric positive definite matrices of size $p\\times p$, $H_j^x, \\tilde H_j^x$ are symmetric positive definite matrices of size $d\\times d$, $C_i, \\tilde C_j$ are matrices of size $d\\times p$, $c_i$, $\\tilde c_j$ are vectors of size $d$, and $d_i, \\tilde d_j$ are vectors of size $p$. +where $H_i^z, \tilde H_j^z$ are symmetric positive definite matrices of size $p\times p$, $H_j^x, \tilde H_j^x$ are symmetric positive definite matrices of size $d\times d$, $C_i, \tilde C_j$ are matrices of size $d\times p$, $c_i$, $\tilde c_j$ are vectors of size $d$, and $d_i, \tilde d_j$ are vectors of size $p$. -The matrices $H_i^z, H_i^x, \\tilde H_j^z, \\tilde H_j^x$ are randomly generated such that the eigenvalues of $\\frac1n\\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\\frac1n\\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\\frac1m\\sum_j \\tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. +The matrices $H_i^z, H_i^x, \tilde H_j^z, \tilde H_j^x$ are randomly generated such that the eigenvalues of $\frac1n\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\frac1n\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\frac1m\sum_j \tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\frac1m\sum_j \tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. -The matrices $C_i, \\tilde C_j$ are generated randomly such that the spectral norm of $\\frac1n\\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\\frac1m\\sum_j \\tilde C_j$ is lower than ``L_cross_outer``. +The matrices $C_i, \tilde C_j$ are generated randomly such that the spectral norm of $\frac1n\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\frac1m\sum_j \tilde C_j$ is lower than ``L_cross_outer``. Note that in this setting, the solution of the inner problem is a linear system. As, the full batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function is evaluated in closed form. -2 - Regularization selection -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +### 2 - Regularization selection In this problem, the inner function $g$ is defined by -$$g(x, z) = \\frac{1}{n} \\sum_{i=1}^{n} \\ell(d_i; z) + \\mathcal{R}(x, z)$$ +$$g(x, z) = \frac{1}{n} \sum_{i=1}^{n} \ell(d_i; z) + \mathcal{R}(x, z)$$ -where $d_1, \\dots, d_n$ are training data samples, $z$ are the parameters of the machine learning model, and the loss function $\\ell$ measures how well the model parameters $z$ predict the data $d_i$. -There is also a regularization $\\mathcal{R}$ that is parametrized by the regularization strengths $x$, which aims at promoting a certain structure on the parameters $z$. +where $d_1, \dots, d_n$ are training data samples, $z$ are the parameters of the machine learning model, and the loss function $\ell$ measures how well the model parameters $z$ predict the data $d_i$. +There is also a regularization $\mathcal{R}$ that is parametrized by the regularization strengths $x$, which aims at promoting a certain structure on the parameters $z$. The outer function $f$ is defined as the unregularized loss on unseen data -$$f(x, z) = \\frac{1}{m} \\sum_{j=1}^{m} \\ell(d'_j; z)$$ +$$f(x, z) = \frac{1}{m} \sum_{j=1}^{m} \ell(d'_j; z)$$ -where the $d'_1, \\dots, d'_m$ are new samples from the same dataset as above. +where the $d'_1, \dots, d'_m$ are new samples from the same dataset as above. There are currently two datasets for this regularization selection problem. -Covtype -+++++++ +#### Covtype *Homepage : https://archive.ics.uci.edu/dataset/31/covertype* -This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i=\\pm1$ is the binary target. -For this problem, the loss is $\\ell(d_i, z) = \\log(1+\\exp(-y_i a_i^T z))$, and the regularization is simply given by -$$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^p\\exp(x_j)z_j^2,$$ -each coefficient in $z$ is independently regularized with the strength $\\exp(x_j)$. +This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i=\pm1$ is the binary target. +For this problem, the loss is $\ell(d_i, z) = \log(1+\exp(-y_i a_i^T z))$, and the regularization is simply given by +$$\mathcal{R}(x, z) = \frac12\sum_{j=1}^p\exp(x_j)z_j^2,$$ +each coefficient in $z$ is independently regularized with the strength $\exp(x_j)$. -Ijcnn1 -++++++ +#### Ijcnn1 *Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active* -This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\\in\\mathbb{R}^p$ are the features and $y_i\\in \\{1,\\dots, k\\}$ is the integer target, with k the number of classes. -For this problem, the loss is $\\ell(d_i, z) = \\text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by -$$\\mathcal{R}(x, z) = \\frac12\\sum_{j=1}^k\\exp(x_j)\\|z_j\\|^2,$$ -each line in $z$ is independently regularized with the strength $\\exp(x_j)$. +This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i\in \{1,\dots, k\}$ is the integer target, with k the number of classes. +For this problem, the loss is $\ell(d_i, z) = \text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by +$$\mathcal{R}(x, z) = \frac12\sum_{j=1}^k\exp(x_j)\|z_j\|^2,$$ +each line in $z$ is independently regularized with the strength $\exp(x_j)$. -3 - Hyper data cleaning -^^^^^^^^^^^^^^^^^^^^^^^ +### 3 - Hyper data cleaning -This problem was first introduced by [Fra2017]_ . +This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\\in\\{1,\\dots,10\\}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\in\{1,\dots,10\}$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. @@ -91,13 +88,13 @@ To do so, a set of weights -- one per train sample -- is learned as well as the Ideally, we would want a weight of 0 for data that has been corrupted, and a weight of 1 for uncorrupted data. The problem is cast as a bilevel problem with $g$ given by -$$g(x, z) =\\frac1n \\sum_{i=1}^n \\sigma(x_i)\\ell(d_i, z) + \\frac C 2 \\|z\\|^2$$ +$$g(x, z) =\frac1n \sum_{i=1}^n \sigma(x_i)\ell(d_i, z) + \frac C 2 \|z\|^2$$ -where the $d_i$ are the corrupted training data, $\\ell$ is the loss of a CNN parameterized by $z$, $\\sigma$ is a sigmoid function, and C is a small regularization constant. -Here the outer variable $x$ is a vector of dimension $n$, and the weight of data $i$ is given by $\\sigma(x_i)$. +where the $d_i$ are the corrupted training data, $\ell$ is the loss of a CNN parameterized by $z$, $\sigma$ is a sigmoid function, and C is a small regularization constant. +Here the outer variable $x$ is a vector of dimension $n$, and the weight of data $i$ is given by $\sigma(x_i)$. The test function is -$$f(x, z) =\\frac1m \\sum_{j=1}^n \\ell(d'_j, z)$$ +$$f(x, z) =\frac1m \sum_{j=1}^n \ell(d'_j, z)$$ where the $d_j$ are uncorrupted testing data. @@ -106,41 +103,40 @@ Install This benchmark can be run using the following commands: -.. code-block:: - +```bash $ pip install -U benchopt $ git clone https://github.com/benchopt/benchmark_bilevel $ benchopt run benchmark_bilevel +``` Apart from the problem, options can be passed to ``benchopt run``, to restrict the benchmarks to some solvers or datasets, e.g.: -.. code-block:: - +```bash $ benchopt run benchmark_bilevel -s solver1 -d dataset2 --max-runs 10 --n-repetitions 10 +```` You can also use config files to setup the benchmark run: -.. code-block:: - +```bash $ benchopt run benchmark_bilevel --config config/X.yml +``` where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file ``X_best_params.yml`` in order to launch an experiment with a single set of parameters for each solver. Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html. -How to contribute to the benchmark? ------------------------------------ +### How to contribute to the benchmark? If you want to add a solver or a new problem, you are welcome to open an issue or submit a pull request! -1 - How to add a new solvers? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +#### 1 - How to add a new solvers? + Each solver derive from the [`benchopt.BaseSolver` class](https://benchopt.github.io/user_guide/generated/benchopt.BaseSolver.html) in the [solvers](solvers) folder. The solvers are separated among the stochastic JAX solvers and the others: * Stochastic Jax solver: these solvers inherit from the [`StochasticJaxSolver` class](benchmark_utils/stochastic_jax_solver.py) see the detailed explanations in the [template stochastic solver](solvers/template_stochastic_solver.py). * Other solver: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). An example is provided in the [template solver](solvers/template_solver.py). -2 - How to add a new problem? -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +#### 2 - How to add a new problem? + In this benchmark, each problem is defined by a [Dataset class](https://benchopt.github.io/user_guide/generated/benchopt.BaseDataset.html) in the [datasets](datasets) folder. A [template](datasets/template_dataset.py) is provided. Cite @@ -148,20 +144,11 @@ Cite If you use this benchmark in your research project, please cite the following paper: -.. code-block:: - +``` @inproceedings{saba, title = {A Framework for Bilevel Optimization That Enables Stochastic and Global Variance Reduction Algorithms}, booktitle = {Advances in {{Neural Information Processing Systems}} ({{NeurIPS}})}, author = {Dagr{\'e}ou, Mathieu and Ablin, Pierre and Vaiter, Samuel and Moreau, Thomas}, year = {2022} } - - -References ----------- -.. [Fra2017] Franceschi, Luca, et al. "Forward and reverse gradient-based hyperparameter optimization." International Conference on Machine Learning. PMLR, 2017. -.. |Build Status| image:: https://github.com/benchopt/benchmark_bilevel/workflows/Tests/badge.svg - :target: https://github.com/benchopt/benchmark_bilevel/actions -.. |Python 3.6+| image:: https://img.shields.io/badge/python-3.6%2B-blue - :target: https://www.python.org/downloads/release/python-360/ +``` From ae2ce5de4d0b636f344eb15a4b600d5979150d00 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:05:03 +0200 Subject: [PATCH 28/50] FIX brackets --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ed8c7e1..0ee79c2 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,9 @@ BenchOpt is a package to simplify, make more transparent, and reproducible the comparisons of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: -$$\min_{x} f(x, z^*(x)) \quad \text{with} \quad z^*(x) = \arg\min_z g(x, z), $$ +$$ +\min_{x} f(x, z^*(x)) \quad \text{with} \quad z^*(x) = \arg\min_z g(x, z), +$$ where $g$ and $f$ are two functions of two variables. @@ -80,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\in\{1,\dots,10\}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\in {{1,\dots,10}}$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From b0cb39af5e104400259a7e3ca3ded6bcc2cbb82b Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:05:36 +0200 Subject: [PATCH 29/50] FIX brackets --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0ee79c2..336875a 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\in {{1,\dots,10}}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\in \{1,\dots,10\}$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From 928b70afcbf6c6d3bf6d8a021b413b0a9ffd95ad Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:06:10 +0200 Subject: [PATCH 30/50] FIX brackets --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 336875a..09ee74d 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\in \{1,\dots,10\}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\in `\{1,\dots,10`\}$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From a29f39355e5f397f98b3b8cce16bda7e7cfece44 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:07:08 +0200 Subject: [PATCH 31/50] FIX brackets --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 09ee74d..7ff76ac 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\in `\{1,\dots,10`\}$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $y\in `\{1,\dots,10\}`$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From 10383590e6d5f74935b9661c8d7485432207c793 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:07:24 +0200 Subject: [PATCH 32/50] FIX brackets --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ff76ac..0f626c9 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. -The training set has been corrupted: with a probability $p$, the label of the image $y\in `\{1,\dots,10\}`$ is replaced by another random label between 1 and 10. +The training set has been corrupted: with a probability $p$, the label of the image $`y\in\{1,\dots,10\}`$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. From 50ca4aac828583f72855e0ffa799faec706c0501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dagr=C3=A9ou?= <77896657+MatDag@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:15:08 +0200 Subject: [PATCH 33/50] Update README.md --- README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 0f626c9..264ced1 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,8 @@ Bilevel Optimization Benchmark *Results can be consulted on https://benchopt.github.io/results/benchmark_bilevel.html* -BenchOpt is a package to simplify, make more transparent, and -reproducible the comparisons of optimization algorithms. +BenchOpt is a package to simplify, to make more transparent, and +reproducible the comparison of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: $$ @@ -18,12 +18,12 @@ where $g$ and $f$ are two functions of two variables. Different problems ------------------ -This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and hyper data cleaning. +This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and data cleaning. ### 1 - Simulated quadratic bilevel problem -In this problem, the inner and the outer functions are quadritic functions defined on $\mathbb{R}^{d\times p}$ +In this problem, the inner and the outer functions are quadratic functions defined on $\mathbb{R}^{d\times p}$ $$g(x, z) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2} z^\top H_i^z z + \frac{1}{2} x^\top H_i^x x + x^\top C_i z + c_i^\top z + d_i^\top x$$ @@ -38,7 +38,7 @@ The matrices $H_i^z, H_i^x, \tilde H_j^z, \tilde H_j^x$ are randomly generated s The matrices $C_i, \tilde C_j$ are generated randomly such that the spectral norm of $\frac1n\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\frac1m\sum_j \tilde C_j$ is lower than ``L_cross_outer``. Note that in this setting, the solution of the inner problem is a linear system. -As, the full batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function is evaluated in closed form. +As the full batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function is evaluated in closed form. ### 2 - Regularization selection @@ -72,13 +72,13 @@ each coefficient in $z$ is independently regularized with the strength $\exp(x_j *Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active* -This is a multicalss logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i\in \{1,\dots, k\}$ is the integer target, with k the number of classes. +This is a multiclass logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i\in \{1,\dots, k\}$ is the integer target, with k the number of classes. For this problem, the loss is $\ell(d_i, z) = \text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by $$\mathcal{R}(x, z) = \frac12\sum_{j=1}^k\exp(x_j)\|z_j\|^2,$$ each line in $z$ is independently regularized with the strength $\exp(x_j)$. -### 3 - Hyper data cleaning +### 3 - Data cleaning This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. @@ -111,19 +111,19 @@ This benchmark can be run using the following commands: $ benchopt run benchmark_bilevel ``` -Apart from the problem, options can be passed to ``benchopt run``, to restrict the benchmarks to some solvers or datasets, e.g.: +Apart from the problem, options can be passed to ``benchopt run`` to restrict the benchmarks to some solvers or datasets, e.g.: ```bash $ benchopt run benchmark_bilevel -s solver1 -d dataset2 --max-runs 10 --n-repetitions 10 ```` -You can also use config files to setup the benchmark run: +You can also use config files to set the benchmark run: ```bash $ benchopt run benchmark_bilevel --config config/X.yml ``` -where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file ``X_best_params.yml`` in order to launch an experiment with a single set of parameters for each solver. +where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file ``X_best_params.yml``to launch an experiment with a single set of parameters for each solver. Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html. @@ -131,9 +131,9 @@ Use ``benchopt run -h`` for more details about these options, or visit https://b If you want to add a solver or a new problem, you are welcome to open an issue or submit a pull request! -#### 1 - How to add a new solvers? +#### 1 - How to add a new solver? -Each solver derive from the [`benchopt.BaseSolver` class](https://benchopt.github.io/user_guide/generated/benchopt.BaseSolver.html) in the [solvers](solvers) folder. The solvers are separated among the stochastic JAX solvers and the others: +Each solver derives from the [`benchopt.BaseSolver` class](https://benchopt.github.io/user_guide/generated/benchopt.BaseSolver.html) in the [solvers](solvers) folder. The solvers are separated among the stochastic JAX solvers and the others: * Stochastic Jax solver: these solvers inherit from the [`StochasticJaxSolver` class](benchmark_utils/stochastic_jax_solver.py) see the detailed explanations in the [template stochastic solver](solvers/template_stochastic_solver.py). * Other solver: see the detailed explanation in the [Benchopt documentation](https://benchopt.github.io/tutorials/add_solver.html). An example is provided in the [template solver](solvers/template_solver.py). @@ -147,7 +147,7 @@ Cite If you use this benchmark in your research project, please cite the following paper: ``` - @inproceedings{saba, + @inproceedings{dagreou2022, title = {A Framework for Bilevel Optimization That Enables Stochastic and Global Variance Reduction Algorithms}, booktitle = {Advances in {{Neural Information Processing Systems}} ({{NeurIPS}})}, author = {Dagr{\'e}ou, Mathieu and Ablin, Pierre and Vaiter, Samuel and Moreau, Thomas}, From cd580c1a78bc93ed1b47326bbc4668af9ba70d36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dagr=C3=A9ou?= <77896657+MatDag@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:23:06 +0200 Subject: [PATCH 34/50] Update README.md --- README.md | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 264ced1..379bf45 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,14 @@ BenchOpt is a package to simplify, to make more transparent, and reproducible the comparison of optimization algorithms. This benchmark is dedicated to solvers for bilevel optimization: -$$ -\min_{x} f(x, z^*(x)) \quad \text{with} \quad z^*(x) = \arg\min_z g(x, z), -$$ +$$\min_{x} f(x, z^* (x)) \quad \text{with} \quad z^*(x) = \arg\min_z g(x, z),$$ where $g$ and $f$ are two functions of two variables. Different problems ------------------ -This benchmark currently implements three bilevel optimization problems: quadratic problem, regularization selection, and data cleaning. +This benchmark implements three bilevel optimization problems: quadratic problem, regularization selection, and data cleaning. ### 1 - Simulated quadratic bilevel problem @@ -51,7 +49,7 @@ $$g(x, z) = \frac{1}{n} \sum_{i=1}^{n} \ell(d_i; z) + \mathcal{R}(x, z)$$ where $d_1, \dots, d_n$ are training data samples, $z$ are the parameters of the machine learning model, and the loss function $\ell$ measures how well the model parameters $z$ predict the data $d_i$. There is also a regularization $\mathcal{R}$ that is parametrized by the regularization strengths $x$, which aims at promoting a certain structure on the parameters $z$. -The outer function $f$ is defined as the unregularized loss on unseen data +The outer function $f$ is defined as the unregularized loss on unseen data $$f(x, z) = \frac{1}{m} \sum_{j=1}^{m} \ell(d'_j; z)$$ @@ -63,7 +61,7 @@ There are currently two datasets for this regularization selection problem. *Homepage : https://archive.ics.uci.edu/dataset/31/covertype* -This is a logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i=\pm1$ is the binary target. +This is a logistic regression problem, where the data have the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ the features and $y_i=\pm1$ the binary target. For this problem, the loss is $\ell(d_i, z) = \log(1+\exp(-y_i a_i^T z))$, and the regularization is simply given by $$\mathcal{R}(x, z) = \frac12\sum_{j=1}^p\exp(x_j)z_j^2,$$ each coefficient in $z$ is independently regularized with the strength $\exp(x_j)$. @@ -87,7 +85,7 @@ We do not know beforehand which data has been corrupted. We have a clean testing set, which has not been corrupted. The goal is to fit a model on the corrupted training data that has good performances on the test set. To do so, a set of weights -- one per train sample -- is learned as well as the model parameters. -Ideally, we would want a weight of 0 for data that has been corrupted, and a weight of 1 for uncorrupted data. +Ideally, we would want a weight of 0 for data that has been corrupted and a weight of 1 for uncorrupted data. The problem is cast as a bilevel problem with $g$ given by $$g(x, z) =\frac1n \sum_{i=1}^n \sigma(x_i)\ell(d_i, z) + \frac C 2 \|z\|^2$$ @@ -123,7 +121,7 @@ You can also use config files to set the benchmark run: $ benchopt run benchmark_bilevel --config config/X.yml ``` -where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will possibly launch a huge grid search. When available, you can rather use the file ``X_best_params.yml``to launch an experiment with a single set of parameters for each solver. +where ``X.yml`` is a config file. See https://benchopt.github.io/index.html#run-a-benchmark for an example of a config file. This will launch a huge grid search. When available, you can rather use the file ``X_best_params.yml`` to launch an experiment with a single set of parameters for each solver. Use ``benchopt run -h`` for more details about these options, or visit https://benchopt.github.io/api.html. From 5f68c116d39e5701683ef6c29e89dc30eca8ef56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dagr=C3=A9ou?= <77896657+MatDag@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:42:17 +0200 Subject: [PATCH 35/50] CLN remove tilde --- README.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 379bf45..9c8484e 100644 --- a/README.md +++ b/README.md @@ -23,17 +23,17 @@ This benchmark implements three bilevel optimization problems: quadratic problem In this problem, the inner and the outer functions are quadratic functions defined on $\mathbb{R}^{d\times p}$ -$$g(x, z) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2} z^\top H_i^z z + \frac{1}{2} x^\top H_i^x x + x^\top C_i z + c_i^\top z + d_i^\top x$$ +$$g(x, z) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2} z^\top A_i z + \frac{1}{2} x^\top B_i x + x^\top C_i z + a_i^\top z + b_i^\top x$$ and -$$f(x, z) = \frac{1}{m} \sum_{j=1}^m \frac{1}{2} z^\top \tilde H_j^z z + \frac{1}{2} x^\top \tilde H_j^x x + x^\top \tilde C_j z + \tilde c_j^\top z + \tilde d_j^\top x$$ +$$f(x, z) = \frac{1}{m} \sum_{j=1}^m \frac{1}{2} z^\top F_j z + \frac{1}{2} x^\top H_j x + x^\top K_j z + f_j^\top z + h_j^\top x$$ -where $H_i^z, \tilde H_j^z$ are symmetric positive definite matrices of size $p\times p$, $H_j^x, \tilde H_j^x$ are symmetric positive definite matrices of size $d\times d$, $C_i, \tilde C_j$ are matrices of size $d\times p$, $c_i$, $\tilde c_j$ are vectors of size $d$, and $d_i, \tilde d_j$ are vectors of size $p$. +where $A_i, F_j$ are symmetric positive definite matrices of size $p\times p$, $B_i, F_j$ are symmetric positive definite matrices of size $d\times d$, $C_i, K_j$ are matrices of size $d\times p$, $a_i$, $f_j$ are vectors of size $d$, and $b_i, h_j$ are vectors of size $p$. -The matrices $H_i^z, H_i^x, \tilde H_j^z, \tilde H_j^x$ are randomly generated such that the eigenvalues of $\frac1n\sum_i H_i^z$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\frac1n\sum_i H_i^x$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\frac1m\sum_j \tilde H_j^z$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\frac1m\sum_j \tilde H_j^x$ are between ``mu_inner``, and ``L_outer_outer``. +The matrices $A_i, B_i, F_j, H_j$ are randomly generated such that the eigenvalues of $\frac1n\sum_i A_i$ are between ``mu_inner``, and ``L_inner_inner``, the eigenvalues of $\frac1n\sum_i B_i$ are between ``mu_inner``, and ``L_inner_outer``, the eigenvalues of $\frac1m\sum_j F_j$ are between ``mu_inner``, and ``L_outer_inner``, and the eigenvalues of $\frac1m\sum_j H_j$ are between ``mu_inner``, and ``L_outer_outer``. -The matrices $C_i, \tilde C_j$ are generated randomly such that the spectral norm of $\frac1n\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\frac1m\sum_j \tilde C_j$ is lower than ``L_cross_outer``. +The matrices $C_i, K_j$ are generated randomly such that the spectral norm of $\frac1n\sum_i C_i$ is lower than ``L_cross_inner``, and the spectral norm of $\frac1m\sum_j K_j$ is lower than ``L_cross_outer``. Note that in this setting, the solution of the inner problem is a linear system. As the full batch inner and outer functions can be computed efficiently with the average Hessian matrices, the value function is evaluated in closed form. @@ -57,18 +57,14 @@ where the $d'_1, \dots, d'_m$ are new samples from the same dataset as above. There are currently two datasets for this regularization selection problem. -#### Covtype - -*Homepage : https://archive.ics.uci.edu/dataset/31/covertype* +#### Covtype - [*Homepage*](https://archive.ics.uci.edu/dataset/31/covertype*) This is a logistic regression problem, where the data have the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ the features and $y_i=\pm1$ the binary target. For this problem, the loss is $\ell(d_i, z) = \log(1+\exp(-y_i a_i^T z))$, and the regularization is simply given by $$\mathcal{R}(x, z) = \frac12\sum_{j=1}^p\exp(x_j)z_j^2,$$ each coefficient in $z$ is independently regularized with the strength $\exp(x_j)$. -#### Ijcnn1 - -*Homepage : https://www.openml.org/search?type=data&sort=runs&id=1575&status=active* +#### Ijcnn1 - [*Homepage*](https://www.openml.org/search?type=data&sort=runs&id=1575&status=active) This is a multiclass logistic regression problem, where the data is of the form $d_i = (a_i, y_i)$ with $a_i\in\mathbb{R}^p$ are the features and $y_i\in \{1,\dots, k\}$ is the integer target, with k the number of classes. For this problem, the loss is $\ell(d_i, z) = \text{CrossEntropy}(za_i, y_i)$ where $z$ is now a k x p matrix. The regularization is given by @@ -112,7 +108,7 @@ This benchmark can be run using the following commands: Apart from the problem, options can be passed to ``benchopt run`` to restrict the benchmarks to some solvers or datasets, e.g.: ```bash - $ benchopt run benchmark_bilevel -s solver1 -d dataset2 --max-runs 10 --n-repetitions 10 + $ benchopt run benchmark_bilevel -s solver1 -d dataset2 --max-runs 10 --n-repetitions 10 ```` You can also use config files to set the benchmark run: From 0d1ab038615460f175a63588607be8aeff2a932e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Dagr=C3=A9ou?= <77896657+MatDag@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:43:30 +0200 Subject: [PATCH 36/50] FIX ref --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9c8484e..cab6e44 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,7 @@ each line in $z$ is independently regularized with the strength $\exp(x_j)$. ### 3 - Data cleaning -This problem was first introduced by [Franceschi et al., 2017](https://arxiv.org/abs/1703.01785). +This problem was first introduced by [Franceschi et al. (2017)](https://arxiv.org/abs/1703.01785). In this problem, the data is the MNIST dataset. The training set has been corrupted: with a probability $p$, the label of the image $`y\in\{1,\dots,10\}`$ is replaced by another random label between 1 and 10. We do not know beforehand which data has been corrupted. From 7d9508e0c75d3f29e328b71d26e239d811ff7705 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 11:16:02 +0200 Subject: [PATCH 37/50] CLN remove useless params --- datasets/template_dataset.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index a126bbf..dc52a66 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -52,10 +52,6 @@ def get_data(self): # The return arguments of this function are passed as keyword arguments # to `Objective.set_data`. This defines the benchmark's # API to pass data. - assert self.reg_parametrization in ['lin', 'exp'], ( - f"unknown reg parameter '{self.reg_parametrization}'. " - "Should be 'lin' or 'exp'." - ) X_train, y_train = fetch_libsvm('ijcnn1') X_val, y_val = fetch_libsvm('ijcnn1_test') From 59001b3da70c1057adea54149bf3de725e2f8b8e Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Fri, 18 Oct 2024 13:53:04 +0200 Subject: [PATCH 38/50] WIP --- datasets/simulated.py | 9 ---- datasets/template_dataset.py | 99 +++++++++++++++++++----------------- 2 files changed, 53 insertions(+), 55 deletions(-) diff --git a/datasets/simulated.py b/datasets/simulated.py index 64cea26..9a7166c 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -31,15 +31,6 @@ def quadratic(inner_var, outer_var, hess_inner, hess_outer, cross, return res -def batched_quadratic(inner_var, outer_var, hess_inner, hess_outer, cross, - linear_inner, linear_outer): - batched_loss = jax.vmap(quadratic, in_axes=(None, None, 0, 0, 0, 0, 0)) - return jnp.mean( - batched_loss(inner_var, outer_var, hess_inner, hess_outer, - cross, linear_inner, linear_outer) - ) - - def get_function(hess_inner, hess_outer, cross, linear_inner, linear_outer): @partial(jax.jit, static_argnames=('batch_size')) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index dc52a66..ebf9e6a 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -6,7 +6,6 @@ # - getting requirements info when all dependencies are not installed. with safe_import_context() as import_ctx: import numpy as np - from libsvmdata import fetch_libsvm import jax import jax.numpy as jnp @@ -15,37 +14,65 @@ from jaxopt import LBFGS -def loss_sample(inner_var, outer_var, x, y): - return -jax.nn.log_sigmoid(y*jnp.dot(inner_var, x)) +def generate_matrices(dim_inner, dim_outer, key=jax.random.PRNGKey(0)): + """Generates the different matrices of the inner and outer quadratic + functions.""" + keys = jax.random.split(key, 4) + eig_inner = jnp.logpsace(-1, 0, dim_inner) + eig_outer = jnp.logpsace(-1, 0, dim_inner) + eig_cross = jnp.logpsace(-1, 0, min(dim_inner, dim_outer)) + # Matrix generation for the inner function + # Generate a PSD matrix with eigenvalues `eig_inner` + hess_inner_inner = jax.random.normal(keys[0], (dim_inner, dim_inner)) + U, _, _ = jnp.linalg.svd(hess_inner_inner) + hess_inner_inner = U @ jnp.diag(eig_inner) @ U.T -def loss(inner_var, outer_var, X, y): - batched_loss = jax.vmap(loss_sample, in_axes=(None, None, 0, 0)) - return jnp.mean(batched_loss(inner_var, outer_var, X, y), axis=0) + # Generate a PSD matrix with eigenvalues `eig_outer` + hess_outer_inner = jax.random.normal(keys[1], (dim_outer, dim_outer)) + U, _, _ = jnp.linalg.svd(hess_outer_inner) + hess_outer_inner = U @ jnp.diag(eig_outer) @ U.T + + # Generate a PSD matrix with eigenvalues `eig_outer` + cross_inner = jax.random.normal(keys[2], (dim_outer, dim_inner)) + D = jnp.zeros((dim_outer, dim_inner)) + D = D.at[:min(dim_outer, dim_inner), :min(dim_outer, dim_inner)].set( + jnp.diag(eig_cross) + ) + U, _, V = jnp.linalg.svd(cross_inner) + cross_inner = U @ D @ V.T + + hess_inner_outer = jax.random.normal(keys[3], (dim_inner, dim_inner)) + U, _, _ = jnp.linalg.svd(hess_inner_outer) + hess_inner_outer = U @ jnp.diag(eig_inner) @ U.T + + return hess_inner_inner, hess_outer_inner, cross_inner, hess_inner_outer + + +def quadratic(inner_var, outer_var, hess_inner, hess_outer, cross): + res = .5 * inner_var @ (hess_inner @ inner_var) + res += .5 * outer_var @ (hess_outer @ outer_var) + res += outer_var @ cross @ inner_var + return res # All datasets must be named `Dataset` and inherit from `BaseDataset` class Dataset(BaseDataset): - """Hyperparameter optimization with IJCNN1 dataset.""" - # Name to select the dataset in the CLI and to display the results. - name = "ijcnn1" """How to add a new problem to the benchmark? This template dataset is an adaptation of the dataset from the benchopt template benchmark (https://github.com/benchopt/template_benchmark/) to the bilevel setting. """ - - install_cmd = 'conda' - # List of packages needed to run the dataset. See the corresponding - # section in objective.py - requirements = ['pip:libsvmdata', 'scikit-learn'] + # Name to select the dataset in the CLI and to display the results. + name = "Template dataset" # List of parameters to generate the datasets. The benchmark will consider # the cross product for each key in the dictionary. # Any parameters 'param' defined here is available as `self.param`. parameters = { - 'reg_parametrization': ['exp'], + 'dim_inner': [10], + 'dim_outer': [10], } def get_data(self): @@ -53,43 +80,23 @@ def get_data(self): # to `Objective.set_data`. This defines the benchmark's # API to pass data. - X_train, y_train = fetch_libsvm('ijcnn1') - X_val, y_val = fetch_libsvm('ijcnn1_test') - - X_train, y_train = jnp.array(X_train), jnp.array(y_train) - X_val, y_val = jnp.array(X_val), jnp.array(y_val) - - self.n_samples_inner = X_train.shape[0] - self.dim_inner = X_train.shape[1] - self.n_samples_outer = X_val.shape[0] - self.dim_outer = X_val.shape[1] + hess_inner_inner, hess_outer_inner, cross, hess_inner_outer = ( + generate_matrices( + self.dim_inner, self.dim_outer + ) + ) @partial(jax.jit, static_argnames=('batch_size')) def f_inner(inner_var, outer_var, start=0, batch_size=1): - x = jax.lax.dynamic_slice( - X_train, (start, 0), (batch_size, X_train.shape[1]) - ) - y = jax.lax.dynamic_slice( - y_train, (start, ), (batch_size, ) - ) - res = loss(inner_var, outer_var, x, y) - - if self.reg_parametrization == 'exp': - res += jnp.dot(jnp.exp(outer_var) * inner_var, inner_var)/2 - elif self.reg_parametrization == 'lin': - res += jnp.dot(outer_var * inner_var, inner_var)/2 - return res + return quadratic(inner_var, outer_var, + hess_inner_inner, hess_outer_inner, + cross) @partial(jax.jit, static_argnames=('batch_size')) def f_outer(inner_var, outer_var, start=0, batch_size=1): - x = jax.lax.dynamic_slice( - X_val, (start, 0), (batch_size, X_val.shape[1]) - ) - y = jax.lax.dynamic_slice( - y_val, (start, ), (batch_size, ) - ) - res = loss(inner_var, outer_var, x, y) - return res + return quadratic(inner_var, outer_var, hess_inner_outer, + jnp.zeros_like(hess_outer_inner), + jnp.zeros_like(cross)) f_inner_fb = partial( f_inner, batch_size=X_train.shape[0], start=0 From 2fac0da614c41f82485f92806059ac605970da44 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Tue, 22 Oct 2024 10:17:49 +0200 Subject: [PATCH 39/50] ENH simplify template_dataset --- datasets/template_dataset.py | 136 ++++++++++++++++++++++------------- 1 file changed, 86 insertions(+), 50 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index ebf9e6a..939c21f 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -5,53 +5,46 @@ # - skipping import to speed up autocompletion in CLI. # - getting requirements info when all dependencies are not installed. with safe_import_context() as import_ctx: - import numpy as np - import jax import jax.numpy as jnp - from functools import partial + from functools import partial # useful for just-in-time compilation - from jaxopt import LBFGS + from jaxopt import LBFGS # useful to define the value function def generate_matrices(dim_inner, dim_outer, key=jax.random.PRNGKey(0)): """Generates the different matrices of the inner and outer quadratic functions.""" - keys = jax.random.split(key, 4) + keys = jax.random.split(key, 3) eig_inner = jnp.logpsace(-1, 0, dim_inner) - eig_outer = jnp.logpsace(-1, 0, dim_inner) - eig_cross = jnp.logpsace(-1, 0, min(dim_inner, dim_outer)) + sing_cross = jnp.logpsace(-1, 0, min(dim_inner, dim_outer)) # Matrix generation for the inner function # Generate a PSD matrix with eigenvalues `eig_inner` - hess_inner_inner = jax.random.normal(keys[0], (dim_inner, dim_inner)) - U, _, _ = jnp.linalg.svd(hess_inner_inner) - hess_inner_inner = U @ jnp.diag(eig_inner) @ U.T - - # Generate a PSD matrix with eigenvalues `eig_outer` - hess_outer_inner = jax.random.normal(keys[1], (dim_outer, dim_outer)) - U, _, _ = jnp.linalg.svd(hess_outer_inner) - hess_outer_inner = U @ jnp.diag(eig_outer) @ U.T + hess_inner = jax.random.normal(keys[0], (dim_inner, dim_inner)) + U, _, _ = jnp.linalg.svd(hess_inner) + hess_inner = U @ jnp.diag(eig_inner) @ U.T - # Generate a PSD matrix with eigenvalues `eig_outer` - cross_inner = jax.random.normal(keys[2], (dim_outer, dim_inner)) + # Generate a matrix with singular values `sing_cross` + cross_inner = jax.random.normal(keys[1], (dim_outer, dim_inner)) D = jnp.zeros((dim_outer, dim_inner)) D = D.at[:min(dim_outer, dim_inner), :min(dim_outer, dim_inner)].set( - jnp.diag(eig_cross) + jnp.diag(sing_cross) ) U, _, V = jnp.linalg.svd(cross_inner) cross_inner = U @ D @ V.T - hess_inner_outer = jax.random.normal(keys[3], (dim_inner, dim_inner)) - U, _, _ = jnp.linalg.svd(hess_inner_outer) - hess_inner_outer = U @ jnp.diag(eig_inner) @ U.T + hess_outer = jax.random.normal(keys[2], (dim_inner, dim_inner)) + U, _, _ = jnp.linalg.svd(hess_outer) + hess_outer = U @ jnp.diag(eig_inner) @ U.T - return hess_inner_inner, hess_outer_inner, cross_inner, hess_inner_outer + return hess_inner, cross_inner, hess_outer -def quadratic(inner_var, outer_var, hess_inner, hess_outer, cross): +def quadratic(inner_var, outer_var, hess_inner, cross): + """Defines a quadratic function for given hessian and cross + derivative matrices.""" res = .5 * inner_var @ (hess_inner @ inner_var) - res += .5 * outer_var @ (hess_outer @ outer_var) res += outer_var @ cross @ inner_var return res @@ -76,40 +69,93 @@ class Dataset(BaseDataset): } def get_data(self): - # The return arguments of this function are passed as keyword arguments - # to `Objective.set_data`. This defines the benchmark's - # API to pass data. - - hess_inner_inner, hess_outer_inner, cross, hess_inner_outer = ( + """This method retrieves/simulated the data, defines the inner and + outer objectives and the metrics to evaluate the results. It is + mandatory for each dataset. he return arguments of this function are + passed as keyword arguments to `Objective.set_data`. + + Returns + ------- + data: dict + A dictionary containing the keys `pb_inner`, `pb_outer`, `metrics` + and optionnally `init_var`. + + The entries of the dictionary are: + - `pb_inner`: tuple + Contains the inner function, the number of inner samples, the + dimension of the inner variable and the full batch version of the + inner objective. + + - `pb_outer`: tuple + Contains the outer function, the number of outer samples, the + dimension of the outer variable and the full batch version of the + outer objective. + + - `metrics`: function + Function that computes the metrics of the problem. + + - `init_var`: function, optional + Function that initializes the inner and outer variables. + """ + + hess_inner, cross, hess_outer = ( generate_matrices( self.dim_inner, self.dim_outer ) ) + # This decorator is used to jit the inner and the outer objective. + # static_argnames=('batch_size') means that the function is recompiled + # each time it is used with a new batch size. @partial(jax.jit, static_argnames=('batch_size')) def f_inner(inner_var, outer_var, start=0, batch_size=1): + """Defines the inner objective function. It should be a pure jax + function so that it can be jitted. + + Parameters + ---------- + inner_var: pytree + Inner variable. + + outer_var: pytree + Outer variable. + + start: int, default=0 + For stochastic problems, index of the first sample of the + batch. + + batch_size: int, default=1 + For stochastic problems, size of the batch. + + Returns + ------- + float + Value of the inner objective function. + """ return quadratic(inner_var, outer_var, - hess_inner_inner, hess_outer_inner, - cross) + hess_inner, cross) + # This is similar to f_inner @partial(jax.jit, static_argnames=('batch_size')) def f_outer(inner_var, outer_var, start=0, batch_size=1): - return quadratic(inner_var, outer_var, hess_inner_outer, - jnp.zeros_like(hess_outer_inner), + return quadratic(inner_var, outer_var, hess_outer, jnp.zeros_like(cross)) - f_inner_fb = partial( - f_inner, batch_size=X_train.shape[0], start=0 - ) - f_outer_fb = partial( - f_outer, batch_size=X_val.shape[0], start=0 - ) + # For stochastic problems, it is useful to define the full batch + # version of f_inner and f_outer, for instance to compute metrics + # or to be used in some solvers (e.g. SRBA). For non-stochastic + # problems, just define f_inner_fb = f_inner and f_outer_fb = f_outer. + f_inner_fb = f_inner + f_outer_fb = f_outer solver_inner = LBFGS(fun=f_inner_fb) + # The value function is useful for the metrics. Note that it is not + # mandatory to define it. In particular, for large scale problems, + # evaluating it can be cumbersome. def value_function(outer_var): inner_var_star = solver_inner.run( - jnp.zeros(X_train.shape[1]), outer_var + jnp.zeros(self.dim_inner), outer_var ).params return f_outer_fb(inner_var_star, outer_var), inner_var_star @@ -129,8 +175,6 @@ def metrics(inner_var, outer_var): value_func=float(value_fun), value=float(jnp.linalg.norm(grad_value)**2), inner_distance=float(jnp.linalg.norm(inner_star-inner_var)**2), - norm_outer_var=float(jnp.linalg.norm(outer_var)**2), - norm_regul=float(jnp.linalg.norm(np.exp(outer_var))**2), ) def init_var(key): @@ -151,12 +195,4 @@ def init_var(key): init_var=init_var, ) - # The output should be a dict that contains the keys `pb_inner`, - # `pb_outer`, `metrics`, and optionnally `init_var`. - # `pb_inner`` is a tuple that contains the inner function, the number - # of inner samples, the dimension of the inner variable and the full - # batch version of the inner version. - # `pb_outer` in analogous. - # The key `metrics` contains the function `metrics`. - # The key `init_var` contains the function `init_var` when applicable. return data From 3d7bded5c2df86cd9adb97cf7ef9d4535743e5fa Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Tue, 22 Oct 2024 10:21:42 +0200 Subject: [PATCH 40/50] FIX typo --- datasets/template_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index 939c21f..83d851a 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -69,9 +69,9 @@ class Dataset(BaseDataset): } def get_data(self): - """This method retrieves/simulated the data, defines the inner and + """This method retrieves/simulates the data, defines the inner and outer objectives and the metrics to evaluate the results. It is - mandatory for each dataset. he return arguments of this function are + mandatory for each dataset. The return arguments of this function are passed as keyword arguments to `Objective.set_data`. Returns From 10b3bc0040048b4248552f16c94b87035b8fdbf8 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Tue, 22 Oct 2024 10:26:21 +0200 Subject: [PATCH 41/50] FIX batched_quadratics disappeared in simulated.py... --- datasets/simulated.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datasets/simulated.py b/datasets/simulated.py index 9a7166c..64cea26 100644 --- a/datasets/simulated.py +++ b/datasets/simulated.py @@ -31,6 +31,15 @@ def quadratic(inner_var, outer_var, hess_inner, hess_outer, cross, return res +def batched_quadratic(inner_var, outer_var, hess_inner, hess_outer, cross, + linear_inner, linear_outer): + batched_loss = jax.vmap(quadratic, in_axes=(None, None, 0, 0, 0, 0, 0)) + return jnp.mean( + batched_loss(inner_var, outer_var, hess_inner, hess_outer, + cross, linear_inner, linear_outer) + ) + + def get_function(hess_inner, hess_outer, cross, linear_inner, linear_outer): @partial(jax.jit, static_argnames=('batch_size')) From 5b80320781eb9b48066fa8baa931353e9a533766 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Tue, 22 Oct 2024 10:28:55 +0200 Subject: [PATCH 42/50] CLN remove plot_quadratics.py --- figures/plot_quadratics.py | 201 ------------------------------------- 1 file changed, 201 deletions(-) delete mode 100644 figures/plot_quadratics.py diff --git a/figures/plot_quadratics.py b/figures/plot_quadratics.py deleted file mode 100644 index b4b640c..0000000 --- a/figures/plot_quadratics.py +++ /dev/null @@ -1,201 +0,0 @@ -from pathlib import Path - -import numpy as np -import pandas as pd -import matplotlib as mpl -import matplotlib.pyplot as plt - -mpl.rc('text', usetex=True) - -FILE_NAME = Path(__file__).with_suffix('') -METRIC = 'objective_value' - -# DEFAULT_WIDTH = 3.25 -DEFAULT_WIDTH = 3 -DEFAULT_HEIGHT = 2 -LEGEND_RATIO = 0.1 - -N_POINTS = 500 -X_LIM = 250 - -# Utils to get common STYLES object and setup matplotlib -# for all plots - -mpl.rcParams.update({ - 'font.size': 10, - 'legend.fontsize': 'small', - 'axes.labelsize': 'small', - 'xtick.labelsize': 'small', - 'ytick.labelsize': 'small' -}) - -STYLES = { - '*': dict(lw=1.5), - - 'amigo': dict(color='#5778a4', label=r'AmIGO'), - 'mrbo': dict(color='#e49444', label=r'MRBO'), - 'vrbo': dict(color='#e7ca60', label=r'VRBO'), - 'saba': dict(color='#d1615d', label=r'SABA'), - 'stocbio': dict(color='#85b6b2', label=r'StocBiO'), - 'srba': dict(color='#6a9f58', label=r'\textbf{SRBA}', lw=2), - 'f2sa': dict(color='#bcbd22', label=r'F2SA'), -} - - -def get_param(name, param='period_frac'): - params = {} - for vals in name.split("[", maxsplit=1)[1][:-1].split(","): - k, v = vals.split("=") - if v.replace(".", "").isnumeric(): - params[k] = float(v) - else: - params[k] = v - return params[param] - - -def drop_param(name, param='period_frac'): - new_name = name.split("[", maxsplit=1)[0] + '[' - for vals in name.split("[", maxsplit=1)[1][:-1].split(","): - k, v = vals.split("=") - if k != param: - new_name += f'{k}={v},' - return new_name[:-1] + ']' - - -if __name__ == "__main__": - fname = "quadratic.parquet" - fname = FILE_NAME.parent / fname - - if Path(f'{fname.stem}_stable.parquet').is_file(): - df = pd.read_parquet(f'{fname.stem}_stable.parquet') - print(f'{fname.stem}_stable.parquet') - else: - df = pd.read_parquet(fname) - print(fname) - - # normalize names - df['solver'] = df['solver_name'].apply( - lambda x: x.split('[')[0].lower() - ) - df['seed_solver'] = df['solver_name'].apply( - lambda x: get_param(x, 'random_state') - ) - df['seed_data'] = df['data_name'].apply( - lambda x: get_param(x, 'random_state') - ) - - df['solver_name'] = df['solver_name'].apply( - lambda x: drop_param(x, 'random_state') - ) - df['data_name'] = df['data_name'].apply( - lambda x: drop_param(x, 'random_state') - ) - df['cond'] = df['data_name'].apply( - lambda x: get_param(x, 'L_inner_inner')/get_param(x, 'mu_inner') - ) - df['n_inner'] = df['data_name'].apply( - lambda x: get_param(x, 'n_samples_inner') - ) - df['n_outer'] = df['data_name'].apply( - lambda x: get_param(x, 'n_samples_outer') - ) - df['n_tot'] = df['n_inner'] + df['n_outer'] - - # keep only runs all the random seeds - df['full'] = False - n_seeds = df.groupby('solver_name')['seed_data'].nunique() - n_seeds *= df.groupby('solver_name')['seed_solver'].nunique() - for s in n_seeds.index: - if n_seeds[s] == 10: - df.loc[df['solver_name'] == s, 'full'] = True - df = df.query('full == True') - df.to_parquet(f'{fname.stem}_stable.parquet') - - fig = plt.figure( - figsize=(DEFAULT_WIDTH, DEFAULT_HEIGHT * (1 + LEGEND_RATIO)) - ) - - gs = plt.GridSpec( - len(df['n_tot'].unique()), len(df['cond'].unique()), - height_ratios=[1] * len(df['n_tot'].unique()), - width_ratios=[1] * len(df['cond'].unique()), - hspace=0.5, wspace=0.3 - ) - - lines = [] - for i, n_tot in enumerate(df['n_tot'].unique()): - for j, cond in enumerate(df['cond'].unique()): - df_pb = df.query("cond == @cond & n_tot == @n_tot") - print(f"Cond: {cond}, n: {df_pb['n_inner'].iloc[0]}, " - + f"m: {df_pb['n_outer'].iloc[0]}") - to_plot = ( - df.query("cond == @cond & n_tot == @n_tot & stop_val <= 100") - .groupby(['solver', 'solver_name', 'data_name', 'stop_val']) - .median(METRIC) - .reset_index().sort_values(METRIC) - .groupby('solver').first()[['solver_name']] - ) - ( - df.query("solver_name in @to_plot.values.ravel()") - .to_parquet(f'{fname.stem}_best_params.parquet') - ) - print("Chosen parameters:") - for s in to_plot['solver_name']: - print(f"- {s}") - ax = fig.add_subplot(gs[i, j]) - for solver_name in to_plot['solver_name']: - df_solver = df_pb.query("solver_name == @solver_name") - solver = df_solver['solver'].iloc[0] - style = STYLES['*'].copy() - style.update(STYLES[solver]) - curves = [data[['time', METRIC]].values - for _, data in df_solver.groupby(['seed_data', - 'seed_solver'])] - vals = [c[:, 1] for c in curves] - times = [c[:, 0] for c in curves] - tmin = np.min([np.min(t) for t in times]) - tmax = np.max([np.max(t) for t in times]) - time_grid = np.linspace(np.log(tmin), np.log(tmax + 1), - N_POINTS) - interp_vals = np.zeros((len(times), N_POINTS)) - for k, (t, val) in enumerate(zip(times, vals)): - interp_vals[k] = np.exp(np.interp(time_grid, np.log(t), - np.log(val))) - time_grid = np.exp(time_grid) - medval = np.quantile(interp_vals, .5, axis=0) - q1 = np.quantile(interp_vals, .2, axis=0) - q2 = np.quantile(interp_vals, .8, axis=0) - if i == 0 and j == 0: - lines.append(ax.semilogy( - time_grid, np.sqrt(medval), - **style - )[0]) - else: - ax.semilogy( - time_grid, np.sqrt(medval), - **style - ) - ax.fill_between( - time_grid, - np.sqrt(q1), - np.sqrt(q2), - color=style['color'], alpha=0.3 - ) - ax.set_xlabel('Time (s)') - ax.set_ylabel(r'$\|\nabla h(x^t)\|$') - print(f"Min score ({solver}):", df_solver[METRIC].min()) - ax.grid() - ax.set_xlim([0, X_LIM]) - - if i == 0 and j == 0: - ax_legend = ax.legend( - handles=lines, - ncol=2, - prop={'size': 6.5} - ) - print(f"Saving {fname.with_suffix('.pdf')}") - fig.savefig( - fname.with_suffix('.pdf'), - bbox_inches='tight', - bbox_extra_artists=[ax_legend] - ) From 302d8afc5ea031ae862805333f1d5a67818e739e Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Wed, 23 Oct 2024 11:18:18 +0200 Subject: [PATCH 43/50] ENH rm generate_matrices --- datasets/template_dataset.py | 49 ++---------------------------------- 1 file changed, 2 insertions(+), 47 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index 83d851a..3cf15b6 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -12,43 +12,6 @@ from jaxopt import LBFGS # useful to define the value function -def generate_matrices(dim_inner, dim_outer, key=jax.random.PRNGKey(0)): - """Generates the different matrices of the inner and outer quadratic - functions.""" - keys = jax.random.split(key, 3) - eig_inner = jnp.logpsace(-1, 0, dim_inner) - sing_cross = jnp.logpsace(-1, 0, min(dim_inner, dim_outer)) - - # Matrix generation for the inner function - # Generate a PSD matrix with eigenvalues `eig_inner` - hess_inner = jax.random.normal(keys[0], (dim_inner, dim_inner)) - U, _, _ = jnp.linalg.svd(hess_inner) - hess_inner = U @ jnp.diag(eig_inner) @ U.T - - # Generate a matrix with singular values `sing_cross` - cross_inner = jax.random.normal(keys[1], (dim_outer, dim_inner)) - D = jnp.zeros((dim_outer, dim_inner)) - D = D.at[:min(dim_outer, dim_inner), :min(dim_outer, dim_inner)].set( - jnp.diag(sing_cross) - ) - U, _, V = jnp.linalg.svd(cross_inner) - cross_inner = U @ D @ V.T - - hess_outer = jax.random.normal(keys[2], (dim_inner, dim_inner)) - U, _, _ = jnp.linalg.svd(hess_outer) - hess_outer = U @ jnp.diag(eig_inner) @ U.T - - return hess_inner, cross_inner, hess_outer - - -def quadratic(inner_var, outer_var, hess_inner, cross): - """Defines a quadratic function for given hessian and cross - derivative matrices.""" - res = .5 * inner_var @ (hess_inner @ inner_var) - res += outer_var @ cross @ inner_var - return res - - # All datasets must be named `Dataset` and inherit from `BaseDataset` class Dataset(BaseDataset): """How to add a new problem to the benchmark? @@ -98,12 +61,6 @@ def get_data(self): Function that initializes the inner and outer variables. """ - hess_inner, cross, hess_outer = ( - generate_matrices( - self.dim_inner, self.dim_outer - ) - ) - # This decorator is used to jit the inner and the outer objective. # static_argnames=('batch_size') means that the function is recompiled # each time it is used with a new batch size. @@ -132,14 +89,12 @@ def f_inner(inner_var, outer_var, start=0, batch_size=1): float Value of the inner objective function. """ - return quadratic(inner_var, outer_var, - hess_inner, cross) + return inner_var ** 2 + 2 * inner_var * outer_var # This is similar to f_inner @partial(jax.jit, static_argnames=('batch_size')) def f_outer(inner_var, outer_var, start=0, batch_size=1): - return quadratic(inner_var, outer_var, hess_outer, - jnp.zeros_like(cross)) + return inner_var ** 2 # For stochastic problems, it is useful to define the full batch # version of f_inner and f_outer, for instance to compute metrics From a7769ec3c2d263eb76cc3af653a0b7ec6764f0e4 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Wed, 23 Oct 2024 11:21:02 +0200 Subject: [PATCH 44/50] CLN docstring --- datasets/template_dataset.py | 51 +++++++++++++----------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index 3cf15b6..f5b7e89 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -23,14 +23,6 @@ class Dataset(BaseDataset): # Name to select the dataset in the CLI and to display the results. name = "Template dataset" - # List of parameters to generate the datasets. The benchmark will consider - # the cross product for each key in the dictionary. - # Any parameters 'param' defined here is available as `self.param`. - parameters = { - 'dim_inner': [10], - 'dim_outer': [10], - } - def get_data(self): """This method retrieves/simulates the data, defines the inner and outer objectives and the metrics to evaluate the results. It is @@ -41,24 +33,22 @@ def get_data(self): ------- data: dict A dictionary containing the keys `pb_inner`, `pb_outer`, `metrics` - and optionnally `init_var`. - - The entries of the dictionary are: - - `pb_inner`: tuple - Contains the inner function, the number of inner samples, the - dimension of the inner variable and the full batch version of the - inner objective. - - - `pb_outer`: tuple - Contains the outer function, the number of outer samples, the - dimension of the outer variable and the full batch version of the - outer objective. - - - `metrics`: function - Function that computes the metrics of the problem. - - - `init_var`: function, optional - Function that initializes the inner and outer variables. + and optionnally `init_var`. The entries of the dictionary are: + - `pb_inner`: tuple + Contains the inner function, the number of inner samples, the + dimension of the inner variable and the full batch version of the + inner objective. + + - `pb_outer`: tuple + Contains the outer function, the number of outer samples, the + dimension of the outer variable and the full batch version of the + outer objective. + + - `metrics`: function + Function that computes the metrics of the problem. + + - `init_var`: function, optional + Function that initializes the inner and outer variables. """ # This decorator is used to jit the inner and the outer objective. @@ -89,7 +79,7 @@ def f_inner(inner_var, outer_var, start=0, batch_size=1): float Value of the inner objective function. """ - return inner_var ** 2 + 2 * inner_var * outer_var + return .5 * inner_var ** 2 + inner_var * outer_var # This is similar to f_inner @partial(jax.jit, static_argnames=('batch_size')) @@ -103,16 +93,11 @@ def f_outer(inner_var, outer_var, start=0, batch_size=1): f_inner_fb = f_inner f_outer_fb = f_outer - solver_inner = LBFGS(fun=f_inner_fb) - # The value function is useful for the metrics. Note that it is not # mandatory to define it. In particular, for large scale problems, # evaluating it can be cumbersome. def value_function(outer_var): - inner_var_star = solver_inner.run( - jnp.zeros(self.dim_inner), outer_var - ).params - + inner_var_star = - outer_var return f_outer_fb(inner_var_star, outer_var), inner_var_star value_and_grad = jax.jit( From 7a881a856411ac021790e9825b1cbdb90476a15f Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Wed, 23 Oct 2024 11:24:01 +0200 Subject: [PATCH 45/50] FIX flake8 --- datasets/template_dataset.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/datasets/template_dataset.py b/datasets/template_dataset.py index f5b7e89..6bd2371 100644 --- a/datasets/template_dataset.py +++ b/datasets/template_dataset.py @@ -9,8 +9,6 @@ import jax.numpy as jnp from functools import partial # useful for just-in-time compilation - from jaxopt import LBFGS # useful to define the value function - # All datasets must be named `Dataset` and inherit from `BaseDataset` class Dataset(BaseDataset): @@ -35,14 +33,14 @@ def get_data(self): A dictionary containing the keys `pb_inner`, `pb_outer`, `metrics` and optionnally `init_var`. The entries of the dictionary are: - `pb_inner`: tuple - Contains the inner function, the number of inner samples, the - dimension of the inner variable and the full batch version of the - inner objective. + Contains the inner function, the number of inner samples, + the dimension of the inner variable and the full batch + version of the inner objective. - `pb_outer`: tuple - Contains the outer function, the number of outer samples, the - dimension of the outer variable and the full batch version of the - outer objective. + Contains the outer function, the number of outer samples, + the dimension of the outer variable and the full batch + version of the outer objective. - `metrics`: function Function that computes the metrics of the problem. @@ -69,10 +67,13 @@ def f_inner(inner_var, outer_var, start=0, batch_size=1): start: int, default=0 For stochastic problems, index of the first sample of the - batch. + batch. Note that for this specific instance, it is not used + since the problem is deterministic. batch_size: int, default=1 - For stochastic problems, size of the batch. + For stochastic problems, size of the batch. Note that for this + specific instance, it is not used since the problem is + deterministic. Returns ------- From 9b4c556f4e69eb41266bc4f38156d30f6b24c0c1 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 24 Oct 2024 11:01:49 +0200 Subject: [PATCH 46/50] ENH callback info template dataset --- solvers/template_solver.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/solvers/template_solver.py b/solvers/template_solver.py index e57af64..eb4226a 100644 --- a/solvers/template_solver.py +++ b/solvers/template_solver.py @@ -6,6 +6,8 @@ with safe_import_context() as import_ctx: # import your reusable functions here from benchmark_utils import constants + + # step size scheduler from benchmark_utils.learning_rate_scheduler import update_lr from benchmark_utils.learning_rate_scheduler import init_lr_scheduler @@ -17,14 +19,6 @@ class Solver(BaseSolver): - """Gradient descent with JAXopt solvers. - - M. Blondel, Q. Berthet, M. Cuturi, R. Frosting, S. Hoyer, F. - Llinares-Lopez, F. Pedregosa and J.-P. Vert. "Efficient and Modular - Implicit Differentiation". NeurIPS 2022""" - # Name to select the solver in the CLI and to display the results. - name = 'jaxopt_GD' - """How to add a new solver to the benchmark? This template solver is an adaptation of the solver from the benchopt @@ -32,6 +26,8 @@ class Solver(BaseSolver): the bilevel setting. Other explanations can be found in https://benchopt.github.io/tutorials/add_solver.html. """ + # Name to select the solver in the CLI and to display the results. + name = 'Template solver' # List of packages needed to run the solver. requirements = ["pip:jaxopt"] @@ -73,7 +69,7 @@ def set_objective(self, f_inner, f_outer, n_inner_samples, n_outer_samples, # The value function is defined for this specific solver, but it is # not mandatory in general. def value_fun(inner_var, outer_var): - """Solver used to solve the inner problem. + """Value function for the bilevel optimization problem. The output of this function is differentiable w.r.t. the outer_variable. The Jacobian is computed using implicit @@ -82,6 +78,8 @@ def value_fun(inner_var, outer_var): inner_var = inner_solver.run(inner_var, outer_var).params return self.f_outer(inner_var, outer_var), inner_var + # The value function and its gradient are jitted by JAX for fast oracle + # computations. self.value_grad = jax.jit(jax.value_and_grad( value_fun, argnums=1, has_aux=True )) @@ -107,13 +105,29 @@ def run(self, callback): [self.step_size_outer] ) exponents = jnp.zeros(1) + # The step size scheduler provides step sizes that have the form + # `a / t**b`, where t is the iteration number, `a` is a constant + # that comes from the first argument of init_lr_scheduler and `b` is + # the exponent that comes from the second argument of + # init_lr_scheduler. state_lr = init_lr_scheduler(step_sizes, exponents) + # The function `callback` calls `Objective.get_result` to check if the + # stopping criterion is met. More informations on the callback + # asmpling strategy can be found here: + # https://benchopt.github.io/user_guide/performance_curves.html#using-a + # -callback while callback(): + # update_lr provides a step_size and a new state of the scheduler. outer_lr, state_lr = update_lr(state_lr) + + # Compute the value and the gradient of the value function. Also + # provide the inner solution. (_, self.inner_var), implicit_grad = self.value_grad( self.inner_var, self.outer_var ) + + # Update the outer variable by a gradient step. self.outer_var -= outer_lr * implicit_grad def get_result(self): From 510c81228c4fe39c04f9dba982531e829cd7d07e Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 24 Oct 2024 11:13:43 +0200 Subject: [PATCH 47/50] ENH lr_scheduler template_stochastic_solver --- solvers/template_stochastic_solver.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py index 86c3ad7..8697d6b 100644 --- a/solvers/template_stochastic_solver.py +++ b/solvers/template_stochastic_solver.py @@ -81,6 +81,12 @@ def init(self): exponents = jnp.array( [.5, .5] ) + + # The step size scheduler provides step sizes that have the form + # `a / t**b`, where t is the iteration number, `a` is a constant + # that comes from the first argument of init_lr_scheduler and `b` is + # the exponent that comes from the second argument of + # init_lr_scheduler. state_lr = init_lr_scheduler(step_sizes, exponents) return dict( inner_var=self.inner_var, outer_var=self.outer_var, v=v, From e918f5a7c705ef62865b287999b11bee1414d940 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 21 Nov 2024 16:57:45 +0100 Subject: [PATCH 48/50] ENH add comments oracles --- solvers/template_stochastic_solver.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py index 8697d6b..fffbd75 100644 --- a/solvers/template_stochastic_solver.py +++ b/solvers/template_stochastic_solver.py @@ -25,9 +25,10 @@ class Solver(StochasticJaxSolver): """How to add a new stochastic solver to the benchmark? Stochastic solvers are Solver classes that inherit from the - `StochasticJaxSolver` class. They should implement the `init` and the + `StochasticJaxSolver` class which aims at making easy the JIT compilation + of the code. Stochastic solvers should implement the `init` and the `get_step_methods` and the class variable `parameters`. One epoch of - StochasticJaxSolver corresponds to `eval_freq` outer iterations of the + `StochasticJaxSolver` corresponds to `eval_freq` outer iterations of the solver. The epochs of these solvers are jitted by JAX to get fast stochastic iterations. @@ -110,22 +111,35 @@ def soba_one_iter(carry, _): start_inner, *_, carry['state_inner_sampler'] = inner_sampler( carry['state_inner_sampler'] ) + + # The gradient of the inner function w.r.t. the inner variable + # and a function that takes as input a vector v and returns + # the product between the Hessian of the inner function w.r.t. grad_inner_var, vjp_train = jax.vjp( lambda z, x: grad_inner(z, x, start_inner), carry['inner_var'], carry['outer_var'] ) + + # Product between the Hessian of the inner function w.r.t. the + # inner variable and the vector v and product between the cross + # derivatives matrix and the vector v. hvp, cross_v = vjp_train(carry['v']) start_outer, *_, carry['state_outer_sampler'] = outer_sampler( carry['state_outer_sampler'] ) + + # Gradient of the outer function grad_in_outer, grad_out_outer = grad_outer( carry['inner_var'], carry['outer_var'], start_outer ) # Step.2 - update inner variable with SGD. + # Gradient step for the inner variable step carry['inner_var'] -= inner_step_size * grad_inner_var + # Gradient step for the linear system variable carry['v'] -= inner_step_size * (hvp + grad_in_outer) + # Approximate gradient step for the outer variable carry['outer_var'] -= outer_step_size * (cross_v + grad_out_outer) return carry, _ From f65755e45648fd4a0e63a33114105b174c3fb285 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 21 Nov 2024 17:02:23 +0100 Subject: [PATCH 49/50] ENH docstring init --- solvers/template_stochastic_solver.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py index fffbd75..a1a0819 100644 --- a/solvers/template_stochastic_solver.py +++ b/solvers/template_stochastic_solver.py @@ -70,6 +70,15 @@ class Solver(StochasticJaxSolver): } def init(self): + """ + Initializes the stochastic solver. + It returns a dictionary which is the initial state of the `carry` + dictionary in the `get_step` method. It contains at least `inner_var` + and `outer_var`. In this specific case, it also contains the initial + value of the linear system variable `v`, the initial state of the + lr scheduler, and the initial state of the samplers. + """ + # Init variables self.inner_var = self.inner_var0.copy() self.outer_var = self.outer_var0.copy() @@ -108,13 +117,15 @@ def soba_one_iter(carry, _): ) # Step.1 - get all gradients and compute the implicit gradient. + + # First index for the samples of the inner function start_inner, *_, carry['state_inner_sampler'] = inner_sampler( carry['state_inner_sampler'] ) # The gradient of the inner function w.r.t. the inner variable # and a function that takes as input a vector v and returns - # the product between the Hessian of the inner function w.r.t. + # the different Hessian-vector products with v. (c.f. l.128) grad_inner_var, vjp_train = jax.vjp( lambda z, x: grad_inner(z, x, start_inner), carry['inner_var'], carry['outer_var'] @@ -125,6 +136,7 @@ def soba_one_iter(carry, _): # derivatives matrix and the vector v. hvp, cross_v = vjp_train(carry['v']) + # First index for the samples of the outer function start_outer, *_, carry['state_outer_sampler'] = outer_sampler( carry['state_outer_sampler'] ) From 24c08fa60ad2a3b8582e98ecd1546b98357736a1 Mon Sep 17 00:00:00 2001 From: MatDag <mathieu.dagreou@inria.fr> Date: Thu, 21 Nov 2024 17:05:11 +0100 Subject: [PATCH 50/50] ENH docstring get_step --- solvers/template_stochastic_solver.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/solvers/template_stochastic_solver.py b/solvers/template_stochastic_solver.py index a1a0819..ddd4c7d 100644 --- a/solvers/template_stochastic_solver.py +++ b/solvers/template_stochastic_solver.py @@ -78,7 +78,7 @@ def init(self): value of the linear system variable `v`, the initial state of the lr scheduler, and the initial state of the samplers. """ - + # Init variables self.inner_var = self.inner_var0.copy() self.outer_var = self.outer_var0.copy() @@ -106,7 +106,26 @@ def init(self): ) def get_step(self, inner_sampler, outer_sampler): - + """Returns a function that compute one iteration of the stochastic + algorithm. + + Parameters + ---------- + inner_sampler: callable + Function that returns the initial index of a batch of samples for + the inner function and the update state of the sampler. + + outer_sampler: callable + Function that returns the initial index of a batch of samples for + the outer function and the update state of the sampler. + + Returns + ------- + soba_one_iter: callable + Function that computes one iteration of the SOBA algorithm. It + takes as input the carry dictionary and an unused argument and + returns the updated carry and the unused argument. + """ grad_inner = jax.grad(self.f_inner, argnums=0) grad_outer = jax.grad(self.f_outer, argnums=(0, 1))