From 59f91f7e4c88734b5bf02f728ad36db8013483c4 Mon Sep 17 00:00:00 2001 From: Badr MOUFAD <65614794+Badr-MOUFAD@users.noreply.github.com> Date: Tue, 31 Oct 2023 11:39:52 +0100 Subject: [PATCH] MNT - Refactor ``dist_fix_point`` function (#194) --- skglm/solvers/anderson_cd.py | 8 +++++--- skglm/solvers/common.py | 20 ++++++++++++-------- skglm/solvers/multitask_bcd.py | 25 ++++++++++++++----------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py index 4b8984b0b..cc4efbc3b 100644 --- a/skglm/solvers/anderson_cd.py +++ b/skglm/solvers/anderson_cd.py @@ -2,7 +2,9 @@ from numba import njit from scipy import sparse from sklearn.utils import check_array -from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point +from skglm.solvers.common import ( + construct_grad, construct_grad_sparse, dist_fix_point_cd +) from skglm.solvers.base import BaseSolver from skglm.utils.anderson import AndersonAcceleration @@ -104,7 +106,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy == "subdiff": opt = penalty.subdiff_distance(w[:n_features], grad, all_feats) elif self.ws_strategy == "fixpoint": - opt = dist_fix_point( + opt = dist_fix_point_cd( w[:n_features], grad, lipschitz, datafit, penalty, all_feats ) @@ -181,7 +183,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy == "subdiff": opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws) elif self.ws_strategy == "fixpoint": - opt_ws = dist_fix_point( + opt_ws = dist_fix_point_cd( w[:n_features], grad_ws, lipschitz, datafit, penalty, ws ) diff --git a/skglm/solvers/common.py b/skglm/solvers/common.py index b6cc37bba..3e933c597 100644 --- a/skglm/solvers/common.py +++ b/skglm/solvers/common.py @@ -3,7 +3,7 @@ @njit -def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws): +def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws): """Compute the violation of the fixed point iterate scheme. Parameters @@ -28,16 +28,20 @@ def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws): Returns ------- - dist_fix_point : array, shape (n_features,) + dist : array, shape (n_features,) Violation score for every feature. """ - dist_fix_point = np.zeros(ws.shape[0]) + dist = np.zeros(ws.shape[0]) + for idx, j in enumerate(ws): - lcj = lipschitz[j] - if lcj != 0: - dist_fix_point[idx] = np.abs( - w[j] - penalty.prox_1d(w[j] - grad_ws[idx] / lcj, 1. / lcj, j)) - return dist_fix_point + if lipschitz[j] == 0.: + continue + + step_j = 1 / lipschitz[j] + dist[idx] = np.abs( + w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j) + ) + return dist @njit diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py index 0759cfa8c..5a68d0c5e 100644 --- a/skglm/solvers/multitask_bcd.py +++ b/skglm/solvers/multitask_bcd.py @@ -66,7 +66,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): if self.ws_strategy == "subdiff": opt = penalty.subdiff_distance(W, grad, all_feats) elif self.ws_strategy == "fixpoint": - opt = dist_fix_point(W, grad, datafit, penalty, all_feats) + opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats) stop_crit = np.max(opt) if self.verbose: print(f"Stopping criterion max violation: {stop_crit:.2e}") @@ -150,7 +150,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None): if self.ws_strategy == "subdiff": opt_ws = penalty.subdiff_distance(W, grad_ws, ws) elif self.ws_strategy == "fixpoint": - opt_ws = dist_fix_point( + opt_ws = dist_fix_point_bcd( W, grad_ws, lipschitz, datafit, penalty, ws ) @@ -231,7 +231,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False) @njit -def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws): +def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws): """Compute the violation of the fixed point iterate schema. Parameters @@ -256,17 +256,20 @@ def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws): Returns ------- - dist_fix_point : array, shape (ws_size,) + dist : array, shape (ws_size,) Contain the violation score for every feature. """ - dist_fix_point = np.zeros(ws.shape[0]) + dist = np.zeros(ws.shape[0]) + for idx, j in enumerate(ws): - lcj = lipschitz[j] - if lcj: - dist_fix_point[idx] = norm( - W[j] - penalty.prox_1feat(W[j] - grad_ws[idx] / lcj, 1. / lcj, j) - ) - return dist_fix_point + if lipschitz[j] == 0.: + continue + + step_j = 1 / lipschitz[j] + dist[idx] = norm( + W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j) + ) + return dist @njit