Skip to content

Commit

Permalink
WIP FISTA
Browse files Browse the repository at this point in the history
  • Loading branch information
PABannier committed Apr 22, 2022
1 parent 51b4cfe commit a7db68b
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions skglm/solvers/gram.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq
lipschitz = np.zeros(n_features, dtype=X.dtype)
for j in range(n_features):
lipschitz[j] = (X[:, j] ** 2).sum() / len(y)
w = w_init if w_init is not None else np.zeros(n_features)
w = w_init.copy() if w_init is not None else np.zeros(n_features)
z = w_init.copy() if w_init is not None else np.zeros(n_features)
beta_0 = beta_1 = 1
weights = weights if weights is not None else np.ones(n_features)
# CD
for n_iter in range(max_iter):
cd_epoch(X, G, grads, w, alpha, lipschitz, weights)
beta_1 = (1 + np.sqrt(1 + 4 * beta_0 ** 2)) / 2

This comment has been minimized.

Copy link
@mathurinm

mathurinm Apr 22, 2022

Collaborator

You need to move it apart from CD solver and do a gradient step instead

cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights)
beta_0 = beta_1
if n_iter % check_freq == 0:
p_obj = primal(alpha, y, X, w, weights)
if p_obj_prev - p_obj < tol:
Expand All @@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
for g in range(n_groups):
X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
w = w_init if w_init is not None else np.zeros(n_features)
w = w_init.copy() if w_init is not None else np.zeros(n_features)
weights = weights if weights is not None else np.ones(n_groups)
# BCD
for n_iter in range(max_iter):
Expand All @@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No


@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
def cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights):
n_features = X.shape[1]
for j in range(n_features):
if lipschitz[j] == 0. or weights[j] == np.inf:
continue
old_w_j = w[j]
w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j])
if old_w_j != w[j]:
grads += G[j, :] * (old_w_j - w[j]) / len(X)
old_z_j = z[j]
w[j] = ST(z[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j])
z[j] = w[j] + ((beta_0 - 1) / beta_1) * (w[j] - old_w_j)
if old_z_j != z[j]:
grads += G[j, :] * (old_z_j - z[j]) / len(X)


@njit
Expand Down

0 comments on commit a7db68b

Please sign in to comment.