Skip to content

Commit

Permalink
ENH Add FISTA solver (#91)
Browse files Browse the repository at this point in the history
Co-authored-by: Badr MOUFAD <[email protected]>
  • Loading branch information
PABannier and Badr-MOUFAD authored Oct 22, 2022
1 parent f9ee2e5 commit 359f4da
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 9 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ Solvers
:toctree: generated/

AndersonCD
FISTA
GramCD
GroupBCD
MultiTaskBCD
Expand Down
60 changes: 52 additions & 8 deletions skglm/datafits/single_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numba import float64

from skglm.datafits.base import BaseDatafit
from skglm.utils import spectral_norm


class Quadratic(BaseDatafit):
Expand All @@ -22,6 +23,10 @@ class Quadratic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.
global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.
Note
----
The class is jit compiled at fit time using Numba compiler.
Expand All @@ -35,6 +40,7 @@ def get_spec(self):
spec = (
('Xty', float64[:]),
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

Expand All @@ -44,14 +50,18 @@ def params_to_dict(self):
def initialize(self, X, y):
    """Precompute ``X.T @ y`` and the Lipschitz constants for a dense ``X``.

    Sets ``self.Xty``, ``self.global_lipschitz`` (squared spectral norm of
    ``X`` divided by ``len(y)``) and ``self.lipschitz`` (squared norm of
    each column of ``X`` divided by ``len(y)``).
    """
    n_samples = len(y)
    self.Xty = X.T @ y
    self.global_lipschitz = norm(X, ord=2) ** 2 / n_samples

    # Explicit loop over columns, one constant per feature.
    coordinate_constants = np.zeros(X.shape[1], dtype=X.dtype)
    for col in range(X.shape[1]):
        coordinate_constants[col] = (X[:, col] ** 2).sum() / n_samples
    self.lipschitz = coordinate_constants

def initialize_sparse(
self, X_data, X_indptr, X_indices, y):
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1
self.Xty = np.zeros(n_features, dtype=X_data.dtype)

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down Expand Up @@ -111,6 +121,10 @@ class Logistic(BaseDatafit):
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / (4 * n_samples).
global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / (4 * n_samples).
Note
----
The class is jit compiled at fit time using Numba compiler.
Expand All @@ -123,6 +137,7 @@ def __init__(self):
def get_spec(self):
    """Return the attribute specification as ``(name, numba type)`` pairs."""
    return (
        ('lipschitz', float64[:]),
        ('global_lipschitz', float64),
    )

Expand All @@ -140,9 +155,14 @@ def raw_hessian(self, y, Xw):

def initialize(self, X, y):
    """Compute the Lipschitz constants of the datafit for a dense ``X``.

    Sets ``self.lipschitz`` to the squared column norms of ``X`` and
    ``self.global_lipschitz`` to the squared spectral norm of ``X``, both
    divided by ``4 * len(y)``.
    """
    denominator = len(y) * 4
    squared_column_norms = (X ** 2).sum(axis=0)
    self.lipschitz = squared_column_norms / denominator
    self.global_lipschitz = norm(X, ord=2) ** 2 / denominator

def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= 4 * len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
Expand Down Expand Up @@ -187,6 +207,11 @@ class QuadraticSVC(BaseDatafit):
----------
lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
Equal to norm(yXT, axis=0) ** 2.
global_lipschitz : float
Global Lipschitz constant. Equal to
norm(yXT, ord=2) ** 2.
Note
----
Expand All @@ -200,6 +225,7 @@ def __init__(self):
def get_spec(self):
    """Return the attribute specification as ``(name, numba type)`` pairs."""
    return (
        ('lipschitz', float64[:]),
        ('global_lipschitz', float64),
    )

Expand All @@ -209,12 +235,16 @@ def params_to_dict(self):
def initialize(self, yXT, y):
    """Compute the Lipschitz constants of the datafit for a dense ``yXT``.

    Sets ``self.lipschitz`` to the squared norm of each column of ``yXT``
    and ``self.global_lipschitz`` to the squared spectral norm of ``yXT``.
    ``y`` is not used by this method.
    """
    n_columns = yXT.shape[1]
    column_constants = np.zeros(n_columns, dtype=yXT.dtype)
    self.lipschitz = column_constants
    self.global_lipschitz = norm(yXT, ord=2) ** 2
    for col in range(n_columns):
        column_constants[col] = norm(yXT[:, col]) ** 2

def initialize_sparse(
self, yXT_data, yXT_indptr, yXT_indices, y):
def initialize_sparse(self, yXT_data, yXT_indptr, yXT_indices, y):
n_features = len(yXT_indptr) - 1

self.global_lipschitz = spectral_norm(
yXT_data, yXT_indptr, yXT_indices, max(yXT_indices)+1) ** 2

self.lipschitz = np.zeros(n_features, dtype=yXT_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down Expand Up @@ -264,8 +294,16 @@ class Huber(BaseDatafit):
Attributes
----------
delta : float
Threshold hyperparameter.
lipschitz : array, shape (n_features,)
The coordinatewise gradient Lipschitz constants.
The coordinatewise gradient Lipschitz constants. Equal to
norm(X, axis=0) ** 2 / n_samples.
global_lipschitz : float
Global Lipschitz constant. Equal to
norm(X, ord=2) ** 2 / n_samples.
Note
----
Expand All @@ -279,7 +317,8 @@ def __init__(self, delta):
def get_spec(self):
spec = (
('delta', float64),
('lipschitz', float64[:])
('lipschitz', float64[:]),
('global_lipschitz', float64),
)
return spec

Expand All @@ -289,12 +328,17 @@ def params_to_dict(self):
def initialize(self, X, y):
    """Compute the Lipschitz constants of the datafit for a dense ``X``.

    Sets ``self.lipschitz`` to the squared norm of each column of ``X``
    divided by ``len(y)`` and ``self.global_lipschitz`` to the squared
    spectral norm of ``X`` divided by ``len(y)``.
    """
    n_samples = len(y)
    n_features = X.shape[1]

    # Use the spectral norm, as documented for ``global_lipschitz`` and as
    # done in ``initialize_sparse``. Summing the squared column norms would
    # give the (larger) squared Frobenius norm instead.
    self.global_lipschitz = norm(X, ord=2) ** 2 / n_samples

    self.lipschitz = np.zeros(n_features, dtype=X.dtype)
    for j in range(n_features):
        self.lipschitz[j] = (X[:, j] ** 2).sum() / n_samples

def initialize_sparse(
self, X_data, X_indptr, X_indices, y):
def initialize_sparse(self, X_data, X_indptr, X_indices, y):
n_features = len(X_indptr) - 1

self.global_lipschitz = spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2
self.global_lipschitz /= len(y)

self.lipschitz = np.zeros(n_features, dtype=X_data.dtype)
for j in range(n_features):
nrm2 = 0.
Expand Down
3 changes: 2 additions & 1 deletion skglm/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .anderson_cd import AndersonCD
from .base import BaseSolver
from .fista import FISTA
from .gram_cd import GramCD
from .group_bcd import GroupBCD
from .multitask_bcd import MultiTaskBCD
from .prox_newton import ProxNewton


__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
# ``__all__`` must list the public names as strings, not the objects themselves;
# otherwise ``from skglm.solvers import *`` raises a TypeError.
__all__ = [
    "AndersonCD",
    "BaseSolver",
    "FISTA",
    "GramCD",
    "GroupBCD",
    "MultiTaskBCD",
    "ProxNewton",
]
82 changes: 82 additions & 0 deletions skglm/solvers/fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import numpy as np
from scipy.sparse import issparse
from skglm.solvers.base import BaseSolver
from skglm.solvers.common import construct_grad, construct_grad_sparse
from skglm.utils import _prox_vec


class FISTA(BaseSolver):
    r"""ISTA solver with Nesterov acceleration (FISTA).

    Attributes
    ----------
    max_iter : int, default 100
        Maximum number of iterations.

    tol : float, default 1e-4
        Tolerance for convergence.

    verbose : bool, default False
        Amount of verbosity. 0/False is silent.

    References
    ----------
    .. [1] Beck, A. and Teboulle M.
           "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
           problems", 2009, SIAM J. Imaging Sci.
           https://epubs.siam.org/doi/10.1137/080716542
    """

    def __init__(self, max_iter=100, tol=1e-4, verbose=0):
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose

    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
        """Run FISTA with a fixed step size of ``1 / datafit.global_lipschitz``.

        Parameters
        ----------
        X : array or scipy sparse matrix, shape (n_samples, n_features)
            Design matrix. When sparse, its ``data``, ``indptr`` and
            ``indices`` attributes are passed to ``construct_grad_sparse``.

        y : array
            Target passed to the datafit.

        datafit : object
            Must expose ``global_lipschitz`` and ``value(y, w, Xw)``; it is
            also forwarded to the gradient construction helpers.

        penalty : object
            Must expose ``value(w)`` and ``subdiff_distance(w, grad, ws)``; it
            is also forwarded to ``_prox_vec``.

        w_init : array, shape (n_features,), optional
            Starting point. Copied, never modified. Defaults to zeros.

        Xw_init : array, optional
            Currently unused: ``X @ w`` is recomputed at every iteration.

        Returns
        -------
        w : array, shape (n_features,)
            Last iterate.

        p_objs_out : array
            Objective value (datafit + penalty) after each iteration.

        stop_crit : float
            Last value of the stopping criterion; ``np.inf`` if no iteration
            was performed.

        Raises
        ------
        NotImplementedError
            If ``datafit`` has no ``global_lipschitz`` attribute.
        """
        if not hasattr(datafit, "global_lipschitz"):
            # TODO: OR line search
            # NotImplementedError subclasses Exception, so callers catching
            # the previously raised bare Exception keep working.
            raise NotImplementedError(
                "Line search is not yet implemented for FISTA solver.")

        n_features = X.shape[1]
        all_features = np.arange(n_features)
        X_is_sparse = issparse(X)

        # The step size does not change across iterations.
        step = 1 / datafit.global_lipschitz

        w = w_init.copy() if w_init is not None else np.zeros(n_features)
        z = w.copy()

        p_objs_out = []
        # Defined up front so that the return statement is valid even when
        # ``max_iter`` is 0 and the loop body never runs.
        stop_crit = np.inf
        t_new = 1.

        for n_iter in range(self.max_iter):
            t_old = t_new
            t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
            w_old = w.copy()

            # Gradient of the datafit at the extrapolated point z.
            Xz = X @ z
            if X_is_sparse:
                grad = construct_grad_sparse(
                    X.data, X.indptr, X.indices, y, z, Xz, datafit,
                    all_features)
            else:
                grad = construct_grad(X, y, z, Xz, datafit, all_features)

            # Gradient step on z, proximal step, then Nesterov extrapolation.
            z -= step * grad
            w = _prox_vec(w, z, penalty, step)
            Xw = X @ w
            z = w + (t_old - 1.) / t_new * (w - w_old)

            # NOTE: ``grad`` was evaluated at the previous extrapolated point,
            # not at the new iterate ``w``.
            opt = penalty.subdiff_distance(w, grad, all_features)
            stop_crit = np.max(opt)

            p_obj = datafit.value(y, w, Xw) + penalty.value(w)
            p_objs_out.append(p_obj)
            if self.verbose:
                print(
                    f"Iteration {n_iter+1}: {p_obj:.10f}, "
                    f"stopping crit: {stop_crit:.2e}"
                )

            if stop_crit < self.tol:
                if self.verbose:
                    print(f"Stopping criterion max violation: {stop_crit:.2e}")
                break
        return w, np.array(p_objs_out), stop_crit
69 changes: 69 additions & 0 deletions skglm/tests/test_fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

import numpy as np
from numpy.linalg import norm

import scipy.sparse
import scipy.sparse.linalg
from scipy.sparse import csc_matrix, issparse

from skglm.penalties import L1, IndicatorBox
from skglm.solvers import FISTA, AndersonCD
from skglm.datafits import Quadratic, Logistic, QuadraticSVC
from skglm.utils import make_correlated_data, compiled_clone, spectral_norm


random_state = 113
n_samples, n_features = 50, 60

rng = np.random.RandomState(random_state)
X, y, _ = make_correlated_data(n_samples, n_features, random_state=rng)
rng.seed(random_state)
# Draw the sparsity mask from the seeded generator rather than from the global
# NumPy RNG, so that the sparse fixture is identical on every run.
X_sparse = csc_matrix(X * rng.binomial(1, 0.5, X.shape))
y_classif = np.sign(y)

# Regularization strength used by the tests: a tenth of
# ``norm(X.T @ y, ord=inf) / n_samples``.
alpha_max = norm(X.T @ y, ord=np.inf) / len(y)
alpha = alpha_max / 10

tol = 1e-10


@pytest.mark.parametrize("X", [X, X_sparse])
@pytest.mark.parametrize("Datafit, Penalty", [
    (Quadratic, L1),
    (Logistic, L1),
    (QuadraticSVC, IndicatorBox),
])
def test_fista_solver(X, Datafit, Penalty):
    """Check that FISTA and AndersonCD return the same solution."""
    # ``Datafit`` is a class, not an instance: ``isinstance(Datafit, Quadratic)``
    # is always False, so the datafit must be selected by identity.
    _y = y if Datafit is Quadratic else y_classif

    if Datafit is QuadraticSVC:
        # QuadraticSVC is initialized with ``yXT``; build it as the transpose
        # of X with each sample scaled by its label, and solve on it.
        # NOTE(review): assumes yXT[j, i] = y[i] * X[i, j] -- confirm against
        # the QuadraticSVC datafit.
        if issparse(X):
            design = X.T.multiply(_y).tocsc()
        else:
            design = (X * _y[:, None]).T
    else:
        design = X

    datafit = compiled_clone(Datafit())
    if issparse(design):
        datafit.initialize_sparse(
            design.data, design.indptr, design.indices, _y)
    else:
        datafit.initialize(design, _y)
    penalty = compiled_clone(Penalty(alpha))

    solver = FISTA(max_iter=1000, tol=tol)
    w_fista = solver.solve(design, _y, datafit, penalty)[0]

    solver_cd = AndersonCD(tol=tol, fit_intercept=False)
    w_cd = solver_cd.solve(design, _y, datafit, penalty)[0]

    np.testing.assert_allclose(w_fista, w_cd, atol=1e-7)


def test_spectral_norm():
    """Check ``spectral_norm`` against the largest singular value from SciPy."""
    n_samples, n_features = 50, 60
    A_sparse = scipy.sparse.random(n_samples, n_features, density=0.7, format='csc',
                                   random_state=random_state)

    A_bundles = (A_sparse.data, A_sparse.indptr, A_sparse.indices)
    # Pass the row count of ``A_sparse`` itself instead of the length of the
    # unrelated module-level ``y`` (which only happens to be equal).
    spectral_norm_our = spectral_norm(*A_bundles, n_samples=n_samples)
    spectral_norm_sp = scipy.sparse.linalg.svds(A_sparse, k=1)[1]

    np.testing.assert_allclose(spectral_norm_our, spectral_norm_sp)


if __name__ == '__main__':
pass
Loading

0 comments on commit 359f4da

Please sign in to comment.