diff --git a/gram_test.py b/gram_test.py new file mode 100644 index 000000000..1fcab4585 --- /dev/null +++ b/gram_test.py @@ -0,0 +1,56 @@ +# data available at https://www.dropbox.com/sh/32b3mr3xghi496g/AACNRS_NOsUXU-hrSLixNg0ja?dl=0 + + +import time +import numpy as np +from celer import GroupLasso +from skglm.solvers.gram import gram_fista_group_lasso, gram_group_lasso + +X = np.load("design_matrix.npy") +y = np.load("target.npy") +groups = np.load("groups.npy") +weights = np.load("weights.npy") +grps = [list(np.where(groups == i)[0]) for i in range(1, 33)] + + +alpha_ratio = 1e-2 +n_alphas = 10 +tol = 1e-8 + +# Case 1: slower runtime for small alphas +# alpha_max = 0.003471727067743962 +alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y) +alpha = alpha_max / 100 +clf = GroupLasso(fit_intercept=False, tol=tol, + groups=grps, weights=weights, alpha=alpha, verbose=1) + +t0 = time.time() +clf.fit(X, y) +t1 = time.time() + +print(f"Celer: {t1 - t0:.3f} s") + +t0 = time.time() +res = gram_group_lasso(X, y, alpha, groups=grps, tol=tol, weights=weights, max_iter=10_000, + check_freq=50) +t1 = time.time() + +print(f"skglm gram: {t1 - t0:.3f} s") + + +# FISTA Gram for very small alphas +alpha = alpha_max / 1e-4 +clf = GroupLasso(fit_intercept=False, tol=tol, groups=grps, weights=weights, alpha=alpha, + verbose=1) + +t0 = time.time() +clf.fit(X, y) +t1 = time.time() + +print(f"Celer: {t1 - t0:.3f} s") + +t0 = time.time() +res = gram_fista_group_lasso(X, y, alpha, groups=grps, tol=tol, weights=weights, max_iter=10_000, + check_freq=50) +t1 = time.time() +print(f"skglm fista gram: {t1 - t0:.3f} s") \ No newline at end of file diff --git a/skglm/gram_solver.py b/skglm/gram_solver.py new file mode 100644 index 000000000..57301a4f5 --- /dev/null +++ b/skglm/gram_solver.py @@ -0,0 +1,73 @@ +from time import time +import numpy as np +from numpy.linalg import norm +from celer import Lasso, GroupLasso +from benchopt.datasets.simulated import make_correlated_data +from skglm.solvers.gram import gram_fista_group_lasso, gram_fista_lasso, gram_lasso, gram_group_lasso + + +n_samples, n_features = 100, 300 +X, y, w_star = make_correlated_data( + n_samples=n_samples, n_features=n_features, random_state=0) +alpha_max = norm(X.T @ y, ord=np.inf) + +# Hyperparameters +max_iter = 10_000 +tol = 1e-8 +reg = 0.1 +group_size = 3 + +alpha = alpha_max * reg / n_samples + +weights = np.random.normal(2, 0.4, n_features) +weights_grp = np.random.normal(2, 0.4, n_features // group_size) + +# Lasso +print("#" * 15) +print("Lasso") +print("#" * 15) +start = time() +w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights) +gram_lasso_time = time() - start +clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False) +start = time() +clf_sk.fit(X, y) +celer_lasso_time = time() - start +start = time() +w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights) +gram_fista_lasso_time = time() - start +np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4) +np.testing.assert_allclose(w, w_fista, rtol=1e-4) + +print("\n") +print("Celer: %.2f" % celer_lasso_time) +print("CD Gram: %.2f" % gram_lasso_time) +print("FISTA Gram: %.2f" % gram_fista_lasso_time) +print("\n") + +# Group Lasso +print("#" * 15) +print("Group Lasso") +print("#" * 15) +start = time() +w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol, weights=weights_grp) +gram_group_lasso_time = time() - start +start = time() +w_fista = gram_fista_group_lasso(X, y, alpha, group_size, max_iter, tol, + weights=weights_grp) +gram_fista_group_lasso_time = time() - start + +np.testing.assert_allclose(w, w_fista, rtol=1e-4) + +clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp, + fit_intercept=False) +start = time() +clf_celer.fit(X, y) +celer_group_lasso_time = time() - start +np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-4) + +print("\n") +print("Celer: %.2f" % celer_group_lasso_time) +print("BCD Gram: %.2f" % gram_group_lasso_time) +print("FISTA Gram: %.2f" % gram_fista_group_lasso_time) +print("\n") diff --git a/skglm/solvers/gram.py b/skglm/solvers/gram.py new file mode 100644 index 000000000..79a5a8929 --- /dev/null +++ b/skglm/solvers/gram.py @@ -0,0 +1,254 @@ +import numpy as np +from numba import njit +from numpy.linalg import norm +from celer.homotopy import _grp_converter + +from skglm.utils import BST, ST, ST_vec + + +@njit +def primal(alpha, r, w, weights): + n_features = len(weights) + p_obj = (r @ r) / (2 * len(r)) + pen = 0. + for j in range(n_features): + if weights[j] == np.inf: + continue + pen += np.abs(w[j] * weights[j]) + return p_obj + alpha * pen + + +@njit +def primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights): + p_obj = norm_r2 / (2 * len(r)) + for g in range(len(grp_ptr) - 1): + if weights[g] == np.inf: + continue + w_g = w[grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + p_obj += alpha * norm(w_g * weights[g], ord=2) + return p_obj + + +@njit +def dual(alpha, norm_y2, theta, y): + d_obj = - np.sum((y / (alpha * len(y)) - theta) ** 2) + d_obj *= 0.5 * alpha ** 2 * len(y) + d_obj += norm_y2 / (2 * len(y)) + return d_obj + + +@njit +def dnorm_l1(theta, X, weights): + n_features = X.shape[1] + scal = 0. + for j in range(n_features): + Xj_theta = X[:, j] @ theta + scal = max(scal, Xj_theta / weights[j]) + return scal + + +@njit +def dnorm_l21(theta, grp_ptr, grp_indices, X, weights): + scal = 0. + n_groups = len(grp_ptr) - 1 + for g in range(n_groups): + if weights[g] == np.inf: + continue + tmp = 0. + for k in range(grp_ptr[g], grp_ptr[g + 1]): + j = grp_indices[k] + Xj_theta = X[:, j] @ theta + tmp += Xj_theta ** 2 + scal = max(scal, np.sqrt(tmp) / weights[g]) + return scal + + +@njit +def create_dual_point(r, alpha, X, y, weights): + theta = r / (alpha * len(y)) + scal = dnorm_l1(theta, X, weights) + if scal > 1.: + theta /= scal + return theta + + +@njit +def create_dual_point_grp(r, alpha, y, X, grp_ptr, grp_indices, weights): + theta = r / (alpha * len(y)) + scal = dnorm_l21(theta, grp_ptr, grp_indices, X, weights) + if scal > 1.: + theta /= scal + return theta + + +@njit +def dual_gap(alpha, norm_y2, y, X, w, weights): + r = y - X @ w + p_obj = primal(alpha, r, w, weights) + theta = create_dual_point(r, alpha, X, y, weights) + d_obj = dual(alpha, norm_y2, theta, y) + return p_obj, d_obj, p_obj - d_obj + + +@njit +def dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, grp_indices, weights): + r = y - X @ w + norm_r2 = r @ r + p_obj = primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights) + theta = create_dual_point_grp(r, alpha, y, X, grp_ptr, grp_indices, weights) + d_obj = dual(alpha, norm_y2, theta, y) + return p_obj, d_obj, p_obj - d_obj + + +@njit +def compute_lipschitz(X, y): + n_features = X.shape[1] + lipschitz = np.zeros(n_features, dtype=X.dtype) + for j in range(n_features): + lipschitz[j] = (X[:, j] ** 2).sum() / len(y) + return lipschitz + + +@njit +def prox_l21(w, u, weights, grp_ptr, grp_indices): + n_groups = len(grp_ptr) - 1 + out = w.copy() + for g in range(n_groups): + idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] + grp_nrm = norm(w[idx], ord=2) + scaling = np.maximum(1 - u / grp_nrm * weights[g], 0) + out[idx] *= scaling + return out + + +def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=100): + n_features = X.shape[1] + norm_y2 = y @ y + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = compute_lipschitz(X, y) + w = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_features) + for n_iter in range(max_iter): + cd_epoch(X, G, grads, w, alpha, lipschitz, weights) + if n_iter % check_freq == 0: + p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f}" + + f" :: gap {d_gap:.5f}") + if d_gap < tol: + print("Convergence reached!") + break + return w + + +def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, + check_freq=100): + n_samples, n_features = X.shape + norm_y2 = y @ y + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_features) + G = X.T @ X + Xty = X.T @ y + L = np.linalg.norm(X, ord=2) ** 2 / n_samples + for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / n_samples + w = ST_vec(z, alpha / L * weights) + z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: + p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: + print("Convergence reached!") + break + return w + + +def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None, + check_freq=100): + n_features = X.shape[1] + grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) + n_groups = len(grp_ptr) - 1 + norm_y2 = y @ y + grads = X.T @ y / len(y) + G = X.T @ X + lipschitz = np.zeros(n_groups, dtype=X.dtype) + for g in range(n_groups): + X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]] + lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y) + w = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_groups) + for n_iter in range(max_iter): + bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights) + if n_iter % check_freq == 0: + p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, + grp_indices, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: + print("Convergence reached!") + break + return w + + +def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, + weights=None, check_freq=100): + n_features = X.shape[1] + norm_y2 = y @ y + grp_ptr, grp_indices = _grp_converter(groups, X.shape[1]) + n_groups = len(grp_ptr) - 1 + t_new = 1 + w = w_init.copy() if w_init is not None else np.zeros(n_features) + z = w_init.copy() if w_init is not None else np.zeros(n_features) + weights = weights if weights is not None else np.ones(n_groups) + G = X.T @ X + Xty = X.T @ y + L = np.linalg.norm(X, ord=2) ** 2 / len(y) + for n_iter in range(max_iter): + t_old = t_new + t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2 + w_old = w.copy() + z -= (G @ z - Xty) / L / len(y) + w = prox_l21(z, alpha / L, weights, grp_ptr, grp_indices) + z = w + (t_old - 1.) / t_new * (w - w_old) + if n_iter % check_freq == 0: + p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, + grp_indices, weights) + print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " + + f":: gap {d_gap:.5f}") + if d_gap < tol: + print("Convergence reached!") + break + return w + + +@njit +def cd_epoch(X, G, grads, w, alpha, lipschitz, weights): + n_features = X.shape[1] + for j in range(n_features): + if lipschitz[j] == 0. or weights[j] == np.inf: + continue + old_w_j = w[j] + w[j] = ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j]) + if old_w_j != w[j]: + grads += G[j, :] * (old_w_j - w[j]) / len(X) + + +@njit +def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights): + n_groups = len(grp_ptr) - 1 + for g in range(n_groups): + if lipschitz[g] == 0. and weights[g] == np.inf: + continue + idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]] + old_w_g = w[idx].copy() + w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g] + * weights[g]) + diff = old_w_g - w[idx] + if np.any(diff != 0.): + grads += diff @ G[idx, :] / len(X)