From ccc634487c28dd9db48a2bd7ef2f21f8092d7251 Mon Sep 17 00:00:00 2001 From: mathurinm Date: Mon, 3 Jun 2024 07:46:57 +0200 Subject: [PATCH] FIX ProxNewton solver with fixpoint strategy (#259) Co-authored-by: Badr-MOUFAD --- skglm/solvers/anderson_cd.py | 2 +- skglm/solvers/common.py | 12 ++++++------ skglm/solvers/multitask_bcd.py | 20 +++++++++++--------- skglm/solvers/prox_newton.py | 29 ++++++++++++++++------------- 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index cc4efbc3b..8b3437de4 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -184,7 +184,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws) elif self.ws_strategy == "fixpoint": opt_ws = dist_fix_point_cd( - w[:n_features], grad_ws, lipschitz, datafit, penalty, ws + w[:n_features], grad_ws, lipschitz[ws], datafit, penalty, ws ) stop_crit_in = np.max(opt_ws) diff --git a/skglm/solvers/common.py b/skglm/solvers/common.py index a8a3f4ec3..a5e7216f2 100644 --- a/skglm/solvers/common.py +++ b/skglm/solvers/common.py @@ -3,7 +3,7 @@ @njit -def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws): +def dist_fix_point_cd(w, grad_ws, lipschitz_ws, datafit, penalty, ws): """Compute the violation of the fixed point iterate scheme. Parameters @@ -14,8 +14,8 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws): grad_ws : array, shape (ws_size,) Gradient restricted to the working set. - lipschitz : array, shape (n_features,) - Coordinatewise gradient Lipschitz constants. + lipschitz_ws : array, shape (len(ws),) + Coordinatewise gradient Lipschitz constants, restricted to working set. datafit: instance of BaseDatafit Datafit. @@ -23,7 +23,7 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws): penalty: instance of BasePenalty Penalty. - ws : array, shape (ws_size,) + ws : array, shape (len(ws),) The working set. Returns @@ -34,10 +34,10 @@ def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws): dist = np.zeros(ws.shape[0], dtype=w.dtype) for idx, j in enumerate(ws): - if lipschitz[j] == 0.: + if lipschitz_ws[idx] == 0.: continue - step_j = 1 / lipschitz[j] + step_j = 1 / lipschitz_ws[idx] dist[idx] = np.abs( w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j) ) diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 5a68d0c5e..16301ac4a 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -66,7 +66,9 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): if self.ws_strategy == "subdiff": opt = penalty.subdiff_distance(W, grad, all_feats) elif self.ws_strategy == "fixpoint": - opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats) + opt = dist_fix_point_bcd( + W, grad, lipschitz, datafit, penalty, all_feats + ) stop_crit = np.max(opt) if self.verbose: print(f"Stopping criterion max violation: {stop_crit:.2e}") @@ -151,7 +153,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): opt_ws = penalty.subdiff_distance(W, grad_ws, ws) elif self.ws_strategy == "fixpoint": opt_ws = dist_fix_point_bcd( - W, grad_ws, lipschitz, datafit, penalty, ws + W, grad_ws, lipschitz[ws], datafit, penalty, ws ) stop_crit_in = np.max(opt_ws) @@ -231,7 +233,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) @njit -def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): +def dist_fix_point_bcd(W, grad_ws, lipschitz_ws, datafit, penalty, ws): """Compute the violation of the fixed point iterate schema. Parameters @@ -239,19 +241,19 @@ def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): W : array, shape (n_features, n_tasks) Coefficient matrix. - grad_ws : array, shape (ws_size, n_tasks) + grad_ws : array, shape (len(ws), n_tasks) Gradient restricted to the working set. datafit: instance of BaseMultiTaskDatafit Datafit. - lipschitz : array, shape (n_features,) - Blockwise gradient Lipschitz constants. + lipschitz_ws : array, shape (len(ws),) + Blockwise gradient Lipschitz constants, restricted to working set. penalty: instance of BasePenalty Penalty. - ws : array, shape (ws_size,) + ws : array, shape (len(ws),) The working set. Returns @@ -262,10 +264,10 @@ def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): dist = np.zeros(ws.shape[0]) for idx, j in enumerate(ws): - if lipschitz[j] == 0.: + if lipschitz_ws[idx] == 0.: continue - step_j = 1 / lipschitz[j] + step_j = 1 / lipschitz_ws[idx] dist[idx] = norm( W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j) ) diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py index b077ea567..4b8e0aaf7 100644 --- a/skglm/solvers/prox_newton.py +++ b/skglm/solvers/prox_newton.py @@ -65,6 +65,9 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.verbose = verbose def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + if self.ws_strategy not in ("subdiff", "fixpoint"): + raise ValueError("ws_strategy must be `subdiff` or `fixpoint`, " + f"got {self.ws_strategy}.") dtype = X.dtype n_samples, n_features = X.shape fit_intercept = self.fit_intercept @@ -206,9 +209,9 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, dtype = X.dtype raw_hess = datafit.raw_hessian(y, Xw_epoch) - lipschitz = np.zeros(len(ws), dtype) + lipschitz_ws = np.zeros(len(ws), dtype) for idx, j in enumerate(ws): - lipschitz[idx] = raw_hess @ X[:, j] ** 2 + lipschitz_ws[idx] = raw_hess @ X[:, j] ** 2 # for a less costly stopping criterion, we do not compute the exact gradient, # but store each coordinate-wise gradient every time we update one coordinate @@ -224,12 +227,12 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, for cd_iter in range(MAX_CD_ITER): for idx, j in enumerate(ws): # skip when X[:, j] == 0 - if lipschitz[idx] == 0: + if lipschitz_ws[idx] == 0: continue past_grads[idx] = grad_ws[idx] + X[:, j] @ (raw_hess * X_delta_w_ws) old_w_idx = w_ws[idx] - stepsize = 1 / lipschitz[idx] + stepsize = 1 / lipschitz_ws[idx] w_ws[idx] = penalty.prox_1d( old_w_idx - stepsize * past_grads[idx], stepsize, j) @@ -253,7 +256,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, opt = penalty.subdiff_distance(current_w, past_grads, ws) elif ws_strategy == "fixpoint": opt = dist_fix_point_cd( - current_w, past_grads, lipschitz, datafit, penalty, ws + current_w, past_grads, lipschitz_ws, datafit, penalty, ws ) stop_crit = np.max(opt) @@ -264,7 +267,7 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit, break # descent direction - return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz + return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws # sparse version of _descent_direction @@ -275,10 +278,10 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, dtype = X_data.dtype raw_hess = datafit.raw_hessian(y, Xw_epoch) - lipschitz = np.zeros(len(ws), dtype) + lipschitz_ws = np.zeros(len(ws), dtype) for idx, j in enumerate(ws): - # equivalent to: lipschitz[idx] += raw_hess * X[:, j] ** 2 - lipschitz[idx] = _sparse_squared_weighted_norm( + # equivalent to: lipschitz_ws[idx] += raw_hess * X[:, j] ** 2 + lipschitz_ws[idx] = _sparse_squared_weighted_norm( X_data, X_indptr, X_indices, j, raw_hess) # see _descent_direction() comment @@ -294,7 +297,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, for cd_iter in range(MAX_CD_ITER): for idx, j in enumerate(ws): # skip when X[:, j] == 0 - if lipschitz[idx] == 0: + if lipschitz_ws[idx] == 0: continue past_grads[idx] = grad_ws[idx] @@ -303,7 +306,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, X_data, X_indptr, X_indices, j, X_delta_w_ws, raw_hess) old_w_idx = w_ws[idx] - stepsize = 1 / lipschitz[idx] + stepsize = 1 / lipschitz_ws[idx] w_ws[idx] = penalty.prox_1d( old_w_idx - stepsize * past_grads[idx], stepsize, j) @@ -328,7 +331,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, opt = penalty.subdiff_distance(current_w, past_grads, ws) elif ws_strategy == "fixpoint": opt = dist_fix_point_cd( - current_w, past_grads, lipschitz, datafit, penalty, ws + current_w, past_grads, lipschitz_ws, datafit, penalty, ws ) stop_crit = np.max(opt) @@ -339,7 +342,7 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch, break # descent direction - return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz + return w_ws - w_epoch[ws_intercept], X_delta_w_ws, lipschitz_ws @njit