-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Badr MOUFAD <[email protected]>
- Loading branch information
1 parent
f9ee2e5
commit 359f4da
Showing
6 changed files
with
310 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,6 +70,7 @@ Solvers | |
:toctree: generated/ | ||
|
||
AndersonCD | ||
FISTA | ||
GramCD | ||
GroupBCD | ||
MultiTaskBCD | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
from .anderson_cd import AndersonCD | ||
from .base import BaseSolver | ||
from .fista import FISTA | ||
from .gram_cd import GramCD | ||
from .group_bcd import GroupBCD | ||
from .multitask_bcd import MultiTaskBCD | ||
from .prox_newton import ProxNewton | ||
|
||
|
||
__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton] | ||
# ``__all__`` must contain the public names as strings; listing the class
# objects themselves makes ``from <package> import *`` raise a TypeError.
__all__ = [
    "AndersonCD",
    "BaseSolver",
    "FISTA",
    "GramCD",
    "GroupBCD",
    "MultiTaskBCD",
    "ProxNewton",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import numpy as np | ||
from scipy.sparse import issparse | ||
from skglm.solvers.base import BaseSolver | ||
from skglm.solvers.common import construct_grad, construct_grad_sparse | ||
from skglm.utils import _prox_vec | ||
|
||
|
||
class FISTA(BaseSolver):
    r"""ISTA solver with Nesterov acceleration (FISTA).

    Attributes
    ----------
    max_iter : int, default 100
        Maximum number of iterations.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool or int, default 0
        Amount of verbosity. 0/False is silent.

    References
    ----------
    .. [1] Beck, A. and Teboulle M.
           "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
           problems", 2009, SIAM J. Imaging Sci.
           https://epubs.siam.org/doi/10.1137/080716542
    """

    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        """Run FISTA iterations for the problem given by ``datafit`` and ``penalty``.

        Parameters
        ----------
        X : array or scipy sparse matrix, shape (n_samples, n_features)
            Design matrix. Sparse input is read through its ``data``,
            ``indptr`` and ``indices`` attributes.

        y : array
            Target, forwarded to the gradient helper and to ``datafit.value``.

        datafit : object
            Must expose ``global_lipschitz`` and ``value(y, w, Xw)``.

        penalty : object
            Must expose ``subdiff_distance(w, grad, features)`` and
            ``value(w)``; it is also handed to ``_prox_vec``.

        w_init : array, optional
            Starting point; it is copied. Zeros of length ``n_features``
            are used when omitted.

        Xw_init : array, optional
            Copied when given, but ``Xw`` is recomputed as ``X @ w`` in every
            iteration before it is read, so this value does not affect the
            result.

        Returns
        -------
        w : array
            Last iterate.

        p_objs_out : array
            ``datafit.value + penalty.value`` recorded once per iteration.

        stop_crit : float
            Largest entry of ``penalty.subdiff_distance`` at the last
            iteration, or ``np.inf`` when no iteration was run.

        Raises
        ------
        Exception
            If ``datafit`` has no ``global_lipschitz`` attribute.
        """
        p_objs_out = []
        n_samples, n_features = X.shape
        all_features = np.arange(n_features)
        t_new = 1.
        # Defined up front so the final ``return`` is valid even when
        # ``max_iter`` is 0 and the loop body never executes.
        stop_crit = np.inf

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        z = w_init.copy() if w_init is not None else np.zeros(n_features)
        Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)

        if hasattr(datafit, "global_lipschitz"):
            lipschitz = datafit.global_lipschitz
        else:
            # TODO: OR line search
            raise Exception("Line search is not yet implemented for FISTA solver.")

        # The step size is constant across iterations: compute it once.
        step = 1 / lipschitz
        X_is_sparse = issparse(X)

        for n_iter in range(self.max_iter):
            # Momentum sequence: t_{k+1} = (1 + sqrt(1 + 4 * t_k ** 2)) / 2
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            # Snapshot of the previous iterate for the extrapolation below
            # (presumably ``_prox_vec`` may write into ``w`` -- TODO confirm).
            w_old = w.copy()

            # The gradient helpers are evaluated at the extrapolated point z.
            if X_is_sparse:
                grad = construct_grad_sparse(
                    X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
            else:
                grad = construct_grad(X, y, z, X @ z, datafit, all_features)

            # Gradient step on z, then the prox helper, then extrapolation.
            z -= step * grad
            w = _prox_vec(w, z, penalty, step)
            Xw = X @ w
            z = w + (t_old - 1.) / t_new * (w - w_old)

            # NOTE: ``grad`` was built from the previous z, not from the
            # updated w that it is paired with here.
            opt = penalty.subdiff_distance(w, grad, all_features)
            stop_crit = np.max(opt)

            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
            p_objs_out.append(p_obj)
            if self.verbose:
                print(
                    f"Iteration {n_iter+1}: {p_obj:.10f}, "
                    f"stopping crit: {stop_crit:.2e}"
                )

            if stop_crit < self.tol:
                if self.verbose:
                    print(f"Stopping criterion max violation: {stop_crit:.2e}")
                break
        return w, np.array(p_objs_out), stop_crit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import pytest | ||
|
||
import numpy as np | ||
from numpy.linalg import norm | ||
|
||
import scipy.sparse | ||
import scipy.sparse.linalg | ||
from scipy.sparse import csc_matrix, issparse | ||
|
||
from skglm.penalties import L1, IndicatorBox | ||
from skglm.solvers import FISTA, AndersonCD | ||
from skglm.datafits import Quadratic, Logistic, QuadraticSVC | ||
from skglm.utils import make_correlated_data, compiled_clone, spectral_norm | ||
|
||
|
||
# Shared fixtures for the tests below.
random_state = 113
n_samples, n_features = 50, 60

rng = np.random.RandomState(random_state)
X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
rng.seed(random_state)
# Draw the sparsity mask from the seeded generator; the global ``np.random``
# state is not controlled here, which made ``X_sparse`` differ between runs.
X_sparse = csc_matrix(X * rng.binomial(1, 0.5, X.shape))
# Sign of the target, used as labels for the non-quadratic datafits.
y_classif = np.sign(y)

# Regularization level: a tenth of ||X.T @ y||_inf / n_samples.
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = alpha_max / 10

tol = 1e-10
|
||
|
||
@pytest.mark.parametrize("X", [X, X_sparse])
@pytest.mark.parametrize("Datafit, Penalty", [
    (Quadratic, L1),
    (Logistic, L1),
    (QuadraticSVC, IndicatorBox),
])
def test_fista_solver(X, Datafit, Penalty):
    """Check that FISTA and AndersonCD return matching coefficients.

    Runs on every combination of dense/sparse ``X`` and the parametrized
    (datafit, penalty) pairs, and compares the first element returned by
    each solver's ``solve`` with ``atol=1e-7``.
    """
    # NOTE(review): ``Datafit`` is a class here (it is instantiated with
    # ``Datafit()`` just below), so ``isinstance(Datafit, Quadratic)`` is
    # always False and ``_y`` is always ``y_classif``. Confirm whether
    # ``Datafit is Quadratic`` was intended.
    _y = y if isinstance(Datafit, Quadratic) else y_classif
    datafit = compiled_clone(Datafit())
    # NOTE(review): same pattern -- this condition is always False, so
    # ``_init`` is always ``X`` and ``y @ X.T`` is never evaluated. Check the
    # shapes of that expression before enabling the branch.
    _init = y @ X.T if isinstance(Datafit, QuadraticSVC) else X
    if issparse(X):
        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
    else:
        datafit.initialize(_init, _y)
    penalty = compiled_clone(Penalty(alpha))

    solver = FISTA(max_iter=1000, tol=tol)
    w_fista = solver.solve(X, _y, datafit, penalty)[0]

    # Reference solution from the coordinate-descent solver, same tolerance.
    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]

    np.testing.assert_allclose(w_fista, w_cd, atol=1e-7)
|
||
|
||
def test_spectral_norm():
    """Compare ``spectral_norm`` with the leading singular value from ``svds``."""
    n_samples, n_features = 50, 60
    A_sparse = scipy.sparse.random(n_samples, n_features, density=0.7, format='csc',
                                   random_state=random_state)

    A_bundles = (A_sparse.data, A_sparse.indptr, A_sparse.indices)
    # Use the row count of the matrix under test instead of the length of the
    # unrelated module-level ``y``, which only matched by coincidence.
    spectral_norm_our = spectral_norm(*A_bundles, n_samples=A_sparse.shape[0])
    # ``svds`` returns (u, s, vt); index 1 holds the k=1 singular value.
    spectral_norm_sp = scipy.sparse.linalg.svds(A_sparse, k=1)[1]

    np.testing.assert_allclose(spectral_norm_our, spectral_norm_sp)
|
||
|
||
# Script entry point: intentionally does nothing when run directly.
if "__main__" == __name__:
    pass
Oops, something went wrong.