Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add FISTA solver #91

Merged
merged 31 commits into from
Oct 22, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0868b0f
POC FISTA
PABannier Oct 12, 2022
8584299
CLN
PABannier Oct 14, 2022
c82e32e
changed obj_freq from 100 to 10
PABannier Oct 14, 2022
4940a0d
WIP Lipschitz
PABannier Oct 14, 2022
e47c68a
ADD global lipschitz constants
PABannier Oct 14, 2022
3635f24
FISTA with global lipschitz
PABannier Oct 14, 2022
4880112
writing tests
PABannier Oct 14, 2022
46a9a76
better tests
PABannier Oct 14, 2022
9f0653a
support sparse matrices
PABannier Oct 14, 2022
fe159be
fix mistake
PABannier Oct 14, 2022
8e74e8a
RM toy_fista
PABannier Oct 14, 2022
a24ed9c
green
PABannier Oct 14, 2022
4362c2c
mv `_prox_vec` to utils
PABannier Oct 16, 2022
2665d5d
rm `opt_freq`
PABannier Oct 16, 2022
2e408bc
fix tests
PABannier Oct 16, 2022
8524cf7
Update skglm/solvers/fista.py
PABannier Oct 16, 2022
dd658f8
huber comment
PABannier Oct 16, 2022
7c9fbe1
Merge branch 'fista' of https://github.com/PABannier/skglm into fista
PABannier Oct 16, 2022
cbc5418
WIP
PABannier Oct 16, 2022
b6c664c
Merge branch 'main' of https://github.com/scikit-learn-contrib/skglm …
Badr-MOUFAD Oct 20, 2022
e76dfb1
implement power method
Badr-MOUFAD Oct 20, 2022
2a4bce3
private ``prox_vec``
Badr-MOUFAD Oct 20, 2022
cd39a62
random init in power method && default args
Badr-MOUFAD Oct 21, 2022
0e4d42a
use power method for ``global_lipschitz``
Badr-MOUFAD Oct 21, 2022
2bbc8f5
fix && refactor unittest
Badr-MOUFAD Oct 21, 2022
ed3686a
add docs for tol and max_iter && clean ups
Badr-MOUFAD Oct 21, 2022
aa15c46
remove square form spectral norm
Badr-MOUFAD Oct 21, 2022
27b918d
refactor ``_prox_vec`` function
Badr-MOUFAD Oct 21, 2022
9d8e3c0
fix bug segmentation fault
Badr-MOUFAD Oct 21, 2022
e5ce21b
add Fista to docs && fix unittest
Badr-MOUFAD Oct 21, 2022
5d2dbaf
cosmetic changes
mathurinm Oct 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skglm/solvers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .anderson_cd import AndersonCD
from .base import BaseSolver
from .fista import FISTA
from .gram_cd import GramCD
from .group_bcd import GroupBCD
from .multitask_bcd import MultiTaskBCD
from .prox_newton import ProxNewton


__all__ = [AndersonCD, BaseSolver, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
__all__ = [AndersonCD, BaseSolver, FISTA, GramCD, GroupBCD, MultiTaskBCD, ProxNewton]
67 changes: 67 additions & 0 deletions skglm/solvers/fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
from numba import njit
from skglm.solvers.base import BaseSolver


@njit
PABannier marked this conversation as resolved.
Show resolved Hide resolved
def _prox_vec(w, z, penalty, lipschitz):
PABannier marked this conversation as resolved.
Show resolved Hide resolved
# XXX: TO DISCUSS: should add a vectorized prox update
n_features = w.shape[0]
for j in range(n_features):
w[j] = penalty.prox_1d(z[j], 1 / lipschitz, j)
return w


class FISTA(BaseSolver):
PABannier marked this conversation as resolved.
Show resolved Hide resolved
r"""ISTA solver with Nesterov acceleration (FISTA)."""
PABannier marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, max_iter=100, tol=1e-4, fit_intercept=False, warm_start=False,
opt_freq=100, verbose=0):
PABannier marked this conversation as resolved.
Show resolved Hide resolved
self.max_iter = max_iter
self.tol = tol
self.fit_intercept = fit_intercept
self.warm_start = warm_start
self.opt_freq = opt_freq
self.verbose = verbose

def solve(self, X, y, penalty, w_init=None, weights=None):
# needs a quadratic datafit, but works with L1, WeightedL1, SLOPE
n_samples, n_features = X.shape
all_features = np.arange(n_features)
t_new = 1
PABannier marked this conversation as resolved.
Show resolved Hide resolved

w = w_init.copy() if w_init is not None else np.zeros(n_features)
z = w_init.copy() if w_init is not None else np.zeros(n_features)
weights = weights if weights is not None else np.ones(n_features)

# FISTA with Gram update
PABannier marked this conversation as resolved.
Show resolved Hide resolved
G = X.T @ X
Xty = X.T @ y
lipschitz = np.linalg.norm(X, ord=2) ** 2 / n_samples
PABannier marked this conversation as resolved.
Show resolved Hide resolved

for n_iter in range(self.max_iter):
t_old = t_new
t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
w_old = w.copy()
grad = (G @ z - Xty) / n_samples
z -= grad / lipschitz
w = _prox_vec(w, z, penalty, lipschitz)
z = w + (t_old - 1.) / t_new * (w - w_old)
PABannier marked this conversation as resolved.
Show resolved Hide resolved
mathurinm marked this conversation as resolved.
Show resolved Hide resolved

if n_iter % self.opt_freq == 0:
opt = penalty.subdiff_distance(w, grad, all_features)
stop_crit = np.max(opt)

if self.verbose:
p_obj = (np.sum((y - X @ w) ** 2) / (2 * n_samples)
+ penalty.value(w))
print(
f"Iteration {n_iter+1}: {p_obj:.10f}, "
f"stopping crit: {stop_crit:.2e}"
)

if stop_crit < self.tol:
if self.verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
break
return w
PABannier marked this conversation as resolved.
Show resolved Hide resolved
27 changes: 27 additions & 0 deletions toy_fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
from numpy.linalg import norm
from skglm.solvers import FISTA
from skglm.penalties import L1
from skglm.estimators import Lasso
from skglm.utils import make_correlated_data, compiled_clone


X, y, _ = make_correlated_data(n_samples=200, n_features=100, random_state=24)

n_samples, n_features = X.shape
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples

alpha = alpha_max / 10

max_iter = 1000
obj_freq = 100
tol = 1e-10

solver = FISTA(max_iter=max_iter, tol=tol, opt_freq=obj_freq, verbose=1)
penalty = compiled_clone(L1(alpha))
w = solver.solve(X, y, penalty)

clf = Lasso(alpha=alpha, tol=tol, fit_intercept=False)
clf.fit(X, y)

np.testing.assert_allclose(w, clf.coef_, rtol=1e-5)