From a7db68b7459e79afbb4efa72078cb6e0e5c2913c Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Fri, 22 Apr 2022 19:13:49 +0200 Subject: [PATCH] WIP FISTA --- skglm/solvers/gram.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py index 915ae8741..317cea99b 100644 --- a/skglm/solvers/gram.py +++ b/skglm/solvers/gram.py @@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq lipschitz = np.zeros(n_features, dtype=X.dtype) for j in range(n_features): lipschitz[j] = (X[:, j] ** 2).sum() / len(y) - w = w_init if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + beta_0 = beta_1 = 1 weights = weights if weights is not None else np.ones(n_features) # CD for n_iter in range(max_iter): - cd_epoch(X, G, grads, w, alpha, lipschitz, weights) + beta_1 = (1 + np.sqrt(1 + 4 * beta_0 ** 2)) / 2 + cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights) + beta_0 = beta_1 if n_iter % check_freq == 0: p_obj = primal(alpha, y, X, w, weights) if p_obj_prev - p_obj < tol: @@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No for g in range(n_groups): X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) - w = w_init if w_init is not None else np.zeros(n_features) + w = w_init.copy() if w_init is not None else np.zeros(n_features) weights = weights if weights is not None else np.ones(n_groups) # BCD for n_iter in range(max_iter): @@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No @njit -def cd_epoch(X, G, grads, w, alpha, lipschitz, weights): +def cd_epoch(X, G, grads, w, z, alpha, beta_1, beta_0, lipschitz, weights): n_features = X.shape[1] for j in range(n_features): if lipschitz[j] == 0. or weights[j] == np.inf: continue old_w_j = w[j] - w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) - if old_w_j != w[j]: - grads += G[j, :] * (old_w_j - w[j]) / len(X) + old_z_j = z[j] + w[j] = ST(z[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) + z[j] = w[j] + ((beta_0 - 1) / beta_1) * (w[j] - old_w_j) + if old_z_j != z[j]: + grads += G[j, :] * (old_z_j - z[j]) / len(X) @njit