diff --git a/doc/changes/0.4.rst b/doc/changes/0.4.rst
index c131143b6..904401fd9 100644
--- a/doc/changes/0.4.rst
+++ b/doc/changes/0.4.rst
@@ -4,6 +4,7 @@ Version 0.4 (in progress)
 -------------------------
 - Add :ref:`GroupLasso Estimator ` (PR: :gh:`228`)
 - Add support and tutorial for positive coefficients to :ref:`Group Lasso Penalty ` (PR: :gh:`221`)
+- Check compatibility of datafit and penalty with solver (PR: :gh:`137`)
 - Add support to weight samples in the quadratic datafit :ref:`Weighted Quadratic Datafit ` (PR: :gh:`258`)
diff --git a/skglm/__init__.py b/skglm/__init__.py
index c134c98f2..d80de3c17 100644
--- a/skglm/__init__.py
+++ b/skglm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.3.2dev'
+__version__ = '0.4dev'
 
 from skglm.estimators import (  # noqa F401
     Lasso, WeightedLasso, ElasticNet, MCPRegression, MultiTaskLasso, LinearSVC,
diff --git a/skglm/datafits/single_task.py b/skglm/datafits/single_task.py
index 5750ea295..1ccb218aa 100644
--- a/skglm/datafits/single_task.py
+++ b/skglm/datafits/single_task.py
@@ -661,9 +661,6 @@ def value(self, y, w, Xw):
     def gradient_scalar(self, X, y, w, Xw, j):
         return X[:, j] @ (1 - y * np.exp(-Xw)) / len(y)
 
-    def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j):
-        pass
-
     def intercept_update_step(self, y, Xw):
         return np.sum(self.raw_grad(y, Xw))
diff --git a/skglm/experimental/pdcd_ws.py b/skglm/experimental/pdcd_ws.py
index b81a68f5f..81e72da8c 100644
--- a/skglm/experimental/pdcd_ws.py
+++ b/skglm/experimental/pdcd_ws.py
@@ -5,11 +5,12 @@
 from scipy.sparse import issparse
 from numba import njit
 
-from skglm.utils.jit_compilation import compiled_clone
+from skglm.solvers import BaseSolver
+
 from sklearn.exceptions import ConvergenceWarning
 
 
-class PDCD_WS:
+class PDCD_WS(BaseSolver):
     r"""Primal-Dual Coordinate Descent solver with working sets.
 
     It solves
@@ -78,6 +79,9 @@ class PDCD_WS:
         https://arxiv.org/abs/2204.07826
     """
 
+    _datafit_required_attr = ("prox_conjugate",)
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
                  p0=100, tol=1e-6, verbose=False):
         self.max_iter = max_iter
@@ -87,11 +91,7 @@ def __init__(self, max_iter=1000, max_epochs=1000, dual_init=None,
         self.tol = tol
         self.verbose = verbose
 
-    def solve(self, X, y, datafit_, penalty_, w_init=None, Xw_init=None):
-        if issparse(X):
-            raise ValueError("Sparse matrices are not yet support in PDCD_WS solver.")
-
-        datafit, penalty = PDCD_WS._validate_init(datafit_, penalty_)
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape
 
         # init steps
@@ -196,27 +196,12 @@ def _solve_subproblem(y, X, w, Xw, z, z_bar, datafit, penalty,
             if stop_crit_in <= tol_in:
                 break
 
-    @staticmethod
-    def _validate_init(datafit_, penalty_):
-        # validate datafit
-        missing_attrs = []
-        for attr in ('prox_conjugate', 'subdiff_distance'):
-            if not hasattr(datafit_, attr):
-                missing_attrs.append(f"`{attr}`")
-
-        if len(missing_attrs):
-            raise AttributeError(
-                "Datafit is not compatible with PDCD_WS solver.\n"
-                "Datafit must implement `prox_conjugate` and `subdiff_distance`.\n"
-                f"Missing {' and '.join(missing_attrs)}."
+    def custom_checks(self, X, y, datafit, penalty):
+        if issparse(X):
+            raise ValueError(
+                "Sparse matrices are not yet supported in `PDCD_WS` solver."
             )
 
-        # jit compile classes
-        compiled_datafit = compiled_clone(datafit_)
-        compiled_penalty = compiled_clone(penalty_)
-
-        return compiled_datafit, compiled_penalty
-
 
 @njit
 def _scores_primal(X, w, z, penalty, primal_steps, ws):
diff --git a/skglm/experimental/tests/test_quantile_regression.py b/skglm/experimental/tests/test_quantile_regression.py
index 509b7079c..65e0c1e65 100644
--- a/skglm/experimental/tests/test_quantile_regression.py
+++ b/skglm/experimental/tests/test_quantile_regression.py
@@ -5,6 +5,7 @@
 from skglm.penalties import L1
 from skglm.experimental.pdcd_ws import PDCD_WS
 from skglm.experimental.quantile_regression import Pinball
+from skglm.utils.jit_compilation import compiled_clone
 
 from skglm.utils.data import make_correlated_data
 from sklearn.linear_model import QuantileRegressor
@@ -21,9 +22,12 @@ def test_PDCD_WS(quantile_level):
     alpha_max = norm(X.T @ (np.sign(y)/2 + (quantile_level - 0.5)), ord=np.inf)
     alpha = alpha_max / 5
 
+    datafit = compiled_clone(Pinball(quantile_level))
+    penalty = compiled_clone(L1(alpha))
+
     w = PDCD_WS(
         dual_init=np.sign(y)/2 + (quantile_level - 0.5)
-    ).solve(X, y, Pinball(quantile_level), L1(alpha))[0]
+    ).solve(X, y, datafit, penalty)[0]
 
     clf = QuantileRegressor(
         quantile=quantile_level,
diff --git a/skglm/experimental/tests/test_sqrt_lasso.py b/skglm/experimental/tests/test_sqrt_lasso.py
index 91722abea..f5b044a86 100644
--- a/skglm/experimental/tests/test_sqrt_lasso.py
+++ b/skglm/experimental/tests/test_sqrt_lasso.py
@@ -7,6 +7,7 @@
 from skglm.experimental.sqrt_lasso import (SqrtLasso, SqrtQuadratic,
                                            _chambolle_pock_sqrt)
 from skglm.experimental.pdcd_ws import PDCD_WS
+from skglm.utils.jit_compilation import compiled_clone
 
 
 def test_alpha_max():
@@ -69,7 +70,10 @@ def test_PDCD_WS(with_dual_init):
 
     dual_init = y / norm(y) if with_dual_init else None
 
-    w = PDCD_WS(dual_init=dual_init).solve(X, y, SqrtQuadratic(), L1(alpha))[0]
+    datafit = compiled_clone(SqrtQuadratic())
+    penalty = compiled_clone(L1(alpha))
+
+    w = PDCD_WS(dual_init=dual_init).solve(X, y, datafit, penalty)[0]
 
     clf = SqrtLasso(alpha=alpha, tol=1e-12).fit(X, y)
     np.testing.assert_allclose(clf.coef_, w, atol=1e-6)
diff --git a/skglm/penalties/block_separable.py b/skglm/penalties/block_separable.py
index 47161080e..091392601 100644
--- a/skglm/penalties/block_separable.py
+++ b/skglm/penalties/block_separable.py
@@ -458,17 +458,6 @@ def prox_1group(self, value, stepsize, g):
         res = ST_vec(value, self.alpha * stepsize * self.weights_features[g])
         return BST(res, self.alpha * stepsize * self.weights_groups[g])
 
-    def subdiff_distance(self, w, grad_ws, ws):
-        """Compute distance to the subdifferential at ``w`` of negative gradient.
-
-        Refer to :ref:`subdiff_positive_group_lasso` for details of the derivation.
-
-        Note:
-        ----
-        ``grad_ws`` is a stacked array of gradients ``[grad_ws_1, grad_ws_2, ...]``.
-        """
-        raise NotImplementedError("Too hard for now")
-
     def is_penalized(self, n_groups):
         return np.ones(n_groups, dtype=np.bool_)
diff --git a/skglm/penalties/non_separable.py b/skglm/penalties/non_separable.py
index c27079323..58f0b8c2e 100644
--- a/skglm/penalties/non_separable.py
+++ b/skglm/penalties/non_separable.py
@@ -48,13 +48,3 @@ def prox_vec(self, x, stepsize):
         prox[sorted_indices] = prox_SLOPE(abs_x[sorted_indices], alphas * stepsize)
 
         return np.sign(x) * prox
-
-    def prox_1d(self, value, stepsize, j):
-        raise ValueError(
-            "No coordinate-wise proximal operator for SLOPE. Use `prox_vec` instead."
-        )
-
-    def subdiff_distance(self, w, grad, ws):
-        return ValueError(
-            "No subdifferential distance for SLOPE. Use `opt_strategy='fixpoint'`"
-        )
diff --git a/skglm/solvers/anderson_cd.py b/skglm/solvers/anderson_cd.py
index 87bd92008..d39a24086 100644
--- a/skglm/solvers/anderson_cd.py
+++ b/skglm/solvers/anderson_cd.py
@@ -7,6 +7,7 @@
 )
 from skglm.solvers.base import BaseSolver
 from skglm.utils.anderson import AndersonAcceleration
+from skglm.utils.validation import check_attrs
 
 
 class AndersonCD(BaseSolver):
@@ -47,6 +48,9 @@ class AndersonCD(BaseSolver):
         code: https://github.com/mathurinm/andersoncd
     """
 
+    _datafit_required_attr = ("get_lipschitz", "gradient_scalar")
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
                  tol=1e-4, ws_strategy="subdiff", fit_intercept=True,
                  warm_start=False, verbose=0):
@@ -59,7 +63,7 @@ def __init__(self, max_iter=50, max_epochs=50_000, p0=10,
         self.warm_start = warm_start
         self.verbose = verbose
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if self.ws_strategy not in ("subdiff", "fixpoint"):
             raise ValueError(
                 'Unsupported value for self.ws_strategy:', self.ws_strategy)
@@ -269,6 +273,21 @@ def path(self, X, y, datafit, penalty, alphas=None, w_init=None,
             results += (n_iters,)
         return results
 
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that the datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=sparse.issparse(X)
+        )
+
+        # ws strategy
+        if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use ws_strategy='subdiff' in AndersonCD solver."
+            )
+
 
 @njit
 def _cd_epoch(X, y, w, Xw, lc, datafit, penalty, ws):
diff --git a/skglm/solvers/base.py b/skglm/solvers/base.py
index 9b5c5b121..06a08a690 100644
--- a/skglm/solvers/base.py
+++ b/skglm/solvers/base.py
@@ -1,11 +1,38 @@
-from abc import abstractmethod
+from abc import abstractmethod, ABC
+from skglm.utils.validation import check_attrs
 
 
-class BaseSolver():
-    """Base class for solvers."""
+class BaseSolver(ABC):
+    """Base class for solvers.
+
+    Attributes
+    ----------
+    _datafit_required_attr : list
+        List of attributes that must be implemented in Datafit.
+
+    _penalty_required_attr : list
+        List of attributes that must be implemented in Penalty.
+
+    Notes
+    -----
+    For required attributes, if an attribute is given as a tuple of attributes,
+    at least one of them should be implemented.
+    For instance, if
+
+        _datafit_required_attr = (
+            "get_global_lipschitz",
+            ("gradient", "gradient_scalar")
+        )
+
+    it means the datafit must implement the method ``get_global_lipschitz``
+    and (``gradient`` or ``gradient_scalar``).
+    """
+
+    _datafit_required_attr: list
+    _penalty_required_attr: list
 
     @abstractmethod
-    def solve(self, X, y, datafit, penalty, w_init, Xw_init):
+    def _solve(self, X, y, datafit, penalty, w_init, Xw_init):
         """Solve an optimization problem.
 
         Parameters
         ----------
@@ -39,3 +66,51 @@ def solve(self, X, y, datafit, penalty, w_init, Xw_init):
         stop_crit : float
             Value of stopping criterion at convergence.
         """
+
+    def custom_checks(self, X, y, datafit, penalty):
+        """Ensure the solver is suited for the `datafit` + `penalty` problem.
+
+        This method includes extra checks to perform
+        aside from checking attribute compatibility.
+
+        Parameters
+        ----------
+        X : array, shape (n_samples, n_features)
+            Training data.
+
+        y : array, shape (n_samples,)
+            Target values.
+
+        datafit : instance of BaseDatafit
+            Datafit.
+
+        penalty : instance of BasePenalty
+            Penalty.
+        """
+        pass
+
+    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None,
+              *, run_checks=True):
+        """Solve the optimization problem after validating its compatibility.
+
+        A proxy for the ``_solve`` method that implicitly ensures the
+        compatibility of ``datafit`` and ``penalty`` with the solver.
+
+        Examples
+        --------
+        >>> ...
+        >>> coefs, obj_out, stop_crit = solver.solve(X, y, datafit, penalty)
+        """
+        if run_checks:
+            self._validate(X, y, datafit, penalty)
+
+        return self._solve(X, y, datafit, penalty, w_init, Xw_init)
+
+    def _validate(self, X, y, datafit, penalty):
+        # execute `custom_checks`, then check attributes
+        self.custom_checks(X, y, datafit, penalty)
+
+        # do not check for sparse support here; the check is made at the solver
+        # level, as some solvers like ProxNewton don't require sparse methods
+        check_attrs(datafit, self, self._datafit_required_attr)
+        check_attrs(penalty, self, self._penalty_required_attr)
diff --git a/skglm/solvers/fista.py b/skglm/solvers/fista.py
index b4653079a..e0933a111 100644
--- a/skglm/solvers/fista.py
+++ b/skglm/solvers/fista.py
@@ -3,6 +3,7 @@
 from skglm.solvers.base import BaseSolver
 from skglm.solvers.common import construct_grad, construct_grad_sparse
 from skglm.utils.prox_funcs import _prox_vec
+from skglm.utils.validation import check_attrs
 
 
 class FISTA(BaseSolver):
@@ -27,6 +28,9 @@ class FISTA(BaseSolver):
         https://epubs.siam.org/doi/10.1137/080716542
     """
 
+    _datafit_required_attr = ("get_global_lipschitz", ("gradient", "gradient_scalar"))
+    _penalty_required_attr = (("prox_1d", "prox_vec"),)
+
     def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.max_iter = max_iter
         self.tol = tol
@@ -35,7 +39,7 @@ def __init__(self, max_iter=100, tol=1e-4, opt_strategy="subdiff", verbose=0):
         self.fit_intercept = False  # needed to be passed to GeneralizedLinearEstimator
         self.warm_start = False
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         p_objs_out = []
         n_samples, n_features = X.shape
         all_features = np.arange(n_features)
@@ -46,19 +50,12 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         z = w_init.copy() if w_init is not None else np.zeros(n_features)
         Xw = Xw_init.copy() if Xw_init is not None else np.zeros(n_samples)
 
-        try:
-            if X_is_sparse:
-                lipschitz = datafit.get_global_lipschitz_sparse(
-                    X.data, X.indptr, X.indices, y
-                )
-            else:
-                lipschitz = datafit.get_global_lipschitz(X, y)
-        except AttributeError as e:
-            sparse_suffix = '_sparse' if X_is_sparse else ''
-
-            raise Exception(
-                "Datafit is not compatible with FISTA solver.\n Datafit must "
-                f"implement `get_global_lipschitz{sparse_suffix}` method") from e
+        if X_is_sparse:
+            lipschitz = datafit.get_global_lipschitz_sparse(
+                X.data, X.indptr, X.indices, y
+            )
+        else:
+            lipschitz = datafit.get_global_lipschitz(X, y)
 
         for n_iter in range(self.max_iter):
             t_old = t_new
@@ -111,3 +108,18 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             print(f"Stopping criterion max violation: {stop_crit:.2e}")
             break
         return w, np.array(p_objs_out), stop_crit
+
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that the datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=issparse(X)
+        )
+
+        # optimality check
+        if self.opt_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use `opt_strategy='subdiff'` in FISTA solver."
+            )
diff --git a/skglm/solvers/gram_cd.py b/skglm/solvers/gram_cd.py
index d18b165b3..9ecf42bfb 100644
--- a/skglm/solvers/gram_cd.py
+++ b/skglm/solvers/gram_cd.py
@@ -2,6 +2,7 @@
 import numpy as np
 from numba import njit
 from scipy.sparse import issparse
+
 from skglm.solvers.base import BaseSolver
 from skglm.utils.anderson import AndersonAcceleration
 
@@ -49,6 +50,9 @@ class GramCD(BaseSolver):
         Amount of verbosity. 0/False is silent.
     """
 
+    _datafit_required_attr = ()
+    _penalty_required_attr = ("prox_1d", "subdiff_distance")
+
     def __init__(self, max_iter=100, use_acc=False, greedy_cd=True, tol=1e-4,
                  fit_intercept=True, warm_start=False, verbose=0):
         self.max_iter = max_iter
@@ -59,7 +63,7 @@ def __init__(self, max_iter=100, use_acc=False, greedy_cd=True, tol=1e-4,
         self.warm_start = warm_start
         self.verbose = verbose
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         # we don't pass Xw_init as the solver uses Gram updates
         # to keep the gradient up-to-date instead of Xw
         n_samples, n_features = X.shape
@@ -132,6 +136,13 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             p_objs_out.append(p_obj)
         return w, np.array(p_objs_out), stop_crit
 
+    def custom_checks(self, X, y, datafit, penalty):
+        if datafit is not None:
+            raise AttributeError(
+                "`GramCD` supports only `Quadratic` datafit and fits it implicitly; "
+                f"argument `datafit` must be `None`, got {datafit.__class__.__name__}."
+            )
+
 
 @njit
 def _gram_cd_epoch(scaled_gram, w, grad, penalty, greedy_cd):
diff --git a/skglm/solvers/group_bcd.py b/skglm/solvers/group_bcd.py
index b2163a160..c7b515dad 100644
--- a/skglm/solvers/group_bcd.py
+++ b/skglm/solvers/group_bcd.py
@@ -4,7 +4,7 @@
 
 from skglm.solvers.base import BaseSolver
 from skglm.utils.anderson import AndersonAcceleration
-from skglm.utils.validation import check_group_compatible
+from skglm.utils.validation import check_group_compatible, check_attrs
 from skglm.solvers.common import dist_fix_point_bcd
 
 
@@ -37,6 +37,9 @@ class GroupBCD(BaseSolver):
         Amount of verbosity. 0/False is silent.
""" + _datafit_required_attr = ("get_lipschitz", "gradient_g") + _penalty_required_attr = ("prox_1group",) + def __init__( self, max_iter=1000, max_epochs=100, p0=10, tol=1e-4, fit_intercept=False, warm_start=False, ws_strategy="subdiff", verbose=0): @@ -49,12 +52,10 @@ def __init__( self.ws_strategy = ws_strategy self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): if self.ws_strategy not in ("subdiff", "fixpoint"): raise ValueError( 'Unsupported value for self.ws_strategy:', self.ws_strategy) - check_group_compatible(datafit) - check_group_compatible(penalty) n_samples, n_features = X.shape n_groups = len(penalty.grp_ptr) - 1 @@ -181,6 +182,24 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): return w, p_objs_out, stop_crit + def custom_checks(self, X, y, datafit, penalty): + check_group_compatible(datafit) + check_group_compatible(penalty) + + # check datafit support sparse data + check_attrs( + datafit, solver=self, + required_attr=self._datafit_required_attr, + support_sparse=issparse(X) + ) + + # ws strategy + if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"): + raise AttributeError( + "Penalty must implement `subdiff_distance` " + "to use ws_strategy='subdiff'." + ) + @njit def _bcd_epoch(X, y, w, Xw, lipschitz, datafit, penalty, ws): diff --git a/skglm/solvers/group_prox_newton.py b/skglm/solvers/group_prox_newton.py index 929ab853b..1492651c3 100644 --- a/skglm/solvers/group_prox_newton.py +++ b/skglm/solvers/group_prox_newton.py @@ -1,9 +1,12 @@ import numpy as np from numba import njit from numpy.linalg import norm +from scipy.sparse import issparse + from skglm.solvers.base import BaseSolver from skglm.utils.validation import check_group_compatible + EPS_TOL = 0.3 MAX_CD_ITER = 20 MAX_BACKTRACK_ITER = 20 @@ -41,6 +44,9 @@ class GroupProxNewton(BaseSolver): code: https://github.com/tbjohns/BlitzL1 """ + _datafit_required_attr = ("raw_grad", "raw_hessian") + _penalty_required_attr = ("prox_1group", "subdiff_distance") + def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, fit_intercept=False, warm_start=False, verbose=0): self.p0 = p0 @@ -51,10 +57,7 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4, self.warm_start = warm_start self.verbose = verbose - def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): - check_group_compatible(datafit) - check_group_compatible(penalty) - + def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): fit_intercept = self.fit_intercept n_samples, n_features = X.shape grp_ptr, grp_indices = penalty.grp_ptr, penalty.grp_indices @@ -142,6 +145,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None): p_objs_out.append(p_obj) return w, np.asarray(p_objs_out), stop_crit + def custom_checks(self, X, y, datafit, penalty): + check_group_compatible(datafit) + check_group_compatible(penalty) + + if issparse(X): + raise ValueError( + "Sparse matrices are not yet supported in `GroupProxNewton` solver." 
+            )
+
 
 @njit
 def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
diff --git a/skglm/solvers/lbfgs.py b/skglm/solvers/lbfgs.py
index 5e7e03051..438c8b97b 100644
--- a/skglm/solvers/lbfgs.py
+++ b/skglm/solvers/lbfgs.py
@@ -7,6 +7,7 @@
 from scipy.sparse import issparse
 
 from skglm.solvers import BaseSolver
+from skglm.utils.validation import check_attrs
 
 
 class LBFGS(BaseSolver):
@@ -27,12 +28,15 @@ class LBFGS(BaseSolver):
         Amount of verbosity. 0/False is silent.
     """
 
+    _datafit_required_attr = ("gradient",)
+    _penalty_required_attr = ("gradient",)
+
     def __init__(self, max_iter=50, tol=1e-4, verbose=False):
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
         def objective(w):
             Xw = X @ w
@@ -102,3 +106,11 @@ def callback_post_iter(w_k):
         stop_crit = norm(result.jac, ord=np.inf)
 
         return w, np.asarray(p_objs_out), stop_crit
+
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that the datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=issparse(X)
+        )
diff --git a/skglm/solvers/multitask_bcd.py b/skglm/solvers/multitask_bcd.py
index 16301ac4a..5a8dfa5e6 100644
--- a/skglm/solvers/multitask_bcd.py
+++ b/skglm/solvers/multitask_bcd.py
@@ -4,11 +4,15 @@
 from numpy.linalg import norm
 from sklearn.utils import check_array
 from skglm.solvers.base import BaseSolver
+from skglm.utils.validation import check_attrs
 
 
 class MultiTaskBCD(BaseSolver):
     """Block coordinate descent solver for multi-task problems."""
 
+    _datafit_required_attr = ("get_lipschitz", "gradient_j")
+    _penalty_required_attr = ("prox_1feat",)
+
     def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6,
                  use_acc=True, ws_strategy="subdiff", fit_intercept=True,
                  warm_start=False, verbose=0):
@@ -22,7 +26,7 @@ def __init__(self, max_iter=100, max_epochs=50_000, p0=10, tol=1e-6,
         self.warm_start = warm_start
         self.verbose = verbose
 
-    def solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
+    def _solve(self, X, Y, datafit, penalty, W_init=None, XW_init=None):
         n_samples, n_features = X.shape
         n_tasks = Y.shape[1]
         pen = penalty.is_penalized(n_features)
@@ -231,6 +235,21 @@ def path(self, X, Y, datafit, penalty, alphas, W_init=None,
                 return_n_iter=False)
         return results
 
+    def custom_checks(self, X, y, datafit, penalty):
+        # check that the datafit supports sparse data
+        check_attrs(
+            datafit, solver=self,
+            required_attr=self._datafit_required_attr,
+            support_sparse=sparse.issparse(X)
+        )
+
+        # ws strategy
+        if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use ws_strategy='subdiff'."
+            )
+
 
 @njit
 def dist_fix_point_bcd(W, grad_ws, lipschitz_ws, datafit, penalty, ws):
diff --git a/skglm/solvers/prox_newton.py b/skglm/solvers/prox_newton.py
index 4b8e0aaf7..76867c7d8 100644
--- a/skglm/solvers/prox_newton.py
+++ b/skglm/solvers/prox_newton.py
@@ -52,6 +52,9 @@ class ProxNewton(BaseSolver):
         code: https://github.com/tbjohns/BlitzL1
     """
 
+    _datafit_required_attr = ("raw_grad", "raw_hessian")
+    _penalty_required_attr = ("prox_1d",)
+
     def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
                  ws_strategy="subdiff", fit_intercept=True, warm_start=False,
                  verbose=0):
@@ -64,7 +67,7 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
         self.warm_start = warm_start
         self.verbose = verbose
 
-    def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
+    def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if self.ws_strategy not in ("subdiff", "fixpoint"):
             raise ValueError("ws_strategy must be `subdiff` or `fixpoint`, "
                              f"got {self.ws_strategy}.")
@@ -196,6 +199,14 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
             )
         return w, np.asarray(p_objs_out), stop_crit
 
+    def custom_checks(self, X, y, datafit, penalty):
+        # ws strategy
+        if self.ws_strategy == "subdiff" and not hasattr(penalty, "subdiff_distance"):
+            raise AttributeError(
+                "Penalty must implement `subdiff_distance` "
+                "to use ws_strategy='subdiff' in ProxNewton solver."
+            )
+
 
 @njit
 def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py
new file mode 100644
index 000000000..7e998bfb8
--- /dev/null
+++ b/skglm/tests/test_validation.py
@@ -0,0 +1,84 @@
+import pytest
+import numpy as np
+from scipy import sparse
+
+from skglm.penalties import L1, WeightedL1GroupL2, WeightedGroupL2
+from skglm.datafits import Poisson, Huber, QuadraticGroup, LogisticGroup
+from skglm.solvers import FISTA, ProxNewton, GroupBCD, GramCD, GroupProxNewton
+
+from skglm.utils.data import grp_converter
+from skglm.utils.data import make_correlated_data
+from skglm.utils.jit_compilation import compiled_clone
+
+
+def test_datafit_penalty_solver_compatibility():
+    grp_size, n_features = 3, 9
+    n_samples = 10
+    X, y, _ = make_correlated_data(n_samples, n_features)
+    X_sparse = sparse.csc_array(X)
+
+    n_groups = n_features // grp_size
+    weights_groups = np.ones(n_groups)
+    weights_features = np.ones(n_features)
+    grp_indices, grp_ptr = grp_converter(grp_size, n_features)
+
+    # basic compatibility checks
+    with pytest.raises(
+        AttributeError, match="Missing `raw_grad` and `raw_hessian`"
+    ):
+        ProxNewton()._validate(
+            X, y, compiled_clone(Huber(1.)), compiled_clone(L1(1.))
+        )
+    with pytest.raises(
+        AttributeError, match="Missing `get_global_lipschitz`"
+    ):
+        FISTA()._validate(
+            X, y, compiled_clone(Poisson()), compiled_clone(L1(1.))
+        )
+    # check Gram Solver
+    with pytest.raises(
+        AttributeError, match="`GramCD` supports only `Quadratic` datafit"
+    ):
+        GramCD()._validate(
+            X, y, compiled_clone(Poisson()), compiled_clone(L1(1.))
+        )
+    # check working set strategy subdiff
+    with pytest.raises(
+        AttributeError, match="Penalty must implement `subdiff_distance`"
+    ):
+        GroupBCD()._validate(
+            X, y,
+            datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)),
+            penalty=compiled_clone(
+                WeightedL1GroupL2(
+                    1., weights_groups, weights_features, grp_ptr, grp_indices)
+            )
+        )
+    # checks for sparsity
+    with pytest.raises(
+        ValueError,
+        match="Sparse matrices are not yet supported in `GroupProxNewton` solver."
+    ):
+        GroupProxNewton()._validate(
+            X_sparse, y,
+            datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)),
+            penalty=compiled_clone(
+                WeightedL1GroupL2(
+                    1., weights_groups, weights_features, grp_ptr, grp_indices)
+            )
+        )
+    with pytest.raises(
+        AttributeError,
+        match="LogisticGroup is not compatible with solver GroupBCD with sparse data."
+    ):
+        GroupBCD()._validate(
+            X_sparse, y,
+            datafit=compiled_clone(LogisticGroup(grp_ptr, grp_indices)),
+            penalty=compiled_clone(
+                WeightedGroupL2(1., weights_groups, grp_ptr, grp_indices)
+            )
+        )
+
+
+if __name__ == "__main__":
+    pass
diff --git a/skglm/utils/validation.py b/skglm/utils/validation.py
index 0da22df40..264ad4bb7 100644
--- a/skglm/utils/validation.py
+++ b/skglm/utils/validation.py
@@ -1,3 +1,8 @@
+import re
+
+
+SPARSE_SUFFIX = "_sparse"
+
 
 def check_group_compatible(obj):
     """Check whether ``obj`` is compatible with ``bcd_solver``.
@@ -23,3 +28,72 @@ def check_group_compatible(obj):
             f"'{obj_name}' is not block-separable. "
             f"Missing '{attr}' attribute."
         )
+
+
+def check_attrs(obj, solver, required_attr, support_sparse=False):
+    """Check whether datafit or penalty is compatible with solver.
+
+    Parameters
+    ----------
+    obj : Instance of Datafit or Penalty
+        The Datafit (or Penalty) instance to check.
+
+    solver : Instance of Solver
+        The Solver instance to check against.
+
+    required_attr : List or tuple of strings
+        The attributes that ``obj`` must have.
+
+    support_sparse : bool, default False
+        If ``True``, append ``SPARSE_SUFFIX`` to attribute names to check sparse data compatibility.
+
+    Raises
+    ------
+    AttributeError
+        If any of the attributes in ``required_attr`` is missing
+        from ``obj`` attributes.
+    """
+    missing_attrs = []
+    suffix = SPARSE_SUFFIX if support_sparse else ""
+
+    # if `attr` is a tuple, check that at least one of its members
+    # is among `obj` attributes
+    for attr in required_attr:
+        attributes = attr if not isinstance(attr, str) else (attr,)
+
+        for a in attributes:
+            if hasattr(obj, f"{a}{suffix}"):
+                break
+        else:
+            missing_attrs.append(_join_attrs_with_or(attributes, suffix))
+
+    if len(missing_attrs):
+        required_attr = [_join_attrs_with_or(attrs, suffix) for attrs in required_attr]
+
+        # get the names of `obj` and `solver`
+        name_matcher = re.compile(r"\.(\w+)'>")
+
+        obj_name = name_matcher.search(str(obj.__class__)).group(1)
+        solver_name = name_matcher.search(str(solver.__class__)).group(1)
+
+        if not support_sparse:
+            err_message = f"{obj_name} is not compatible with solver {solver_name}."
+        else:
+            err_message = (f"{obj_name} is not compatible with solver {solver_name} "
+                           "with sparse data.")
+
+        err_message += (f" It must implement {' and '.join(required_attr)}.\n"
+                        f"Missing {' and '.join(missing_attrs)}.")
+
+        raise AttributeError(err_message)
+
+
+def _join_attrs_with_or(attrs, suffix=""):
+    if isinstance(attrs, str):
+        return f"`{attrs}{suffix}`"
+
+    if len(attrs) == 1:
+        return f"`{attrs[0]}{suffix}`"
+
+    out = " or ".join([f"`{a}{suffix}`" for a in attrs])
+    return f"({out})"
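
A minimal usage sketch of the validation flow introduced by this patch (illustrative, not part of the diff; the data and hyperparameters are made up, the API follows the code above). Solvers now declare `_datafit_required_attr` / `_penalty_required_attr`, and `BaseSolver.solve` runs `custom_checks` and `check_attrs` before dispatching to `_solve`. Note that jit-compiling the datafit and penalty with `compiled_clone` is now the caller's responsibility, as in the updated tests:

    from skglm.datafits import Quadratic
    from skglm.penalties import L1
    from skglm.solvers import AndersonCD
    from skglm.utils.data import make_correlated_data
    from skglm.utils.jit_compilation import compiled_clone

    X, y, _ = make_correlated_data(n_samples=50, n_features=100)

    # compile the datafit and penalty before handing them to the solver
    datafit = compiled_clone(Quadratic())
    penalty = compiled_clone(L1(alpha=1.))

    # `solve` validates the (datafit, penalty, solver) triplet, then runs `_solve`;
    # pass run_checks=False to skip the validation step
    w, p_objs_out, stop_crit = AndersonCD(fit_intercept=False).solve(
        X, y, datafit, penalty
    )

    # incompatible combinations now fail fast, e.g. FISTA with Poisson raises
    # AttributeError: "Poisson is not compatible with solver FISTA. ..."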