ENH - check datafit + penalty compatibility with solver (#137)
Co-authored-by: Badr-MOUFAD <[email protected]>
Co-authored-by: Badr MOUFAD <[email protected]>
Co-authored-by: Quentin Bertrand <[email protected]>
Co-authored-by: mathurinm <[email protected]>
6 people authored Jul 15, 2024
1 parent d6ab8c2 commit ed7bf2d
Showing 19 changed files with 408 additions and 84 deletions.
1 change: 1 addition & 0 deletions doc/changes/0.4.rst
@@ -4,6 +4,7 @@ Version 0.4 (in progress)
 -------------------------
 - Add :ref:`GroupLasso Estimator <skglm.GroupLasso>` (PR: :gh:`228`)
 - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty <skglm.penalties.WeightedGroupL2>` (PR: :gh:`221`)
+- Check compatibility of datafit and penalty with solver (PR: :gh:`137`)
 - Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit <skglm.datafit.WeightedQuadratic>` (PR: :gh:`258`)


2 changes: 1 addition & 1 deletion skglm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.3.2dev'
+__version__ = '0.4dev'
 
 from skglm.estimators import ( # noqa F401
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
3 changes: 0 additions & 3 deletions skglm/datafits/single_task.py
@@ -661,9 +661,6 @@ def value(self, y, w, Xw):
     def gradient_scalar(self, X, y, w, Xw, j):
         return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y)
 
-    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
-        pass
-
     def intercept_update_step(self, y, Xw):
         return np.sum(self.raw_grad(y, Xw))

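A plausible reading of this removal: the new attribute-based compatibility checks rely on `hasattr`, so a `pass` stub would make this datafit look sparse-capable when it is not. A minimal sketch of that failure mode, using hypothetical toy classes (not skglm code):

class WithStub:
    def gradient_scalar_sparse(self, *args):
        pass  # no-op stub: present but useless

class WithoutStub:
    pass

# hasattr-based capability detection, as the new checks perform
print(hasattr(WithStub(), "gradient_scalar_sparse"))     # True  -- false positive
print(hasattr(WithoutStub(), "gradient_scalar_sparse"))  # False -- correct signal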
37 changes: 11 additions & 26 deletions skglm/experimental/pdcd_ws.py
@@ -5,11 +5,12 @@
 from scipy.sparse import issparse
 
 from numba import njit
 from skglm.utils.jit_compilation import compiled_clone
+from skglm.solvers import BaseSolver
 
 from sklearn.exceptions import ConvergenceWarning
 
 
-class PDCD_WS:
+class PDCD_WS(BaseSolver):
     r"""Primal-Dual Coordinate Descent solver with working sets.
 
     It solves
@@ -78,6 +79,9 @@ class PDCD_WS:
     https://arxiv.org/abs/2204.07826
     """
 
+    _datafit_required_attr = ('prox_conjugate',)
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
                  p0=100, tol=1e-6, verbose=False):
         self.max_iter = max_iter
@@ -87,11 +91,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
         self.tol = tol
         self.verbose = verbose
 
-    def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
-        if issparse(X):
-            raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.")
-
-        datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape
 
         # init steps
@@ -196,27 +196,12 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
             if stop_crit_in <= tol_in:
                 break
 
-    @staticmethod
-    def _validate_init(datafit_, penalty_):
-        # validate datafit
-        missing_attrs = []
-        for attr in ('prox_conjugate', 'subdiff_distance'):
-            if not hasattr(datafit_, attr):
-                missing_attrs.append(f"`{attr}`")
-
-        if len(missing_attrs):
-            raise AttributeError(
-                "Datafit is not compatible with PDCD_WS solver.\n"
-                "Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
-                f"Missing {' and '.join(missing_attrs)}."
-            )
-
-        # jit compile classes
-        compiled_datafit = compiled_clone(datafit_)
-        compiled_penalty = compiled_clone(penalty_)
-
-        return compiled_datafit, compiled_penalty
+    def custom_checks(self, X, y, datafit, penalty):
+        if issparse(X):
+            raise ValueError(
+                "Sparse matrices are not yet supported in `PDCD_WS` solver."
+            )
 
 
 @njit
 def _scores_primal(X, w, z, penalty, primal_steps, ws):
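With `_datafit_required_attr = ('prox_conjugate',)` declared and `custom_checks` guarding against sparse input, `solve` can now reject an unsuited problem up front. A hedged sketch of the expected behavior (it assumes skglm's smooth `Quadratic` datafit has no `prox_conjugate`, so validation should fail fast with an `AttributeError`):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.experimental.pdcd_ws import PDCD_WS
from skglm.utils.jit_compilation import compiled_clone

X, y = np.random.randn(20, 40), np.random.randn(20)

try:
    # Quadratic lacks `prox_conjugate`: the check fires before any iteration runs
    PDCD_WS().solve(X, y, compiled_clone(Quadratic()), compiled_clone(L1(alpha=0.1)))
except AttributeError as e:
    print(e)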
6 changes: 5 additions & 1 deletion skglm/experimental/tests/test_quantile_regression.py
@@ -5,6 +5,7 @@
 from skglm.penalties import L1
 from skglm.experimental.pdcd_ws import PDCD_WS
 from skglm.experimental.quantile_regression import Pinball
+from skglm.utils.jit_compilation import compiled_clone
 
 from skglm.utils.data import make_correlated_data
 from sklearn.linear_model import QuantileRegressor
@@ -21,9 +22,12 @@ def test_PDCD_WS(quantile_level):
     alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
     alpha = alpha_max / 5
 
+    datafit = compiled_clone(Pinball(quantile_level))
+    penalty = compiled_clone(L1(alpha))
+
     w = PDCD_WS(
         dual_init=np.sign(y)/2 + (quantile_level - 0.5)
-    ).solve(X, y, Pinball(quantile_level), L1(alpha))[0]
+    ).solve(X, y, datafit, penalty)[0]
 
     clf = QuantileRegressor(
         quantile=quantile_level,
6 changes: 5 additions & 1 deletion skglm/experimental/tests/test_sqrt_lasso.py
@@ -7,6 +7,7 @@
 from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
                                            _chambolle_pock_sqrt)
 from skglm.experimental.pdcd_ws import PDCD_WS
+from skglm.utils.jit_compilation import compiled_clone
 
 
 def test_alpha_max():
@@ -69,7 +70,10 @@ def test_PDCD_WS(with_dual_init):
 
     dual_init = y / norm(y) if with_dual_init else None
 
-    w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
+    datafit = compiled_clone(SqrtQuadratic())
+    penalty = compiled_clone(L1(alpha))
+
+    w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
     clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
     np.testing.assert_allclose(clf.coef_, w, atol=1e-6)

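Both test updates reflect the same API shift: `solve` no longer jit-compiles the datafit and penalty internally, so callers do it once with `compiled_clone` and pass the compiled objects in. A minimal sketch of the new convention (assuming `AndersonCD`, `Quadratic`, and `L1` are the standard skglm exports):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm.solvers import AndersonCD
from skglm.utils.jit_compilation import compiled_clone

X, y = np.random.randn(50, 100), np.random.randn(50)

datafit = compiled_clone(Quadratic())
penalty = compiled_clone(L1(alpha=0.1))

# `solve` validates datafit/penalty compatibility, then dispatches to `_solve`
w, p_objs_out, stop_crit = AndersonCD(fit_intercept=False).solve(X, y, datafit, penalty)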
11 changes: 0 additions & 11 deletions skglm/penalties/block_separable.py
@@ -458,17 +458,6 @@ def prox_1group(self, value, stepsize, g):
         res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
         return BST(res, self.alpha * stepsize * self.weights_groups[g])
 
-    def subdiff_distance(self, w, grad_ws, ws):
-        """Compute distance to the subdifferential at ``w`` of negative gradient.
-
-        Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.
-
-        Note:
-        ----
-        ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
-        """
-        raise NotImplementedError("Too hard for now")
-
     def is_penalized(self, n_groups):
         return np.ones(n_groups, dtype=np.bool_)
10 changes: 0 additions & 10 deletions skglm/penalties/non_separable.py
@@ -48,13 +48,3 @@ def prox_vec(self, x, stepsize):
         prox[sorted_indices] = prox_SLOPE(abs_x[sorted_indices], alphas * stepsize)
 
         return np.sign(x) * prox
-
-    def prox_1d(self, value, stepsize, j):
-        raise ValueError(
-            "No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead."
-        )
-
-    def subdiff_distance(self, w, grad, ws):
-        return ValueError(
-            "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`"
-        )
21 changes: 20 additions & 1 deletion skglm/solvers/anderson_cd.py
@@ -7,6 +7,7 @@
 )
 from skglm.solvers.base import BaseSolver
 from skglm.utils.anderson import AndersonAcceleration
+from skglm.utils.validation import check_attrs
 
 
 class AndersonCD(BaseSolver):
@@ -47,6 +48,9 @@ class AndersonCD(BaseSolver):
     code: https://github.com/mathurinm/andersoncd
     """
 
+    _datafit_required_attr = ("get_lipschitz", "gradient_scalar")
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
                  tol=1e-4, ws_strategy="subdiff", fit_intercept=True,
                  warm_start=False, verbose=0):
@@ -59,7 +63,7 @@ def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
         self.warm_start = warm_start
         self.verbose = verbose
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if self.ws_strategy not in ("subdiff", "fixpoint"):
             raise ValueError(
                 'Unsupported value for self.ws_strategy:', self.ws_strategy)
@@ -269,6 +273,21 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None,
             results += (n_iters,)
         return results
 
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=sparse.issparse(X)
+        )
+
+        # ws strategy
+        if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use ws_strategy='subdiff' in solver AndersonCD."
+            )
+
 
 @njit
 def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws):
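This `custom_checks` turns what used to be a confusing mid-solve crash into an immediate, explicit error. A hedged sketch of the guard in action (it assumes `SLOPE`, which after this commit implements neither `prox_1d` nor `subdiff_distance`, is importable from skglm.penalties):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import SLOPE
from skglm.solvers import AndersonCD
from skglm.utils.jit_compilation import compiled_clone

X, y = np.random.randn(20, 40), np.random.randn(20)
alphas = 0.1 * np.arange(40, 0, -1)  # decreasing sequence, as SLOPE expects

try:
    # custom_checks runs first: the default ws_strategy="subdiff" requires
    # `subdiff_distance`, which SLOPE does not implement
    AndersonCD().solve(X, y, compiled_clone(Quadratic()), compiled_clone(SLOPE(alphas)))
except AttributeError as e:
    print(e)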
83 changes: 79 additions & 4 deletions skglm/solvers/base.py
@@ -1,11 +1,38 @@
-from abc import abstractmethod
+from abc import abstractmethod, ABC
+from skglm.utils.validation import check_attrs
 
 
-class BaseSolver():
-    """Base class for solvers."""
+class BaseSolver(ABC):
+    """Base class for solvers.
+
+    Attributes
+    ----------
+    _datafit_required_attr : list
+        List of attributes that must be implemented in Datafit.
+
+    _penalty_required_attr : list
+        List of attributes that must be implemented in Penalty.
+
+    Notes
+    -----
+    Within the required attributes, an entry given as a tuple of attribute
+    names means that at least one of them must be implemented.
+    For instance, if
+
+        _datafit_required_attr = (
+            "get_global_lipschitz",
+            ("gradient", "gradient_scalar")
+        )
+
+    it means the datafit must implement ``get_global_lipschitz``
+    and (``gradient`` or ``gradient_scalar``).
+    """
+
+    _datafit_required_attr: list
+    _penalty_required_attr: list
 
     @abstractmethod
-    def solve(self, X, y, datafit, penalty, w_init, Xw_init):
+    def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
         """Solve an optimization problem.
 
         Parameters
@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
         stop_crit : float
             Value of stopping criterion at convergence.
         """
+
+    def custom_checks(self, X, y, datafit, penalty):
+        """Ensure the solver is suited for the `datafit` + `penalty` problem.
+
+        This method gathers the extra checks to perform
+        aside from the attribute compatibility checks.
+
+        Parameters
+        ----------
+        X : array, shape (n_samples, n_features)
+            Training data.
+
+        y : array, shape (n_samples,)
+            Target values.
+
+        datafit : instance of BaseDatafit
+            Datafit.
+
+        penalty : instance of BasePenalty
+            Penalty.
+        """
+        pass
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
+              *, run_checks=True):
+        """Solve the optimization problem after validating its compatibility.
+
+        A proxy of the ``_solve`` method that implicitly ensures the
+        compatibility of ``datafit`` and ``penalty`` with the solver.
+
+        Examples
+        --------
+        >>> ...
+        >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
+        """
+        if run_checks:
+            self._validate(X, y, datafit, penalty)
+
+        return self._solve(X, y, datafit, penalty, w_init, Xw_init)
+
+    def _validate(self, X, y, datafit, penalty):
+        # run `custom_checks`, then check required attributes
+        self.custom_checks(X, y, datafit, penalty)
+
+        # do not check for sparse support here; the check is made at the solver
+        # level since some solvers, like ProxNewton, need no extra methods for it
+        check_attrs(datafit, self, self._datafit_required_attr)
+        check_attrs(penalty, self, self._penalty_required_attr)
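The tuple-of-alternatives convention documented in the Notes is the heart of the validation. A self-contained sketch of those semantics follows; the real implementation is `check_attrs` in `skglm.utils.validation`, whose signature and error messages may differ:

def missing_attrs(obj, spec):
    """Return unmet entries of `spec`; a nested tuple means 'any one of these'."""
    missing = []
    for attr in spec:
        group = attr if isinstance(attr, tuple) else (attr,)
        if not any(hasattr(obj, a) for a in group):
            missing.append(" or ".join(group))
    return missing

class ToyDatafit:
    def get_global_lipschitz(self, X, y): ...
    def gradient(self, X, y, Xw): ...

# FISTA's datafit spec from this diff: `gradient` satisfies the alternatives group
spec = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
assert missing_attrs(ToyDatafit(), spec) == []
assert missing_attrs(object(), spec) == [
    "get_global_lipschitz", "gradient or gradient_scalar"]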
40 changes: 26 additions & 14 deletions skglm/solvers/fista.py
@@ -3,6 +3,7 @@
 from skglm.solvers.base import BaseSolver
 from skglm.solvers.common import construct_grad, construct_grad_sparse
 from skglm.utils.prox_funcs import _prox_vec
+from skglm.utils.validation import check_attrs
 
 
 class FISTA(BaseSolver):
@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
     https://epubs.siam.org/doi/10.1137/080716542
     """
 
+    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
+    _penalty_required_attr = (("prox_1d", "prox_vec"),)
+
     def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.max_iter = max_iter
         self.tol = tol
@@ -35,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.fit_intercept = False  # needed to be passed to GeneralizedLinearEstimator
         self.warm_start = False
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         p_objs_out = []
         n_samples, n_features = X.shape
         all_features = np.arange(n_features)
@@ -46,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         z = w_init.copy() if w_init is not None else np.zeros(n_features)
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
 
-        try:
-            if X_is_sparse:
-                lipschitz = datafit.get_global_lipschitz_sparse(
-                    X.data, X.indptr, X.indices, y
-                )
-            else:
-                lipschitz = datafit.get_global_lipschitz(X, y)
-        except AttributeError as e:
-            sparse_suffix = '_sparse' if X_is_sparse else ''
-
-            raise Exception(
-                "Datafit is not compatible with FISTA solver.\n Datafit must "
-                f"implement `get_global_lipschitz{sparse_suffix}` method") from e
+        if X_is_sparse:
+            lipschitz = datafit.get_global_lipschitz_sparse(
+                X.data, X.indptr, X.indices, y
+            )
+        else:
+            lipschitz = datafit.get_global_lipschitz(X, y)
 
         for n_iter in range(self.max_iter):
             t_old = t_new
@@ -111,3 +108,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                     print(f"Stopping criterion max violation: {stop_crit:.2e}")
                 break
         return w, np.array(p_objs_out), stop_crit
+
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=issparse(X)
+        )
+
+        # optimality check
+        if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use `opt_strategy='subdiff'` in FISTA solver."
+            )
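This spec also explains why SLOPE could shed its `prox_1d` stub: FISTA's penalty requirement `(("prox_1d", "prox_vec"),)` is satisfied by `prox_vec` alone. A hedged end-to-end sketch (it assumes SLOPE is exported from skglm.penalties; `opt_strategy="fixpoint"` is used since SLOPE has no `subdiff_distance`):

import numpy as np
from skglm.datafits import Quadratic
from skglm.penalties import SLOPE
from skglm.solvers import FISTA
from skglm.utils.jit_compilation import compiled_clone

X, y = np.random.randn(30, 50), np.random.randn(30)
alphas = 0.1 * np.arange(50, 0, -1)  # decreasing regularization sequence

datafit = compiled_clone(Quadratic())
penalty = compiled_clone(SLOPE(alphas))

# passes validation: Quadratic provides `get_global_lipschitz` and a gradient,
# SLOPE provides `prox_vec`, and "fixpoint" sidesteps `subdiff_distance`
w = FISTA(opt_strategy="fixpoint").solve(X, y, datafit, penalty)[0]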