diff --git a/pygam/tests/test_utils.py b/pygam/tests/test_utils.py index 19027995..80d1f420 100644 --- a/pygam/tests/test_utils.py +++ b/pygam/tests/test_utils.py @@ -7,7 +7,6 @@ from pygam import * from pygam.utils import check_X, check_y, check_X_y -from pygam.datasets import wage, default, mcycle # TODO check dtypes works as expected @@ -187,3 +186,12 @@ def test_input_data_after_fitting(mcycle_X_y): with pytest.raises(ValueError): gam.sample(X, y, weights=weights_nan, n_bootstraps=2) # # def test_b_spline_basis_clamped_what_we_want(): + +def test_catch_chol_pos_def_error(default_X_y): + """ + regresion test + + doing a gridsearch with a poorly conditioned penalty matrix should not crash + """ + X, y = default_X_y + gam = LogisticGAM().gridsearch(X, y, lam=np.logspace(10, 12, 3)) diff --git a/pygam/utils.py b/pygam/utils.py index 99b1d8d4..e22b5942 100644 --- a/pygam/utils.py +++ b/pygam/utils.py @@ -49,15 +49,15 @@ def cholesky(A, sparse=True, verbose=True): """ if SKSPIMPORT: A = sp.sparse.csc_matrix(A) - F = spcholesky(A) - - # permutation matrix P - P = sp.sparse.lil_matrix(A.shape) - p = F.P() - P[np.arange(len(p)), p] = 1 - - # permute try: + F = spcholesky(A) + + # permutation matrix P + P = sp.sparse.lil_matrix(A.shape) + p = F.P() + P[np.arange(len(p)), p] = 1 + + # permute L = F.L() L = P.T.dot(L) except CholmodNotPositiveDefiniteError as e: