Commit
Merge branch 'main' into solver_dispatcher
Showing 19 changed files with 650 additions and 53 deletions.
@@ -69,6 +69,7 @@ Datafits
     Quadratic
     QuadraticGroup
     QuadraticSVC
+    WeightedQuadratic

 Solvers
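The newly listed ``WeightedQuadratic`` datafit plugs into the same estimator API as the other datafits. Below is a minimal sketch of how it could be used, assuming it accepts per-sample weights through a ``sample_weights`` constructor argument; that signature is an assumption, not confirmed by this diff.

.. code-block:: python

    import numpy as np
    from skglm import GeneralizedLinearEstimator
    from skglm.datafits import WeightedQuadratic   # added in this commit
    from skglm.penalties import L1
    from skglm.solvers import AndersonCD

    rng = np.random.default_rng(0)
    X = rng.standard_normal((20, 5))
    y = rng.standard_normal(20)
    weights = rng.uniform(0.5, 1.5, size=20)  # per-sample weights (assumed API)

    datafit = WeightedQuadratic(sample_weights=weights)  # hypothetical signature
    model = GeneralizedLinearEstimator(datafit, L1(alpha=0.1), solver=AndersonCD())
    model.fit(X, y)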
@@ -0,0 +1,98 @@
.. _alpha_max:

================================================================
Critical regularization strength above which the solution is 0
================================================================

This tutorial shows that for :math:`\lambda \geq \lambda_{\text{max}} = || \nabla f(0) ||_{\infty}`, the solution to
:math:`\min f(x) + \lambda || x ||_1` is 0.

In skglm, we thus frequently use

.. code-block:: python

    alpha_max = np.max(np.abs(gradient0))

and choose for the regularization strength :math:`\alpha` a fraction of this critical value, e.g. ``alpha = 0.01 * alpha_max``.
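As a concrete illustration, here is a minimal sketch computing ``alpha_max`` for the quadratic datafit worked out in the example below; the data is synthetic and purely illustrative.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 100))  # synthetic design matrix, for illustration
    y = rng.standard_normal(50)
    n_samples = X.shape[0]

    # for f(x) = ||Xx - y||^2 / (2 * n_samples), the gradient at 0 is -X.T @ y / n_samples
    gradient0 = -X.T @ y / n_samples
    alpha_max = np.max(np.abs(gradient0))
    alpha = 0.01 * alpha_max  # a common choice: a fraction of the critical value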
Problem setup
=============

Consider the optimization problem:

.. math::

    \min_x f(x) + \lambda || x ||_1

where:

- :math:`f: \mathbb{R}^d \to \mathbb{R}` is a convex differentiable function,
- :math:`|| x ||_1` is the L1 norm of :math:`x`,
- :math:`\lambda > 0` is the regularization parameter.

We aim to determine the conditions under which the solution to this problem is :math:`x = 0`.
Theoretical background
======================

Let

.. math::

    g(x) = f(x) + \lambda || x ||_1

According to Fermat's rule, 0 is a minimizer of :math:`g` if and only if 0 belongs to the subdifferential of :math:`g` at 0.
The subdifferential of :math:`|| x ||_1` at 0 is the L-infinity unit ball:

.. math::

    \partial || \cdot ||_1 (0) = \{ u \in \mathbb{R}^d : ||u||_{\infty} \leq 1 \}

Thus,

.. math::
    :nowrap:

    \begin{equation}
        \begin{aligned}
            0 \in \text{argmin} ~ g(x)
            &\Leftrightarrow 0 \in \partial g(0) \\
            &\Leftrightarrow 0 \in \nabla f(0) + \lambda \, \partial || \cdot ||_1 (0) \\
            &\Leftrightarrow - \nabla f(0) \in \lambda \{ u \in \mathbb{R}^d : ||u||_{\infty} \leq 1 \} \\
            &\Leftrightarrow || \nabla f(0) ||_{\infty} \leq \lambda
        \end{aligned}
    \end{equation}

We have just shown that the minimizer of :math:`g = f + \lambda || \cdot ||_1` is 0 if and only if :math:`\lambda \geq ||\nabla f(0)||_{\infty}`.
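This characterization can also be checked numerically: a proximal gradient step from :math:`x = 0` soft-thresholds :math:`-\eta \nabla f(0)` at level :math:`\eta \lambda`, which returns exactly 0 whenever :math:`||\nabla f(0)||_{\infty} \leq \lambda`. A minimal sketch with a made-up gradient vector:

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    grad0 = rng.standard_normal(5)    # stands in for the gradient of f at 0
    lambda_ = np.max(np.abs(grad0))   # the critical value lambda_max
    eta = 0.1                         # any positive step size

    # proximal gradient step from x = 0: soft-thresholding of -eta * grad0
    z = -eta * grad0
    x_new = np.sign(z) * np.maximum(np.abs(z) - eta * lambda_, 0.0)
    print(np.all(x_new == 0.0))  # expected: True, 0 is a fixed point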
Example
=======

Consider the loss function for Ordinary Least Squares, :math:`f(x) = \frac{1}{2n} ||Ax - b||_2^2`, where :math:`n` is the number of samples. We have:

.. math::

    \nabla f(x) = \frac{1}{n} A^T (Ax - b)

At :math:`x = 0`:

.. math::

    \nabla f(0) = -\frac{1}{n} A^T b

The infinity norm of the gradient at 0 is:

.. math::

    ||\nabla f(0)||_{\infty} = \frac{1}{n} ||A^T b||_{\infty}

For :math:`\lambda \geq \frac{1}{n} ||A^T b||_{\infty}`, the solution to :math:`\min_x \frac{1}{2n} ||Ax - b||_2^2 + \lambda || x ||_1` is :math:`x = 0`.
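This threshold can be verified numerically, for instance with scikit-learn's ``Lasso``, which minimizes exactly this objective (its ``alpha`` plays the role of :math:`\lambda`); the data below is synthetic:

.. code-block:: python

    import numpy as np
    from sklearn.linear_model import Lasso

    rng = np.random.default_rng(0)
    A = rng.standard_normal((50, 100))
    b = rng.standard_normal(50)
    n = A.shape[0]

    alpha_max = np.max(np.abs(A.T @ b)) / n

    # at alpha_max the solution is exactly 0; just below it, it no longer is
    coef_at = Lasso(alpha=alpha_max, fit_intercept=False).fit(A, b).coef_
    coef_below = Lasso(alpha=0.99 * alpha_max, fit_intercept=False).fit(A, b).coef_
    print(np.all(coef_at == 0.0))     # expected: True
    print(np.all(coef_below == 0.0))  # expected: False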
References
==========

Refer in particular to Section 3.1 and Proposition 4 of [1] for more details.

.. _1:

[1] Eugene Ndiaye, Olivier Fercoq, Alexandre Gramfort, and Joseph Salmon. 2017. Gap safe screening rules for sparsity enforcing penalties. J. Mach. Learn. Res. 18, 1 (January 2017), 4671–4703.
@@ -0,0 +1,61 @@
""" | ||
================================= | ||
Fast Sparse Group Lasso in python | ||
================================= | ||
Scikit-learn is missing a Sparse Group Lasso regression estimator. We show how to | ||
implement one with ``skglm``. | ||
""" | ||
|
||
# Author: Mathurin Massias | ||
|
||
# %% | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from skglm.solvers import GroupBCD | ||
from skglm.datafits import QuadraticGroup | ||
from skglm import GeneralizedLinearEstimator | ||
from skglm.penalties import WeightedL1GroupL2 | ||
from skglm.utils.data import make_correlated_data, grp_converter | ||
|
||
n_features = 30 | ||
X, y, _ = make_correlated_data( | ||
n_samples=10, n_features=30, random_state=0) | ||
# %%
# Model creation: combination of penalty, datafit and solver.
#
# Penalty:
grp_size = 10  # take groups of 10 consecutive features
grp_indices, grp_ptr = grp_converter(grp_size, n_features)
n_groups = len(grp_ptr) - 1
weights_g = np.ones(n_groups, dtype=np.float64)
weights_f = 0.5 * np.ones(n_features)
penalty = WeightedL1GroupL2(
    alpha=0.5, weights_groups=weights_g,
    weights_features=weights_f, grp_indices=grp_indices, grp_ptr=grp_ptr)

# %%
# Datafit and solver:
datafit = QuadraticGroup(grp_ptr, grp_indices)
solver = GroupBCD(ws_strategy="fixpoint", verbose=1, fit_intercept=False, tol=1e-10)

# %%
# Train the model:
clf = GeneralizedLinearEstimator(datafit, penalty, solver)
clf.fit(X, y)
# %%
# Some groups are entirely 0 and, inside non-zero groups,
# some individual coefficients are 0 too.
plt.imshow(clf.coef_.reshape(-1, grp_size) != 0, cmap='Greys')
plt.title("Non-zero values (in black) in model coefficients")
plt.ylabel('Group index')
plt.xlabel('Feature index inside group')
plt.xticks(np.arange(grp_size))
plt.yticks(np.arange(n_groups))
plt.show()
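# %%
# A quick sanity check: how sparse is the solution at the group and
# feature level? (This cell only reads quantities defined above.)
coefs = clf.coef_.reshape(n_groups, grp_size)
zero_groups = np.all(coefs == 0, axis=1)
print(f"{zero_groups.sum()} out of {n_groups} groups are entirely zero")
print(f"{(clf.coef_ == 0).mean():.0%} of all coefficients are zero")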