Skip to content

Commit

Permalink
dist_fix_point ---> dist_fix_point_bcd
Browse files Browse the repository at this point in the history
  • Loading branch information
Badr-MOUFAD committed Oct 26, 2023
1 parent c52e50c commit 041407d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions skglm/solvers/multitask_bcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
if self.ws_strategy == "subdiff":
opt = penalty.subdiff_distance(W, grad, all_feats)
elif self.ws_strategy == "fixpoint":
opt = dist_fix_point(W, grad, datafit, penalty, all_feats)
opt = dist_fix_point_bcd(W, grad, datafit, penalty, all_feats)
stop_crit = np.max(opt)
if self.verbose:
print(f"Stopping criterion max violation: {stop_crit:.2e}")
Expand Down Expand Up @@ -150,7 +150,7 @@ def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
if self.ws_strategy == "subdiff":
opt_ws = penalty.subdiff_distance(W, grad_ws, ws)
elif self.ws_strategy == "fixpoint":
opt_ws = dist_fix_point(
opt_ws = dist_fix_point_bcd(
W, grad_ws, lipschitz, datafit, penalty, ws
)

Expand Down Expand Up @@ -231,7 +231,7 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None, return_n_iter=False)


@njit
def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
def dist_fix_point_bcd(W, grad_ws, lipschitz, datafit, penalty, ws):
"""Compute the violation of the fixed point iterate schema.
Parameters
Expand All @@ -256,17 +256,20 @@ def dist_fix_point(W, grad_ws, lipschitz, datafit, penalty, ws):
Returns
-------
dist_fix_point : array, shape (ws_size,)
dist : array, shape (ws_size,)
Contain the violation score for every feature.
"""
dist_fix_point = np.zeros(ws.shape[0])
dist = np.zeros(ws.shape[0])

for idx, j in enumerate(ws):
lcj = lipschitz[j]
if lcj:
dist_fix_point[idx] = norm(
W[j] - penalty.prox_1feat(W[j] - grad_ws[idx] / lcj, 1. / lcj, j)
)
return dist_fix_point
if lipschitz[j] == 0.:
continue

step_j = 1 / lipschitz[j]
dist[idx] = norm(
W[j] - penalty.prox_1feat(W[j] - step_j * grad_ws[idx], step_j, j)
)
return dist


@njit
Expand Down

0 comments on commit 041407d

Please sign in to comment.