Skip to content

Commit

Permalink
MNT - Refactor dist_fix_point function (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
Badr-MOUFAD authored Oct 31, 2023
1 parent 920cd41 commit 59f91f7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 22 deletions.
8 changes: 5 additions & 3 deletions skglm/solvers/anderson_cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from numba import njit
from scipy import sparse
from sklearn.utils import check_array
from skglm.solvers.common import construct_grad, construct_grad_sparse, dist_fix_point
from skglm.solvers.common import (
construct_grad, construct_grad_sparse, dist_fix_point_cd
)
from skglm.solvers.base import BaseSolver
from skglm.utils.anderson import AndersonAcceleration

Expand Down Expand Up @@ -104,7 +106,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
if self.ws_strategy == "subdiff":
opt = penalty.subdiff_distance(w[:n_features], grad, all_feats)
elif self.ws_strategy == "fixpoint":
opt = dist_fix_point(
opt = dist_fix_point_cd(
w[:n_features], grad, lipschitz, datafit, penalty, all_feats
)

Expand Down Expand Up @@ -181,7 +183,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
if self.ws_strategy == "subdiff":
opt_ws = penalty.subdiff_distance(w[:n_features], grad_ws, ws)
elif self.ws_strategy == "fixpoint":
opt_ws = dist_fix_point(
opt_ws = dist_fix_point_cd(
w[:n_features], grad_ws, lipschitz, datafit, penalty, ws
)

Expand Down
20 changes: 12 additions & 8 deletions skglm/solvers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


@njit
def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws):
def dist_fix_point_cd(w, grad_ws, lipschitz, datafit, penalty, ws):
"""Compute the violation of the fixed point iterate scheme.
Parameters
Expand All @@ -28,16 +28,20 @@ def dist_fix_point(w, grad_ws, lipschitz, datafit, penalty, ws):
Returns
-------
dist_fix_point : array, shape (n_features,)
dist : array, shape (n_features,)
Violation score for every feature.
"""
dist_fix_point = np.zeros(ws.shape[0])
dist = np.zeros(ws.shape[0])

for idx, j in enumerate(ws):
lcj = lipschitz[j]
if lcj != 0:
dist_fix_point[idx] = np.abs(
w[j] - penalty.prox_1d(w[j] - grad_ws[idx] / lcj, 1. / lcj, j))
return dist_fix_point
if lipschitz[j] == 0.:
continue

step_j = 1 / lipschitz[j]
dist[idx] = np.abs(
w[j] - penalty.prox_1d(w[j] - step_j * grad_ws[idx], step_j, j)
)
return dist


@njit
Expand Down
25 changes: 14 additions & 11 deletions skglm/solvers/multitask_bcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
if self.ws_strategy == "subdiff":
opt = penalty.subdiff_distance(W, grad, all_feats)
elif self.ws_strategy == "fixpoint":
opt = dist_fix_point(W, grad, datafit, penalty, all_feats)
opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats)
stop_crit = np.max(opt)
if self.verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
Expand Down Expand Up @@ -150,7 +150,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
if self.ws_strategy == "subdiff":
opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
elif self.ws_strategy == "fixpoint":
opt_ws = dist_fix_point(
opt_ws = dist_fix_point_bcd(
W, grad_ws, lipschitz, datafit, penalty, ws
)

Expand Down Expand Up @@ -231,7 +231,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)


@njit
def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
"""Compute the violation of the fixed point iterate schema.
Parameters
Expand All @@ -256,17 +256,20 @@ def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
Returns
-------
dist_fix_point : array, shape (ws_size,)
dist : array, shape (ws_size,)
Contain the violation score for every feature.
"""
dist_fix_point = np.zeros(ws.shape[0])
dist = np.zeros(ws.shape[0])

for idx, j in enumerate(ws):
lcj = lipschitz[j]
if lcj:
dist_fix_point[idx] = norm(
W[j] - penalty.prox_1feat(W[j] - grad_ws[idx] / lcj, 1. / lcj, j)
)
return dist_fix_point
if lipschitz[j] == 0.:
continue

step_j = 1 / lipschitz[j]
dist[idx] = norm(
W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j)
)
return dist


@njit
Expand Down

0 comments on commit 59f91f7

Please sign in to comment.