ENH - Add weights and positivity constraint to MCP (#184)
Co-authored-by: Badr MOUFAD <[email protected]>
Co-authored-by: Badr-MOUFAD <[email protected]>
3 people authored Jul 18, 2023
1 parent 562f42b commit a9f88f6
Showing 8 changed files with 184 additions and 39 deletions.
1 change: 1 addition & 0 deletions doc/api.rst
@@ -44,6 +44,7 @@ Penalties
PositiveConstraint
WeightedL1
WeightedGroupL2
WeightedMCPenalty
SCAD
BlockSCAD
SLOPE
5 changes: 5 additions & 0 deletions doc/changes/0.4.rst
@@ -0,0 +1,5 @@
.. _changes_0_4:

Version 0.4 (in progress)
---------------------------
- Add support for weights and positive coefficients to :ref:`MCPRegression Estimator <skglm.MCPRegression>` (PR: :gh:`184`)
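
A minimal usage sketch of the two new options (the data and the alpha/gamma
values below are illustrative, not taken from the PR):

    import numpy as np
    from skglm import MCPRegression

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 10))
    y = X @ rng.standard_normal(10)

    weights = np.ones(10)
    weights[5:] = 2.  # penalize the last five features twice as hard

    clf = MCPRegression(alpha=0.1, gamma=3, weights=weights, positive=True).fit(X, y)
    print(clf.coef_.min() >= 0)  # True: coefficients are forced to be non-negative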
3 changes: 2 additions & 1 deletion doc/changes/whats_new.rst
@@ -5,9 +5,10 @@ What's new

.. currentmodule:: skglm

.. include:: 0.4.rst

.. include:: 0.3.rst

.. include:: 0.2.rst

.. include:: 0.1.rst

50 changes: 40 additions & 10 deletions skglm/estimators.py
@@ -22,7 +22,7 @@
from skglm.solvers import AndersonCD, MultiTaskBCD
from skglm.datafits import Cox, Quadratic, Logistic, QuadraticSVC, QuadraticMultiTask
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2,
MCPenalty, IndicatorBox, L2_1)
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)


def _glm_fit(X, y, model, datafit, penalty, solver):
@@ -792,6 +792,10 @@ class MCPRegression(LinearModel, RegressorMixin):
If ``gamma = np.inf`` it is a soft thresholding.
Should be larger than (or equal to) 1.
weights : array, shape (n_features,), optional (default=None)
Positive weights used in the penalty part of the MCP
objective. If ``None``, weights equal to 1 are used.
max_iter : int, optional
The maximum number of iterations (subproblem definitions).
@@ -807,6 +811,9 @@ class MCPRegression(LinearModel, RegressorMixin):
tol : float, optional
Stopping criterion for the optimization.
positive : bool, optional
When set to ``True``, forces the coefficient vector to be positive.
fit_intercept : bool, optional (default=True)
Whether or not to fit an intercept.
@@ -836,20 +843,22 @@ class MCPRegression(LinearModel, RegressorMixin):
Lasso : Lasso regularization.
"""

def __init__(self, alpha=1., gamma=3, max_iter=50, max_epochs=50_000, p0=10,
verbose=0, tol=1e-4, fit_intercept=True, warm_start=False,
ws_strategy="subdiff"):
def __init__(self, alpha=1., gamma=3, weights=None, max_iter=50, max_epochs=50_000,
p0=10, verbose=0, tol=1e-4, positive=False, fit_intercept=True,
warm_start=False, ws_strategy="subdiff"):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.tol = tol
self.weights = weights
self.max_iter = max_iter
self.max_epochs = max_epochs
self.p0 = p0
self.ws_strategy = ws_strategy
self.verbose = verbose
self.tol = tol
self.positive = positive
self.fit_intercept = fit_intercept
self.warm_start = warm_start
self.verbose = verbose
self.ws_strategy = ws_strategy

def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
"""Compute MCPRegression path.
@@ -890,7 +899,19 @@ def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **params):
The number of iterations along the path. If return_n_iter is set to
``True``.
"""
penalty = compiled_clone(MCPenalty(self.alpha, self.gamma))
if self.weights is None:
penalty = compiled_clone(
MCPenalty(self.alpha, self.gamma, self.positive)
)
else:
if X.shape[1] != len(self.weights):
raise ValueError(
"The number of weights must match the number of features. "
f"Got {len(self.weights)}, expected {X.shape[1]}."
)
penalty = compiled_clone(
WeightedMCPenalty(self.alpha, self.gamma, self.weights, self.positive)
)
datafit = compiled_clone(Quadratic(), to_float32=X.dtype == np.float32)
solver = AndersonCD(
self.max_iter, self.max_epochs, self.p0, tol=self.tol,
@@ -914,12 +935,21 @@ def fit(self, X, y):
self :
Fitted estimator.
"""
if self.weights is None:
penalty = MCPenalty(self.alpha, self.gamma, self.positive)
else:
if X.shape[1] != len(self.weights):
raise ValueError(
"The number of weights must match the number of features. "
f"Got {len(self.weights)}, expected {X.shape[1]}."
)
penalty = WeightedMCPenalty(
self.alpha, self.gamma, self.weights, self.positive)
solver = AndersonCD(
self.max_iter, self.max_epochs, self.p0, tol=self.tol,
ws_strategy=self.ws_strategy, fit_intercept=self.fit_intercept,
warm_start=self.warm_start, verbose=self.verbose)
return _glm_fit(X, y, self, Quadratic(), MCPenalty(self.alpha, self.gamma),
solver)
return _glm_fit(X, y, self, Quadratic(), penalty, solver)


class SparseLogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
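
Both fit and path now dispatch on self.weights: None keeps the plain MCPenalty,
while any other value must be an array of length n_features and yields a
WeightedMCPenalty. A minimal sketch of the validation behaviour (shapes and
data are illustrative):

    import numpy as np
    from skglm import MCPRegression

    X, y = np.ones((5, 3)), np.zeros(5)
    try:
        MCPRegression(weights=np.ones(2)).fit(X, y)  # 2 weights for 3 features
    except ValueError as err:
        print(err)  # The number of weights must match the number of features. ...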
9 changes: 5 additions & 4 deletions skglm/penalties/__init__.py
@@ -1,7 +1,7 @@
from .base import BasePenalty
from .separable import (
L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
PositiveConstraint
L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD,
WeightedL1, IndicatorBox, PositiveConstraint
)
from .block_separable import (
L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2
@@ -12,6 +12,7 @@

__all__ = [
BasePenalty,
L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, SCAD, WeightedL1, IndicatorBox,
PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD, WeightedGroupL2, SLOPE
L1_plus_L2, L0_5, L1, L2, L2_3, MCPenalty, WeightedMCPenalty, SCAD, WeightedL1,
IndicatorBox, PositiveConstraint, L2_05, L2_1, BlockMCPenalty, BlockSCAD,
WeightedGroupL2, SLOPE
]
121 changes: 107 additions & 14 deletions skglm/penalties/separable.py
@@ -4,7 +4,8 @@

from skglm.penalties.base import BasePenalty
from skglm.utils.prox_funcs import (
ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP, value_MCP)
ST, box_proj, prox_05, prox_2_3, prox_SCAD, value_SCAD, prox_MCP,
value_MCP, value_weighted_MCP)


class L1(BasePenalty):
@@ -216,48 +217,57 @@ class MCPenalty(BasePenalty):
With :math:`x >= 0`:
.. math::
"pen"(x) = {(alpha x - x^2 / (2 gamma), if x =< alpha gamma),
"pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
(gamma alpha^2 / 2 , if x > alpha gamma):}
.. math::
"value" = sum_(j=1)^(n_"features") "pen"(abs(w_j))
"""

def __init__(self, alpha, gamma):
def __init__(self, alpha, gamma, positive=False):
self.alpha = alpha
self.gamma = gamma
self.positive = positive

def get_spec(self):
spec = (
('alpha', float64),
('gamma', float64),
('positive', bool_)
)
return spec

def params_to_dict(self):
return dict(alpha=self.alpha,
gamma=self.gamma)
gamma=self.gamma,
positive=self.positive)

def value(self, w):
return value_MCP(w, self.alpha, self.gamma)

def prox_1d(self, value, stepsize, j):
"""Compute the proximal operator of MCP."""
return prox_MCP(value, stepsize, self.alpha, self.gamma)
return prox_MCP(value, stepsize, self.alpha, self.gamma, self.positive)

def subdiff_distance(self, w, grad, ws):
"""Compute distance of negative gradient to the subdifferential at w."""
subdiff_dist = np.zeros_like(grad)
for idx, j in enumerate(ws):
if w[j] == 0:
# distance of -grad to alpha * [-1, 1]
subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
elif np.abs(w[j]) < self.alpha * self.gamma:
# distance of -grad_j to (alpha * sign(w[j]) - w[j] / gamma)
subdiff_dist[idx] = np.abs(
grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
if self.positive and w[j] < 0:
subdiff_dist[idx] = np.inf
elif self.positive and w[j] == 0:
# distance of -grad to (-infty, alpha]
subdiff_dist[idx] = max(0, - grad[idx] - self.alpha)
else:
# distance of grad to 0
subdiff_dist[idx] = np.abs(grad[idx])
if w[j] == 0:
# distance of -grad to [-alpha, alpha]
subdiff_dist[idx] = max(0, np.abs(grad[idx]) - self.alpha)
elif np.abs(w[j]) < self.alpha * self.gamma:
# distance of -grad to {alpha * sign(w[j]) - w[j] / gamma}
subdiff_dist[idx] = np.abs(
grad[idx] + self.alpha * np.sign(w[j]) - w[j] / self.gamma)
else:
# distance of grad to 0
subdiff_dist[idx] = np.abs(grad[idx])
return subdiff_dist

def is_penalized(self, n_features):
@@ -273,6 +283,89 @@ def alpha_max(self, gradient0):
return np.max(np.abs(gradient0))


class WeightedMCPenalty(BasePenalty):
"""Weighted Minimax Concave Penalty (MCP), a non-convex sparse penalty.
Notes
-----
With :math:`x >= 0`:
.. math::
"pen"(x) = {(alpha x - x^2 / (2 gamma), if x <= alpha gamma),
(gamma alpha^2 / 2 , if x > alpha gamma):}
.. math::
"value" = sum_(j=1)^(n_"features") "weights"_j xx "pen"(abs(w_j))
"""

def __init__(self, alpha, gamma, weights, positive=False):
self.alpha = alpha
self.gamma = gamma
self.weights = weights.astype(np.float64)
self.positive = positive

def get_spec(self):
spec = (
('alpha', float64),
('gamma', float64),
('weights', float64[:]),
('positive', bool_)
)
return spec

def params_to_dict(self):
return dict(alpha=self.alpha,
gamma=self.gamma,
weights=self.weights,
positive=self.positive)

def value(self, w):
return value_weighted_MCP(w, self.alpha, self.gamma, self.weights)

def prox_1d(self, value, stepsize, j):
"""Compute the proximal operator of the weighted MCP."""
return prox_MCP(
value, stepsize, self.alpha, self.gamma, self.positive, self.weights[j])

def subdiff_distance(self, w, grad, ws):
"""Compute distance of negative gradient to the subdifferential at w."""
subdiff_dist = np.zeros_like(grad)
for idx, j in enumerate(ws):
if self.positive and w[j] < 0:
subdiff_dist[idx] = np.inf
elif self.positive and w[j] == 0:
# distance of -grad to (-infty, alpha * weights[j]]
subdiff_dist[idx] = max(
0, - grad[idx] - self.alpha * self.weights[j])
else:
if w[j] == 0:
# distance of -grad to weights[j] * [-alpha, alpha]
subdiff_dist[idx] = max(
0, np.abs(grad[idx]) - self.alpha * self.weights[j])
elif np.abs(w[j]) < self.alpha * self.gamma:
# distance of -grad to
# {weights[j] * alpha * sign(w[j]) - w[j] / gamma}
subdiff_dist[idx] = np.abs(
grad[idx] + self.alpha * self.weights[j] * np.sign(w[j])
- self.weights[j] * w[j] / self.gamma)
else:
# distance of grad to 0
subdiff_dist[idx] = np.abs(grad[idx])
return subdiff_dist

def is_penalized(self, n_features):
"""Return a binary mask with the penalized features."""
return np.ones(n_features, bool_)

def generalized_support(self, w):
"""Return a mask with non-zero coefficients."""
return w != 0

def alpha_max(self, gradient0):
"""Return penalization value for which 0 is solution."""
nnz_weights = self.weights != 0
return np.max(np.abs(gradient0[nnz_weights] / self.weights[nnz_weights]))


class SCAD(BasePenalty):
r"""Smoothly Clipped Absolute Deviation.
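
For the weighted penalty, alpha_max rescales the gradient at zero feature-wise
and skips zero-weight (unpenalized) features. A small NumPy sketch of that
computation, with made-up numbers:

    import numpy as np

    grad0 = np.array([3., -1., 4.])
    weights = np.array([1., 0.5, 0.])  # third feature is unpenalized

    nnz_weights = weights != 0
    alpha_max = np.max(np.abs(grad0[nnz_weights] / weights[nnz_weights]))
    print(alpha_max)  # 3.0 = max(|3.| / 1., |-1.| / 0.5)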
12 changes: 9 additions & 3 deletions skglm/tests/test_estimators.py
@@ -74,6 +74,11 @@
dict_estimators_ours["MCP"] = MCPRegression(
alpha=alpha, gamma=np.inf, tol=tol)

dict_estimators_sk["wMCP"] = Lasso_sklearn(
alpha=alpha, tol=tol)
dict_estimators_ours["wMCP"] = MCPRegression(
alpha=alpha, gamma=np.inf, tol=tol, weights=np.ones(n_features))

dict_estimators_sk["LogisticRegression"] = LogReg_sklearn(
C=1/(alpha * n_samples), tol=tol, penalty='l1',
solver='liblinear')
@@ -88,7 +93,7 @@

@pytest.mark.parametrize(
"estimator_name",
["Lasso", "wLasso", "ElasticNet", "MCP", "LogisticRegression", "SVC"])
["Lasso", "wLasso", "ElasticNet", "MCP", "wMCP", "LogisticRegression", "SVC"])
def test_check_estimator(estimator_name):
if estimator_name == "SVC":
pytest.xfail("SVC check_estimator is too slow due to bug.")
@@ -97,7 +102,7 @@ def test_check_estimator(estimator_name):
pytest.xfail("ProxNewton does not yet support intercept fitting")
clf = clone(dict_estimators_ours[estimator_name])
clf.tol = 1e-6 # failure in float32 computation otherwise
if isinstance(clf, WeightedLasso):
if isinstance(clf, (WeightedLasso, MCPRegression)):
clf.weights = None
check_estimator(clf)

@@ -113,7 +118,8 @@ def test_estimator(estimator_name, X, fit_intercept, positive):
pytest.xfail("sklearn LogisticRegression does not support intercept.")
if fit_intercept and estimator_name == "SVC":
pytest.xfail("Intercept is not supported for SVC.")
if positive and estimator_name not in ("Lasso", "ElasticNet", "WeightedLasso"):
if positive and estimator_name not in (
"Lasso", "ElasticNet", "wLasso", "MCP", "wMCP"):
pytest.xfail("`positive` option is only supported by L1, L1_plus_L2, wL1, MCP and wMCP.")

estimator_sk = clone(dict_estimators_sk[estimator_name])
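
The new wMCP entry pins down the degenerate case where the weighted MCP
collapses to the Lasso: gamma=np.inf turns the MCP prox into plain
soft-thresholding, and unit weights make WeightedMCPenalty coincide with
MCPenalty. Roughly the comparison the test performs, as a standalone sketch
(data, tolerances and the fit_intercept choice are placeholders):

    import numpy as np
    from sklearn.linear_model import Lasso
    from skglm import MCPRegression

    rng = np.random.default_rng(42)
    X = rng.standard_normal((30, 8))
    y = X @ rng.standard_normal(8)

    coef_sk = Lasso(alpha=0.1, fit_intercept=False, tol=1e-10).fit(X, y).coef_
    coef_ours = MCPRegression(alpha=0.1, gamma=np.inf, weights=np.ones(8),
                              fit_intercept=False, tol=1e-10).fit(X, y).coef_
    np.testing.assert_allclose(coef_ours, coef_sk, rtol=1e-5)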
22 changes: 15 additions & 7 deletions skglm/utils/prox_funcs.py
@@ -60,15 +60,23 @@ def value_MCP(w, alpha, gamma):


@njit
def prox_MCP(value, stepsize, alpha, gamma):
"""Compute the proximal operator of stepsize * MCP penalty."""
tau = alpha * stepsize
g = gamma / stepsize # what does g stand for ?
if np.abs(value) <= tau:
def value_weighted_MCP(w, alpha, gamma, weights):
"""Compute the value of the weighted MCP."""
s0 = np.abs(w) < gamma * alpha
value = np.full_like(w, gamma * alpha ** 2 / 2.)
value[s0] = alpha * np.abs(w[s0]) - w[s0]**2 / (2 * gamma)
return np.sum(weights * value)


@njit
def prox_MCP(value, stepsize, alpha, gamma, positive=False, weight=1.):
"""Compute the proximal operator of stepsize * weight MCP penalty."""
wstepsize = weight * stepsize # weighted stepsize
if (np.abs(value) <= alpha * wstepsize) or (positive and value <= 0.):
return 0.
if np.abs(value) > g * tau:
if np.abs(value) > alpha * gamma:
return value
return np.sign(value) * (np.abs(value) - tau) / (1. - 1./g)
return np.sign(value) * (np.abs(value) - alpha * wstepsize) / (1. - wstepsize/gamma)


@njit
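
With weight=1. and positive=False, the rewritten prox_MCP is the classic
firm-thresholding operator; the weight only enters through the effective
threshold alpha * weight * stepsize. A plain-Python sketch of the three regimes
(parameter values are illustrative; assumes gamma > stepsize):

    import numpy as np

    def firm_threshold(value, stepsize=1., alpha=0.1, gamma=3.):
        # Mirrors prox_MCP with weight=1., positive=False.
        if np.abs(value) <= alpha * stepsize:
            return 0.  # small inputs are zeroed
        if np.abs(value) > alpha * gamma:
            return value  # large inputs pass through unshrunk
        # in between: soft-threshold, then rescale by 1 / (1 - stepsize / gamma)
        return np.sign(value) * (np.abs(value) - alpha * stepsize) / (1. - stepsize / gamma)

    print([firm_threshold(v) for v in (0.05, 0.2, 5.)])  # [0.0, 0.15..., 5.0]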
