diff --git a/doc/api.rst b/doc/api.rst
index 4d003891e..d05ed8a57 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -70,6 +70,7 @@ Solvers
    :toctree: generated/
 
    AndersonCD
+   FISTA
    GramCD
    GroupBCD
    MultiTaskBCD
diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py
index 1f4d39503..948cd1454 100644
--- a/skglm/datafits/single_task.py
+++ b/skglm/datafits/single_task.py
@@ -4,6 +4,7 @@
 from numba import float64
 
 from skglm.datafits.base import BaseDatafit
+from skglm.utils import spectral_norm
 
 
 class Quadratic(BaseDatafit):
@@ -22,6 +23,10 @@ class Quadratic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / n_samples.
 
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -35,6 +40,7 @@ def get_spec(self):
         spec = (
             ('Xty', float64[:]),
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -44,14 +50,18 @@ def params_to_dict(self):
     def initialize(self, X, y):
         self.Xty = X.T @ y
         n_features = X.shape[1]
+        self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
 
-    def initialize_sparse(
-            self, X_data, X_indptr, X_indices, y):
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
         self.Xty = np.zeros(n_features, dtype=X_data.dtype)
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
@@ -111,6 +121,10 @@ class Logistic(BaseDatafit):
         The coordinatewise gradient Lipschitz constants. Equal to
         norm(X, axis=0) ** 2 / (4 * n_samples).
 
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / (4 * n_samples).
+
     Note
     ----
     The class is jit compiled at fit time using Numba compiler.
@@ -123,6 +137,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -140,9 +155,14 @@ def raw_hessian(self, y, Xw):
 
     def initialize(self, X, y):
         self.lipschitz = (X ** 2).sum(axis=0) / (len(y) * 4)
+        self.global_lipschitz = norm(X, ord=2) ** 2 / (len(y) * 4)
 
     def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= 4 * len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             Xj = X_data[X_indptr[j]:X_indptr[j+1]]
@@ -187,6 +207,11 @@ class QuadraticSVC(BaseDatafit):
     ----------
     lipschitz : array, shape (n_features,)
         The coordinatewise gradient Lipschitz constants.
+        Equal to norm(yXT, axis=0) ** 2.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(yXT, ord=2) ** 2.
 
     Note
     ----
@@ -200,6 +225,7 @@ def __init__(self):
     def get_spec(self):
         spec = (
             ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -209,12 +235,16 @@ def params_to_dict(self):
     def initialize(self, yXT, y):
         n_features = yXT.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=yXT.dtype)
+        self.global_lipschitz = norm(yXT, ord=2) ** 2
         for j in range(n_features):
             self.lipschitz[j] = norm(yXT[:, j]) ** 2
 
-    def initialize_sparse(
-            self, yXT_data, yXT_indptr, yXT_indices, y):
+    def initialize_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
         n_features = len(yXT_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(
+            yXT_data, yXT_indptr, yXT_indices, max(yXT_indices) + 1) ** 2
+
         self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
@@ -264,8 +294,16 @@ class Huber(BaseDatafit):
 
     Attributes
     ----------
+    delta : float
+        Threshold hyperparameter.
+
     lipschitz : array, shape (n_features,)
-        The coordinatewise gradient Lipschitz constants.
+        The coordinatewise gradient Lipschitz constants. Equal to
+        norm(X, axis=0) ** 2 / n_samples.
+
+    global_lipschitz : float
+        Global Lipschitz constant. Equal to
+        norm(X, ord=2) ** 2 / n_samples.
 
     Note
     ----
@@ -279,7 +317,8 @@ def __init__(self, delta):
     def get_spec(self):
         spec = (
             ('delta', float64),
-            ('lipschitz', float64[:])
+            ('lipschitz', float64[:]),
+            ('global_lipschitz', float64),
         )
         return spec
@@ -289,12 +328,16 @@ def params_to_dict(self):
     def initialize(self, X, y):
         n_features = X.shape[1]
         self.lipschitz = np.zeros(n_features, dtype=X.dtype)
+        self.global_lipschitz = norm(X, ord=2) ** 2 / len(y)
         for j in range(n_features):
             self.lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
 
-    def initialize_sparse(
-            self, X_data, X_indptr, X_indices, y):
+    def initialize_sparse(self, X_data, X_indptr, X_indices, y):
         n_features = len(X_indptr) - 1
+
+        self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
+        self.global_lipschitz /= len(y)
+
         self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
         for j in range(n_features):
             nrm2 = 0.
diff --git a/skglm/solvers/__init__.py b/skglm/solvers/__init__.py
index 0f8016f40..d86eb275a 100644
--- a/skglm/solvers/__init__.py
+++ b/skglm/solvers/__init__.py
@@ -1,9 +1,10 @@
 from .anderson_cd import AndersonCD
 from .base import BaseSolver
+from .fista import FISTA
 from .gram_cd import GramCD
 from .group_bcd import GroupBCD
 from .multitask_bcd import MultiTaskBCD
 from .prox_newton import ProxNewton
 
-__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
+__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py
new file mode 100644
index 000000000..bbdc451b7
--- /dev/null
+++ b/skglm/solvers/fista.py
@@ -0,0 +1,82 @@
+import numpy as np
+from scipy.sparse import issparse
+from skglm.solvers.base import BaseSolver
+from skglm.solvers.common import construct_grad, construct_grad_sparse
+from skglm.utils import _prox_vec
+
+
+class FISTA(BaseSolver):
+    r"""ISTA solver with Nesterov acceleration (FISTA).
+
+    Parameters
+    ----------
+    max_iter : int, default 100
+        Maximum number of iterations.
+
+    tol : float, default 1e-4
+        Tolerance for convergence.
+
+    verbose : bool, default False
+        Amount of verbosity. 0/False is silent.
+
+    References
+    ----------
+    .. [1] Beck, A. and Teboulle, M.
+        "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
+        Problems", 2009, SIAM J. Imaging Sci.
+        https://epubs.siam.org/doi/10.1137/080716542
+    """
+
+    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
+        self.max_iter = max_iter
+        self.tol = tol
+        self.verbose = verbose
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+        p_objs_out = []
+        n_samples, n_features = X.shape
+        all_features = np.arange(n_features)
+        t_new = 1.
+
+        w = w_init.copy() if w_init is not None else np.zeros(n_features)
+        z = w_init.copy() if w_init is not None else np.zeros(n_features)
+        Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
+
+        if hasattr(datafit, "global_lipschitz"):
+            lipschitz = datafit.global_lipschitz
+        else:
+            # TODO: implement line search as a fallback
+            raise ValueError(
+                "Line search is not yet implemented for FISTA solver.")
+
+        for n_iter in range(self.max_iter):
+            t_old = t_new
+            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
+            w_old = w.copy()
+            if issparse(X):
+                grad = construct_grad_sparse(
+                    X.data, X.indptr, X.indices, y, z, X @ z, datafit, all_features)
+            else:
+                grad = construct_grad(X, y, z, X @ z, datafit, all_features)
+
+            step = 1 / lipschitz
+            z -= step * grad
+            w = _prox_vec(w, z, penalty, step)
+            Xw = X @ w
+            z = w + (t_old - 1.) / t_new * (w - w_old)
+
+            opt = penalty.subdiff_distance(w, grad, all_features)
+            stop_crit = np.max(opt)
+
+            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+            p_objs_out.append(p_obj)
+            if self.verbose:
+                print(
+                    f"Iteration {n_iter+1}: {p_obj:.10f}, "
+                    f"stopping crit: {stop_crit:.2e}"
+                )
+
+            if stop_crit < self.tol:
+                if self.verbose:
+                    print(f"Stopping criterion max violation: {stop_crit:.2e}")
+                break
+        return w, np.array(p_objs_out), stop_crit
diff --git a/skglm/tests/test_fista.py b/skglm/tests/test_fista.py
new file mode 100644
index 000000000..559c5d9ca
--- /dev/null
+++ b/skglm/tests/test_fista.py
@@ -0,0 +1,69 @@
+import pytest
+
+import numpy as np
+from numpy.linalg import norm
+
+import scipy.sparse
+import scipy.sparse.linalg
+from scipy.sparse import csc_matrix, issparse
+
+from skglm.penalties import L1, IndicatorBox
+from skglm.solvers import FISTA, AndersonCD
+from skglm.datafits import Quadratic, Logistic, QuadraticSVC
+from skglm.utils import make_correlated_data, compiled_clone, spectral_norm
+
+
+random_state = 113
+n_samples, n_features = 50, 60
+
+rng = np.random.RandomState(random_state)
+X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
+rng.seed(random_state)
+X_sparse = csc_matrix(X * rng.binomial(1, 0.5, X.shape))
+y_classif = np.sign(y)
+
+alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
+alpha = alpha_max / 10
+
+tol = 1e-10
+
+
+@pytest.mark.parametrize("X", [X, X_sparse])
+@pytest.mark.parametrize("Datafit, Penalty", [
+    (Quadratic, L1),
+    (Logistic, L1),
+    (QuadraticSVC, IndicatorBox),
+])
+def test_fista_solver(X, Datafit, Penalty):
+    _y = y if Datafit == Quadratic else y_classif
+    datafit = compiled_clone(Datafit())
+    _init = y_classif * X.T if Datafit == QuadraticSVC else X
+    if issparse(X):
+        datafit.initialize_sparse(_init.data, _init.indptr, _init.indices, _y)
+    else:
+        datafit.initialize(_init, _y)
+    penalty = compiled_clone(Penalty(alpha))
+
+    solver = FISTA(max_iter=1000, tol=tol)
+    w_fista = solver.solve(X, _y, datafit, penalty)[0]
+
+    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
+    w_cd = solver_cd.solve(X, _y, datafit, penalty)[0]
+
+    np.testing.assert_allclose(w_fista, w_cd, atol=1e-7)
+
+
+def test_spectral_norm():
+    n_samples, n_features = 50, 60
+    A_sparse = scipy.sparse.random(n_samples, n_features, density=0.7, format='csc',
+                                   random_state=random_state)
+
+    A_bundles = (A_sparse.data, A_sparse.indptr, A_sparse.indices)
+    spectral_norm_our = spectral_norm(*A_bundles, n_samples=n_samples)
+    spectral_norm_sp = scipy.sparse.linalg.svds(A_sparse, k=1)[1]
+
+    np.testing.assert_allclose(spectral_norm_our, spectral_norm_sp)
+
+
+if __name__ == '__main__':
+    pass
diff --git a/skglm/utils.py b/skglm/utils.py
index 55c2fa462..2932c1822 100644
--- a/skglm/utils.py
+++ b/skglm/utils.py
@@ -457,3 +457,107 @@ def extrapolate(self, w, Xw):
         C = inv_UTU_ones / np.sum(inv_UTU_ones)
         # floating point errors may cause w and Xw to disagree
         return self.arr_w_[:, 1:] @ C, self.arr_Xw_[:, 1:] @ C, True
+
+
+@njit
+def _prox_vec(w, z, penalty, step):
+    # evaluate the full proximal operator: w[j] = prox of penalty at z[j]
+    n_features = w.shape[0]
+    for j in range(n_features):
+        w[j] = penalty.prox_1d(z[j], step, j)
+    return w
+
+
+@njit
+def spectral_norm(X_data, X_indptr, X_indices, n_samples,
+                  max_iter=100, tol=1e-6):
+    """Compute the spectral norm of sparse matrix ``X`` with the power method.
+
+    Parameters
+    ----------
+    X_data : array, shape (n_elements,)
+        ``data`` attribute of the sparse CSC matrix ``X``.
+
+    X_indptr : array, shape (n_features + 1,)
+        ``indptr`` attribute of the sparse CSC matrix ``X``.
+
+    X_indices : array, shape (n_elements,)
+        ``indices`` attribute of the sparse CSC matrix ``X``.
+
+    n_samples : int
+        Number of rows of ``X``.
+
+    max_iter : int, default 100
+        Maximum number of power method iterations.
+
+    tol : float, default 1e-6
+        Tolerance for convergence.
+
+    Returns
+    -------
+    spec_norm : float
+        The largest singular value of ``X``.
+
+    References
+    ----------
+    .. [1] Alfio Quarteroni, Riccardo Sacco, Fausto Saleri.
+        "Numerical Mathematics", chapter 5, pages 192-195.
+    """
+    # init vec with norm(vec) == 1.
+    eigenvector = np.random.randn(n_samples)
+    eigenvector /= norm(eigenvector)
+    eigenvalue = 1.
+
+    for _ in range(max_iter):
+        vec = _XXT_dot_vec(X_data, X_indptr, X_indices, eigenvector, n_samples)
+        norm_vec = norm(vec)
+        eigenvalue = vec @ eigenvector
+
+        # norm(X @ X.T @ eigenvector - eigenvalue * eigenvector) <= tol
+        # inequality (5.25) in ref [1] is squared
+        if norm_vec ** 2 - eigenvalue ** 2 <= tol ** 2:
+            break
+
+        eigenvector = vec / norm_vec
+
+    return np.sqrt(eigenvalue)
+
+
+@njit
+def _XXT_dot_vec(X_data, X_indptr, X_indices, vec, n_samples):
+    # compute X @ X.T @ vec, with X CSC encoded
+    return _X_dot_vec(X_data, X_indptr, X_indices,
+                      _XT_dot_vec(X_data, X_indptr, X_indices, vec), n_samples)
+
+
+@njit
+def _X_dot_vec(X_data, X_indptr, X_indices, vec, n_samples):
+    # compute X @ vec, with X CSC encoded
+    result = np.zeros(n_samples)
+
+    # loop over features
+    for j in range(len(X_indptr) - 1):
+        if vec[j] == 0:
+            continue
+
+        col_j_rows_idx = slice(X_indptr[j], X_indptr[j+1])
+        result[X_indices[col_j_rows_idx]] += vec[j] * X_data[col_j_rows_idx]
+
+    return result
+
+
+@njit
+def _XT_dot_vec(X_data, X_indptr, X_indices, vec):
+    # compute X.T @ vec, with X CSC encoded
+    n_features = len(X_indptr) - 1
+    result = np.zeros(n_features)
+
+    for j in range(n_features):
+        for idx in range(X_indptr[j], X_indptr[j+1]):
+            result[j] += X_data[idx] * vec[X_indices[idx]]
+
+    return result
+
+
+if __name__ == '__main__':
+    pass
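
A minimal usage sketch of the new solver (not part of the patch), mirroring the setup in skglm/tests/test_fista.py; it relies only on helpers already present in skglm (make_correlated_data, compiled_clone) and on the global_lipschitz attribute that datafit.initialize now precomputes:

import numpy as np
from numpy.linalg import norm

from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import FISTA
from skglm.utils import compiled_clone, make_correlated_data

X, y, _ = make_correlated_data(n_samples=50, n_features=60, random_state=0)

# jit-compile datafit and penalty, then precompute lipschitz/global_lipschitz
datafit = compiled_clone(Quadratic())
datafit.initialize(X, y)
penalty = compiled_clone(L1(alpha=norm(X.T @ y, ord=np.inf) / (10 * len(y))))

# solve returns the solution, the objective history, and the stopping criterion
w, p_objs, stop_crit = FISTA(max_iter=1000, tol=1e-10).solve(X, y, datafit, penalty)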
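For reference, the iteration implemented in FISTA.solve is the standard Beck-Teboulle recursion. A self-contained plain-numpy sketch for the lasso case (fista_lasso is a hypothetical helper written for illustration; the soft-thresholding line plays the role that _prox_vec delegates to penalty.prox_1d):

import numpy as np
from numpy.linalg import norm


def fista_lasso(X, y, alpha, max_iter=100):
    # minimize norm(y - X @ w) ** 2 / (2 * n_samples) + alpha * norm(w, 1)
    n_samples, n_features = X.shape
    lipschitz = norm(X, ord=2) ** 2 / n_samples  # datafit.global_lipschitz
    w, z, t_new = np.zeros(n_features), np.zeros(n_features), 1.

    for _ in range(max_iter):
        t_old = t_new
        t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
        w_old = w.copy()

        # gradient step on the extrapolated point z, with step 1 / lipschitz
        z -= X.T @ (X @ z - y) / (n_samples * lipschitz)
        # soft-thresholding: prox of alpha * norm(., 1) with step 1 / lipschitz
        w = np.sign(z) * np.maximum(np.abs(z) - alpha / lipschitz, 0.)
        # Nesterov extrapolation
        z = w + (t_old - 1.) / t_new * (w - w_old)

    return w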