diff --git a/skglm/tests/test_validation.py b/skglm/tests/test_validation.py index 25c23c9a7..7e998bfb8 100644 --- a/skglm/tests/test_validation.py +++ b/skglm/tests/test_validation.py @@ -1,35 +1,89 @@ import pytest +import numpy as np +from scipy import sparse -from skglm.penalties import L1 -from skglm.datafits import Poisson, Huber -from skglm.solvers import FISTA, ProxNewton +from skglm.penalties import L1, WeightedL1GroupL2, WeightedGroupL2 +from skglm.datafits import Poisson, Huber, QuadraticGroup, LogisticGroup +from skglm.solvers import FISTA, ProxNewton, GroupBCD, GramCD, GroupProxNewton + +from skglm.utils.data import grp_converter from skglm.utils.data import make_correlated_data from skglm.utils.jit_compilation import compiled_clone def test_datafit_penalty_solver_compatibility(): - X, y, _ = make_correlated_data(n_samples=3, n_features=5) + grp_size, n_features = 3, 9 + n_samples = 10 + X, y, _ = make_correlated_data(n_samples, n_features) + X_sparse = sparse.csc_array(X) + + n_groups = n_features // grp_size + weights_groups = np.ones(n_groups) + weights_features = np.ones(n_features) + grp_indices, grp_ptr = grp_converter(grp_size, n_features) + # basic compatibility checks with pytest.raises( AttributeError, match="Missing `raw_grad` and `raw_hessian`" ): ProxNewton()._validate( X, y, compiled_clone(Huber(1.)), compiled_clone(L1(1.)) ) - with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) - with pytest.raises( AttributeError, match="Missing `get_global_lipschitz`" ): FISTA()._validate( X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) ) + # check Gram Solver + with pytest.raises( + AttributeError, match="`GramCD` supports only `Quadratic` datafit" + ): + GramCD()._validate( + X, y, compiled_clone(Poisson()), compiled_clone(L1(1.)) + ) + # check working set strategy subdiff + with pytest.raises( + AttributeError, match="Penalty must implement `subdiff_distance`" + ): + GroupBCD()._validate( + X, y, + datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedL1GroupL2( + 1., weights_groups, weights_features, grp_ptr, grp_indices) + ) + ) + # checks for sparsity + with pytest.raises( + ValueError, + match="Sparse matrices are not yet supported in `GroupProxNewton` solver." + ): + GroupProxNewton()._validate( + X_sparse, y, + datafit=compiled_clone(QuadraticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedL1GroupL2( + 1., weights_groups, weights_features, grp_ptr, grp_indices) + ) + ) + with pytest.raises( + AttributeError, + match="LogisticGroup is not compatible with solver GroupBCD with sparse data." + ): + GroupBCD()._validate( + X_sparse, y, + datafit=compiled_clone(LogisticGroup(grp_ptr, grp_indices)), + penalty=compiled_clone( + WeightedGroupL2(1., weights_groups, grp_ptr, grp_indices) + ) + ) if __name__ == "__main__":